Compare commits

...

15 Commits

Author SHA1 Message Date
jmorganca
e23ddd84b8 x/grammar: add experimental GPU accelerated constrained decoding package 2026-01-11 00:50:11 -08:00
Jeffrey Morgan
7cc2a653f2 dockerfile: remove unused COPY command (#13664) 2026-01-09 23:07:15 -08:00
Jeffrey Morgan
2584940016 Add z-image image generation prototype (#13659) 2026-01-09 21:09:46 -08:00
Michael
c6d4c0c7f2 Documentation edits made through Mintlify web editor 2026-01-09 21:29:03 -05:00
Parth Sareen
1ef4241727 x: request access for all commands, add welcome message (#13662) 2026-01-09 18:20:39 -08:00
Parth Sareen
68fafd3002 x: improve approval selector with clearer labels (#13663) 2026-01-09 17:08:12 -08:00
Parth Sareen
2b2cda7a2b api: implement anthropic api (#13600)
* api: add Anthropic Messages API compatibility layer

Add middleware to support the Anthropic Messages API format at /v1/messages.
This enables tools like Claude Code to work with Ollama local and cloud models through the
Anthropic API interface.
2026-01-09 11:53:36 -08:00
Daniel Hiltgen
3cfe9fe146 docker: add missing deps (#13654)
The new MLX library has extra dependencies.
2026-01-09 07:34:40 -08:00
Parth Sareen
a23b559b4c x: disable web search tool registration (#13656) 2026-01-09 01:42:20 -08:00
Daniel Hiltgen
33ee7168ba Add experimental MLX backend and engine with imagegen support (#13648)
* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

To build the new MLX backend code:
  cmake --preset MLX
  cmake --build --preset MLX --parallel
  cmake --install build --component MLX
  go build -tags mlx .

Note: the main.go entrypoint for the MLX engine will change in a follow up commit.

* add experimental image generation runtime

* add experimental image generation runtime

* MLX: wire up cuda build for linux

* MLX: get dependencies correct and dedup

This is still too large for a unified github artifact, but is now "correct" for the mlx_cuda_v13
directory.

* fix relative link bug in dedup

* Add darwin build and readme

* add go build tag for mlx dependent code and wire up build_darwin.sh

* lint cleanup

* macos: build mlx for x86

This will be CPU only.

* cuda build instructions and fix drift from mlx bump

* stale comment

* Delete agent helper doc

* Clean up readme.md

* Revise README for tokenizer clarity and details

Updated README to clarify tokenizer functionality and removed correctness section.

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2026-01-08 16:18:59 -08:00
Daniel Hiltgen
34d0c55ea5 Linux: switch to zstd compression (#13651)
With the upcoming addition of MLX, the linux bundle will exceed the
maximum github artifact size of 2G.  This change will bring the size
back down.

The install.sh changes support backwards compatibility for prior versions
thus should be safe to merge concurrently with this change.
2026-01-08 15:47:32 -08:00
Parth Sareen
53a5a9e9ae x: redesign agent UI with minimal styling (#13650) 2026-01-08 15:40:07 -08:00
Parth Sareen
e30e08a7d6 x: remove Ctrl+O tool output expansion feature (#13640) 2026-01-07 15:34:08 -08:00
Parth Sareen
12e2b3514a x: agent loop ux improvements (#13635) 2026-01-07 01:27:15 -08:00
Devon Rifkin
626af2d809 template: fix args-as-json rendering (#13636)
In #13525, I accidentally broke templates' ability to automatically
render tool call function arguments as JSON.

We do need these to be proper maps because we need templates to be able
to call range, which can't be done on custom types.
2026-01-06 18:33:57 -08:00
210 changed files with 43718 additions and 351 deletions

View File

@@ -68,6 +68,7 @@ jobs:
name: bundles-darwin
path: |
dist/*.tgz
dist/*.tar.zst
dist/*.zip
dist/*.dmg
@@ -392,13 +393,13 @@ jobs:
done
- run: |
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | zstd --ultra -22 -T0 >$(basename ${ARCHIVE//.*/}.tar.zst);
done
- uses: actions/upload-artifact@v4
with:
name: bundles-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.target }}
path: |
*.tgz
*.tar.zst
# Build each Docker variant (OS, arch, and flavor) separately. Using QEMU is unreliable and slower.
docker-build-push:
@@ -531,7 +532,7 @@ jobs:
- name: Upload release artifacts
run: |
pids=()
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.exe dist/*.dmg ; do
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg ; do
echo "Uploading $payload"
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
pids[$!]=$!

View File

@@ -2,6 +2,22 @@ cmake_minimum_required(VERSION 3.21)
project(Ollama C CXX)
# Handle cross-compilation on macOS: when CMAKE_OSX_ARCHITECTURES is set to a
# single architecture different from the host, override CMAKE_SYSTEM_PROCESSOR
# to match. This is necessary because CMAKE_SYSTEM_PROCESSOR defaults to the
# host architecture, but downstream projects (like MLX) use it to detect the
# target architecture.
if(CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES ";")
# Single architecture specified
if(CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
message(STATUS "Cross-compiling for x86_64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to x86_64")
set(CMAKE_SYSTEM_PROCESSOR "x86_64")
elseif(CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
message(STATUS "Cross-compiling for arm64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to arm64")
set(CMAKE_SYSTEM_PROCESSOR "arm64")
endif()
endif()
include(CheckLanguage)
include(GNUInstallDirs)
@@ -12,7 +28,7 @@ set(BUILD_SHARED_LIBS ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_CXX_EXTENSIONS ON) # Recent versions of MLX Requires gnu++17 extensions to compile properly
set(GGML_BUILD ON)
set(GGML_SHARED ON)
@@ -147,14 +163,48 @@ if(CMAKE_HIP_COMPILER)
endif()
endif()
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
if(NOT APPLE)
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
endif()
endif()
option(MLX_ENGINE "Enable MLX backend" OFF)
if(MLX_ENGINE)
message(STATUS "Setting up MLX (this takes a while...)")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
# Find CUDA toolkit if MLX is built with CUDA support
find_package(CUDAToolkit)
install(TARGETS mlx mlxc
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
)
# Manually install cudart and cublas since they might not be picked up as direct dependencies
if(CUDAToolkit_FOUND)
file(GLOB CUDART_LIBS
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
if(CUDART_LIBS)
install(FILES ${CUDART_LIBS}
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
endif()
endif()

View File

@@ -41,7 +41,7 @@
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
"CMAKE_CUDA_FLAGS": "-t 2",
"CMAKE_CUDA_FLAGS": "-t 4",
"OLLAMA_RUNNER_DIR": "cuda_v13"
}
},
@@ -83,6 +83,28 @@
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "vulkan"
}
},
{
"name": "MLX",
"inherits": [ "Default" ],
"cacheVariables": {
"MLX_ENGINE": "ON",
"OLLAMA_RUNNER_DIR": "mlx"
}
},
{
"name": "MLX CUDA 12",
"inherits": [ "MLX", "CUDA 12" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v12"
}
},
{
"name": "MLX CUDA 13",
"inherits": [ "MLX", "CUDA 13" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
}
}
],
"buildPresets": [
@@ -140,6 +162,21 @@
"name": "Vulkan",
"targets": [ "ggml-vulkan" ],
"configurePreset": "Vulkan"
},
{
"name": "MLX",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX"
},
{
"name": "MLX CUDA 12",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 12"
},
{
"name": "MLX CUDA 13",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 13"
}
]
}

View File

@@ -131,8 +131,36 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'Vulkan' \
&& cmake --build --parallel --preset 'Vulkan' \
&& cmake --install build --component Vulkan --strip --parallel 8
&& cmake --install build --component Vulkan --strip --parallel 8
FROM base AS mlx
ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} \
&& dnf install -y openblas-devel lapack-devel \
&& dnf install -y libcudnn9-cuda-13 libcudnn9-devel-cuda-13 \
&& dnf install -y libnccl libnccl-devel
ENV PATH=/usr/local/cuda-13/bin:$PATH
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
ARG PARALLEL
WORKDIR /go/src/github.com/ollama/ollama
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/ml/backend/mlx x/ml/backend/mlx
COPY go.mod go.sum .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
COPY . .
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama
@@ -153,6 +181,7 @@ FROM --platform=linux/amd64 scratch AS amd64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
COPY --from=vulkan dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
FROM --platform=linux/arm64 scratch AS arm64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
@@ -171,7 +200,7 @@ COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04
RUN apt-get update \
&& apt-get install -y ca-certificates libvulkan1 \
&& apt-get install -y ca-certificates libvulkan1 libopenblas0 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
COPY --from=archive /bin /usr/bin

778
anthropic/anthropic.go Normal file
View File

@@ -0,0 +1,778 @@
package anthropic
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"time"
"github.com/ollama/ollama/api"
)
// Error types matching Anthropic API
type Error struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ErrorResponse struct {
Type string `json:"type"` // always "error"
Error Error `json:"error"`
RequestID string `json:"request_id,omitempty"`
}
// NewError creates a new ErrorResponse with the appropriate error type based on HTTP status code
func NewError(code int, message string) ErrorResponse {
var etype string
switch code {
case http.StatusBadRequest:
etype = "invalid_request_error"
case http.StatusUnauthorized:
etype = "authentication_error"
case http.StatusForbidden:
etype = "permission_error"
case http.StatusNotFound:
etype = "not_found_error"
case http.StatusTooManyRequests:
etype = "rate_limit_error"
case http.StatusServiceUnavailable, 529:
etype = "overloaded_error"
default:
etype = "api_error"
}
return ErrorResponse{
Type: "error",
Error: Error{Type: etype, Message: message},
RequestID: generateID("req"),
}
}
// Request types
// MessagesRequest represents an Anthropic Messages API request
type MessagesRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
Messages []MessageParam `json:"messages"`
System any `json:"system,omitempty"` // string or []ContentBlock
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
Metadata *Metadata `json:"metadata,omitempty"`
}
// MessageParam represents a message in the request
type MessageParam struct {
Role string `json:"role"` // "user" or "assistant"
Content any `json:"content"` // string or []ContentBlock
}
// ContentBlock represents a content block in a message.
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
// only when set, which is required for SDK streaming accumulation.
type ContentBlock struct {
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
Text *string `json:"text,omitempty"`
// For image blocks
Source *ImageSource `json:"source,omitempty"`
// For tool_use blocks
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
// For tool_result blocks
ToolUseID string `json:"tool_use_id,omitempty"`
Content any `json:"content,omitempty"` // string or []ContentBlock
IsError bool `json:"is_error,omitempty"`
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
Thinking *string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ImageSource represents the source of an image
type ImageSource struct {
Type string `json:"type"` // "base64" or "url"
MediaType string `json:"media_type,omitempty"`
Data string `json:"data,omitempty"`
URL string `json:"url,omitempty"`
}
// Tool represents a tool definition
type Tool struct {
Type string `json:"type,omitempty"` // "custom" for user-defined tools
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema json.RawMessage `json:"input_schema,omitempty"`
}
// ToolChoice controls how the model uses tools
type ToolChoice struct {
Type string `json:"type"` // "auto", "any", "tool", "none"
Name string `json:"name,omitempty"`
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
}
// ThinkingConfig controls extended thinking
type ThinkingConfig struct {
Type string `json:"type"` // "enabled" or "disabled"
BudgetTokens int `json:"budget_tokens,omitempty"`
}
// Metadata for the request
type Metadata struct {
UserID string `json:"user_id,omitempty"`
}
// Response types
// MessagesResponse represents an Anthropic Messages API response
type MessagesResponse struct {
ID string `json:"id"`
Type string `json:"type"` // "message"
Role string `json:"role"` // "assistant"
Model string `json:"model"`
Content []ContentBlock `json:"content"`
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
Usage Usage `json:"usage"`
}
// Usage contains token usage information
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
// Streaming event types
// MessageStartEvent is sent at the start of streaming
type MessageStartEvent struct {
Type string `json:"type"` // "message_start"
Message MessagesResponse `json:"message"`
}
// ContentBlockStartEvent signals the start of a content block
type ContentBlockStartEvent struct {
Type string `json:"type"` // "content_block_start"
Index int `json:"index"`
ContentBlock ContentBlock `json:"content_block"`
}
// ContentBlockDeltaEvent contains incremental content updates
type ContentBlockDeltaEvent struct {
Type string `json:"type"` // "content_block_delta"
Index int `json:"index"`
Delta Delta `json:"delta"`
}
// Delta represents an incremental update
type Delta struct {
Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta"
Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ContentBlockStopEvent signals the end of a content block
type ContentBlockStopEvent struct {
Type string `json:"type"` // "content_block_stop"
Index int `json:"index"`
}
// MessageDeltaEvent contains updates to the message
type MessageDeltaEvent struct {
Type string `json:"type"` // "message_delta"
Delta MessageDelta `json:"delta"`
Usage DeltaUsage `json:"usage"`
}
// MessageDelta contains stop information
type MessageDelta struct {
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
}
// DeltaUsage contains cumulative token usage
type DeltaUsage struct {
OutputTokens int `json:"output_tokens"`
}
// MessageStopEvent signals the end of the message
type MessageStopEvent struct {
Type string `json:"type"` // "message_stop"
}
// PingEvent is a keepalive event
type PingEvent struct {
Type string `json:"type"` // "ping"
}
// StreamErrorEvent is an error during streaming
type StreamErrorEvent struct {
Type string `json:"type"` // "error"
Error Error `json:"error"`
}
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
var messages []api.Message
if r.System != nil {
switch sys := r.System.(type) {
case string:
if sys != "" {
messages = append(messages, api.Message{Role: "system", Content: sys})
}
case []any:
// System can be an array of content blocks
var content strings.Builder
for _, block := range sys {
if blockMap, ok := block.(map[string]any); ok {
if blockMap["type"] == "text" {
if text, ok := blockMap["text"].(string); ok {
content.WriteString(text)
}
}
}
}
if content.Len() > 0 {
messages = append(messages, api.Message{Role: "system", Content: content.String()})
}
}
}
for _, msg := range r.Messages {
converted, err := convertMessage(msg)
if err != nil {
return nil, err
}
messages = append(messages, converted...)
}
options := make(map[string]any)
options["num_predict"] = r.MaxTokens
if r.Temperature != nil {
options["temperature"] = *r.Temperature
}
if r.TopP != nil {
options["top_p"] = *r.TopP
}
if r.TopK != nil {
options["top_k"] = *r.TopK
}
if len(r.StopSequences) > 0 {
options["stop"] = r.StopSequences
}
var tools api.Tools
for _, t := range r.Tools {
tool, err := convertTool(t)
if err != nil {
return nil, err
}
tools = append(tools, tool)
}
var think *api.ThinkValue
if r.Thinking != nil && r.Thinking.Type == "enabled" {
think = &api.ThinkValue{Value: true}
}
stream := r.Stream
return &api.ChatRequest{
Model: r.Model,
Messages: messages,
Options: options,
Stream: &stream,
Tools: tools,
Think: think,
}, nil
}
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
func convertMessage(msg MessageParam) ([]api.Message, error) {
var messages []api.Message
role := strings.ToLower(msg.Role)
switch content := msg.Content.(type) {
case string:
messages = append(messages, api.Message{Role: role, Content: content})
case []any:
var textContent strings.Builder
var images []api.ImageData
var toolCalls []api.ToolCall
var thinking string
var toolResults []api.Message
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
return nil, errors.New("invalid content block format")
}
blockType, _ := blockMap["type"].(string)
switch blockType {
case "text":
if text, ok := blockMap["text"].(string); ok {
textContent.WriteString(text)
}
case "image":
source, ok := blockMap["source"].(map[string]any)
if !ok {
return nil, errors.New("invalid image source")
}
sourceType, _ := source["type"].(string)
if sourceType == "base64" {
data, _ := source["data"].(string)
decoded, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return nil, fmt.Errorf("invalid base64 image data: %w", err)
}
images = append(images, decoded)
} else {
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
}
// URL images would need to be fetched - skip for now
case "tool_use":
id, ok := blockMap["id"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'id' field")
}
name, ok := blockMap["name"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'name' field")
}
tc := api.ToolCall{
ID: id,
Function: api.ToolCallFunction{
Name: name,
},
}
if input, ok := blockMap["input"].(map[string]any); ok {
tc.Function.Arguments = mapToArgs(input)
}
toolCalls = append(toolCalls, tc)
case "tool_result":
toolUseID, _ := blockMap["tool_use_id"].(string)
var resultContent string
switch c := blockMap["content"].(type) {
case string:
resultContent = c
case []any:
for _, cb := range c {
if cbMap, ok := cb.(map[string]any); ok {
if cbMap["type"] == "text" {
if text, ok := cbMap["text"].(string); ok {
resultContent += text
}
}
}
}
}
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: resultContent,
ToolCallID: toolUseID,
})
case "thinking":
if t, ok := blockMap["thinking"].(string); ok {
thinking = t
}
}
}
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
m := api.Message{
Role: role,
Content: textContent.String(),
Images: images,
ToolCalls: toolCalls,
Thinking: thinking,
}
messages = append(messages, m)
}
// Add tool results as separate messages
messages = append(messages, toolResults...)
default:
return nil, fmt.Errorf("invalid message content type: %T", content)
}
return messages, nil
}
// convertTool converts an Anthropic Tool to an Ollama api.Tool
func convertTool(t Tool) (api.Tool, error) {
var params api.ToolFunctionParameters
if len(t.InputSchema) > 0 {
if err := json.Unmarshal(t.InputSchema, &params); err != nil {
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
}
}
return api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: t.Name,
Description: t.Description,
Parameters: params,
},
}, nil
}
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
func ToMessagesResponse(id string, r api.ChatResponse) MessagesResponse {
var content []ContentBlock
if r.Message.Thinking != "" {
content = append(content, ContentBlock{
Type: "thinking",
Thinking: ptr(r.Message.Thinking),
})
}
if r.Message.Content != "" {
content = append(content, ContentBlock{
Type: "text",
Text: ptr(r.Message.Content),
})
}
for _, tc := range r.Message.ToolCalls {
content = append(content, ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: tc.Function.Arguments,
})
}
stopReason := mapStopReason(r.DoneReason, len(r.Message.ToolCalls) > 0)
return MessagesResponse{
ID: id,
Type: "message",
Role: "assistant",
Model: r.Model,
Content: content,
StopReason: stopReason,
Usage: Usage{
InputTokens: r.Metrics.PromptEvalCount,
OutputTokens: r.Metrics.EvalCount,
},
}
}
// mapStopReason converts Ollama done_reason to Anthropic stop_reason
func mapStopReason(reason string, hasToolCalls bool) string {
if hasToolCalls {
return "tool_use"
}
switch reason {
case "stop":
return "end_turn"
case "length":
return "max_tokens"
default:
if reason != "" {
return "stop_sequence"
}
return ""
}
}
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
type StreamConverter struct {
ID string
Model string
firstWrite bool
contentIndex int
inputTokens int
outputTokens int
thinkingStarted bool
thinkingDone bool
textStarted bool
toolCallsSent map[string]bool
}
func NewStreamConverter(id, model string) *StreamConverter {
return &StreamConverter{
ID: id,
Model: model,
firstWrite: true,
toolCallsSent: make(map[string]bool),
}
}
// StreamEvent represents a streaming event to be sent to the client
type StreamEvent struct {
Event string
Data any
}
// Process converts an Ollama ChatResponse to Anthropic streaming events
func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
var events []StreamEvent
if c.firstWrite {
c.firstWrite = false
c.inputTokens = r.Metrics.PromptEvalCount
events = append(events, StreamEvent{
Event: "message_start",
Data: MessageStartEvent{
Type: "message_start",
Message: MessagesResponse{
ID: c.ID,
Type: "message",
Role: "assistant",
Model: c.Model,
Content: []ContentBlock{},
Usage: Usage{
InputTokens: c.inputTokens,
OutputTokens: 0,
},
},
},
})
}
if r.Message.Thinking != "" && !c.thinkingDone {
if !c.thinkingStarted {
c.thinkingStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "thinking_delta",
Thinking: r.Message.Thinking,
},
},
})
}
if r.Message.Content != "" {
if c.thinkingStarted && !c.thinkingDone {
c.thinkingDone = true
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
}
if !c.textStarted {
c.textStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "text",
Text: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "text_delta",
Text: r.Message.Content,
},
},
})
}
for _, tc := range r.Message.ToolCalls {
if c.toolCallsSent[tc.ID] {
continue
}
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
c.textStarted = false
}
argsJSON, err := json.Marshal(tc.Function.Arguments)
if err != nil {
slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
continue
}
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: map[string]any{},
},
},
})
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "input_json_delta",
PartialJSON: string(argsJSON),
},
},
})
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.toolCallsSent[tc.ID] = true
c.contentIndex++
}
if r.Done {
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
} else if c.thinkingStarted && !c.thinkingDone {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
}
c.outputTokens = r.Metrics.EvalCount
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
events = append(events, StreamEvent{
Event: "message_delta",
Data: MessageDeltaEvent{
Type: "message_delta",
Delta: MessageDelta{
StopReason: stopReason,
},
Usage: DeltaUsage{
OutputTokens: c.outputTokens,
},
},
})
events = append(events, StreamEvent{
Event: "message_stop",
Data: MessageStopEvent{
Type: "message_stop",
},
})
}
return events
}
// generateID generates a unique ID with the given prefix using crypto/rand
func generateID(prefix string) string {
b := make([]byte, 12)
if _, err := rand.Read(b); err != nil {
// Fallback to time-based ID if crypto/rand fails
return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
}
return fmt.Sprintf("%s_%x", prefix, b)
}
// GenerateMessageID generates a unique message ID
func GenerateMessageID() string {
return generateID("msg")
}
// ptr returns a pointer to the given string value
func ptr(s string) *string {
return &s
}
// mapToArgs converts a map to ToolCallFunctionArguments
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}

953
anthropic/anthropic_test.go Normal file
View File

@@ -0,0 +1,953 @@
package anthropic
import (
"encoding/base64"
"encoding/json"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
const (
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
func TestFromMessagesRequest_Basic(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", result.Model)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
t.Errorf("unexpected message: %+v", result.Messages[0])
}
if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 1024 {
t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"])
}
}
func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: "You are a helpful assistant.",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." {
t.Errorf("unexpected system message: %+v", result.Messages[0])
}
}
func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: []any{
map[string]any{"type": "text", "text": "You are helpful."},
map[string]any{"type": "text", "text": " Be concise."},
},
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Content != "You are helpful. Be concise." {
t.Errorf("unexpected system message content: %q", result.Messages[0].Content)
}
}
func TestFromMessagesRequest_WithOptions(t *testing.T) {
temp := 0.7
topP := 0.9
topK := 40
req := MessagesRequest{
Model: "test-model",
MaxTokens: 2048,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Temperature: &temp,
TopP: &topP,
TopK: &topK,
StopSequences: []string{"\n", "END"},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Options["temperature"] != 0.7 {
t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"])
}
if result.Options["top_p"] != 0.9 {
t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"])
}
if result.Options["top_k"] != 40 {
t.Errorf("expected top_k 40, got %v", result.Options["top_k"])
}
if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" {
t.Errorf("stop sequences mismatch: %s", diff)
}
}
func TestFromMessagesRequest_WithImage(t *testing.T) {
imgData, _ := base64.StdEncoding.DecodeString(testImage)
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{"type": "text", "text": "What's in this image?"},
map[string]any{
"type": "image",
"source": map[string]any{
"type": "base64",
"media_type": "image/png",
"data": testImage,
},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Content != "What's in this image?" {
t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content)
}
if len(result.Messages[0].Images) != 1 {
t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images))
}
if string(result.Messages[0].Images[0]) != string(imgData) {
t.Error("image data mismatch")
}
}
func TestFromMessagesRequest_WithToolUse(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
"name": "get_weather",
"input": map[string]any{"location": "Paris"},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if len(result.Messages[1].ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls))
}
tc := result.Messages[1].ToolCalls[0]
if tc.ID != "call_123" {
t.Errorf("expected tool call ID 'call_123', got %q", tc.ID)
}
if tc.Function.Name != "get_weather" {
t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name)
}
}
func TestFromMessagesRequest_WithToolResult(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{
"type": "tool_result",
"tool_use_id": "call_123",
"content": "The weather in Paris is sunny, 22°C",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
msg := result.Messages[0]
if msg.Role != "tool" {
t.Errorf("expected role 'tool', got %q", msg.Role)
}
if msg.ToolCallID != "call_123" {
t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID)
}
if msg.Content != "The weather in Paris is sunny, 22°C" {
t.Errorf("unexpected content: %q", msg.Content)
}
}
func TestFromMessagesRequest_WithTools(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "get_weather",
Description: "Get current weather",
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(result.Tools))
}
tool := result.Tools[0]
if tool.Type != "function" {
t.Errorf("expected type 'function', got %q", tool.Type)
}
if tool.Function.Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", tool.Function.Name)
}
if tool.Function.Description != "Get current weather" {
t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description)
}
}
func TestFromMessagesRequest_WithThinking(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Think == nil {
t.Fatal("expected Think to be set")
}
if v, ok := result.Think.Value.(bool); !ok || !v {
t.Errorf("expected Think.Value to be true, got %v", result.Think.Value)
}
}
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "thinking",
"thinking": "Let me think about this...",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
assistantMsg := result.Messages[1]
if assistantMsg.Thinking != "Let me think about this..." {
t.Errorf("expected thinking content, got %q", assistantMsg.Thinking)
}
}
func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"name": "get_weather",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use id")
}
if err.Error() != "tool_use block missing required 'id' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use name")
}
if err.Error() != "tool_use block missing required 'name' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "bad_tool",
InputSchema: json.RawMessage(`{invalid json`),
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for invalid tool schema")
}
}
func TestToMessagesResponse_Basic(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello there!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 5,
},
}
result := ToMessagesResponse("msg_123", resp)
if result.ID != "msg_123" {
t.Errorf("expected ID 'msg_123', got %q", result.ID)
}
if result.Type != "message" {
t.Errorf("expected type 'message', got %q", result.Type)
}
if result.Role != "assistant" {
t.Errorf("expected role 'assistant', got %q", result.Role)
}
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "text" || result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
t.Errorf("unexpected content: %+v", result.Content[0])
}
if result.StopReason != "end_turn" {
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
}
if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 {
t.Errorf("unexpected usage: %+v", result.Usage)
}
}
func TestToMessagesResponse_WithToolCalls(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "tool_use" {
t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type)
}
if result.Content[0].ID != "call_123" {
t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID)
}
if result.Content[0].Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name)
}
if result.StopReason != "tool_use" {
t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason)
}
}
func TestToMessagesResponse_WithThinking(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "The answer is 42.",
Thinking: "Let me think about this...",
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 2 {
t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
}
if result.Content[0].Type != "thinking" {
t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type)
}
if result.Content[0].Thinking == nil || *result.Content[0].Thinking != "Let me think about this..." {
t.Errorf("unexpected thinking content: %v", result.Content[0].Thinking)
}
if result.Content[1].Type != "text" {
t.Errorf("expected second block type 'text', got %q", result.Content[1].Type)
}
}
func TestMapStopReason(t *testing.T) {
tests := []struct {
reason string
hasToolCalls bool
want string
}{
{"stop", false, "end_turn"},
{"length", false, "max_tokens"},
{"stop", true, "tool_use"},
{"other", false, "stop_sequence"},
{"", false, ""},
}
for _, tt := range tests {
got := mapStopReason(tt.reason, tt.hasToolCalls)
if got != tt.want {
t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want)
}
}
}
func TestNewError(t *testing.T) {
tests := []struct {
code int
want string
}{
{400, "invalid_request_error"},
{401, "authentication_error"},
{403, "permission_error"},
{404, "not_found_error"},
{429, "rate_limit_error"},
{500, "api_error"},
{503, "overloaded_error"},
{529, "overloaded_error"},
}
for _, tt := range tests {
result := NewError(tt.code, "test message")
if result.Type != "error" {
t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type)
}
if result.Error.Type != tt.want {
t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want)
}
if result.Error.Message != "test message" {
t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message)
}
if result.RequestID == "" {
t.Errorf("NewError(%d) request_id should not be empty", tt.code)
}
}
}
func TestGenerateMessageID(t *testing.T) {
id1 := GenerateMessageID()
id2 := GenerateMessageID()
if id1 == "" {
t.Error("GenerateMessageID returned empty string")
}
if id1 == id2 {
t.Error("GenerateMessageID returned duplicate IDs")
}
if len(id1) < 10 {
t.Errorf("GenerateMessageID returned short ID: %q", id1)
}
if id1[:4] != "msg_" {
t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4])
}
}
func TestStreamConverter_Basic(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
// First chunk
resp1 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello",
},
Metrics: api.Metrics{PromptEvalCount: 10},
}
events1 := conv.Process(resp1)
if len(events1) < 3 {
t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1))
}
// Should have message_start, content_block_start, content_block_delta
if events1[0].Event != "message_start" {
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
}
if events1[1].Event != "content_block_start" {
t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event)
}
if events1[2].Event != "content_block_delta" {
t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event)
}
// Final chunk
resp2 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: " world!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{EvalCount: 5},
}
events2 := conv.Process(resp2)
// Should have content_block_delta, content_block_stop, message_delta, message_stop
hasStop := false
for _, e := range events2 {
if e.Event == "message_stop" {
hasStop = true
}
}
if !hasStop {
t.Error("expected message_stop event in final chunk")
}
}
func TestStreamConverter_WithToolCalls(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
}
events := conv.Process(resp)
hasToolStart := false
hasToolDelta := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
hasToolDelta = true
}
}
}
}
if !hasToolStart {
t.Error("expected tool_use content_block_start event")
}
if !hasToolDelta {
t.Error("expected input_json_delta event")
}
}
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
// Test that unmarshalable arguments (like channels) are handled gracefully
// and don't cause a panic or corrupt stream
conv := NewStreamConverter("msg_123", "test-model")
// Create a channel which cannot be JSON marshaled
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
// Should not panic and should skip the unmarshalable tool call
events := conv.Process(resp)
// Verify no tool_use block was started (since marshal failed before block start)
hasToolStart := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
}
if hasToolStart {
t.Error("expected no tool_use block when arguments cannot be marshaled")
}
}
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
// Test that valid tool calls still work when mixed with invalid ones
conv := NewStreamConverter("msg_123", "test-model")
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_good",
Function: api.ToolCallFunction{
Name: "good_function",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
events := conv.Process(resp)
// Count tool_use blocks - should only have 1 (the valid one)
toolStartCount := 0
toolDeltaCount := 0
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
toolStartCount++
if start.ContentBlock.Name != "good_function" {
t.Errorf("expected tool name 'good_function', got %q", start.ContentBlock.Name)
}
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
toolDeltaCount++
}
}
}
}
if toolStartCount != 1 {
t.Errorf("expected 1 tool_use block, got %d", toolStartCount)
}
if toolDeltaCount != 1 {
t.Errorf("expected 1 input_json_delta, got %d", toolDeltaCount)
}
}
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
tests := []struct {
name string
block ContentBlock
wantKeys []string
}{
{
name: "text block includes empty text field",
block: ContentBlock{
Type: "text",
Text: ptr(""),
},
wantKeys: []string{"type", "text"},
},
{
name: "thinking block includes empty thinking field",
block: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
wantKeys: []string{"type", "thinking"},
},
{
name: "text block with content",
block: ContentBlock{
Type: "text",
Text: ptr("hello"),
},
wantKeys: []string{"type", "text"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.block)
if err != nil {
t.Fatalf("failed to marshal: %v", err)
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
for _, key := range tt.wantKeys {
if _, ok := result[key]; !ok {
t.Errorf("expected key %q to be present in JSON output, got: %s", key, string(data))
}
}
})
}
}
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
// events include the required empty fields for SDK compatibility.
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
t.Run("text block start includes empty text", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Content: "hello"},
}
events := conv.Process(resp)
var foundTextStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "text" {
foundTextStart = true
// Marshal and verify the text field is present
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["text"]; !ok {
t.Error("content_block_start for text should include 'text' field")
}
}
}
}
}
if !foundTextStart {
t.Error("expected text content_block_start event")
}
})
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Thinking: "let me think..."},
}
events := conv.Process(resp)
var foundThinkingStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "thinking" {
foundThinkingStart = true
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["thinking"]; !ok {
t.Error("content_block_start for thinking should include 'thinking' field")
}
}
}
}
}
if !foundThinkingStart {
t.Error("expected thinking content_block_start event")
}
})
}

View File

@@ -46,6 +46,8 @@ import (
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd"
"github.com/ollama/ollama/x/imagegen"
imagegenclient "github.com/ollama/ollama/x/imagegen/client"
)
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
@@ -96,6 +98,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) {
if filename == "" {
// No Modelfile found - check if current directory is an image gen model
if imagegen.IsTensorModelDir(".") {
return imagegenclient.CreateModel(args[0], ".", p)
}
reader = strings.NewReader("FROM .\n")
} else {
return errModelfileNotFound
@@ -457,6 +463,15 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
name := args[0]
// Check if this is a known image generation model (skip Show/Pull)
if imagegen.HasTensorLayers(name) {
if opts.Prompt == "" && !interactive {
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
}
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
}
info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: name}
info, err := client.Show(cmd.Context(), showReq)
@@ -520,6 +535,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
// Check for experimental flag
isExperimental, _ := cmd.Flags().GetBool("experimental")
yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
if interactive {
if err := loadOrUnloadModel(cmd, &opts); err != nil {
@@ -547,9 +563,9 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
}
// Use experimental agent loop with
// Use experimental agent loop with tools
if isExperimental {
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive)
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode)
}
return generateInteractive(cmd, opts)
@@ -821,6 +837,11 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
}
func ShowHandler(cmd *cobra.Command, args []string) error {
// Check if this is an image generation model
if imagegen.HasTensorLayers(args[0]) {
return imagegen.Show(args[0], os.Stdout)
}
client, err := api.ClientFromEnvironment()
if err != nil {
return err
@@ -1764,6 +1785,10 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
// Image generation flags (width, height, steps, seed, etc.)
imagegen.RegisterFlags(runCmd)
stopCmd := &cobra.Command{
Use: "stop MODEL",

View File

@@ -6,11 +6,14 @@ import (
"errors"
"fmt"
"io/fs"
"iter"
"log/slog"
"maps"
"os"
"slices"
"strings"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -18,8 +21,13 @@ type ModelParameters struct {
Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"`
// TODO is this needed?
ModelType string `json:"model_type"`
TextModel struct {
VocabSize uint32 `json:"vocab_size"`
VocabSize uint32 `json:"vocab_size"`
HiddenSize uint32 `json:"hidden_size"`
ModelType string `json:"model_type"`
} `json:"text_config"`
}
@@ -33,8 +41,94 @@ type AdapterParameters struct {
} `json:"lora_parameters"`
}
func (ModelParameters) KV(t *Tokenizer) ggml.KV {
kv := ggml.KV{
type KV map[string]any
func (kv KV) Architecture() string {
return kv.String("general.architecture", "unknown")
}
type valueTypes interface {
uint8 | int8 | uint16 | int16 |
uint32 | int32 | uint64 | int64 |
string | float32 | float64 | bool
}
type arrayValueTypes interface {
[]uint8 | []int8 | []uint16 | []int16 |
[]uint32 | []int32 | []uint64 | []int64 |
[]string | []float32 | []float64 | []bool
}
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key
}
if val, ok := kv[key].(T); ok {
return val, true
}
return defaultValue[0], false
}
func (kv KV) String(key string, defaultValue ...string) string {
val, _ := keyValue(kv, key, append(defaultValue, "")...)
return val
}
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
}
func (kv KV) Float(key string, defaultValue ...float32) float32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
}
func (kv KV) Bool(key string, defaultValue ...bool) bool {
val, _ := keyValue(kv, key, append(defaultValue, false)...)
return val
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
val, _ := keyValue(kv, key, append(defaultValue, []string{""})...)
return val
}
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
val, _ := keyValue(kv, key, append(defaultValue, []int32{0})...)
return val
}
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
val, _ := keyValue(kv, key, append(defaultValue, []uint32{0})...)
return val
}
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
val, _ := keyValue(kv, key, append(defaultValue, []float32{0})...)
return val
}
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
val, _ := keyValue(kv, key, append(defaultValue, []bool{false})...)
return val
}
func (kv KV) Len() int {
return len(kv)
}
func (kv KV) Keys() iter.Seq[string] {
return maps.Keys(kv)
}
func (kv KV) Value(key string) any {
return kv[key]
}
func (ModelParameters) KV(t *Tokenizer) KV {
kv := KV{
"general.file_type": uint32(1),
"general.quantization_version": uint32(2),
"tokenizer.ggml.pre": t.Pre,
@@ -63,7 +157,7 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV {
return kv
}
func (p AdapterParameters) KV() ggml.KV {
func (p AdapterParameters) KV() KV {
var alpha float32
if p.LoraParameters.Alpha == 0 {
alpha = float32(p.Alpha)
@@ -71,7 +165,7 @@ func (p AdapterParameters) KV() ggml.KV {
alpha = p.LoraParameters.Alpha
}
kv := ggml.KV{
kv := KV{
"adapter.lora.alpha": alpha,
"adapter.type": "lora",
"general.file_type": uint32(1),
@@ -88,9 +182,14 @@ func (ModelParameters) specialTokenTypes() []string {
}
}
type ModelConverter interface {
type ModelKV interface {
// KV maps parameters to LLM key-values
KV(*Tokenizer) ggml.KV
KV(*Tokenizer) KV
}
type ModelConverter interface {
ModelKV
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
@@ -107,7 +206,7 @@ type moreParser interface {
type AdapterConverter interface {
// KV maps parameters to LLM key-values
KV(ggml.KV) ggml.KV
KV(ofs.Config) KV
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
@@ -115,7 +214,7 @@ type AdapterConverter interface {
Replacements() []string
}
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
bts, err := fs.ReadFile(fsys, "adapter_config.json")
if err != nil {
return err
@@ -126,8 +225,8 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
return err
}
arch, ok := baseKV["general.architecture"]
if !ok {
arch := baseKV.Architecture()
if arch == "" {
return errors.New("architecture not set for the base model")
}
@@ -153,23 +252,19 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
}
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
bts, err := fs.ReadFile(fsys, "config.json")
if err != nil {
return err
return nil, nil, err
}
var p ModelParameters
if err := json.Unmarshal(bts, &p); err != nil {
return err
return nil, nil, err
}
if len(p.Architectures) < 1 {
return errors.New("unknown architecture")
return nil, nil, errors.New("unknown architecture")
}
var conv ModelConverter
@@ -217,22 +312,22 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
case "DeepseekV3ForCausalLM":
conv = &deepseek2Model{}
default:
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}
if err := json.Unmarshal(bts, conv); err != nil {
return err
return nil, nil, err
}
if t, ok := conv.(moreParser); ok {
if err := t.parseMore(fsys); err != nil {
return err
return nil, nil, err
}
}
t, err := parseTokenizer(fsys, conv.specialTokenTypes())
if err != nil {
return err
return nil, nil, err
}
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
@@ -254,6 +349,19 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
default:
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
}
return conv, t, nil
}
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
kv, t, err := LoadModelMetadata(fsys)
if err != nil {
return err
}
conv := kv.(ModelConverter)
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
if err != nil {
@@ -263,7 +371,7 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
return writeFile(f, conv.KV(t), conv.Tensors(ts))
}
func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
func writeFile(f *os.File, kv KV, ts []*ggml.Tensor) error {
for i := range ts {
ts[i].Shape = slices.Clone(ts[i].Shape)
slices.Reverse(ts[i].Shape)

View File

@@ -88,7 +88,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
return nil
}
func (p *bertModel) KV(t *Tokenizer) ggml.KV {
func (p *bertModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "bert"
kv["bert.attention.causal"] = false

View File

@@ -24,7 +24,7 @@ type commandrModel struct {
var _ ModelConverter = (*commandrModel)(nil)
func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
func (p *commandrModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "command-r"
kv["general.name"] = "command-r"

View File

@@ -47,7 +47,7 @@ type deepseek2Model struct {
Architecture string
}
func (p *deepseek2Model) KV(t *Tokenizer) ggml.KV {
func (p *deepseek2Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "deepseek2"
kv["general.type"] = "model"

View File

@@ -41,7 +41,7 @@ type deepseekocr struct {
} `json:"vision_config"`
}
func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
func (m *deepseekocr) KV(t *Tokenizer) KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "deepseekocr"
kv["block_count"] = m.LanguageConfig.HiddenLayers

View File

@@ -23,7 +23,7 @@ type gemmaModel struct {
var _ ModelConverter = (*gemmaModel)(nil)
func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
func (p *gemmaModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma"
kv["gemma.context_length"] = p.MaxPositionEmbeddings

View File

@@ -1,7 +1,5 @@
package convert
import "github.com/ollama/ollama/fs/ggml"
type gemma2Model struct {
gemmaModel
SlidingWindow uint32 `json:"sliding_window"`
@@ -9,7 +7,7 @@ type gemma2Model struct {
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
}
func (p *gemma2Model) KV(t *Tokenizer) ggml.KV {
func (p *gemma2Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma2"
kv["gemma2.context_length"] = p.MaxPositionEmbeddings

View File

@@ -6,6 +6,7 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -15,7 +16,7 @@ type gemma2Adapter struct {
var _ AdapterConverter = (*gemma2Adapter)(nil)
func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
func (p *gemma2Adapter) KV(baseKV fs.Config) KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "gemma2"
return kv

View File

@@ -3,8 +3,6 @@ package convert
import (
"cmp"
"slices"
"github.com/ollama/ollama/fs/ggml"
)
type gemma3Model struct {
@@ -55,7 +53,7 @@ const (
gemma27BLayerCount = 62
)
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
func (p *gemma3Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3"

View File

@@ -38,7 +38,7 @@ type gemma3nModel struct {
VisionModel struct{} `json:"vision_config"`
}
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
func (m *gemma3nModel) KV(t *Tokenizer) KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3n"
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {

View File

@@ -37,7 +37,7 @@ type gptossModel struct {
var _ ModelConverter = (*gptossModel)(nil)
func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
func (m *gptossModel) KV(t *Tokenizer) KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gptoss"
kv["general.file_type"] = uint32(4)

View File

@@ -48,7 +48,7 @@ type llamaModel struct {
var _ ModelConverter = (*llamaModel)(nil)
func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
func (p *llamaModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama"
kv["llama.vocab_size"] = p.VocabSize

View File

@@ -35,7 +35,7 @@ type llama4Model struct {
}
// KV implements ModelConverter.
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
func (p *llama4Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama4"

View File

@@ -7,6 +7,7 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -18,13 +19,13 @@ type llamaAdapter struct {
var _ AdapterConverter = (*llamaAdapter)(nil)
func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
func (p *llamaAdapter) KV(baseKV fs.Config) KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "llama"
kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
kv["llama.attention.head_count_kv"] = baseKV["llama.attention.head_count_kv"]
kv["llama.attention.head_count"] = baseKV.Value("llama.attention.head_count")
kv["llama.attention.head_count_kv"] = baseKV.Value("llama.attention.head_count_kv")
p.NumAttentionHeads = baseKV["llama.attention.head_count"].(uint32)
p.NumAttentionHeads = baseKV.Value("llama.attention.head_count").(uint32)
return kv
}

View File

@@ -60,7 +60,7 @@ type mistral3Model struct {
ProjectorHiddenAct string `json:"projector_hidden_act"`
}
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
func (p *mistral3Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.TextModel.VocabSize

View File

@@ -39,7 +39,7 @@ type mistral3CausalModel struct {
} `json:"rope_parameters"`
}
func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV {
func (p *mistral3CausalModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.VocabSize

View File

@@ -12,7 +12,7 @@ type mixtralModel struct {
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
}
func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
func (p *mixtralModel) KV(t *Tokenizer) KV {
kv := p.llamaModel.KV(t)
if p.NumLocalExperts > 0 {

View File

@@ -34,7 +34,7 @@ type mllamaModel struct {
} `json:"vision_config"`
}
func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
func (m *mllamaModel) KV(t *Tokenizer) KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "mllama"

View File

@@ -87,7 +87,7 @@ func (p *nomicbertModel) parseMore(fsys fs.FS) error {
return nil
}
func (p *nomicbertModel) KV(t *Tokenizer) ggml.KV {
func (p *nomicbertModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
// Determine architecture based on MoE parameters (following qwen3 pattern)

View File

@@ -34,7 +34,7 @@ type olmoModel struct {
var _ ModelConverter = (*olmoModel)(nil)
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
func (p *olmoModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "olmo3"
kv["olmo3.block_count"] = p.NumHiddenLayers

View File

@@ -37,7 +37,7 @@ type phi3Model struct {
var _ ModelConverter = (*phi3Model)(nil)
func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
func (p *phi3Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "phi3"
kv["phi3.context_length"] = p.MaxPositionEmbeddings

View File

@@ -22,7 +22,7 @@ type qwen2Model struct {
var _ ModelConverter = (*qwen2Model)(nil)
func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
func (q *qwen2Model) KV(t *Tokenizer) KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen2"
kv["qwen2.block_count"] = q.HiddenLayers

View File

@@ -29,7 +29,7 @@ type qwen25VLModel struct {
var _ ModelConverter = (*qwen25VLModel)(nil)
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
func (q *qwen25VLModel) KV(t *Tokenizer) KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen25vl"

View File

@@ -32,7 +32,7 @@ type qwen3Model struct {
}
// KV implements ModelConverter.
func (q *qwen3Model) KV(t *Tokenizer) ggml.KV {
func (q *qwen3Model) KV(t *Tokenizer) KV {
arch := "qwen3"
if q.NumExperts > 0 {
arch += "moe"

View File

@@ -45,7 +45,7 @@ func (m *qwen3VLModel) parseMore(fsys fs.FS) error {
return json.Unmarshal(bts, &m.VisionModel)
}
func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV {
func (m *qwen3VLModel) KV(t *Tokenizer) KV {
kv := m.qwen3Model.KV(t)
arch := "qwen3vl"

View File

@@ -19,6 +19,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
fsc "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -28,7 +29,7 @@ type tensorData struct {
Shape []int `json:"shape"`
}
func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors) {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "f16")
@@ -59,9 +60,10 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
return r, m.KV(), m.Tensors()
}
func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string {
func generateResultsJSON(t *testing.T, f *os.File, kv fsc.Config, tensors ggml.Tensors) map[string]string {
actual := make(map[string]string)
for k, v := range kv {
for k := range kv.Keys() {
v := kv.Value(k)
if s, ok := v.(json.Marshaler); !ok {
actual[k] = fmt.Sprintf("%v", v)
} else {
@@ -277,7 +279,7 @@ func generateSafetensorTestData(t *testing.T, tempDir string, tensorData map[str
func TestConvertAdapter(t *testing.T) {
type AdapterCase struct {
Name string
BaseKV map[string]any
BaseKV KV
Expected map[string]string
}

View File

@@ -14,6 +14,7 @@
* [API Reference](https://docs.ollama.com/api)
* [Modelfile Reference](https://docs.ollama.com/modelfile)
* [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility)
* [Anthropic Compatibility](./api/anthropic-compatibility.mdx)
### Resources

View File

@@ -0,0 +1,406 @@
---
title: Anthropic compatibility
---
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
## Recommended models
For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.
Pull a model before use:
```shell
ollama pull qwen3-coder
ollama pull glm-4.7:cloud
```
## Usage
### Environment variables
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama # required but ignored
```
### Simple `/v1/messages` example
<CodeGroup dropdown>
```python basic.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama', # required but ignored
)
message = client.messages.create(
model='qwen3-coder',
max_tokens=1024,
messages=[
{'role': 'user', 'content': 'Hello, how are you?'}
]
)
print(message.content[0].text)
```
```javascript basic.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama", // required but ignored
});
const message = await anthropic.messages.create({
model: "qwen3-coder",
max_tokens: 1024,
messages: [{ role: "user", content: "Hello, how are you?" }],
});
console.log(message.content[0].text);
```
```shell basic.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-H "x-api-key: ollama" \
-H "anthropic-version: 2023-06-01" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"messages": [{ "role": "user", "content": "Hello, how are you?" }]
}'
```
</CodeGroup>
### Streaming example
<CodeGroup dropdown>
```python streaming.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
with client.messages.stream(
model='qwen3-coder',
max_tokens=1024,
messages=[{'role': 'user', 'content': 'Count from 1 to 10'}]
) as stream:
for text in stream.text_stream:
print(text, end='', flush=True)
```
```javascript streaming.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const stream = await anthropic.messages.stream({
model: "qwen3-coder",
max_tokens: 1024,
messages: [{ role: "user", content: "Count from 1 to 10" }],
});
for await (const event of stream) {
if (
event.type === "content_block_delta" &&
event.delta.type === "text_delta"
) {
process.stdout.write(event.delta.text);
}
}
```
```shell streaming.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"stream": true,
"messages": [{ "role": "user", "content": "Count from 1 to 10" }]
}'
```
</CodeGroup>
### Tool calling example
<CodeGroup dropdown>
```python tools.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
message = client.messages.create(
model='qwen3-coder',
max_tokens=1024,
tools=[
{
'name': 'get_weather',
'description': 'Get the current weather in a location',
'input_schema': {
'type': 'object',
'properties': {
'location': {
'type': 'string',
'description': 'The city and state, e.g. San Francisco, CA'
}
},
'required': ['location']
}
}
],
messages=[{'role': 'user', 'content': "What's the weather in San Francisco?"}]
)
for block in message.content:
if block.type == 'tool_use':
print(f'Tool: {block.name}')
print(f'Input: {block.input}')
```
```javascript tools.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const message = await anthropic.messages.create({
model: "qwen3-coder",
max_tokens: 1024,
tools: [
{
name: "get_weather",
description: "Get the current weather in a location",
input_schema: {
type: "object",
properties: {
location: {
type: "string",
description: "The city and state, e.g. San Francisco, CA",
},
},
required: ["location"],
},
},
],
messages: [{ role: "user", content: "What's the weather in San Francisco?" }],
});
for (const block of message.content) {
if (block.type === "tool_use") {
console.log("Tool:", block.name);
console.log("Input:", block.input);
}
}
```
```shell tools.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"tools": [
{
"name": "get_weather",
"description": "Get the current weather in a location",
"input_schema": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state"
}
},
"required": ["location"]
}
}
],
"messages": [{ "role": "user", "content": "What is the weather in San Francisco?" }]
}'
```
</CodeGroup>
## Using with Claude Code
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
Or set the environment variables in your shell profile:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
Then run Claude Code with any Ollama model:
```shell
# Local models
claude --model qwen3-coder
claude --model gpt-oss:20b
# Cloud models
claude --model glm-4.7:cloud
claude --model minimax-m2.1:cloud
```
## Endpoints
### `/v1/messages`
#### Supported features
- [x] Messages
- [x] Streaming
- [x] System prompts
- [x] Multi-turn conversations
- [x] Vision (images)
- [x] Tools (function calling)
- [x] Tool results
- [x] Thinking/extended thinking
#### Supported request fields
- [x] `model`
- [x] `max_tokens`
- [x] `messages`
- [x] Text `content`
- [x] Image `content` (base64)
- [x] Array of content blocks
- [x] `tool_use` blocks
- [x] `tool_result` blocks
- [x] `thinking` blocks
- [x] `system` (string or array)
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `top_k`
- [x] `stop_sequences`
- [x] `tools`
- [x] `thinking`
- [ ] `tool_choice`
- [ ] `metadata`
#### Supported response fields
- [x] `id`
- [x] `type`
- [x] `role`
- [x] `model`
- [x] `content` (text, tool_use, thinking blocks)
- [x] `stop_reason` (end_turn, max_tokens, tool_use)
- [x] `usage` (input_tokens, output_tokens)
#### Streaming events
- [x] `message_start`
- [x] `content_block_start`
- [x] `content_block_delta` (text_delta, input_json_delta, thinking_delta)
- [x] `content_block_stop`
- [x] `message_delta`
- [x] `message_stop`
- [x] `ping`
- [x] `error`
## Models
Ollama supports both local and cloud models.
### Local models
Pull a local model before use:
```shell
ollama pull qwen3-coder
```
Recommended local models:
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model
### Cloud models
Cloud models are available immediately without pulling:
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
### Default model names
For tooling that relies on default Anthropic model names such as `claude-3-5-sonnet`, use `ollama cp` to copy an existing model name:
```shell
ollama cp qwen3-coder claude-3-5-sonnet
```
Afterwards, this new model name can be specified in the `model` field:
```shell
curl http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "claude-3-5-sonnet",
"max_tokens": 1024,
"messages": [
{
"role": "user",
"content": "Hello!"
}
]
}'
```
## Differences from the Anthropic API
### Behavior differences
- API key is accepted but not validated
- `anthropic-version` header is accepted but not used
- Token counts are approximations based on the underlying model's tokenizer
### Not supported
The following Anthropic API features are not currently supported:
| Feature | Description |
|---------|-------------|
| `/v1/messages/count_tokens` | Token counting endpoint |
| `tool_choice` | Forcing specific tool use or disabling tools |
| `metadata` | Request metadata (user_id) |
| Prompt caching | `cache_control` blocks for caching prefixes |
| Batches API | `/v1/messages/batches` for async batch processing |
| Citations | `citations` content blocks |
| PDF support | `document` content blocks with PDF files |
| Server-sent errors | `error` events during streaming (errors return HTTP status) |
### Partial support
| Feature | Status |
|---------|--------|
| Image content | Base64 images supported; URL images not supported |
| Extended thinking | Basic support; `budget_tokens` accepted but not enforced |

View File

@@ -32,7 +32,9 @@
"codeblocks": "system"
},
"contextual": {
"options": ["copy"]
"options": [
"copy"
]
},
"navbar": {
"links": [
@@ -52,7 +54,9 @@
"display": "simple"
},
"examples": {
"languages": ["curl"]
"languages": [
"curl"
]
}
},
"redirects": [
@@ -97,6 +101,7 @@
{
"group": "Integrations",
"pages": [
"/integrations/claude-code",
"/integrations/vscode",
"/integrations/jetbrains",
"/integrations/codex",
@@ -139,7 +144,8 @@
"/api/streaming",
"/api/usage",
"/api/errors",
"/api/openai-compatibility"
"/api/openai-compatibility",
"/api/anthropic-compatibility"
]
},
{

View File

@@ -0,0 +1,69 @@
---
title: Claude Code
---
## Install
Install [Claude Code](https://code.claude.com/docs/en/overview):
<CodeGroup>
```shell macOS / Linux
curl -fsSL https://claude.ai/install.sh | bash
```
```powershell Windows
irm https://claude.ai/install.ps1 | iex
```
</CodeGroup>
## Usage with Ollama
Claude Code connects to Ollama using the Anthropic-compatible API.
1. Set the environment variables:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
2. Run Claude Code with an Ollama model:
```shell
claude --model qwen3-coder
```
Or run with environment variables inline:
```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
## Connecting to ollama.com
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
2. Set the environment variables:
```shell
export ANTHROPIC_BASE_URL=https://ollama.com
export ANTHROPIC_API_KEY=<your-api-key>
```
3. Run Claude Code with a cloud model:
```shell
claude --model glm-4.7:cloud
```
## Recommended Models
### Cloud models
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
- `qwen3-coder:480b` - Large coding model
### Local models
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model

View File

@@ -1,5 +1,5 @@
---
title: Linux
title: "Linux"
---
## Install
@@ -13,8 +13,7 @@ curl -fsSL https://ollama.com/install.sh | sh
## Manual install
<Note>
If you are upgrading from a prior version, you should remove the old libraries
with `sudo rm -rf /usr/lib/ollama` first.
If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
</Note>
Download and extract the package:
@@ -113,11 +112,7 @@ sudo systemctl status ollama
```
<Note>
While AMD has contributed the `amdgpu` driver upstream to the official linux
kernel source, the version is older and may not support all ROCm features. We
recommend you install the latest driver from
https://www.amd.com/en/support/linux-drivers for best support of your Radeon
GPU.
While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
</Note>
## Customizing
@@ -196,4 +191,4 @@ Remove the downloaded models and Ollama service user and group:
sudo userdel ollama
sudo groupdel ollama
sudo rm -r /usr/share/ollama
```
```

View File

@@ -1,5 +1,7 @@
package fs
import "iter"
type Config interface {
Architecture() string
String(string, ...string) string
@@ -11,4 +13,8 @@ type Config interface {
Ints(string, ...[]int32) []int32
Floats(string, ...[]float32) []float32
Bools(string, ...[]bool) []bool
Len() int
Keys() iter.Seq[string]
Value(key string) any
}

View File

@@ -6,7 +6,9 @@ import (
"errors"
"fmt"
"io"
"iter"
"log/slog"
"maps"
"math"
"slices"
"strings"
@@ -239,6 +241,18 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
return val.values
}
func (kv KV) Len() int {
return len(kv)
}
func (kv KV) Keys() iter.Seq[string] {
return maps.Keys(kv)
}
func (kv KV) Value(key string) any {
return kv[key]
}
func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{
"bert",

View File

@@ -8,12 +8,12 @@ import (
"fmt"
"io"
"log/slog"
"maps"
"os"
"runtime"
"slices"
"strings"
"github.com/ollama/ollama/fs"
"golang.org/x/sync/errgroup"
)
@@ -508,7 +508,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
return binary.Write(w, binary.LittleEndian, s)
}
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
arch := kv.String("general.architecture")
if arch == "" {
return fmt.Errorf("architecture not set")
@@ -526,12 +526,12 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
return err
}
if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
if err := binary.Write(f, binary.LittleEndian, uint64(kv.Len())); err != nil {
return err
}
for _, key := range slices.Sorted(maps.Keys(kv)) {
if err := ggufWriteKV(f, arch, key, kv[key]); err != nil {
for _, key := range slices.Sorted(kv.Keys()) {
if err := ggufWriteKV(f, arch, key, kv.Value(key)); err != nil {
return err
}
}

18
go.mod
View File

@@ -15,8 +15,8 @@ require (
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.17.0
golang.org/x/sys v0.37.0
golang.org/x/sync v0.19.0
golang.org/x/sys v0.39.0
)
require (
@@ -30,8 +30,8 @@ require (
github.com/tkrajina/typescriptify-golang-structs v0.2.0
github.com/wk8/go-ordered-map/v2 v2.1.8
golang.org/x/image v0.22.0
golang.org/x/mod v0.30.0
golang.org/x/tools v0.38.0
golang.org/x/mod v0.31.0
golang.org/x/tools v0.40.0
gonum.org/v1/gonum v0.15.0
)
@@ -81,11 +81,11 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.43.0
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
golang.org/x/net v0.46.0 // indirect
golang.org/x/term v0.36.0
golang.org/x/text v0.30.0
golang.org/x/crypto v0.46.0
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93
golang.org/x/net v0.48.0 // indirect
golang.org/x/term v0.38.0
golang.org/x/text v0.32.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)

36
go.sum
View File

@@ -233,16 +233,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 h1:fQsdNF2N+/YewlRZiricy4P1iimyPKZ/xwniHj8Q2a0=
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93/go.mod h1:EPRbTFwzwjXj9NpYyyrvenVh9Y+GFeEvMNh7Xuz7xgU=
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
@@ -264,8 +264,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -278,8 +278,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -289,8 +289,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -306,17 +306,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -330,8 +330,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

149
middleware/anthropic.go Normal file
View File

@@ -0,0 +1,149 @@
package middleware
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/anthropic"
"github.com/ollama/ollama/api"
)
// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format
type AnthropicWriter struct {
BaseWriter
stream bool
id string
model string
converter *anthropic.StreamConverter
}
func (w *AnthropicWriter) writeError(data []byte) (int, error) {
var errData struct {
Error string `json:"error"`
}
if err := json.Unmarshal(data, &errData); err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.ResponseWriter.Status(), errData.Error))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *AnthropicWriter) writeEvent(eventType string, data any) error {
d, err := json.Marshal(data)
if err != nil {
return err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
if err != nil {
return err
}
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
return nil
}
func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
var chatResponse api.ChatResponse
err := json.Unmarshal(data, &chatResponse)
if err != nil {
return 0, err
}
if w.stream {
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
events := w.converter.Process(chatResponse)
for _, event := range events {
if err := w.writeEvent(event.Event, event.Data); err != nil {
return 0, err
}
}
return len(data), nil
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
response := anthropic.ToMessagesResponse(w.id, chatResponse)
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
}
func (w *AnthropicWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
// AnthropicMessagesMiddleware handles Anthropic Messages API requests
func AnthropicMessagesMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req anthropic.MessagesRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "model is required"))
return
}
if req.MaxTokens <= 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "max_tokens is required and must be positive"))
return
}
if len(req.Messages) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "messages is required"))
return
}
chatReq, err := anthropic.FromMessagesRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
messageID := anthropic.GenerateMessageID()
w := &AnthropicWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: messageID,
model: req.Model,
converter: anthropic.NewStreamConverter(messageID, req.Model),
}
if req.Stream {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
}
c.Writer = w
c.Next()
}
}

View File

@@ -0,0 +1,584 @@
package middleware
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/anthropic"
"github.com/ollama/ollama/api"
)
func captureAnthropicRequest(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
_ = json.Unmarshal(bodyBytes, capturedRequest)
c.Next()
}
}
// testProps creates ToolPropertiesMap from a map (convenience function for tests)
func testProps(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
func TestAnthropicMessagesMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.ChatRequest
err anthropic.ErrorResponse
}
var capturedRequest *api.ChatRequest
stream := true
testCases := []testCase{
{
name: "basic message",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with system prompt",
body: `{
"model": "test-model",
"max_tokens": 1024,
"system": "You are helpful.",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with options",
body: `{
"model": "test-model",
"max_tokens": 2048,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"stop_sequences": ["\n", "END"],
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{
"num_predict": 2048,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"stop": []string{"\n", "END"},
},
Stream: &False,
},
},
{
name: "streaming",
body: `{
"model": "test-model",
"max_tokens": 1024,
"stream": true,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &stream,
},
},
{
name: "with tools",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "What's the weather?"}
],
"tools": [{
"name": "get_weather",
"description": "Get current weather",
"input_schema": {
"type": "object",
"properties": {
"location": {"type": "string"}
},
"required": ["location"]
}
}]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
},
Tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: testProps(map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
}),
},
},
},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with tool result",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "What's the weather?"},
{"role": "assistant", "content": [
{"type": "tool_use", "id": "call_123", "name": "get_weather", "input": {"location": "Paris"}}
]},
{"role": "user", "content": [
{"type": "tool_result", "tool_use_id": "call_123", "content": "Sunny, 22°C"}
]}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call_123"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with thinking enabled",
body: `{
"model": "test-model",
"max_tokens": 1024,
"thinking": {"type": "enabled", "budget_tokens": 1000},
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
Think: &api.ThinkValue{Value: true},
},
},
{
name: "missing model error",
body: `{
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "model is required",
},
},
},
{
name: "missing max_tokens error",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "max_tokens is required and must be positive",
},
},
},
{
name: "missing messages error",
body: `{
"model": "test-model",
"max_tokens": 1024
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "messages is required",
},
},
},
{
name: "tool_use missing id error",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "assistant", "content": [
{"type": "tool_use", "name": "test"}
]}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "tool_use block missing required 'id' field",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware(), captureAnthropicRequest(&capturedRequest))
router.Handle(http.MethodPost, "/v1/messages", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
defer func() { capturedRequest = nil }()
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if tc.err.Type != "" {
// Expect error
if resp.Code == http.StatusOK {
t.Fatalf("expected error response, got 200 OK")
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error: %v", err)
}
if errResp.Type != tc.err.Type {
t.Errorf("expected error type %q, got %q", tc.err.Type, errResp.Type)
}
if errResp.Error.Type != tc.err.Error.Type {
t.Errorf("expected error.type %q, got %q", tc.err.Error.Type, errResp.Error.Type)
}
if errResp.Error.Message != tc.err.Error.Message {
t.Errorf("expected error.message %q, got %q", tc.err.Error.Message, errResp.Error.Message)
}
return
}
if resp.Code != http.StatusOK {
t.Fatalf("unexpected status code: %d, body: %s", resp.Code, resp.Body.String())
}
if capturedRequest == nil {
t.Fatal("request was not captured")
}
// Compare relevant fields
if capturedRequest.Model != tc.req.Model {
t.Errorf("model mismatch: got %q, want %q", capturedRequest.Model, tc.req.Model)
}
if diff := cmp.Diff(tc.req.Messages, capturedRequest.Messages,
cmpopts.IgnoreUnexported(api.ToolCallFunctionArguments{}, api.ToolPropertiesMap{})); diff != "" {
t.Errorf("messages mismatch (-want +got):\n%s", diff)
}
if tc.req.Stream != nil && capturedRequest.Stream != nil {
if *tc.req.Stream != *capturedRequest.Stream {
t.Errorf("stream mismatch: got %v, want %v", *capturedRequest.Stream, *tc.req.Stream)
}
}
if tc.req.Think != nil {
if capturedRequest.Think == nil {
t.Error("expected Think to be set")
} else if capturedRequest.Think.Value != tc.req.Think.Value {
t.Errorf("Think mismatch: got %v, want %v", capturedRequest.Think.Value, tc.req.Think.Value)
}
}
})
}
}
func TestAnthropicMessagesMiddleware_Headers(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Run("streaming sets correct headers", func(t *testing.T) {
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Check headers were set
if c.Writer.Header().Get("Content-Type") != "text/event-stream" {
t.Errorf("expected Content-Type text/event-stream, got %q", c.Writer.Header().Get("Content-Type"))
}
if c.Writer.Header().Get("Cache-Control") != "no-cache" {
t.Errorf("expected Cache-Control no-cache, got %q", c.Writer.Header().Get("Cache-Control"))
}
c.Status(http.StatusOK)
})
body := `{"model": "test", "max_tokens": 100, "stream": true, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
})
}
func TestAnthropicMessagesMiddleware_InvalidJSON(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{invalid json`))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", resp.Code)
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error: %v", err)
}
if errResp.Type != "error" {
t.Errorf("expected type 'error', got %q", errResp.Type)
}
if errResp.Error.Type != "invalid_request_error" {
t.Errorf("expected error type 'invalid_request_error', got %q", errResp.Error.Type)
}
}
func TestAnthropicWriter_NonStreaming(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Simulate Ollama response
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello there!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 5,
},
}
data, _ := json.Marshal(resp)
c.Writer.WriteHeader(http.StatusOK)
_, _ = c.Writer.Write(data)
})
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", resp.Code)
}
var result anthropic.MessagesResponse
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if result.Type != "message" {
t.Errorf("expected type 'message', got %q", result.Type)
}
if result.Role != "assistant" {
t.Errorf("expected role 'assistant', got %q", result.Role)
}
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
t.Errorf("expected text 'Hello there!', got %v", result.Content[0].Text)
}
if result.StopReason != "end_turn" {
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
}
if result.Usage.InputTokens != 10 {
t.Errorf("expected input_tokens 10, got %d", result.Usage.InputTokens)
}
if result.Usage.OutputTokens != 5 {
t.Errorf("expected output_tokens 5, got %d", result.Usage.OutputTokens)
}
}
// TestAnthropicWriter_ErrorFromRoutes tests error handling when routes.go sends
// gin.H{"error": "message"} without a StatusCode field (which is the common case)
func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
statusCode int
errorPayload any
wantErrorType string
wantMessage string
}{
// routes.go sends errors without StatusCode in JSON, so we must use HTTP status
{
name: "404 with gin.H error (model not found)",
statusCode: http.StatusNotFound,
errorPayload: gin.H{"error": "model 'nonexistent' not found"},
wantErrorType: "not_found_error",
wantMessage: "model 'nonexistent' not found",
},
{
name: "400 with gin.H error (bad request)",
statusCode: http.StatusBadRequest,
errorPayload: gin.H{"error": "model is required"},
wantErrorType: "invalid_request_error",
wantMessage: "model is required",
},
{
name: "500 with gin.H error (internal error)",
statusCode: http.StatusInternalServerError,
errorPayload: gin.H{"error": "something went wrong"},
wantErrorType: "api_error",
wantMessage: "something went wrong",
},
{
name: "404 with api.StatusError",
statusCode: http.StatusNotFound,
errorPayload: api.StatusError{
StatusCode: http.StatusNotFound,
ErrorMessage: "model not found via StatusError",
},
wantErrorType: "not_found_error",
wantMessage: "model not found via StatusError",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Simulate what routes.go does - set status and write error JSON
data, _ := json.Marshal(tt.errorPayload)
c.Writer.WriteHeader(tt.statusCode)
_, _ = c.Writer.Write(data)
})
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != tt.statusCode {
t.Errorf("expected status %d, got %d", tt.statusCode, resp.Code)
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error response: %v\nbody: %s", err, resp.Body.String())
}
if errResp.Type != "error" {
t.Errorf("expected type 'error', got %q", errResp.Type)
}
if errResp.Error.Type != tt.wantErrorType {
t.Errorf("expected error type %q, got %q", tt.wantErrorType, errResp.Error.Type)
}
if errResp.Error.Message != tt.wantMessage {
t.Errorf("expected message %q, got %q", tt.wantMessage, errResp.Error.Message)
}
})
}
}

View File

@@ -21,6 +21,7 @@ import (
"golang.org/x/text/encoding/unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/fs/ggml"
)
@@ -801,7 +802,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
}
defer f.Close()
base := map[string]any{"general.architecture": "test"}
var base convert.KV = map[string]any{"general.architecture": "test"}
maps.Copy(base, kv)
if err := ggml.WriteGGUF(f, base, ti); err != nil {

33
progress/stepbar.go Normal file
View File

@@ -0,0 +1,33 @@
package progress
import (
"fmt"
"strings"
)
// StepBar displays step-based progress (e.g., for image generation steps).
type StepBar struct {
message string
current int
total int
}
func NewStepBar(message string, total int) *StepBar {
return &StepBar{message: message, total: total}
}
func (s *StepBar) Set(current int) {
s.current = current
}
func (s *StepBar) String() string {
percent := float64(s.current) / float64(s.total) * 100
barWidth := s.total
empty := barWidth - s.current
// "Generating 0% ▕ ▏ 0/9"
return fmt.Sprintf("%s %3.0f%% ▕%s%s▏ %d/%d",
s.message, percent,
strings.Repeat("█", s.current), strings.Repeat(" ", empty),
s.current, s.total)
}

View File

@@ -18,6 +18,7 @@ const (
CharCtrlL = 12
CharEnter = 13
CharNext = 14
CharCtrlO = 15 // Ctrl+O - used for expanding tool output
CharPrev = 16
CharBckSearch = 18
CharFwdSearch = 19

View File

@@ -3,6 +3,7 @@ package runner
import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
imagerunner "github.com/ollama/ollama/x/imagegen/runner"
)
func Execute(args []string) error {
@@ -11,12 +12,19 @@ func Execute(args []string) error {
}
var newRunner bool
if args[0] == "--ollama-engine" {
var imageRunner bool
if len(args) > 0 && args[0] == "--ollama-engine" {
args = args[1:]
newRunner = true
}
if len(args) > 0 && args[0] == "--image-engine" {
args = args[1:]
imageRunner = true
}
if newRunner {
if imageRunner {
return imagerunner.Execute(args)
} else if newRunner {
return ollamarunner.Execute(args)
} else {
return llamarunner.Execute(args)

View File

@@ -42,18 +42,39 @@ shift $(( $OPTIND - 1 ))
_build_darwin() {
for ARCH in $ARCHS; do
status "Building darwin $ARCH"
INSTALL_PREFIX=dist/darwin-$ARCH/
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
INSTALL_PREFIX=dist/darwin-$ARCH/
if [ "$ARCH" = "amd64" ]; then
status "Building darwin $ARCH dynamic backends"
cmake -B build/darwin-$ARCH \
BUILD_DIR=build/darwin-$ARCH
cmake -B $BUILD_DIR \
-DCMAKE_OSX_ARCHITECTURES=x86_64 \
-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 \
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
-DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \
-DMLX_ENGINE=ON \
-DMLX_ENABLE_X64_MAC=ON \
-DOLLAMA_RUNNER_DIR=./
cmake --build $BUILD_DIR --target ggml-cpu -j
cmake --build $BUILD_DIR --target mlx mlxc -j
cmake --install $BUILD_DIR --component CPU
cmake --install $BUILD_DIR --component MLX
# Override CGO flags to point to the amd64 build directory
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Accelerate -mmacosx-version-min=14.0"
else
BUILD_DIR=build
cmake --preset MLX \
-DOLLAMA_RUNNER_DIR=./ \
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
-DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX
cmake --build build/darwin-$ARCH --target ggml-cpu -j
cmake --install build/darwin-$ARCH --component CPU
cmake --build --preset MLX --parallel
cmake --install $BUILD_DIR --component MLX
# Use default CGO flags from mlx.go for arm64
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
fi
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/imagegen ./x/imagegen/cmd/engine
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
done
}
@@ -61,10 +82,12 @@ _sign_darwin() {
status "Creating universal binary..."
mkdir -p dist/darwin
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
lipo -create -output dist/darwin/imagegen dist/darwin-*/imagegen
chmod +x dist/darwin/ollama
chmod +x dist/darwin/imagegen
if [ -n "$APPLE_IDENTITY" ]; then
for F in dist/darwin/ollama dist/darwin-amd64/lib/ollama/*; do
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/imagegen; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
done
@@ -131,17 +154,23 @@ _build_macapp() {
mkdir -p dist/Ollama.app/Contents/Resources
if [ -d dist/darwin-amd64 ]; then
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
cp dist/darwin-amd64/lib/ollama/*.so dist/darwin-amd64/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
lipo -create -output dist/Ollama.app/Contents/Resources/imagegen dist/darwin-amd64/imagegen dist/darwin-arm64/imagegen
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
done
cp dist/darwin-*/lib/ollama/*.so dist/darwin-*/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
cp dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
else
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
fi
cp -a dist/darwin/imagegen dist/Ollama.app/Contents/Resources/imagegen
chmod a+x dist/Ollama.app/Contents/Resources/ollama
# Sign
if [ -n "$APPLE_IDENTITY" ]; then
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib ; do
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/imagegen ; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
done
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
@@ -149,7 +178,7 @@ _build_macapp() {
rm -f dist/Ollama-darwin.zip
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama imagegen *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
# Notarize and Staple
if [ -n "$APPLE_IDENTITY" ]; then

View File

@@ -12,6 +12,17 @@ set -eu
. $(dirname $0)/env.sh
# Check for required tools
if ! command -v zstd >/dev/null 2>&1; then
echo "ERROR: zstd is required but not installed." >&2
echo "Please install zstd:" >&2
echo " - macOS: brew install zstd" >&2
echo " - Debian/Ubuntu: sudo apt-get install zstd" >&2
echo " - RHEL/CentOS/Fedora: sudo dnf install zstd" >&2
echo " - Arch: sudo pacman -S zstd" >&2
exit 1
fi
mkdir -p dist
docker buildx build \
@@ -37,19 +48,68 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then
.
fi
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
deduplicate_cuda_libs() {
local base_dir="$1"
echo "Deduplicating CUDA libraries in ${base_dir}..."
# Find all mlx_cuda_* directories
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
[ -d "${mlx_dir}" ] || continue
# Extract CUDA version (e.g., v12, v13)
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
# Skip if corresponding cuda_* directory doesn't exist
[ -d "${cuda_dir}" ] || continue
echo " Checking ${mlx_dir} against ${cuda_dir}..."
# Find all .so* files in mlx directory
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
filename=$(basename "${mlx_file}")
cuda_file="${cuda_dir}/${filename}"
# Skip if file doesn't exist in cuda directory
[ -f "${cuda_file}" ] || continue
# Compare checksums
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
if [ "${mlx_sum}" = "${cuda_sum}" ]; then
echo " Deduplicating ${filename}"
# Calculate relative path from mlx_dir to cuda_dir
rel_path="../cuda_${cuda_version}/${filename}"
rm -f "${mlx_file}"
ln -s "${rel_path}" "${mlx_file}"
fi
done
done
}
# Run deduplication for each platform output directory
if echo $PLATFORM | grep "," > /dev/null ; then
deduplicate_cuda_libs "./dist/linux_amd64"
deduplicate_cuda_libs "./dist/linux_arm64"
elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then
deduplicate_cuda_libs "./dist"
fi
# buildx behavior changes for single vs. multiplatform
echo "Compressing linux tar bundles..."
if echo $PLATFORM | grep "," > /dev/null ; then
tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | pigz -9vc >./dist/ollama-linux-arm64.tgz
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | pigz -9vc >./dist/ollama-linux-arm64-jetpack5.tgz
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | pigz -9vc >./dist/ollama-linux-arm64-jetpack6.tgz
tar c -C ./dist/linux_amd64 --exclude rocm . | pigz -9vc >./dist/ollama-linux-amd64.tgz
tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | pigz -9vc >./dist/ollama-linux-amd64-rocm.tgz
tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst
tar c -C ./dist/linux_amd64 --exclude rocm . | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst
elif echo $PLATFORM | grep "arm64" > /dev/null ; then
tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | pigz -9vc >./dist/ollama-linux-arm64.tgz
tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | pigz -9vc >./dist/ollama-linux-arm64-jetpack5.tgz
tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | pigz -9vc >./dist/ollama-linux-arm64-jetpack6.tgz
tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst
tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst
tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst
elif echo $PLATFORM | grep "amd64" > /dev/null ; then
tar c -C ./dist/ --exclude rocm bin lib | pigz -9vc >./dist/ollama-linux-amd64.tgz
tar c -C ./dist/ ./lib/ollama/rocm | pigz -9vc >./dist/ollama-linux-amd64-rocm.tgz
tar c -C ./dist/ --exclude rocm bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
tar c -C ./dist/ ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst
fi

View File

@@ -66,6 +66,36 @@ if [ -n "$NEEDS" ]; then
exit 1
fi
# Function to download and extract with fallback from zst to tgz
download_and_extract() {
local url_base="$1"
local dest_dir="$2"
local filename="$3"
# Check if .tar.zst is available
if curl --fail --silent --head --location "${url_base}/${filename}.tar.zst${VER_PARAM}" >/dev/null 2>&1; then
# zst file exists - check if we have zstd tool
if ! available zstd; then
error "This version requires zstd for extraction. Please install zstd and try again:
- Debian/Ubuntu: sudo apt-get install zstd
- RHEL/CentOS/Fedora: sudo dnf install zstd
- Arch: sudo pacman -S zstd"
fi
status "Downloading ${filename}.tar.zst"
curl --fail --show-error --location --progress-bar \
"${url_base}/${filename}.tar.zst${VER_PARAM}" | \
zstd -d | $SUDO tar -xf - -C "${dest_dir}"
return 0
fi
# Fall back to .tgz for older versions
status "Downloading ${filename}.tgz"
curl --fail --show-error --location --progress-bar \
"${url_base}/${filename}.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "${dest_dir}"
}
for BINDIR in /usr/local/bin /usr/bin /bin; do
echo $PATH | grep -q $BINDIR && break || continue
done
@@ -78,10 +108,7 @@ fi
status "Installing ollama to $OLLAMA_INSTALL_DIR"
$SUDO install -o0 -g0 -m755 -d $BINDIR
$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR/lib/ollama"
status "Downloading Linux ${ARCH} bundle"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}"
if [ "$OLLAMA_INSTALL_DIR/bin/ollama" != "$BINDIR/ollama" ] ; then
status "Making ollama accessible in the PATH in $BINDIR"
@@ -91,15 +118,9 @@ fi
# Check for NVIDIA JetPack systems with additional downloads
if [ -f /etc/nv_tegra_release ] ; then
if grep R36 /etc/nv_tegra_release > /dev/null ; then
status "Downloading JetPack 6 components"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}-jetpack6.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-jetpack6"
elif grep R35 /etc/nv_tegra_release > /dev/null ; then
status "Downloading JetPack 5 components"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}-jetpack5.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-jetpack5"
else
warning "Unsupported JetPack version detected. GPU may not be supported"
fi
@@ -222,10 +243,7 @@ if ! check_gpu lspci nvidia && ! check_gpu lshw nvidia && ! check_gpu lspci amdg
fi
if check_gpu lspci amdgpu || check_gpu lshw amdgpu; then
status "Downloading Linux ROCm ${ARCH} bundle"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}-rocm.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-rocm"
install_success
status "AMD GPU ready."

View File

@@ -26,6 +26,7 @@ import (
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
@@ -454,7 +455,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
return layers, nil
}
func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
for _, l := range baseLayers {
if l.GGML != nil {
return l.KV(), nil

View File

@@ -30,6 +30,7 @@ import (
"github.com/ollama/ollama/thinking"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen/transfer"
)
var (
@@ -73,6 +74,11 @@ type Model struct {
func (m *Model) Capabilities() []model.Capability {
capabilities := []model.Capability{}
// Check for image generation model via config capabilities
if slices.Contains(m.Config.Capabilities, "image") {
return []model.Capability{model.CapabilityImageGeneration}
}
// Check for completion capability
if m.ModelPath != "" {
f, err := gguf.Open(m.ModelPath)
@@ -555,6 +561,24 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Config)
}
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
// Read raw manifest JSON to preserve tensor metadata fields
manifestPath, err := mp.GetManifestPath()
if err != nil {
return err
}
manifestJSON, err := os.ReadFile(manifestPath)
if err != nil {
return err
}
if err := pushWithTransfer(ctx, mp, layers, manifestJSON, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
return nil
}
for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
@@ -620,6 +644,15 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Config)
}
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
if err := pullWithTransfer(ctx, mp, layers, manifest, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
return nil
}
skipVerify := make(map[string]bool)
for _, layer := range layers {
cacheHit, err := downloadBlob(ctx, downloadOpts{
@@ -634,7 +667,6 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
skipVerify[layer.Digest] = cacheHit
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers {
@@ -643,13 +675,11 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
// something went wrong, delete the blob
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(fp); err != nil {
// log this, but return the original error
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
}
}
@@ -657,6 +687,11 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
}
for _, layer := range layers {
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(manifest)
@@ -690,6 +725,148 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return nil
}
// hasTensorLayers checks if any layer has tensor media type.
func hasTensorLayers(layers []Layer) bool {
for _, layer := range layers {
if layer.MediaType == MediaTypeImageTensor {
return true
}
}
return false
}
// pullWithTransfer uses the simplified x/transfer package for downloading blobs.
func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifest *Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
Digest: layer.Digest,
Size: layer.Size,
}
}
destDir, err := GetBlobsPath("")
if err != nil {
return err
}
base := mp.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
baseURL := base.String()
var totalSize int64
for _, blob := range blobs {
totalSize += blob.Size
}
progress := func(completed, total int64) {
fn(api.ProgressResponse{
Status: "pulling model",
Digest: "sha256:model",
Total: total,
Completed: completed,
})
}
getToken := func(ctx context.Context, challenge transfer.AuthChallenge) (string, error) {
return getAuthorizationToken(ctx, registryChallenge{
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
})
}
if err := transfer.Download(ctx, transfer.DownloadOptions{
Blobs: blobs,
BaseURL: baseURL,
DestDir: destDir,
Repository: mp.GetNamespaceRepository(),
Progress: progress,
Token: regOpts.Token,
GetToken: getToken,
Logger: slog.Default(),
}); err != nil {
return err
}
// Write manifest
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}
fp, err := mp.GetManifestPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
return err
}
return os.WriteFile(fp, manifestJSON, 0o644)
}
// pushWithTransfer uses the simplified x/transfer package for uploading blobs and manifest.
func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
Digest: layer.Digest,
Size: layer.Size,
From: layer.From,
}
}
srcDir, err := GetBlobsPath("")
if err != nil {
return err
}
base := mp.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
baseURL := base.String()
var totalSize int64
for _, blob := range blobs {
totalSize += blob.Size
}
progress := func(completed, total int64) {
fn(api.ProgressResponse{
Status: "pushing model",
Digest: "sha256:model",
Total: total,
Completed: completed,
})
}
getToken := func(ctx context.Context, challenge transfer.AuthChallenge) (string, error) {
return getAuthorizationToken(ctx, registryChallenge{
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
})
}
return transfer.Upload(ctx, transfer.UploadOptions{
Blobs: blobs,
BaseURL: baseURL,
SrcDir: srcDir,
Progress: progress,
Token: regOpts.Token,
GetToken: getToken,
Logger: slog.Default(),
Manifest: manifestJSON,
ManifestRef: mp.Tag,
Repository: mp.GetNamespaceRepository(),
})
}
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)

View File

@@ -47,6 +47,15 @@ func TestModelCapabilities(t *testing.T) {
model Model
expectedCaps []model.Capability
}{
{
name: "model with image generation capability via config",
model: Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
},
expectedCaps: []model.Capability{model.CapabilityImageGeneration},
},
{
name: "model with completion capability",
model: Model{

View File

@@ -13,9 +13,14 @@ type Layer struct {
Digest string `json:"digest"`
Size int64 `json:"size"`
From string `json:"from,omitempty"`
Name string `json:"name,omitempty"` // tensor name, e.g., "text_encoder/model.embed_tokens.weight"
status string
}
const (
MediaTypeImageTensor = "application/vnd.ollama.image.tensor"
)
func NewLayer(r io.Reader, mediatype string) (Layer, error) {
blobs, err := GetBlobsPath("")
if err != nil {

View File

@@ -50,6 +50,8 @@ import (
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen"
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
)
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
@@ -162,6 +164,29 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return runner.llama, model, &opts, nil
}
// ScheduleImageGenRunner schedules an image generation model runner.
// This implements the imagegenapi.RunnerScheduler interface.
func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) {
m := &Model{
Name: modelName,
ShortName: modelName,
ModelPath: modelName, // For image gen, ModelPath is just the model name
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}
runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err := <-errCh:
return nil, err
}
return runner.llama, nil
}
func signinURL() (string, error) {
pubKey, err := auth.GetPublicKey()
if err != nil {
@@ -189,6 +214,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
// Check if this is a known image generation model
if imagegen.ResolveModelName(req.Model) != "" {
imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
return
}
name := model.ParseName(req.Model)
if !name.IsValid() {
// Ideally this is "invalid model name" but we're keeping with
@@ -1544,6 +1575,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
// Experimental image generation support
imagegenapi.RegisterRoutes(r, s)
if rc != nil {
// wrap old with new
rs := &registry.Local{

View File

@@ -22,6 +22,7 @@ import (
gocmpopts "github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/types/model"
@@ -41,7 +42,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
}
defer f.Close()
base := map[string]any{"general.architecture": "test"}
var base convert.KV = map[string]any{"general.architecture": "test"}
maps.Copy(base, kv)
if err := ggml.WriteGGUF(f, base, ti); err != nil {

View File

@@ -21,6 +21,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
)
type LlmRequest struct {
@@ -194,6 +195,14 @@ func (s *Scheduler) processPending(ctx context.Context) {
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
}
// Check for image generation model before attempting GGML load
if slices.Contains(pending.model.Config.Capabilities, "image") {
if s.loadImageGen(pending) {
break
}
continue
}
// Load model for fitting
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
@@ -543,6 +552,48 @@ iGPUScan:
return false
}
// loadImageGen loads an image generation model.
func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
// Use model name for imagegen (it resolves manifests by name, not file path)
modelName := req.model.ShortName
server, err := imagegen.NewServer(modelName)
if err != nil {
req.errCh <- err
return true
}
sessionDuration := envconfig.KeepAlive()
if req.sessionDuration != nil {
sessionDuration = req.sessionDuration.Duration
}
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
llama: server,
Options: &req.opts,
loading: false,
sessionDuration: sessionDuration,
refCount: 1,
}
s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner
s.loadedMu.Unlock()
// Set up expiration timer
runner.refMu.Lock()
if sessionDuration > 0 {
runner.expireTimer = time.AfterFunc(sessionDuration, func() {
s.expiredCh <- runner
})
}
runner.refMu.Unlock()
req.useLoadedRunner(runner, s.finishedReqCh)
return true
}
func (s *Scheduler) updateFreeSpace(allGpus []ml.DeviceInfo) {
if len(allGpus) == 0 {
return

View File

@@ -6,6 +6,7 @@ import (
"errors"
"log/slog"
"os"
"slices"
"testing"
"time"
@@ -16,6 +17,7 @@ import (
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
)
func TestMain(m *testing.M) {
@@ -804,3 +806,61 @@ func (s *mockLlm) GetPort() int { return -
func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
func (s *mockLlm) HasExited() bool { return false }
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }
// TestImageGenCapabilityDetection verifies that models with "image" capability
// are correctly identified and routed differently from language models.
func TestImageGenCapabilityDetection(t *testing.T) {
// Model with image capability should be detected
imageModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}
require.True(t, slices.Contains(imageModel.Config.Capabilities, "image"))
// Model without image capability should not be detected
langModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"completion"},
},
}
require.False(t, slices.Contains(langModel.Config.Capabilities, "image"))
// Empty capabilities should not match
emptyModel := &Model{}
require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image"))
}
// TestImageGenRunnerCanBeEvicted verifies that an image generation model
// loaded in the scheduler can be evicted by a language model request.
func TestImageGenRunnerCanBeEvicted(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getSystemInfoFn = getSystemInfoFn
// Simulate an image gen runner already loaded
imageGenRunner := &runnerRef{
model: &Model{Name: "z-image", ModelPath: "/fake/image/model"},
modelPath: "/fake/image/model",
llama: &mockLlm{vramSize: 21 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{}},
sessionDuration: 5 * time.Millisecond,
refCount: 0, // idle
}
s.loadedMu.Lock()
s.loaded["/fake/image/model"] = imageGenRunner
s.loadedMu.Unlock()
// Verify the image gen runner is loaded
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
// findRunnerToUnload should find the idle image gen runner
runner := s.findRunnerToUnload()
require.NotNil(t, runner)
require.Equal(t, "/fake/image/model", runner.modelPath)
}

View File

@@ -381,6 +381,28 @@ func (t templateTools) String() string {
return string(bts)
}
// templateArgs is a map type with JSON string output for templates.
type templateArgs map[string]any
func (t templateArgs) String() string {
if t == nil {
return "{}"
}
bts, _ := json.Marshal(t)
return string(bts)
}
// templateProperties is a map type with JSON string output for templates.
type templateProperties map[string]api.ToolProperty
func (t templateProperties) String() string {
if t == nil {
return "{}"
}
bts, _ := json.Marshal(t)
return string(bts)
}
// templateTool is a template-compatible representation of api.Tool
// with Properties as a regular map for template ranging.
type templateTool struct {
@@ -396,11 +418,11 @@ type templateToolFunction struct {
}
type templateToolFunctionParameters struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required,omitempty"`
Properties map[string]api.ToolProperty `json:"properties"`
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required,omitempty"`
Properties templateProperties `json:"properties"`
}
// templateToolCall is a template-compatible representation of api.ToolCall
@@ -413,7 +435,7 @@ type templateToolCall struct {
type templateToolCallFunction struct {
Index int
Name string
Arguments map[string]any
Arguments templateArgs
}
// templateMessage is a template-compatible representation of api.Message
@@ -446,7 +468,7 @@ func convertToolsForTemplate(tools api.Tools) templateTools {
Defs: tool.Function.Parameters.Defs,
Items: tool.Function.Parameters.Items,
Required: tool.Function.Parameters.Required,
Properties: tool.Function.Parameters.Properties.ToMap(),
Properties: templateProperties(tool.Function.Parameters.Properties.ToMap()),
},
},
}
@@ -468,7 +490,7 @@ func convertMessagesForTemplate(messages []*api.Message) []*templateMessage {
Function: templateToolCallFunction{
Index: tc.Function.Index,
Name: tc.Function.Name,
Arguments: tc.Function.Arguments.ToMap(),
Arguments: templateArgs(tc.Function.Arguments.ToMap()),
},
})
}

View File

@@ -613,3 +613,159 @@ func TestCollate(t *testing.T) {
})
}
}
func TestTemplateArgumentsJSON(t *testing.T) {
// Test that {{ .Function.Arguments }} outputs valid JSON, not map[key:value]
tmpl := `{{- range .Messages }}{{- range .ToolCalls }}{{ .Function.Arguments }}{{- end }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
args := api.NewToolCallFunctionArguments()
args.Set("location", "Tokyo")
args.Set("unit", "celsius")
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{
Role: "assistant",
ToolCalls: []api.ToolCall{{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args,
},
}},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
// Should be valid JSON, not "map[location:Tokyo unit:celsius]"
if strings.HasPrefix(got, "map[") {
t.Errorf("Arguments output as Go map format: %s", got)
}
var parsed map[string]any
if err := json.Unmarshal([]byte(got), &parsed); err != nil {
t.Errorf("Arguments not valid JSON: %s, error: %v", got, err)
}
}
func TestTemplatePropertiesJSON(t *testing.T) {
// Test that {{ .Function.Parameters.Properties }} outputs valid JSON
// Note: template must reference .Messages to trigger the modern code path that converts Tools
tmpl := `{{- range .Messages }}{{- end }}{{- range .Tools }}{{ .Function.Parameters.Properties }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "City name"})
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{Role: "user", Content: "test"}},
Tools: api.Tools{{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
},
},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
// Should be valid JSON, not "map[location:{...}]"
if strings.HasPrefix(got, "map[") {
t.Errorf("Properties output as Go map format: %s", got)
}
var parsed map[string]any
if err := json.Unmarshal([]byte(got), &parsed); err != nil {
t.Errorf("Properties not valid JSON: %s, error: %v", got, err)
}
}
func TestTemplateArgumentsRange(t *testing.T) {
// Test that we can range over Arguments in templates
tmpl := `{{- range .Messages }}{{- range .ToolCalls }}{{- range $k, $v := .Function.Arguments }}{{ $k }}={{ $v }};{{- end }}{{- end }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
args := api.NewToolCallFunctionArguments()
args.Set("city", "Tokyo")
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{
Role: "assistant",
ToolCalls: []api.ToolCall{{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args,
},
}},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
if got != "city=Tokyo;" {
t.Errorf("Range over Arguments failed, got: %s, want: city=Tokyo;", got)
}
}
func TestTemplatePropertiesRange(t *testing.T) {
// Test that we can range over Properties in templates
// Note: template must reference .Messages to trigger the modern code path that converts Tools
tmpl := `{{- range .Messages }}{{- end }}{{- range .Tools }}{{- range $name, $prop := .Function.Parameters.Properties }}{{ $name }}:{{ $prop.Type }};{{- end }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{Role: "user", Content: "test"}},
Tools: api.Tools{{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
},
},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
if got != "location:string;" {
t.Errorf("Range over Properties failed, got: %s, want: location:string;", got)
}
}

View File

@@ -3,12 +3,13 @@ package model
type Capability string
const (
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
CapabilityVision = Capability("vision")
CapabilityEmbedding = Capability("embedding")
CapabilityThinking = Capability("thinking")
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
CapabilityVision = Capability("vision")
CapabilityEmbedding = Capability("embedding")
CapabilityThinking = Capability("thinking")
CapabilityImageGeneration = Capability("image")
)
func (c Capability) String() string {

24
x/README.md Normal file
View File

@@ -0,0 +1,24 @@
# Experimental Features
## MLX Backend
We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx)
Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors. To build:
```
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install --component MLX
go build -tags mlx .
```
On linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled.
## Image Generation
Based on the experimental MLX backend, we're working on adding imagegen support. After running the cmake commands above:
```
go build -o imagegen ./x/imagegen/cmd/engine
```

View File

@@ -4,6 +4,7 @@ package agent
import (
"fmt"
"os"
"path"
"path/filepath"
"strings"
"sync"
@@ -32,10 +33,29 @@ type ApprovalResult struct {
// Option labels for the selector (numbered for quick selection)
var optionLabels = []string{
"1. Execute once",
"2. Always allow",
"2. Allow for this session",
"3. Deny",
}
// toolDisplayNames maps internal tool names to human-readable display names.
var toolDisplayNames = map[string]string{
"bash": "Bash",
"web_search": "Web Search",
}
// ToolDisplayName returns the human-readable display name for a tool.
func ToolDisplayName(toolName string) string {
if displayName, ok := toolDisplayNames[toolName]; ok {
return displayName
}
// Default: capitalize first letter and replace underscores with spaces
name := strings.ReplaceAll(toolName, "_", " ")
if len(name) > 0 {
return strings.ToUpper(name[:1]) + name[1:]
}
return toolName
}
// autoAllowCommands are commands that are always allowed without prompting.
// These are zero-risk, read-only commands.
var autoAllowCommands = map[string]bool{
@@ -179,6 +199,7 @@ func FormatDeniedResult(command string, pattern string) string {
// extractBashPrefix extracts a prefix pattern from a bash command.
// For commands like "cat tools/tools_test.go | head -200", returns "cat:tools/"
// For commands without path args, returns empty string.
// Paths with ".." traversal that escape the base directory return empty string for security.
func extractBashPrefix(command string) string {
// Split command by pipes and get the first part
parts := strings.Split(command, "|")
@@ -204,8 +225,8 @@ func extractBashPrefix(command string) string {
return ""
}
// Find the first path-like argument (must contain / or start with .)
// First pass: look for clear paths (containing / or starting with .)
// Find the first path-like argument (must contain / or \ or start with .)
// First pass: look for clear paths (containing path separators or starting with .)
for _, arg := range fields[1:] {
// Skip flags
if strings.HasPrefix(arg, "-") {
@@ -215,19 +236,49 @@ func extractBashPrefix(command string) string {
if isNumeric(arg) {
continue
}
// Only process if it looks like a path (contains / or starts with .)
if !strings.Contains(arg, "/") && !strings.HasPrefix(arg, ".") {
// Only process if it looks like a path (contains / or \ or starts with .)
if !strings.Contains(arg, "/") && !strings.Contains(arg, "\\") && !strings.HasPrefix(arg, ".") {
continue
}
// If arg ends with /, it's a directory - use it directly
if strings.HasSuffix(arg, "/") {
return fmt.Sprintf("%s:%s", baseCmd, arg)
// Normalize to forward slashes for consistent cross-platform matching
arg = strings.ReplaceAll(arg, "\\", "/")
// Security: reject absolute paths
if path.IsAbs(arg) {
return "" // Absolute path - don't create prefix
}
// Get the directory part of a file path
dir := filepath.Dir(arg)
// Normalize the path using stdlib path.Clean (resolves . and ..)
cleaned := path.Clean(arg)
// Security: reject if cleaned path escapes to parent directory
if strings.HasPrefix(cleaned, "..") {
return "" // Path escapes - don't create prefix
}
// Security: if original had "..", verify cleaned path didn't escape to sibling
// e.g., "tools/a/b/../../../etc" -> "etc" (escaped tools/ to sibling)
if strings.Contains(arg, "..") {
origBase := strings.SplitN(arg, "/", 2)[0]
cleanedBase := strings.SplitN(cleaned, "/", 2)[0]
if origBase != cleanedBase {
return "" // Path escaped to sibling directory
}
}
// Check if arg ends with / (explicit directory)
isDir := strings.HasSuffix(arg, "/")
// Get the directory part
var dir string
if isDir {
dir = cleaned
} else {
dir = path.Dir(cleaned)
}
if dir == "." {
// Path is just a directory like "tools" or "src" (no trailing /)
return fmt.Sprintf("%s:%s/", baseCmd, arg)
return fmt.Sprintf("%s:./", baseCmd)
}
return fmt.Sprintf("%s:%s/", baseCmd, dir)
}
@@ -332,6 +383,8 @@ func AllowlistKey(toolName string, args map[string]any) string {
}
// IsAllowed checks if a tool/command is allowed (exact match or prefix match).
// For bash commands, hierarchical path matching is used - if "cat:tools/" is allowed,
// then "cat:tools/subdir/" is also allowed (subdirectories inherit parent permissions).
func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
a.mu.RLock()
defer a.mu.RUnlock()
@@ -342,12 +395,20 @@ func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
return true
}
// For bash commands, check prefix matches
// For bash commands, check prefix matches with hierarchical path support
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
prefix := extractBashPrefix(cmd)
if prefix != "" && a.prefixes[prefix] {
return true
if prefix != "" {
// Check exact prefix match first
if a.prefixes[prefix] {
return true
}
// Check hierarchical match: if any stored prefix is a parent of current prefix
// e.g., stored "cat:tools/" should match current "cat:tools/subdir/"
if a.matchesHierarchicalPrefix(prefix) {
return true
}
}
}
}
@@ -360,6 +421,40 @@ func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
return false
}
// matchesHierarchicalPrefix checks if the given prefix matches any stored prefix hierarchically.
// For example, if "cat:tools/" is stored, it will match "cat:tools/subdir/" or "cat:tools/a/b/c/".
func (a *ApprovalManager) matchesHierarchicalPrefix(currentPrefix string) bool {
// Split prefix into command and path parts (format: "cmd:path/")
colonIdx := strings.Index(currentPrefix, ":")
if colonIdx == -1 {
return false
}
currentCmd := currentPrefix[:colonIdx]
currentPath := currentPrefix[colonIdx+1:]
for storedPrefix := range a.prefixes {
storedColonIdx := strings.Index(storedPrefix, ":")
if storedColonIdx == -1 {
continue
}
storedCmd := storedPrefix[:storedColonIdx]
storedPath := storedPrefix[storedColonIdx+1:]
// Commands must match exactly
if currentCmd != storedCmd {
continue
}
// Check if current path starts with stored path (hierarchical match)
// e.g., "tools/subdir/" starts with "tools/"
if strings.HasPrefix(currentPath, storedPath) {
return true
}
}
return false
}
// AddToAllowlist adds a tool/command to the session allowlist.
// For bash commands, it adds the prefix pattern instead of exact command.
func (a *ApprovalManager) AddToAllowlist(toolName string, args map[string]any) {
@@ -399,16 +494,32 @@ func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any)
// This prevents buffered input from causing double-press issues
flushStdin(fd)
// Check if bash command targets paths outside cwd
isWarning := false
var warningMsg string
var allowlistInfo string
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
isWarning = isCommandOutsideCwd(cmd)
if isCommandOutsideCwd(cmd) {
isWarning = true
warningMsg = "command targets paths outside project"
}
if prefix := extractBashPrefix(cmd); prefix != "" {
colonIdx := strings.Index(prefix, ":")
if colonIdx != -1 {
cmdName := prefix[:colonIdx]
dirPath := prefix[colonIdx+1:]
if dirPath != "./" {
allowlistInfo = fmt.Sprintf("%s in %s directory (includes subdirs)", cmdName, dirPath)
} else {
allowlistInfo = fmt.Sprintf("%s in %s directory", cmdName, dirPath)
}
}
}
}
}
// Run interactive selector
selected, denyReason, err := runSelector(fd, oldState, toolDisplay, isWarning)
selected, denyReason, err := runSelector(fd, oldState, toolDisplay, isWarning, warningMsg, allowlistInfo)
if err != nil {
term.Restore(fd, oldState)
return ApprovalResult{Decision: ApprovalDeny}, err
@@ -433,27 +544,29 @@ func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any)
// formatToolDisplay creates the display string for a tool call.
func formatToolDisplay(toolName string, args map[string]any) string {
var sb strings.Builder
displayName := ToolDisplayName(toolName)
// For bash, show command directly
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
sb.WriteString(fmt.Sprintf("Command: %s", cmd))
return sb.String()
}
}
// For web search, show query
// For web search, show query and internet notice
if toolName == "web_search" {
if query, ok := args["query"].(string); ok {
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
sb.WriteString(fmt.Sprintf("Query: %s", query))
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
sb.WriteString(fmt.Sprintf("Query: %s\n", query))
sb.WriteString("Uses internet via ollama.com")
return sb.String()
}
}
// Generic display
sb.WriteString(fmt.Sprintf("Tool: %s", toolName))
sb.WriteString(fmt.Sprintf("Tool: %s", displayName))
if len(args) > 0 {
sb.WriteString("\nArguments: ")
first := true
@@ -470,24 +583,28 @@ func formatToolDisplay(toolName string, args map[string]any) string {
// selectorState holds the state for the interactive selector
type selectorState struct {
toolDisplay string
selected int
totalLines int
termWidth int
termHeight int
boxWidth int
innerWidth int
denyReason string // deny reason (always visible in box)
isWarning bool // true if command targets paths outside cwd (red box)
toolDisplay string
selected int
totalLines int
termWidth int
termHeight int
boxWidth int
innerWidth int
denyReason string // deny reason (always visible in box)
isWarning bool // true if command has warning
warningMessage string // dynamic warning message to display
allowlistInfo string // show what will be allowlisted (for "Allow for this session" option)
}
// runSelector runs the interactive selector and returns the selected index and optional deny reason.
// If isWarning is true, the box is rendered in red to indicate the command targets paths outside cwd.
func runSelector(fd int, oldState *term.State, toolDisplay string, isWarning bool) (int, string, error) {
func runSelector(fd int, oldState *term.State, toolDisplay string, isWarning bool, warningMessage string, allowlistInfo string) (int, string, error) {
state := &selectorState{
toolDisplay: toolDisplay,
selected: 0,
isWarning: isWarning,
toolDisplay: toolDisplay,
selected: 0,
isWarning: isWarning,
warningMessage: warningMessage,
allowlistInfo: allowlistInfo,
}
// Get terminal size
@@ -647,7 +764,7 @@ func wrapText(text string, maxWidth int) []string {
// getHintLines returns the hint text wrapped to terminal width
func getHintLines(state *selectorState) []string {
hint := "↑/↓ navigate, Enter confirm, 1-3 quick, Ctrl+C cancel"
hint := "up/down select, enter confirm, 1-3 quick select, ctrl+c cancel"
if state.termWidth >= len(hint)+1 {
return []string{hint}
}
@@ -657,86 +774,70 @@ func getHintLines(state *selectorState) []string {
// calculateTotalLines calculates how many lines the selector will use
func calculateTotalLines(state *selectorState) int {
toolLines := wrapText(state.toolDisplay, state.innerWidth)
toolLines := strings.Split(state.toolDisplay, "\n")
hintLines := getHintLines(state)
// top border + (warning line if applicable) + tool lines + separator + options + bottom border + hint lines
// warning line (if applicable) + tool lines + blank line + options + blank line + hint lines
warningLines := 0
if state.isWarning {
warningLines = 1
warningLines = 2 // warning line + blank line after
}
return 1 + warningLines + len(toolLines) + 1 + len(optionLabels) + 1 + len(hintLines)
return warningLines + len(toolLines) + 1 + len(optionLabels) + 1 + len(hintLines)
}
// renderSelectorBox renders the complete selector box
// renderSelectorBox renders the selector (minimal, no box)
func renderSelectorBox(state *selectorState) {
toolLines := wrapText(state.toolDisplay, state.innerWidth)
toolLines := strings.Split(state.toolDisplay, "\n")
hintLines := getHintLines(state)
// Use red for warning (outside cwd), cyan for normal
boxColor := "\033[36m" // cyan
// Draw warning line if needed
if state.isWarning {
boxColor = "\033[91m" // bright red
}
// Draw box top
fmt.Fprintf(os.Stderr, "%s┌%s┐\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
// Draw warning line if needed (inside the box)
if state.isWarning {
warning := "!! OUTSIDE PROJECT !!"
padding := (state.innerWidth - len(warning)) / 2
if padding < 0 {
padding = 0
if state.warningMessage != "" {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m %s\033[K\r\n", state.warningMessage)
} else {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m command targets paths outside project\033[K\r\n")
}
fmt.Fprintf(os.Stderr, "%s│\033[0m %s%s%s %s│\033[0m\033[K\r\n", boxColor,
strings.Repeat(" ", padding), warning, strings.Repeat(" ", state.innerWidth-len(warning)-padding), boxColor)
fmt.Fprintf(os.Stderr, "\033[K\r\n") // blank line after warning
}
// Draw tool info
// Draw tool info (plain white)
for _, line := range toolLines {
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth, line, boxColor)
fmt.Fprintf(os.Stderr, "%s\033[K\r\n", line)
}
// Draw separator
fmt.Fprintf(os.Stderr, "%s├%s┤\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
// Blank line separator
fmt.Fprintf(os.Stderr, "\033[K\r\n")
// Draw options with numbers (Deny option includes reason input)
for i, label := range optionLabels {
if i == 2 { // Deny option - show with reason input beside it
if i == 2 {
denyLabel := "3. Deny: "
availableWidth := state.innerWidth - 2 - len(denyLabel)
if availableWidth < 5 {
availableWidth = 5
}
inputDisplay := state.denyReason
if len(inputDisplay) > availableWidth {
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
if inputDisplay == "" {
inputDisplay = "\033[90m(optional reason)\033[0m"
}
if i == state.selected {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
} else {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
}
} else {
displayLabel := label
if len(displayLabel) > state.innerWidth-2 {
displayLabel = displayLabel[:state.innerWidth-5] + "..."
if i == 1 && state.allowlistInfo != "" {
displayLabel = fmt.Sprintf("%s \033[90m%s\033[0m", label, state.allowlistInfo)
}
if i == state.selected {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", displayLabel)
} else {
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", displayLabel)
}
}
}
// Draw box bottom
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
// Blank line before hint
fmt.Fprintf(os.Stderr, "\033[K\r\n")
// Draw hint (may be multiple lines)
// Draw hint (dark grey)
for i, line := range hintLines {
if i == len(hintLines)-1 {
// Last line - no newline
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
} else {
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
@@ -748,50 +849,39 @@ func renderSelectorBox(state *selectorState) {
func updateSelectorOptions(state *selectorState) {
hintLines := getHintLines(state)
// Use red for warning (outside cwd), cyan for normal
boxColor := "\033[36m" // cyan
if state.isWarning {
boxColor = "\033[91m" // bright red
}
// Move up to the first option line
// Cursor is at end of last hint line, need to go up:
// (hint lines - 1) + 1 (bottom border) + numOptions
// (hint lines - 1) + 1 (blank line) + numOptions
linesToMove := len(hintLines) - 1 + 1 + len(optionLabels)
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
// Redraw options (Deny option includes reason input)
for i, label := range optionLabels {
if i == 2 { // Deny option
if i == 2 {
denyLabel := "3. Deny: "
availableWidth := state.innerWidth - 2 - len(denyLabel)
if availableWidth < 5 {
availableWidth = 5
}
inputDisplay := state.denyReason
if len(inputDisplay) > availableWidth {
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
if inputDisplay == "" {
inputDisplay = "\033[90m(optional reason)\033[0m"
}
if i == state.selected {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
} else {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
}
} else {
displayLabel := label
if len(displayLabel) > state.innerWidth-2 {
displayLabel = displayLabel[:state.innerWidth-5] + "..."
if i == 1 && state.allowlistInfo != "" {
displayLabel = fmt.Sprintf("%s \033[90m%s\033[0m", label, state.allowlistInfo)
}
if i == state.selected {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", displayLabel)
} else {
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", displayLabel)
}
}
}
// Redraw bottom and hint
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
// Blank line + hint
fmt.Fprintf(os.Stderr, "\033[K\r\n")
for i, line := range hintLines {
if i == len(hintLines)-1 {
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
@@ -805,36 +895,26 @@ func updateSelectorOptions(state *selectorState) {
func updateReasonInput(state *selectorState) {
hintLines := getHintLines(state)
// Use red for warning (outside cwd), cyan for normal
boxColor := "\033[36m" // cyan
if state.isWarning {
boxColor = "\033[91m" // bright red
}
// Move up to the Deny line (3rd option, index 2)
// Cursor is at end of last hint line, need to go up:
// (hint lines - 1) + 1 (bottom border) + 1 (Deny is last option)
// (hint lines - 1) + 1 (blank line) + 1 (Deny is last option)
linesToMove := len(hintLines) - 1 + 1 + 1
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
// Redraw Deny line with reason
denyLabel := "3. Deny: "
availableWidth := state.innerWidth - 2 - len(denyLabel)
if availableWidth < 5 {
availableWidth = 5
}
inputDisplay := state.denyReason
if len(inputDisplay) > availableWidth {
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
if inputDisplay == "" {
inputDisplay = "\033[90m(optional reason)\033[0m"
}
if state.selected == 2 {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
} else {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
}
// Redraw bottom and hint
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
// Blank line + hint
fmt.Fprintf(os.Stderr, "\033[K\r\n")
for i, line := range hintLines {
if i == len(hintLines)-1 {
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
@@ -858,11 +938,10 @@ func clearSelectorBox(state *selectorState) {
// fallbackApproval handles approval when terminal control isn't available.
func (a *ApprovalManager) fallbackApproval(toolDisplay string) (ApprovalResult, error) {
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
fmt.Fprintln(os.Stderr, toolDisplay)
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Always allow [3] Deny")
fmt.Fprint(os.Stderr, "Choice: ")
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Allow for this session [3] Deny")
fmt.Fprint(os.Stderr, "choice: ")
var input string
fmt.Scanln(&input)
@@ -905,19 +984,16 @@ func (a *ApprovalManager) AllowedTools() []string {
// FormatApprovalResult returns a formatted string showing the approval result.
func FormatApprovalResult(toolName string, args map[string]any, result ApprovalResult) string {
var status string
var icon string
var label string
displayName := ToolDisplayName(toolName)
switch result.Decision {
case ApprovalOnce:
status = "Approved"
icon = "\033[32m✓\033[0m"
label = "Approved"
case ApprovalAlways:
status = "Always allowed"
icon = "\033[32m✓\033[0m"
label = "Always allowed"
case ApprovalDeny:
status = "Denied"
icon = "\033[31m✗\033[0m"
label = "Denied"
}
// Format based on tool type
@@ -927,7 +1003,7 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
if len(cmd) > 40 {
cmd = cmd[:37] + "..."
}
return fmt.Sprintf("▶ bash: %s [%s] %s", cmd, status, icon)
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, cmd)
}
}
@@ -937,11 +1013,11 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
if len(query) > 40 {
query = query[:37] + "..."
}
return fmt.Sprintf("▶ web_search: %s [%s] %s", query, status, icon)
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, query)
}
}
return fmt.Sprintf("▶ %s [%s] %s", toolName, status, icon)
return fmt.Sprintf("\033[1m%s:\033[0m %s", label, displayName)
}
// FormatDenyResult returns the tool result message when a tool is denied.
@@ -951,3 +1027,78 @@ func FormatDenyResult(toolName string, reason string) string {
}
return fmt.Sprintf("User denied execution of %s.", toolName)
}
// PromptYesNo displays a simple Yes/No prompt and returns the user's choice.
// Returns true for Yes, false for No.
func PromptYesNo(question string) (bool, error) {
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return false, err
}
defer term.Restore(fd, oldState)
selected := 0 // 0 = Yes, 1 = No
options := []string{"Yes", "No"}
// Hide cursor
fmt.Fprint(os.Stderr, "\033[?25l")
defer fmt.Fprint(os.Stderr, "\033[?25h")
renderYesNo := func() {
// Move to start of line and clear
fmt.Fprintf(os.Stderr, "\r\033[K")
fmt.Fprintf(os.Stderr, "%s ", question)
for i, opt := range options {
if i == selected {
fmt.Fprintf(os.Stderr, "\033[1m%s\033[0m ", opt)
} else {
fmt.Fprintf(os.Stderr, "\033[37m%s\033[0m ", opt)
}
}
}
renderYesNo()
buf := make([]byte, 3)
for {
n, err := os.Stdin.Read(buf)
if err != nil {
return false, err
}
if n == 1 {
switch buf[0] {
case 'y', 'Y':
selected = 0
renderYesNo()
case 'n', 'N':
selected = 1
renderYesNo()
case '\r', '\n': // Enter
fmt.Fprintf(os.Stderr, "\r\033[K") // Clear line
return selected == 0, nil
case 3: // Ctrl+C
fmt.Fprintf(os.Stderr, "\r\033[K")
return false, nil
case 27: // Escape - could be arrow key
// Read more bytes for arrow keys
continue
}
} else if n == 3 && buf[0] == 27 && buf[1] == 91 {
// Arrow keys
switch buf[2] {
case 'D': // Left
if selected > 0 {
selected--
}
renderYesNo()
case 'C': // Right
if selected < len(options)-1 {
selected++
}
renderYesNo()
}
}
}
}

View File

@@ -151,6 +151,27 @@ func TestExtractBashPrefix(t *testing.T) {
command: "head -n 100",
expected: "",
},
// Path traversal security tests
{
name: "path traversal - parent escape",
command: "cat tools/../../etc/passwd",
expected: "", // Should NOT create a prefix - path escapes
},
{
name: "path traversal - deep escape",
command: "cat tools/a/b/../../../etc/passwd",
expected: "", // Normalizes to "../etc/passwd" - escapes
},
{
name: "path traversal - absolute path",
command: "cat /etc/passwd",
expected: "", // Absolute paths should not create prefix
},
{
name: "path with safe dotdot - normalized",
command: "cat tools/subdir/../file.go",
expected: "cat:tools/", // Normalizes to tools/file.go - safe, creates prefix
},
}
for _, tt := range tests {
@@ -164,6 +185,34 @@ func TestExtractBashPrefix(t *testing.T) {
}
}
func TestApprovalManager_PathTraversalBlocked(t *testing.T) {
am := NewApprovalManager()
// Allow "cat tools/file.go" - creates prefix "cat:tools/"
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
// Path traversal attack: should NOT be allowed
if am.IsAllowed("bash", map[string]any{"command": "cat tools/../../etc/passwd"}) {
t.Error("SECURITY: path traversal attack should NOT be allowed")
}
// Another traversal variant
if am.IsAllowed("bash", map[string]any{"command": "cat tools/../../../etc/shadow"}) {
t.Error("SECURITY: deep path traversal should NOT be allowed")
}
// Valid subdirectory access should still work
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/file.go"}) {
t.Error("expected cat tools/subdir/file.go to be allowed")
}
// Safe ".." that normalizes to within allowed directory should work
// tools/subdir/../other.go normalizes to tools/other.go which is under tools/
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/../other.go"}) {
t.Error("expected cat tools/subdir/../other.go to be allowed (normalizes to tools/other.go)")
}
}
func TestApprovalManager_PrefixAllowlist(t *testing.T) {
am := NewApprovalManager()
@@ -186,6 +235,119 @@ func TestApprovalManager_PrefixAllowlist(t *testing.T) {
}
}
func TestApprovalManager_HierarchicalPrefixAllowlist(t *testing.T) {
am := NewApprovalManager()
// Allow "cat tools/file.go" - this creates prefix "cat:tools/"
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
// Should allow subdirectories (hierarchical matching)
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/file.go"}) {
t.Error("expected cat tools/subdir/file.go to be allowed via hierarchical prefix")
}
// Should allow deeply nested subdirectories
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/a/b/c/deep.go"}) {
t.Error("expected cat tools/a/b/c/deep.go to be allowed via hierarchical prefix")
}
// Should still allow same directory
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/another.go"}) {
t.Error("expected cat tools/another.go to be allowed")
}
// Should NOT allow different base directory
if am.IsAllowed("bash", map[string]any{"command": "cat src/main.go"}) {
t.Error("expected cat src/main.go to NOT be allowed")
}
// Should NOT allow different command even in subdirectory
if am.IsAllowed("bash", map[string]any{"command": "ls tools/subdir/"}) {
t.Error("expected ls tools/subdir/ to NOT be allowed (different command)")
}
// Should NOT allow similar but different directory name
if am.IsAllowed("bash", map[string]any{"command": "cat toolsbin/file.go"}) {
t.Error("expected cat toolsbin/file.go to NOT be allowed (different directory)")
}
}
func TestApprovalManager_HierarchicalPrefixAllowlist_CrossPlatform(t *testing.T) {
am := NewApprovalManager()
// Allow with forward slashes (Unix-style)
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
// Should work with backslashes too (Windows-style) - normalized internally
if !am.IsAllowed("bash", map[string]any{"command": "cat tools\\subdir\\file.go"}) {
t.Error("expected cat tools\\subdir\\file.go to be allowed via hierarchical prefix (Windows path)")
}
// Mixed slashes should also work
if !am.IsAllowed("bash", map[string]any{"command": "cat tools\\a/b\\c/deep.go"}) {
t.Error("expected mixed slash path to be allowed via hierarchical prefix")
}
}
func TestMatchesHierarchicalPrefix(t *testing.T) {
am := NewApprovalManager()
// Add prefix for "cat:tools/"
am.prefixes["cat:tools/"] = true
tests := []struct {
name string
prefix string
expected bool
}{
{
name: "exact match",
prefix: "cat:tools/",
expected: true, // exact match also passes HasPrefix - caller handles exact match first
},
{
name: "subdirectory",
prefix: "cat:tools/subdir/",
expected: true,
},
{
name: "deeply nested",
prefix: "cat:tools/a/b/c/",
expected: true,
},
{
name: "different base directory",
prefix: "cat:src/",
expected: false,
},
{
name: "different command same path",
prefix: "ls:tools/",
expected: false,
},
{
name: "similar directory name",
prefix: "cat:toolsbin/",
expected: false,
},
{
name: "invalid prefix format",
prefix: "cattools",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := am.matchesHierarchicalPrefix(tt.prefix)
if result != tt.expected {
t.Errorf("matchesHierarchicalPrefix(%q) = %v, expected %v",
tt.prefix, result, tt.expected)
}
})
}
}
func TestFormatApprovalResult(t *testing.T) {
tests := []struct {
name string

View File

@@ -6,10 +6,12 @@ import (
"errors"
"fmt"
"io"
"net/url"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/spf13/cobra"
"golang.org/x/term"
@@ -22,6 +24,101 @@ import (
"github.com/ollama/ollama/x/tools"
)
// Tool output capping constants
const (
// localModelTokenLimit is the token limit for local models (smaller context).
localModelTokenLimit = 4000
// defaultTokenLimit is the token limit for cloud/remote models.
defaultTokenLimit = 10000
// charsPerToken is a rough estimate of characters per token.
// TODO: Estimate tokens more accurately using tokenizer if available
charsPerToken = 4
)
// isLocalModel checks if the model is running locally (not a cloud model).
// TODO: Improve local/cloud model identification - could check model metadata
func isLocalModel(modelName string) bool {
return !strings.HasSuffix(modelName, "-cloud")
}
// isLocalServer checks if connecting to a local Ollama server.
// TODO: Could also check other indicators of local vs cloud server
func isLocalServer() bool {
host := os.Getenv("OLLAMA_HOST")
if host == "" {
return true // Default is localhost:11434
}
// Parse the URL to check host
parsed, err := url.Parse(host)
if err != nil {
return true // If can't parse, assume local
}
hostname := parsed.Hostname()
return hostname == "localhost" || hostname == "127.0.0.1" || strings.Contains(parsed.Host, ":11434")
}
// truncateToolOutput truncates tool output to prevent context overflow.
// Uses a smaller limit (4k tokens) for local models, larger (10k) for cloud/remote.
func truncateToolOutput(output, modelName string) string {
var tokenLimit int
if isLocalModel(modelName) && isLocalServer() {
tokenLimit = localModelTokenLimit
} else {
tokenLimit = defaultTokenLimit
}
maxChars := tokenLimit * charsPerToken
if len(output) > maxChars {
return output[:maxChars] + "\n... (output truncated)"
}
return output
}
// waitForOllamaSignin shows the signin URL and polls until authentication completes.
func waitForOllamaSignin(ctx context.Context) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
// Get signin URL from initial Whoami call
_, err = client.Whoami(ctx)
if err != nil {
var aErr api.AuthorizationError
if errors.As(err, &aErr) && aErr.SigninURL != "" {
fmt.Fprintf(os.Stderr, "\n To sign in, navigate to:\n")
fmt.Fprintf(os.Stderr, " %s\n\n", aErr.SigninURL)
fmt.Fprintf(os.Stderr, " \033[90mwaiting for sign in to complete...\033[0m")
// Poll until auth succeeds
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\n")
return ctx.Err()
case <-ticker.C:
user, whoamiErr := client.Whoami(ctx)
if whoamiErr == nil && user != nil && user.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K \033[1msigned in:\033[0m %s\n", user.Name)
return nil
}
// Still waiting, show dot
fmt.Fprintf(os.Stderr, ".")
}
}
}
return err
}
return nil
}
// RunOptions contains options for running an interactive agent session.
type RunOptions struct {
Model string
@@ -37,6 +134,9 @@ type RunOptions struct {
// Agent fields (managed externally for session persistence)
Tools *tools.Registry
Approval *agent.ApprovalManager
// YoloMode skips all tool approval prompts
YoloMode bool
}
// Chat runs an agent chat loop with tool support.
@@ -77,6 +177,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
var thinkTagOpened bool = false
var thinkTagClosed bool = false
var pendingToolCalls []api.ToolCall
var consecutiveErrors int // Track consecutive 500 errors for retry limit
role := "assistant"
messages := opts.Messages
@@ -159,6 +260,58 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
return nil, nil
}
// Check for 401 Unauthorized - prompt user to sign in
var authErr api.AuthorizationError
if errors.As(err, &authErr) {
p.StopAndClear()
fmt.Fprintf(os.Stderr, "\033[1mauth required:\033[0m cloud model requires authentication\n")
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
if promptErr == nil && result {
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
// Retry the chat request
fmt.Fprintf(os.Stderr, "\033[90mretrying...\033[0m\n")
continue // Retry the loop
}
}
return nil, fmt.Errorf("authentication required - run 'ollama signin' to authenticate")
}
// Check for 500 errors (often tool parsing failures) - inform the model
var statusErr api.StatusError
if errors.As(err, &statusErr) && statusErr.StatusCode >= 500 {
consecutiveErrors++
p.StopAndClear()
if consecutiveErrors >= 3 {
fmt.Fprintf(os.Stderr, "\033[1merror:\033[0m too many consecutive errors, giving up\n")
return nil, fmt.Errorf("too many consecutive server errors: %s", statusErr.ErrorMessage)
}
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m server error (attempt %d/3): %s\n", consecutiveErrors, statusErr.ErrorMessage)
// Include both the model's response and the error so it can learn
assistantContent := fullResponse.String()
if assistantContent == "" {
assistantContent = "(empty response)"
}
errorMsg := fmt.Sprintf("Your previous response caused an error: %s\n\nYour response was:\n%s\n\nPlease try again with a valid response.", statusErr.ErrorMessage, assistantContent)
messages = append(messages,
api.Message{Role: "user", Content: errorMsg},
)
// Reset state and retry
fullResponse.Reset()
thinkingContent.Reset()
thinkTagOpened = false
thinkTagClosed = false
pendingToolCalls = nil
state = &displayResponseState{}
p = progress.NewProgress(os.Stderr)
spinner = progress.NewSpinner("")
p.Add("", spinner)
continue
}
if strings.Contains(err.Error(), "upstream error") {
p.StopAndClear()
fmt.Println("An error occurred while processing your message. Please try again.")
@@ -168,6 +321,9 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
return nil, err
}
// Reset consecutive error counter on success
consecutiveErrors = 0
// If no tool calls, we're done
if len(pendingToolCalls) == 0 || toolRegistry == nil {
break
@@ -197,8 +353,8 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
if cmd, ok := args["command"].(string); ok {
// Check if command is denied (dangerous pattern)
if denied, pattern := agent.IsDenied(cmd); denied {
fmt.Fprintf(os.Stderr, "\033[91m✗ Blocked: %s\033[0m\n", formatToolShort(toolName, args))
fmt.Fprintf(os.Stderr, "\033[91m Matches dangerous pattern: %s\033[0m\n", pattern)
fmt.Fprintf(os.Stderr, "\033[1mblocked:\033[0m %s\n", formatToolShort(toolName, args))
fmt.Fprintf(os.Stderr, " matches dangerous pattern: %s\n", pattern)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: agent.FormatDeniedResult(cmd, pattern),
@@ -208,15 +364,21 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
}
// Check if command is auto-allowed (safe command)
if agent.IsAutoAllowed(cmd) {
fmt.Fprintf(os.Stderr, "\033[90m▶ Auto-allowed: %s\033[0m\n", formatToolShort(toolName, args))
skipApproval = true
}
// TODO(parthsareen): re-enable with tighter scoped allowlist
// if agent.IsAutoAllowed(cmd) {
// fmt.Fprintf(os.Stderr, "\033[1mauto-allowed:\033[0m %s\n", formatToolShort(toolName, args))
// skipApproval = true
// }
}
}
// Check approval (uses prefix matching for bash commands)
if !skipApproval && !approval.IsAllowed(toolName, args) {
// In yolo mode, skip all approval prompts
if opts.YoloMode {
if !skipApproval {
fmt.Fprintf(os.Stderr, "\033[1mrunning:\033[0m %s\n", formatToolShort(toolName, args))
}
} else if !skipApproval && !approval.IsAllowed(toolName, args) {
result, err := approval.RequestApproval(toolName, args)
if err != nil {
fmt.Fprintf(os.Stderr, "Error requesting approval: %v\n", err)
@@ -244,13 +406,30 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
}
} else if !skipApproval {
// Already allowed - show running indicator
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
fmt.Fprintf(os.Stderr, "\033[1mrunning:\033[0m %s\n", formatToolShort(toolName, args))
}
// Execute the tool
toolResult, err := toolRegistry.Execute(call)
if err != nil {
fmt.Fprintf(os.Stderr, "\033[31m Error: %v\033[0m\n", err)
// Check if web search needs authentication
if errors.Is(err, tools.ErrWebSearchAuthRequired) {
// Prompt user to sign in
fmt.Fprintf(os.Stderr, "\033[1mauth required:\033[0m web search requires authentication\n")
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
if promptErr == nil && result {
// Get signin URL and wait for auth completion
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
// Retry the web search
fmt.Fprintf(os.Stderr, "\033[90mretrying web search...\033[0m\n")
toolResult, err = toolRegistry.Execute(call)
if err == nil {
goto toolSuccess
}
}
}
}
fmt.Fprintf(os.Stderr, "\033[1merror:\033[0m %v\n", err)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: fmt.Sprintf("Error: %v", err),
@@ -258,6 +437,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
})
continue
}
toolSuccess:
// Display tool output (truncated for display)
if toolResult != "" {
@@ -269,9 +449,12 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(output, "\n", "\n "))
}
// Truncate output to prevent context overflow
toolResultForLLM := truncateToolOutput(toolResult, opts.Model)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: toolResult,
Content: toolResultForLLM,
ToolCallID: call.ID,
})
}
@@ -317,17 +500,18 @@ func truncateUTF8(s string, limit int) string {
// formatToolShort returns a short description of a tool call.
func formatToolShort(toolName string, args map[string]any) string {
displayName := agent.ToolDisplayName(toolName)
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
return fmt.Sprintf("bash: %s", truncateUTF8(cmd, 50))
return fmt.Sprintf("%s: %s", displayName, truncateUTF8(cmd, 50))
}
}
if toolName == "web_search" {
if query, ok := args["query"].(string); ok {
return fmt.Sprintf("web_search: %s", truncateUTF8(query, 50))
return fmt.Sprintf("%s: %s", displayName, truncateUTF8(query, 50))
}
}
return toolName
return displayName
}
// Helper types and functions for display
@@ -449,7 +633,8 @@ func checkModelCapabilities(ctx context.Context, modelName string) (supportsTool
// GenerateInteractive runs an interactive agent session.
// This is called from cmd.go when --experimental flag is set.
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration) error {
// If yoloMode is true, all tool approvals are skipped.
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool) error {
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
AltPrompt: "... ",
@@ -466,7 +651,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
// Check if model supports tools
supportsTools, err := checkModelCapabilities(cmd.Context(), modelName)
if err != nil {
fmt.Fprintf(os.Stderr, "\033[33mWarning: Could not check model capabilities: %v\033[0m\n", err)
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m could not check model capabilities: %v\n", err)
supportsTools = false
}
@@ -474,14 +659,17 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
var toolRegistry *tools.Registry
if supportsTools {
toolRegistry = tools.DefaultRegistry()
fmt.Fprintf(os.Stderr, "Tools available: %s\n", strings.Join(toolRegistry.Names(), ", "))
// Check for OLLAMA_API_KEY for web search
if os.Getenv("OLLAMA_API_KEY") == "" {
fmt.Fprintf(os.Stderr, "\033[33mWarning: OLLAMA_API_KEY not set - web search will not work\033[0m\n")
if toolRegistry.Has("bash") {
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "This experimental version of Ollama has the \033[1mbash\033[0m tool enabled.")
fmt.Fprintln(os.Stderr, "Models can read files on your computer, or run commands (after you allow them).")
fmt.Fprintln(os.Stderr)
}
if yoloMode {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m yolo mode - all tool approvals will be skipped\n")
}
} else {
fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n")
}
// Create approval manager for session
@@ -524,6 +712,9 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, "Keyboard Shortcuts:")
fmt.Fprintln(os.Stderr, " Ctrl+O Expand last tool output")
fmt.Fprintln(os.Stderr, "")
continue
case strings.HasPrefix(line, "/"):
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
@@ -546,6 +737,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
KeepAlive: keepAlive,
Tools: toolRegistry,
Approval: approval,
YoloMode: yoloMode,
}
assistant, err := Chat(cmd.Context(), opts)

180
x/cmd/run_test.go Normal file
View File

@@ -0,0 +1,180 @@
package cmd
import (
"testing"
)
func TestIsLocalModel(t *testing.T) {
tests := []struct {
name string
modelName string
expected bool
}{
{
name: "local model without suffix",
modelName: "llama3.2",
expected: true,
},
{
name: "local model with version",
modelName: "qwen2.5:7b",
expected: true,
},
{
name: "cloud model",
modelName: "gpt-4-cloud",
expected: false,
},
{
name: "cloud model with version",
modelName: "claude-3-cloud",
expected: false,
},
{
name: "empty model name",
modelName: "",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isLocalModel(tt.modelName)
if result != tt.expected {
t.Errorf("isLocalModel(%q) = %v, expected %v", tt.modelName, result, tt.expected)
}
})
}
}
func TestIsLocalServer(t *testing.T) {
tests := []struct {
name string
host string
expected bool
}{
{
name: "empty host (default)",
host: "",
expected: true,
},
{
name: "localhost",
host: "http://localhost:11434",
expected: true,
},
{
name: "127.0.0.1",
host: "http://127.0.0.1:11434",
expected: true,
},
{
name: "custom port on localhost",
host: "http://localhost:8080",
expected: true, // localhost is always considered local
},
{
name: "remote host",
host: "http://ollama.example.com:11434",
expected: true, // has :11434
},
{
name: "remote host different port",
host: "http://ollama.example.com:8080",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", tt.host)
result := isLocalServer()
if result != tt.expected {
t.Errorf("isLocalServer() with OLLAMA_HOST=%q = %v, expected %v", tt.host, result, tt.expected)
}
})
}
}
func TestTruncateToolOutput(t *testing.T) {
// Create outputs of different sizes
localLimitOutput := make([]byte, 20000) // > 4k tokens (16k chars)
defaultLimitOutput := make([]byte, 50000) // > 10k tokens (40k chars)
for i := range localLimitOutput {
localLimitOutput[i] = 'a'
}
for i := range defaultLimitOutput {
defaultLimitOutput[i] = 'b'
}
tests := []struct {
name string
output string
modelName string
host string
shouldTrim bool
expectedLimit int
}{
{
name: "short output local model",
output: "hello world",
modelName: "llama3.2",
host: "",
shouldTrim: false,
expectedLimit: localModelTokenLimit,
},
{
name: "long output local model - trimmed at 4k",
output: string(localLimitOutput),
modelName: "llama3.2",
host: "",
shouldTrim: true,
expectedLimit: localModelTokenLimit,
},
{
name: "long output cloud model - uses 10k limit",
output: string(localLimitOutput), // 20k chars, under 10k token limit
modelName: "gpt-4-cloud",
host: "",
shouldTrim: false,
expectedLimit: defaultTokenLimit,
},
{
name: "very long output cloud model - trimmed at 10k",
output: string(defaultLimitOutput),
modelName: "gpt-4-cloud",
host: "",
shouldTrim: true,
expectedLimit: defaultTokenLimit,
},
{
name: "long output remote server - uses 10k limit",
output: string(localLimitOutput),
modelName: "llama3.2",
host: "http://remote.example.com:8080",
shouldTrim: false,
expectedLimit: defaultTokenLimit,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", tt.host)
result := truncateToolOutput(tt.output, tt.modelName)
if tt.shouldTrim {
maxLen := tt.expectedLimit * charsPerToken
if len(result) > maxLen+50 { // +50 for the truncation message
t.Errorf("expected output to be truncated to ~%d chars, got %d", maxLen, len(result))
}
if result == tt.output {
t.Error("expected output to be truncated but it wasn't")
}
} else {
if result != tt.output {
t.Error("expected output to not be truncated")
}
}
})
}
}

185
x/grammar/README.md Normal file
View File

@@ -0,0 +1,185 @@
# grammar
Grammar-constrained decoding for LLM outputs using MLX.
## Performance
Performance depends on hardware, vocabulary size, grammar, and whether you
evaluate the MLX graph. See [Benchmarks](#benchmarks) for how to measure on your
setup.
### Design choices that keep masking fast
| Technique | Impact |
|-----------|--------|
| Precomputed token analysis | Terminal matches computed once at startup |
| Mask caching by grammar state signature | Reuse masks for repeated parser states |
| Partitioned tokens | Exact matches separated from DP candidates |
### Comparison Notes
- **llama.cpp**: Decodes each token to UTF-8, checks against PDA. No caching.
- **Outlines**: FSM-based. Compilation can take 40s-10min for complex schemas. Fast after compile.
- **XGrammar**: PDA with 99% context-independent tokens precomputed. State-of-the-art before this.
- **x/grammar**: Precomputed token analysis + mask caching by grammar state signature.
## Usage
```go
import (
"github.com/ollama/ollama/x/grammar"
"github.com/ollama/ollama/x/grammar/schema"
)
// Use built-in JSON grammar
g, _ := grammar.JSONGrammar()
// Or from JSON Schema (OpenAI-compatible)
g, _ := schema.Grammar(`{
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"}
},
"required": ["name", "age"]
}`)
// Or parse custom EBNF
g, _ := grammar.ParseEBNF(myGrammar, "root")
// Create engine with model vocabulary
engine, _ := grammar.NewEngine(g, vocab)
defer engine.Close()
// Generation loop
for !engine.IsComplete() {
logits := model.Forward(tokens)
masked := engine.ApplyMask(logits) // Invalid tokens → -inf
nextToken := sample(masked)
engine.Accept(nextToken)
}
// Output conforms to the grammar when you only sample from masked tokens and call Accept
```
## EBNF Syntax
```ebnf
rule = expression . # Rule definition (ends with .)
"literal" # Literal string
"a" "z" # Character range (inclusive)
( a | b ) # Grouping with alternation
[ optional ] # Optional (0 or 1)
{ repeated } # Repetition (0 or more)
```
### Example: JSON Grammar
```ebnf
json = value .
value = object | array | string | number | "true" | "false" | "null" .
object = "{" ws "}" | "{" members "}" .
members = member { "," member } .
member = ws string ws ":" element .
array = "[" ws "]" | "[" elements "]" .
elements = element { "," element } .
element = ws value ws .
string = "\"" { character } "\"" .
character = unescaped | escaped .
unescaped = " " | "!" | "#" "[" | "]" "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" ) .
number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
digit = "0" "9" .
onenine = "1" "9" .
ws = { " " | "\t" | "\n" | "\r" } .
```
### Example: Custom Schema
```ebnf
root = "{" ws name_field "," ws age_field ws "}" .
name_field = "\"name\"" ws ":" ws string .
age_field = "\"age\"" ws ":" ws number .
string = "\"" { char } "\"" .
char = " " | "!" | "#" "~" .
number = [ "-" ] digit { digit } .
digit = "0" "9" .
ws = { " " | "\n" } .
```
## JSON Schema Support
OpenAI-compatible JSON Schema support with automatic EBNF generation:
```go
schema := `{
"type": "object",
"properties": {
"user": {"$ref": "#/$defs/User"}
},
"required": ["user"],
"$defs": {
"User": {
"type": "object",
"properties": {
"name": {"type": "string"},
"email": {"type": "string", "format": "email"},
"role": {"enum": ["admin", "user", "guest"]}
},
"required": ["name", "email", "role"]
}
}
}`
grammar, _ := schema.Grammar(schema)
```
### Supported Features
| Feature | Example |
|---------|---------|
| Basic types | `string`, `integer`, `number`, `boolean`, `null` |
| Objects | `properties`, `required` |
| Arrays | `items`, `minItems`, `maxItems` |
| Enums | `enum: ["a", "b", "c"]` |
| Constants | `const: "value"` |
| Union types | `anyOf`, `oneOf`, `type: ["string", "null"]` |
| References | `$ref: "#/$defs/Name"`, `$defs` |
| Formats | `date`, `time`, `date-time`, `email`, `uuid`, `ipv4` |
## Benchmarks
```bash
# Run all tests
go test -tags mlx ./x/grammar/...
# Run benchmarks
go test -tags mlx ./x/grammar/ -bench=.
# Compare with llama.cpp (outputs JSON)
go run -tags mlx ./x/grammar/cmd/compare -vocab-size 128000 -iterations 500
# Compare with a more complex schema
go run -tags mlx ./x/grammar/cmd/compare \
-gbnf x/grammar/cmd/compare/complex.gbnf \
-schema x/grammar/cmd/compare/complex.schema.json \
-vocab-size 128000 -iterations 500
```
## References
- [XGrammar Paper](https://arxiv.org/abs/2411.15100) - Flexible and Efficient Structured Generation
- [Outlines](https://github.com/dottxt-ai/outlines) - Structured Text Generation
- [JSONSchemaBench](https://arxiv.org/abs/2501.10868) - Benchmark for Structured Outputs

161
x/grammar/analyzer.go Normal file
View File

@@ -0,0 +1,161 @@
//go:build mlx
package grammar
// terminalTokenGroups contains pre-partitioned tokens for a terminal.
// This enables O(1) lookup of tokens that exactly match vs need DP validation.
type terminalTokenGroups struct {
// ExactMatches are tokens that exactly match this terminal (O(1) validation)
ExactMatches []int32
// DPCandidates are tokens that start with this terminal but need DP validation
DPCandidates []int
}
// tokenAnalysis contains precomputed terminal matches for a token
type tokenAnalysis struct {
// The token string
Token string
// TokenID in the vocabulary
TokenID int
// Matches at each byte position
// MatchesAtPos[i] = terminals matching at position i with their lengths
MatchesAtPos [][]terminalMatch
// Fast path: if token exactly matches one terminal
// -1 if no exact match
exactMatch int
// Whether this token can be consumed at all (has at least one match)
HasMatches bool
}
// analyzer precomputes terminal matches for a vocabulary
type analyzer struct {
matcher *terminalMatcher
analyses []tokenAnalysis // Indexed by token ID
vocab []string
// Pre-partitioned tokens by terminal (exact match vs DP candidates)
// This enables direct slice appends instead of per-token branching
tokensByTerminal []terminalTokenGroups
}
// newAnalyzer creates an analyzer for the given vocabulary and terminals
func newAnalyzer(vocab []string, matcher *terminalMatcher) *analyzer {
a := &analyzer{
matcher: matcher,
analyses: make([]tokenAnalysis, len(vocab)),
vocab: vocab,
}
// Precompute analysis for each token
for i, token := range vocab {
a.analyses[i] = a.analyze(token, i)
}
// Build pre-partitioned token groups for fast ApplyMask
a.buildTokenPartitions()
return a
}
// analyze computes terminal matches for a single token
func (a *analyzer) analyze(token string, tokenID int) tokenAnalysis {
analysis := tokenAnalysis{
Token: token,
TokenID: tokenID,
MatchesAtPos: make([][]terminalMatch, len(token)),
exactMatch: -1,
HasMatches: false,
}
if len(token) == 0 {
return analysis
}
// Compute matches at each position
data := []byte(token)
for pos := 0; pos < len(data); pos++ {
matches := a.matcher.matchesAt(data, pos)
analysis.MatchesAtPos[pos] = matches
if len(matches) > 0 {
analysis.HasMatches = true
}
}
// Exact match is only valid when a single terminal spans the entire token
if len(analysis.MatchesAtPos) > 0 {
var exactID int = -1
for _, match := range analysis.MatchesAtPos[0] {
if match.Length != len(token) {
continue
}
if exactID >= 0 && exactID != match.TerminalID {
exactID = -1
break
}
exactID = match.TerminalID
}
analysis.exactMatch = exactID
}
return analysis
}
// analysis returns the precomputed analysis for a token ID
func (a *analyzer) analysis(tokenID int) tokenAnalysis {
if tokenID < 0 || tokenID >= len(a.analyses) {
return tokenAnalysis{exactMatch: -1}
}
return a.analyses[tokenID]
}
// vocabSize returns the vocabulary size
func (a *analyzer) vocabSize() int {
return len(a.vocab)
}
// buildTokenPartitions pre-partitions tokens into exact-match vs needs-DP groups per terminal.
// This enables ApplyMask to use direct slice appends instead of per-token branching.
func (a *analyzer) buildTokenPartitions() {
numTerminals := a.matcher.terminalCount()
a.tokensByTerminal = make([]terminalTokenGroups, numTerminals)
for tokenID, analysis := range a.analyses {
if !analysis.HasMatches {
continue
}
if analysis.exactMatch >= 0 {
// Token exactly matches one terminal - fast path (O(1) validation)
tid := analysis.exactMatch
a.tokensByTerminal[tid].ExactMatches = append(
a.tokensByTerminal[tid].ExactMatches, int32(tokenID))
} else {
// Token needs DP validation - add to all terminals it can start with
// This way, when a terminal is valid, we know exactly which tokens need DP
if len(analysis.MatchesAtPos) > 0 {
seen := make(map[int]bool)
for _, match := range analysis.MatchesAtPos[0] {
tid := match.TerminalID
if !seen[tid] {
seen[tid] = true
a.tokensByTerminal[tid].DPCandidates = append(
a.tokensByTerminal[tid].DPCandidates, tokenID)
}
}
}
}
}
}
// terminalGroups returns the pre-partitioned token groups for a terminal ID
func (a *analyzer) terminalGroups(terminalID int) terminalTokenGroups {
if terminalID < 0 || terminalID >= len(a.tokensByTerminal) {
return terminalTokenGroups{}
}
return a.tokensByTerminal[terminalID]
}

648
x/grammar/bridge.go Normal file
View File

@@ -0,0 +1,648 @@
//go:build mlx
package grammar
import (
"encoding/binary"
"hash/fnv"
"sort"
"sync"
)
// visitedMapPool reduces allocations for visited maps in bridge operations
var visitedMapPool = sync.Pool{
New: func() interface{} {
return make(map[stateStackKey]bool, 16)
},
}
// getVisitedMap gets a map from the pool
func getVisitedMap() map[stateStackKey]bool {
return visitedMapPool.Get().(map[stateStackKey]bool)
}
// putVisitedMap returns a map to the pool after clearing it
func putVisitedMap(m map[stateStackKey]bool) {
for k := range m {
delete(m, k)
}
visitedMapPool.Put(m)
}
// parserConfig represents a pda state+stack combination
type parserConfig struct {
state state
Stack []stackSymbol
}
// clone creates a deep copy of the config
func (c *parserConfig) clone() *parserConfig {
newStack := make([]stackSymbol, len(c.Stack))
copy(newStack, c.Stack)
return &parserConfig{
state: c.state,
Stack: newStack,
}
}
// key returns a unique key for this config for deduplication
func (c *parserConfig) key() uint64 {
h := fnv.New64a()
var buf [8]byte
binary.LittleEndian.PutUint64(buf[:], uint64(c.state))
h.Write(buf[:])
for _, sym := range c.Stack {
binary.LittleEndian.PutUint64(buf[:], uint64(sym))
h.Write(buf[:])
}
return h.Sum64()
}
// configSet represents a set of parser configurations (for nondeterminism)
type configSet struct {
configs []*parserConfig
normalized bool // true if already deduplicated and sorted
cachedSig uint64 // cached signature after normalization
}
// newConfigSet creates a new config set with a single configuration
func newConfigSet(state state, stack []stackSymbol) *configSet {
return &configSet{
configs: []*parserConfig{
{state: state, Stack: stack},
},
normalized: true, // single config is already normalized
}
}
// normalize deduplicates and sorts configs for stable signatures
func (c *configSet) normalize() {
if c.normalized || len(c.configs) <= 1 {
c.normalized = true
return
}
// Deduplicate using a map
seen := make(map[uint64]*parserConfig, len(c.configs))
for _, cfg := range c.configs {
key := cfg.key()
if _, exists := seen[key]; !exists {
seen[key] = cfg
}
}
// Extract unique configs
unique := make([]*parserConfig, 0, len(seen))
for _, cfg := range seen {
unique = append(unique, cfg)
}
// Sort by key for deterministic ordering
sort.Slice(unique, func(i, j int) bool {
return unique[i].key() < unique[j].key()
})
c.configs = unique
c.normalized = true
}
// signature returns a hash for cache lookup (normalizes first)
func (c *configSet) signature() uint64 {
c.normalize()
// Return cached signature if available
if c.cachedSig != 0 {
return c.cachedSig
}
h := fnv.New64a()
// Hash number of configs
var buf [8]byte
binary.LittleEndian.PutUint64(buf[:], uint64(len(c.configs)))
h.Write(buf[:])
// Hash each config (already sorted)
for _, cfg := range c.configs {
binary.LittleEndian.PutUint64(buf[:], uint64(cfg.state))
h.Write(buf[:])
binary.LittleEndian.PutUint64(buf[:], uint64(len(cfg.Stack)))
h.Write(buf[:])
for _, sym := range cfg.Stack {
binary.LittleEndian.PutUint64(buf[:], uint64(sym))
h.Write(buf[:])
}
}
c.cachedSig = h.Sum64()
return c.cachedSig
}
// isEmpty returns true if there are no configurations
func (c *configSet) isEmpty() bool {
return len(c.configs) == 0
}
// clone creates a deep copy of the config set
func (c *configSet) clone() *configSet {
newConfigs := make([]*parserConfig, len(c.configs))
for i, cfg := range c.configs {
newConfigs[i] = cfg.clone()
}
return &configSet{configs: newConfigs}
}
// bridge connects token analysis to pda validation
type bridge struct {
pda *pda
analyzer *analyzer
}
// newBridge creates a new bridge
func newBridge(pda *pda, analyzer *analyzer) *bridge {
return &bridge{
pda: pda,
analyzer: analyzer,
}
}
// IsTokenValid checks if token T can be consumed from the current config
// This is the main entry point for token validation
func (b *bridge) IsTokenValid(tokenID int, config *configSet) bool {
analysis := b.analyzer.analysis(tokenID)
if !analysis.HasMatches {
return false
}
// Fast path: exact terminal match
if analysis.exactMatch >= 0 {
terminal := b.analyzer.matcher.terminals[analysis.exactMatch]
return b.canAcceptTerminal(config, terminal.Pattern)
}
// General path: DP over (pos, config)
return b.dpValidate(&analysis, config)
}
// canAcceptTerminal checks if any config can accept the terminal
func (b *bridge) canAcceptTerminal(config *configSet, pattern string) bool {
for _, cfg := range config.configs {
if b.canConfigAcceptTerminal(cfg, pattern) {
return true
}
}
return false
}
// canConfigAcceptTerminal checks if a single config can accept the terminal
func (b *bridge) canConfigAcceptTerminal(cfg *parserConfig, pattern string) bool {
// Use pooled visited map to reduce allocations
visited := getVisitedMap()
result := b.tryAcceptTerminal(cfg.state, cfg.Stack, pattern, visited)
putVisitedMap(visited)
return result
}
// tryAcceptTerminal recursively tries to accept a terminal from a state
func (b *bridge) tryAcceptTerminal(state state, stack []stackSymbol, pattern string, visited map[stateStackKey]bool) bool {
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
if visited[key] {
return false
}
visited[key] = true
stackTop := stackEmpty
if len(stack) > 0 {
stackTop = stack[len(stack)-1]
}
for _, t := range b.pda.Transitions[state] {
// Check stack constraint
if t.stackTop != stackEmpty && t.stackTop != stackTop {
continue
}
// Can't pop more than we have
if t.StackPop > len(stack) {
continue
}
if t.Pattern == pattern {
// Direct match
return true
}
if t.Pattern == "" {
// Epsilon transition - follow it
newStack := make([]stackSymbol, len(stack))
copy(newStack, stack)
// Pop
if t.StackPop > 0 {
newStack = newStack[:len(newStack)-t.StackPop]
}
// Push
newStack = append(newStack, t.StackPush...)
if b.tryAcceptTerminal(t.ToState, newStack, pattern, visited) {
return true
}
}
}
return false
}
// dpValidate runs DP for multi-terminal tokens
func (b *bridge) dpValidate(analysis *tokenAnalysis, startConfig *configSet) bool {
// state: (pos, configSet)
// Memoize by (pos, configSig)
type dpKey struct {
pos int
sig uint64
}
memo := make(map[dpKey]bool)
var dp func(pos int, config *configSet) bool
dp = func(pos int, config *configSet) bool {
if pos == len(analysis.Token) {
return true // Consumed entire token
}
if config.isEmpty() {
return false
}
key := dpKey{pos, config.signature()}
if result, ok := memo[key]; ok {
return result
}
// Try each terminal that matches at this position
for _, match := range analysis.MatchesAtPos[pos] {
terminal := b.analyzer.matcher.terminals[match.TerminalID]
newConfig := b.advanceConfig(config, terminal.Pattern)
if newConfig != nil && !newConfig.isEmpty() && dp(pos+match.Length, newConfig) {
memo[key] = true
return true
}
}
memo[key] = false
return false
}
return dp(0, startConfig)
}
// advanceConfig advances all configs that can accept the terminal
func (b *bridge) advanceConfig(config *configSet, pattern string) *configSet {
var newConfigs []*parserConfig
for _, cfg := range config.configs {
advanced := b.advanceSingleConfig(cfg, pattern)
newConfigs = append(newConfigs, advanced...)
}
if len(newConfigs) == 0 {
return nil
}
return &configSet{configs: newConfigs}
}
// advanceSingleConfig advances a single config by accepting a terminal
func (b *bridge) advanceSingleConfig(cfg *parserConfig, pattern string) []*parserConfig {
var results []*parserConfig
visited := getVisitedMap()
b.collectAdvanced(cfg.state, cfg.Stack, pattern, visited, &results)
putVisitedMap(visited)
return results
}
// collectAdvanced collects all configs reachable by accepting the pattern
func (b *bridge) collectAdvanced(state state, stack []stackSymbol, pattern string, visited map[stateStackKey]bool, results *[]*parserConfig) {
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
if visited[key] {
return
}
visited[key] = true
stackTop := stackEmpty
if len(stack) > 0 {
stackTop = stack[len(stack)-1]
}
for _, t := range b.pda.Transitions[state] {
// Check stack constraint
if t.stackTop != stackEmpty && t.stackTop != stackTop {
continue
}
// Can't pop more than we have
if t.StackPop > len(stack) {
continue
}
if t.Pattern == pattern {
// Match! Create new config after transition
newStack := make([]stackSymbol, len(stack))
copy(newStack, stack)
if t.StackPop > 0 {
newStack = newStack[:len(newStack)-t.StackPop]
}
newStack = append(newStack, t.StackPush...)
*results = append(*results, &parserConfig{
state: t.ToState,
Stack: newStack,
})
}
if t.Pattern == "" {
// Epsilon transition - follow it
newStack := make([]stackSymbol, len(stack))
copy(newStack, stack)
if t.StackPop > 0 {
newStack = newStack[:len(newStack)-t.StackPop]
}
newStack = append(newStack, t.StackPush...)
b.collectAdvanced(t.ToState, newStack, pattern, visited, results)
}
}
}
// validTokens returns all token IDs that are valid from the given config
func (b *bridge) validTokens(config *configSet) []int {
var valid []int
for tokenID := 0; tokenID < b.analyzer.vocabSize(); tokenID++ {
if b.IsTokenValid(tokenID, config) {
valid = append(valid, tokenID)
}
}
return valid
}
// acceptToken attempts to accept a token and returns the new config set
// Returns nil if the token is not valid from this config
func (b *bridge) acceptToken(tokenID int, config *configSet) *configSet {
analysis := b.analyzer.analysis(tokenID)
if !analysis.HasMatches {
return nil
}
// Fast path: exact terminal match
if analysis.exactMatch >= 0 {
terminal := b.analyzer.matcher.terminals[analysis.exactMatch]
newConfig := b.advanceConfig(config, terminal.Pattern)
if newConfig != nil && !newConfig.isEmpty() {
newConfig.normalize()
return newConfig
}
return nil
}
// General path: DP to find final config after consuming token
return b.dpAccept(&analysis, config)
}
// dpAccept runs DP to accept a multi-terminal token and return final config
// Returns the union of all possible end configurations (preserves nondeterminism)
func (b *bridge) dpAccept(analysis *tokenAnalysis, startConfig *configSet) *configSet {
type dpKey struct {
pos int
sig uint64
}
// Memoize the configs reachable at each (pos, sig)
memo := make(map[dpKey]*configSet)
var dp func(pos int, config *configSet) *configSet
dp = func(pos int, config *configSet) *configSet {
if pos == len(analysis.Token) {
return config // Consumed entire token, return final config
}
if config.isEmpty() {
return nil
}
key := dpKey{pos, config.signature()}
if result, ok := memo[key]; ok {
return result
}
// Collect all valid result configs from all possible paths
var allConfigs []*parserConfig
// Try each terminal that matches at this position
for _, match := range analysis.MatchesAtPos[pos] {
terminal := b.analyzer.matcher.terminals[match.TerminalID]
newConfig := b.advanceConfig(config, terminal.Pattern)
if newConfig != nil && !newConfig.isEmpty() {
finalConfig := dp(pos+match.Length, newConfig)
if finalConfig != nil {
// Collect all configs, don't return early
allConfigs = append(allConfigs, finalConfig.configs...)
}
}
}
// Build result: nil if no valid paths, normalized configSet otherwise
var result *configSet
if len(allConfigs) > 0 {
result = &configSet{configs: allConfigs}
result.normalize() // Dedup using parserConfig.key(), sort for consistent signature
}
memo[key] = result // Cache normalized result
return result
}
return dp(0, startConfig)
}
// isAccepting returns true if any config can reach an accepting state
func (b *bridge) isAccepting(config *configSet) bool {
visited := getVisitedMap()
defer putVisitedMap(visited)
for _, cfg := range config.configs {
// Clear visited for each config check
for k := range visited {
delete(visited, k)
}
if b.canReachAccept(cfg.state, cfg.Stack, visited) {
return true
}
}
return false
}
// canReachAccept checks if we can reach an accepting state via epsilon transitions
func (b *bridge) canReachAccept(state state, stack []stackSymbol, visited map[stateStackKey]bool) bool {
// Check if this state is accepting with empty stack
if b.pda.AcceptStates[state] && len(stack) == 0 {
return true
}
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
if visited[key] {
return false
}
visited[key] = true
// Try epsilon transitions
stackTop := stackEmpty
if len(stack) > 0 {
stackTop = stack[len(stack)-1]
}
for _, t := range b.pda.Transitions[state] {
if t.Pattern != "" {
continue // Not epsilon
}
if t.stackTop != stackEmpty && t.stackTop != stackTop {
continue
}
if t.StackPop > len(stack) {
continue
}
newStack := make([]stackSymbol, len(stack))
copy(newStack, stack)
if t.StackPop > 0 {
newStack = newStack[:len(newStack)-t.StackPop]
}
newStack = append(newStack, t.StackPush...)
if b.canReachAccept(t.ToState, newStack, visited) {
return true
}
}
return false
}
// validTerminals returns the valid terminal patterns from the given config
func (b *bridge) validTerminals(config *configSet) []string {
seen := make(map[string]bool)
var terminals []string
visited := getVisitedMap()
defer putVisitedMap(visited)
for _, cfg := range config.configs {
// Clear visited for each config
for k := range visited {
delete(visited, k)
}
b.collectValidTerminals(cfg.state, cfg.Stack, visited, seen, &terminals)
}
return terminals
}
// collectValidTerminals collects all reachable terminals
func (b *bridge) collectValidTerminals(state state, stack []stackSymbol, visited map[stateStackKey]bool, seen map[string]bool, terminals *[]string) {
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
if visited[key] {
return
}
visited[key] = true
stackTop := stackEmpty
if len(stack) > 0 {
stackTop = stack[len(stack)-1]
}
for _, t := range b.pda.Transitions[state] {
if t.stackTop != stackEmpty && t.stackTop != stackTop {
continue
}
if t.StackPop > len(stack) {
continue
}
if t.Pattern != "" && !seen[t.Pattern] {
seen[t.Pattern] = true
*terminals = append(*terminals, t.Pattern)
}
if t.Pattern == "" {
newStack := make([]stackSymbol, len(stack))
copy(newStack, stack)
if t.StackPop > 0 {
newStack = newStack[:len(newStack)-t.StackPop]
}
newStack = append(newStack, t.StackPush...)
b.collectValidTerminals(t.ToState, newStack, visited, seen, terminals)
}
}
}
// validTerminalIDs returns the IDs of valid terminals from the given config
func (b *bridge) validTerminalIDs(config *configSet) []int {
seen := make(map[int]bool)
var terminalIDs []int
visited := getVisitedMap()
defer putVisitedMap(visited)
for _, cfg := range config.configs {
// Clear visited for each config
for k := range visited {
delete(visited, k)
}
b.collectValidTerminalIDs(cfg.state, cfg.Stack, visited, seen, &terminalIDs)
}
return terminalIDs
}
// collectValidTerminalIDs collects IDs of all reachable terminals
func (b *bridge) collectValidTerminalIDs(state state, stack []stackSymbol, visited map[stateStackKey]bool, seen map[int]bool, terminalIDs *[]int) {
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
if visited[key] {
return
}
visited[key] = true
stackTop := stackEmpty
if len(stack) > 0 {
stackTop = stack[len(stack)-1]
}
for _, t := range b.pda.Transitions[state] {
if t.stackTop != stackEmpty && t.stackTop != stackTop {
continue
}
if t.StackPop > len(stack) {
continue
}
if t.Pattern != "" {
// Look up terminal ID from pattern
if tid, ok := b.analyzer.matcher.patternToID[t.Pattern]; ok && !seen[tid] {
seen[tid] = true
*terminalIDs = append(*terminalIDs, tid)
}
}
if t.Pattern == "" {
newStack := make([]stackSymbol, len(stack))
copy(newStack, stack)
if t.StackPop > 0 {
newStack = newStack[:len(newStack)-t.StackPop]
}
newStack = append(newStack, t.StackPush...)
b.collectValidTerminalIDs(t.ToState, newStack, visited, seen, terminalIDs)
}
}
}

View File

@@ -0,0 +1,45 @@
root ::= ws "{" ws id-field "," ws kind-field "," ws items-field "," ws alt-field "," ws flags-field "," ws meta-field "," ws priority-field ws "}" ws
id-field ::= "\"id\"" ws ":" ws uuid
kind-field ::= "\"kind\"" ws ":" ws kind
items-field ::= "\"items\"" ws ":" ws items
alt-field ::= "\"alt\"" ws ":" ws alt
flags-field ::= "\"flags\"" ws ":" ws flags
meta-field ::= "\"meta\"" ws ":" ws meta
priority-field ::= "\"priority\"" ws ":" ws int
kind ::= "\"order\"" | "\"invoice\"" | "\"shipment\""
status ::= "\"new\"" | "\"backorder\"" | "\"shipped\""
flag ::= "\"fragile\"" | "\"gift\"" | "\"priority\"" | "\"insured\""
source ::= "\"api\"" | "\"batch\"" | "\"import\""
items ::= "[" ws item ( "," ws item )? ( "," ws item )? ws "]"
flags ::= "[" ws "]" | "[" ws flag ( "," ws flag )? ( "," ws flag )? ( "," ws flag )? ws "]"
item ::= "{" ws item-sku "," ws item-qty "," ws item-status "," ws item-notes ws "}"
item-sku ::= "\"sku\"" ws ":" ws string
item-qty ::= "\"qty\"" ws ":" ws int
item-status ::= "\"status\"" ws ":" ws status
item-notes ::= "\"notes\"" ws ":" ws string
meta ::= "{" ws meta-created "," ws meta-source "," ws meta-ip ws "}"
meta-created ::= "\"created\"" ws ":" ws date-time
meta-source ::= "\"source\"" ws ":" ws source
meta-ip ::= "\"ip\"" ws ":" ws ipv4
alt ::= string | int | "null"
uuid ::= "\"" hex hex hex hex hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex hex hex hex hex hex hex hex hex "\""
date-time ::= "\"" digit digit digit digit "-" digit digit "-" digit digit "T" digit digit ":" digit digit ":" digit digit ( "Z" | ( "+" | "-" ) digit digit ":" digit digit ) "\""
ipv4 ::= "\"" digit+ "." digit+ "." digit+ "." digit+ "\""
string ::= "\"" characters "\""
characters ::= character*
character ::= [^"\\] | "\\" escape
escape ::= ["\\bfnrt]
int ::= "-"? digit+
digit ::= [0-9]
hex ::= [0-9a-fA-F]
ws ::= [ \t\n\r]*

View File

@@ -0,0 +1,46 @@
{
"type": "object",
"properties": {
"id": { "type": "string", "format": "uuid" },
"kind": { "enum": ["order", "invoice", "shipment"] },
"items": {
"type": "array",
"minItems": 1,
"maxItems": 3,
"items": {
"type": "object",
"properties": {
"sku": { "type": "string" },
"qty": { "type": "integer" },
"status": { "enum": ["new", "backorder", "shipped"] },
"notes": { "type": "string" }
},
"required": ["sku", "qty", "status", "notes"]
}
},
"alt": {
"oneOf": [
{ "type": "string" },
{ "type": "null" },
{ "type": "integer" }
]
},
"flags": {
"type": "array",
"minItems": 0,
"maxItems": 4,
"items": { "enum": ["fragile", "gift", "priority", "insured"] }
},
"meta": {
"type": "object",
"properties": {
"created": { "type": "string", "format": "date-time" },
"source": { "enum": ["api", "batch", "import"] },
"ip": { "type": "string", "format": "ipv4" }
},
"required": ["created", "source", "ip"]
},
"priority": { "type": "integer" }
},
"required": ["id", "kind", "items", "alt", "flags", "meta", "priority"]
}

View File

@@ -0,0 +1,235 @@
//go:build mlx
package main
import (
"encoding/json"
"flag"
"fmt"
"os"
"time"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/x/grammar"
"github.com/ollama/ollama/x/grammar/schema"
"github.com/ollama/ollama/x/imagegen/mlx"
)
const jsonGBNF = `
root ::= value
value ::= object | array | string | number | "true" | "false" | "null"
object ::= "{" ws "}" | "{" members "}"
members ::= member ("," member)*
member ::= ws string ws ":" element
array ::= "[" ws "]" | "[" elements "]"
elements ::= element ("," element)*
element ::= ws value ws
string ::= "\"" characters "\""
characters ::= character*
character ::= [^"\\] | "\\" escape
escape ::= ["\\bfnrt]
number ::= "-"? integer fraction? exponent?
integer ::= "0" | [1-9] [0-9]*
fraction ::= "." [0-9]+
exponent ::= [eE] [+-]? [0-9]+
ws ::= [ \t\n\r]*
`
type result struct {
vocabSize int `json:"vocab_size"`
Iterations int `json:"iterations"`
Warmup int `json:"warmup"`
ConstrainedSource string `json:"constrained_source"`
LlamaSource string `json:"llama_source"`
LlamaApply string `json:"llama_apply"`
ConstrainedGraph string `json:"constrained_graph"`
ConstrainedWithEval string `json:"constrained_with_eval,omitempty"`
EvalOnly string `json:"eval_only,omitempty"`
ConstrainedEvalNet string `json:"constrained_eval_net,omitempty"`
}
func main() {
var (
vocabSize = flag.Int("vocab-size", 128000, "Vocabulary size")
iterations = flag.Int("iterations", 500, "Benchmark iterations")
warmup = flag.Int("warmup", 50, "Warmup iterations")
withEval = flag.Bool("eval", true, "Measure ApplyMask with mlx.Eval")
gbnfPath = flag.String("gbnf", "", "GBNF grammar file for llama.cpp")
schemaPath = flag.String("schema", "", "JSON Schema file for grammar constraints")
ebnfPath = flag.String("ebnf", "", "EBNF grammar file for grammar constraints")
startRule = flag.String("start", "root", "Start rule for EBNF")
)
flag.Parse()
if *vocabSize <= 0 || *iterations <= 0 || *warmup < 0 {
fmt.Fprintln(os.Stderr, "invalid flags")
os.Exit(2)
}
vocab := createVocab(*vocabSize)
if *schemaPath != "" && *ebnfPath != "" {
fmt.Fprintln(os.Stderr, "only one of -schema or -ebnf may be set")
os.Exit(2)
}
var constrainedSource string
var compiled *grammar.Grammar
var err error
switch {
case *schemaPath != "":
data, readErr := os.ReadFile(*schemaPath)
if readErr != nil {
fmt.Fprintf(os.Stderr, "read schema: %v\n", readErr)
os.Exit(1)
}
compiled, err = schema.Grammar(string(data))
constrainedSource = "schema:" + *schemaPath
case *ebnfPath != "":
data, readErr := os.ReadFile(*ebnfPath)
if readErr != nil {
fmt.Fprintf(os.Stderr, "read ebnf: %v\n", readErr)
os.Exit(1)
}
compiled, err = grammar.ParseEBNF(string(data), *startRule)
constrainedSource = "ebnf:" + *ebnfPath
default:
compiled, err = grammar.JSONGrammar()
constrainedSource = "json"
}
if err != nil {
fmt.Fprintf(os.Stderr, "grammar: %v\n", err)
os.Exit(1)
}
engine, err := grammar.NewEngine(compiled, vocab)
if err != nil {
fmt.Fprintf(os.Stderr, "engine: %v\n", err)
os.Exit(1)
}
defer engine.Close()
logits := mlx.Ones(int32(*vocabSize))
mlx.Keep(logits)
for i := 0; i < *warmup; i++ {
masked := engine.ApplyMask(logits)
if *withEval {
mlx.Eval(masked)
}
}
graphAvg := measure(*iterations, func() {
_ = engine.ApplyMask(logits)
})
var evalAvg time.Duration
var evalOnlyAvg time.Duration
if *withEval {
evalOnlyAvg = measure(*iterations, func() {
baseline := mlx.MulScalar(logits, 1)
mlx.Eval(baseline)
baseline.Free()
})
evalAvg = measure(*iterations, func() {
masked := engine.ApplyMask(logits)
mlx.Eval(masked)
})
}
vocabIDs := make([]uint32, *vocabSize)
for i := range vocabIDs {
vocabIDs[i] = uint32(i)
}
eogTokens := []int32{0}
gbnf := jsonGBNF
llamaSource := "json"
if *gbnfPath != "" {
data, readErr := os.ReadFile(*gbnfPath)
if readErr != nil {
fmt.Fprintf(os.Stderr, "read gbnf: %v\n", readErr)
os.Exit(1)
}
gbnf = string(data)
llamaSource = *gbnfPath
}
llamaGrammar := llama.NewGrammar(gbnf, vocabIDs, vocab, eogTokens)
if llamaGrammar == nil {
fmt.Fprintln(os.Stderr, "llama grammar initialization failed")
os.Exit(1)
}
defer llamaGrammar.Free()
llamaTokens := make([]llama.TokenData, *vocabSize)
for i := 0; i < *warmup; i++ {
for j := range llamaTokens {
llamaTokens[j].Logit = 1.0
}
llamaGrammar.Apply(llamaTokens)
}
llamaAvg := measure(*iterations, func() {
for j := range llamaTokens {
llamaTokens[j].Logit = 1.0
}
llamaGrammar.Apply(llamaTokens)
})
out := result{
vocabSize: *vocabSize,
Iterations: *iterations,
Warmup: *warmup,
LlamaApply: llamaAvg.String(),
ConstrainedGraph: graphAvg.String(),
ConstrainedSource: constrainedSource,
LlamaSource: llamaSource,
}
if *withEval {
out.ConstrainedWithEval = evalAvg.String()
out.EvalOnly = evalOnlyAvg.String()
if evalAvg > evalOnlyAvg {
out.ConstrainedEvalNet = (evalAvg - evalOnlyAvg).String()
} else {
out.ConstrainedEvalNet = "0s"
}
}
enc := json.NewEncoder(os.Stdout)
if err := enc.Encode(out); err != nil {
fmt.Fprintf(os.Stderr, "encode: %v\n", err)
os.Exit(1)
}
}
func measure(iterations int, fn func()) time.Duration {
start := time.Now()
for i := 0; i < iterations; i++ {
fn()
}
return time.Since(start) / time.Duration(iterations)
}
func createVocab(size int) []string {
vocab := make([]string, size)
jsonTokens := []string{
"{", "}", "[", "]", ":", ",",
"true", "false", "null",
" ", "\n", "\t", "\r",
"\"",
}
for i, t := range jsonTokens {
if i < size {
vocab[i] = t
}
}
for i := len(jsonTokens); i < size; i++ {
vocab[i] = fmt.Sprintf("tok%d", i)
}
return vocab
}

320
x/grammar/compiled.go Normal file
View File

@@ -0,0 +1,320 @@
//go:build mlx
package grammar
import (
"fmt"
"strconv"
"strings"
"unicode/utf8"
)
// Grammar is the compiled form of an EBNF grammar.
// It contains terminals, parse tables, and the start state.
// Use ParseEBNF or JSONGrammar to create a Grammar.
type Grammar struct {
// The underlying pda
pda *pda
// Compiled terminal matcher
matcher *terminalMatcher
}
// ParseEBNF compiles an EBNF grammar string into a Grammar.
// startRule is the name of the start rule (e.g., "root", "json").
func ParseEBNF(ebnf string, startRule string) (*Grammar, error) {
pda, err := compileString(ebnf, startRule)
if err != nil {
return nil, fmt.Errorf("failed to compile EBNF: %w", err)
}
matcher, err := compileTerminalsStrict(pda)
if err != nil {
return nil, fmt.Errorf("failed to compile terminals: %w", err)
}
return &Grammar{
pda: pda,
matcher: matcher,
}, nil
}
// JSONGrammar returns the compiled JSON grammar.
// This is a convenience wrapper for ParseEBNF(JSONGrammarEBNF, "json").
func JSONGrammar() (*Grammar, error) {
return ParseEBNF(JSONGrammarEBNF, "json")
}
// JSONObjectGrammar returns a JSON grammar that only allows objects at the top level.
// Use this when you want to ensure the output is a JSON object (starts with {).
func JSONObjectGrammar() (*Grammar, error) {
return ParseEBNF(JSONObjectGrammarEBNF, "json")
}
// compileTerminalsStrict builds a matcher that properly handles:
// - Escaped literals ("\n", \"", \uXXXX)
// - Unicode ranges (rune-based, not byte-based)
// - Rejects unsupported patterns with an error (no silent fallback)
func compileTerminalsStrict(pda *pda) (*terminalMatcher, error) {
m := &terminalMatcher{
literalTrie: &trieNode{terminalID: -1},
ranges: make([]terminal, 0),
terminals: make([]terminal, 0, len(pda.Terminals)),
patternToID: make(map[string]int),
}
// Track which pattern produced each unescaped value for collision detection
unescapedSource := make(map[string]string) // unescaped -> original pattern
for i, pattern := range pda.Terminals {
terminal, err := parseTerminalPattern(pattern, i)
if err != nil {
return nil, fmt.Errorf("terminal %q: %w", pattern, err)
}
if terminal.Type == terminalLiteral {
// Use the unescaped pattern for trie matching
m.addLiteralToTrie(terminal.Unescaped, i)
// Detect collisions between literals that unescape to the same value
if existingPattern, exists := unescapedSource[terminal.Unescaped]; exists {
if existingPattern != pattern {
return nil, fmt.Errorf("collision: patterns %q and %q both unescape to %q",
existingPattern, pattern, terminal.Unescaped)
}
} else {
unescapedSource[terminal.Unescaped] = pattern
}
} else if terminal.Type == terminalRange {
m.ranges = append(m.ranges, terminal)
}
m.terminals = append(m.terminals, terminal)
m.patternToID[pattern] = i
}
return m, nil
}
// parseTerminalPattern parses a terminal pattern and returns a terminal.
// Supports:
// - Literal strings (with escape sequences)
// - Character ranges [X-Y] (unicode-aware)
func parseTerminalPattern(pattern string, id int) (terminal, error) {
if len(pattern) == 0 {
return terminal{}, fmt.Errorf("empty pattern")
}
// Check for range pattern: [X-Y]
if isUnicodeRangePattern(pattern) {
lowRune, highRune, err := parseUnicodeRange(pattern)
if err != nil {
return terminal{}, err
}
return terminal{
ID: id,
Type: terminalRange,
Pattern: pattern,
Unescaped: pattern,
LowRune: lowRune,
HighRune: highRune,
}, nil
}
// It's a literal - unescape it
unescaped, err := unescapeLiteral(pattern)
if err != nil {
return terminal{}, fmt.Errorf("invalid escape sequence: %w", err)
}
return terminal{
ID: id,
Type: terminalLiteral,
Pattern: pattern,
Unescaped: unescaped,
}, nil
}
// isUnicodeRangePattern checks if pattern is a character range like [a-z] or [\u0000-\uFFFF]
func isUnicodeRangePattern(pattern string) bool {
if len(pattern) < 5 || pattern[0] != '[' || pattern[len(pattern)-1] != ']' {
return false
}
// Find the dash that separates low-high
inner := pattern[1 : len(pattern)-1]
dashIdx := strings.Index(inner, "-")
// Handle escaped dash at start
if dashIdx <= 0 {
return false
}
return true
}
// parseUnicodeRange parses [X-Y] into low and high runes
func parseUnicodeRange(pattern string) (rune, rune, error) {
if len(pattern) < 5 || pattern[0] != '[' || pattern[len(pattern)-1] != ']' {
return 0, 0, fmt.Errorf("invalid range pattern")
}
inner := pattern[1 : len(pattern)-1]
// Simple case: [a-z] where a and z are single chars
if len(inner) == 3 && inner[1] == '-' {
return rune(inner[0]), rune(inner[2]), nil
}
// Handle escaped characters like [\u0000-\uFFFF]
dashIdx := findRangeDash(inner)
if dashIdx < 0 {
return 0, 0, fmt.Errorf("no dash in range")
}
lowStr := inner[:dashIdx]
highStr := inner[dashIdx+1:]
lowRune, err := parseRune(lowStr)
if err != nil {
return 0, 0, fmt.Errorf("invalid low bound: %w", err)
}
highRune, err := parseRune(highStr)
if err != nil {
return 0, 0, fmt.Errorf("invalid high bound: %w", err)
}
if lowRune > highRune {
return 0, 0, fmt.Errorf("low bound > high bound")
}
return lowRune, highRune, nil
}
// findRangeDash finds the dash separating low-high in a range pattern
func findRangeDash(inner string) int {
i := 0
for i < len(inner) {
if inner[i] == '\\' && i+1 < len(inner) {
// Skip escape sequence
if inner[i+1] == 'u' && i+6 <= len(inner) {
i += 6 // \uXXXX
} else {
i += 2 // \n, \t, etc.
}
continue
}
if inner[i] == '-' && i > 0 {
return i
}
i++
}
return -1
}
// parseRune parses a single rune from a string (handles escapes)
func parseRune(s string) (rune, error) {
if len(s) == 0 {
return 0, fmt.Errorf("empty rune")
}
// Handle escape sequences
if s[0] == '\\' {
if len(s) < 2 {
return 0, fmt.Errorf("incomplete escape")
}
switch s[1] {
case 'n':
return '\n', nil
case 't':
return '\t', nil
case 'r':
return '\r', nil
case '\\':
return '\\', nil
case '"':
return '"', nil
case '\'':
return '\'', nil
case 'u':
if len(s) < 6 {
return 0, fmt.Errorf("incomplete unicode escape")
}
val, err := strconv.ParseInt(s[2:6], 16, 32)
if err != nil {
return 0, fmt.Errorf("invalid unicode escape: %w", err)
}
return rune(val), nil
default:
return 0, fmt.Errorf("unknown escape: \\%c", s[1])
}
}
// Plain character
r, _ := utf8.DecodeRuneInString(s)
if r == utf8.RuneError {
return 0, fmt.Errorf("invalid utf8")
}
return r, nil
}
// unescapeLiteral unescapes a literal pattern string
func unescapeLiteral(pattern string) (string, error) {
// Try strconv.Unquote if it looks quoted
if len(pattern) >= 2 && pattern[0] == '"' && pattern[len(pattern)-1] == '"' {
unquoted, err := strconv.Unquote(pattern)
if err != nil {
return "", err
}
return unquoted, nil
}
// If no backslashes, return as-is
if !strings.Contains(pattern, "\\") {
return pattern, nil
}
// Manual unescape
var result strings.Builder
i := 0
for i < len(pattern) {
if pattern[i] == '\\' && i+1 < len(pattern) {
switch pattern[i+1] {
case 'n':
result.WriteByte('\n')
i += 2
case 't':
result.WriteByte('\t')
i += 2
case 'r':
result.WriteByte('\r')
i += 2
case '\\':
result.WriteByte('\\')
i += 2
case '"':
result.WriteByte('"')
i += 2
case '\'':
result.WriteByte('\'')
i += 2
case 'u':
if i+6 <= len(pattern) {
val, err := strconv.ParseInt(pattern[i+2:i+6], 16, 32)
if err != nil {
return "", fmt.Errorf("invalid unicode escape at %d", i)
}
result.WriteRune(rune(val))
i += 6
} else {
return "", fmt.Errorf("incomplete unicode escape at %d", i)
}
default:
// Reject unknown escape sequences
return "", fmt.Errorf("unknown escape sequence: \\%c at position %d", pattern[i+1], i)
}
} else {
result.WriteByte(pattern[i])
i++
}
}
return result.String(), nil
}

329
x/grammar/engine.go Normal file
View File

@@ -0,0 +1,329 @@
//go:build mlx
package grammar
import (
"container/list"
"fmt"
"math"
"sync"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// maskCache provides LRU caching for computed masks.
type maskCache struct {
cache map[uint64]*list.Element
order *list.List
maxSize int
mu sync.Mutex
}
type maskEntry struct {
sig uint64
mask *mlx.Array
}
// newMaskCache creates a new mask cache with the given max size
// If maxSize <= 0, the cache is disabled (Get/Put are no-ops)
func newMaskCache(maxSize int) *maskCache {
if maxSize <= 0 {
return &maskCache{
cache: make(map[uint64]*list.Element),
order: list.New(),
maxSize: 0, // Signals disabled
}
}
return &maskCache{
cache: make(map[uint64]*list.Element),
order: list.New(),
maxSize: maxSize,
}
}
// get retrieves a cached mask, returning nil if not found.
// Updates LRU order on cache hit.
func (c *maskCache) get(sig uint64) *mlx.Array {
if c.maxSize <= 0 {
return nil // Cache disabled
}
c.mu.Lock()
defer c.mu.Unlock()
if elem, ok := c.cache[sig]; ok {
c.order.MoveToFront(elem)
return elem.Value.(*maskEntry).mask
}
return nil
}
// put stores a mask in the cache with LRU eviction.
func (c *maskCache) put(sig uint64, mask *mlx.Array) {
if c.maxSize <= 0 {
return // Cache disabled
}
c.mu.Lock()
defer c.mu.Unlock()
if elem, exists := c.cache[sig]; exists {
c.order.MoveToFront(elem)
return
}
// Evict oldest if at capacity (safe since maxSize > 0)
if c.order.Len() >= c.maxSize {
oldest := c.order.Back()
if oldest != nil {
entry := oldest.Value.(*maskEntry)
entry.mask.Free()
delete(c.cache, entry.sig)
c.order.Remove(oldest)
}
}
elem := c.order.PushFront(&maskEntry{sig: sig, mask: mask})
c.cache[sig] = elem
}
// clear frees all cached masks.
func (c *maskCache) clear() {
c.mu.Lock()
defer c.mu.Unlock()
for elem := c.order.Front(); elem != nil; elem = elem.Next() {
elem.Value.(*maskEntry).mask.Free()
}
c.cache = make(map[uint64]*list.Element)
c.order.Init()
}
// size returns the number of cached masks.
func (c *maskCache) size() int {
c.mu.Lock()
defer c.mu.Unlock()
return len(c.cache)
}
// Engine applies grammar constraints to model outputs using MLX.
// It uses a token→pda bridge for strict correctness with arbitrary BPE tokens.
type Engine struct {
// The compiled grammar
grammar *Grammar
// bridge for token validation
bridge *bridge
analyzer *analyzer
// Current parser state (configSet for nondeterminism)
configSet *configSet
// Token vocabulary from the model
vocab []string
tokenToID map[string]int // O(1) lookup for AcceptString
// Mask cache: configSig → valid token mask (LRU)
maskCache *maskCache
// Cached negative infinity mask for invalid tokens
negInfMask *mlx.Array
// Threshold for comparison (0.5 since mask values are 0 or 1)
threshold *mlx.Array
// Vocabulary size
vocabSize int32
// Reusable buffers for candidate filtering (avoid allocations)
candidateMark []bool // indexed by tokenID, true if in candidate set
touched []int // tokenIDs that were marked (for reset)
dpCandidates []int // candidates requiring DP validation
// Reusable buffer for valid token indices (for GPU scatter)
validTokenIDs []int32
}
// EngineOption configures an Engine
type EngineOption func(*Engine)
// WithMaskCacheSize sets the mask cache size (default 1024)
func WithMaskCacheSize(size int) EngineOption {
return func(e *Engine) {
e.maskCache = newMaskCache(size)
}
}
// NewEngine creates a new constrained decoding engine.
// grammar is the compiled grammar (use JSONGrammar() or ParseEBNF()).
// vocab is the list of token strings from the model's tokenizer.
func NewEngine(grammar *Grammar, vocab []string, opts ...EngineOption) (*Engine, error) {
if grammar == nil {
return nil, fmt.Errorf("grammar cannot be nil")
}
// Build analyzer and bridge
analyzer := newAnalyzer(vocab, grammar.matcher)
bridge := newBridge(grammar.pda, analyzer)
// Initialize config set from pda initial state
initialConfig := newConfigSet(grammar.pda.StartState, nil)
// Build token lookup map for O(1) AcceptString
tokenToID := make(map[string]int, len(vocab))
for i, tok := range vocab {
tokenToID[tok] = i
}
e := &Engine{
grammar: grammar,
bridge: bridge,
analyzer: analyzer,
configSet: initialConfig,
vocab: vocab,
tokenToID: tokenToID,
maskCache: newMaskCache(1024),
vocabSize: int32(len(vocab)),
candidateMark: make([]bool, len(vocab)),
touched: make([]int, 0, 10000),
validTokenIDs: make([]int32, 0, 10000),
}
// Apply options
for _, opt := range opts {
opt(e)
}
// Create the negative infinity mask and threshold
if e.vocabSize > 0 {
e.negInfMask = mlx.FullDtype(float32(math.Inf(-1)), mlx.DtypeFloat32, e.vocabSize)
mlx.Keep(e.negInfMask)
e.threshold = mlx.NewScalarArray(0.5)
mlx.Keep(e.threshold)
}
return e, nil
}
// ApplyMask applies grammar constraints to logits.
// Returns logits with invalid tokens set to -inf.
func (e *Engine) ApplyMask(logits *mlx.Array) *mlx.Array {
sig := e.configSet.signature()
// Check state cache first (exact state match)
if cached := e.maskCache.get(sig); cached != nil {
condition := mlx.GreaterEqual(cached, e.threshold)
return mlx.Where(condition, logits, e.negInfMask)
}
// Compute valid tokens using candidate filtering:
// 1. Get valid terminal IDs from current grammar state
// 2. Get candidate tokens (those that START with valid terminals)
// 3. Run DP validation only on candidates
// This is O(candidates) instead of O(vocab_size)
validTerminalIDs := e.bridge.validTerminalIDs(e.configSet)
// Use pre-partitioned token groups for fast candidate building
// This eliminates per-token branching - just direct slice appends
e.validTokenIDs = e.validTokenIDs[:0]
e.dpCandidates = e.dpCandidates[:0]
e.touched = e.touched[:0]
for _, tid := range validTerminalIDs {
groups := e.analyzer.terminalGroups(tid)
// Direct append of exact matches (no per-token check needed)
e.validTokenIDs = append(e.validTokenIDs, groups.ExactMatches...)
// Collect DP candidates (may have duplicates across terminals)
for _, tokenID := range groups.DPCandidates {
if !e.candidateMark[tokenID] {
e.candidateMark[tokenID] = true
e.dpCandidates = append(e.dpCandidates, tokenID)
e.touched = append(e.touched, tokenID)
}
}
}
// Reset marks for next call
for _, id := range e.touched {
e.candidateMark[id] = false
}
for _, tokenID := range e.dpCandidates {
if e.bridge.IsTokenValid(tokenID, e.configSet) {
e.validTokenIDs = append(e.validTokenIDs, int32(tokenID))
}
}
// Create and cache the mask on GPU using index updates
mask := mlx.Zeros([]int32{e.vocabSize})
if len(e.validTokenIDs) > 0 {
indices := mlx.NewArrayInt32(e.validTokenIDs, []int32{int32(len(e.validTokenIDs))})
values := mlx.Ones(int32(len(e.validTokenIDs)))
mask = mlx.PutAlongAxis(mask, indices, values, 0)
}
mlx.Keep(mask)
// Cache by state signature
e.maskCache.put(sig, mask)
// Apply mask
condition := mlx.GreaterEqual(mask, e.threshold)
return mlx.Where(condition, logits, e.negInfMask)
}
// Accept processes a token and updates the parser state.
// Returns true if the token was valid and accepted.
func (e *Engine) Accept(tokenID int) bool {
if tokenID < 0 || tokenID >= len(e.vocab) {
return false
}
newConfig := e.bridge.acceptToken(tokenID, e.configSet)
if newConfig == nil {
return false
}
e.configSet = newConfig
return true
}
// AcceptString processes a token string directly.
// Returns true if the token was valid and accepted.
func (e *Engine) AcceptString(token string) bool {
if id, ok := e.tokenToID[token]; ok {
return e.Accept(id)
}
return false
}
// IsComplete returns true if the current state is accepting.
func (e *Engine) IsComplete() bool {
return e.bridge.isAccepting(e.configSet)
}
// Reset resets the engine to initial state.
func (e *Engine) Reset() {
e.configSet = newConfigSet(e.grammar.pda.StartState, nil)
}
// validTokens returns the indices of tokens that are currently valid.
func (e *Engine) validTokens() []int {
return e.bridge.validTokens(e.configSet)
}
// validTerminals returns the valid terminal patterns from the current state.
func (e *Engine) validTerminals() []string {
return e.bridge.validTerminals(e.configSet)
}
// Close releases MLX resources.
func (e *Engine) Close() {
if e.maskCache != nil {
e.maskCache.clear()
}
if e.negInfMask != nil {
e.negInfMask.Free()
}
if e.threshold != nil {
e.threshold.Free()
}
}

View File

@@ -0,0 +1,414 @@
//go:build mlx
package grammar
import (
"fmt"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// newBenchEngine creates a JSON engine for benchmarks
func newBenchEngine(b *testing.B, vocab []string) *Engine {
b.Helper()
grammar, err := JSONGrammar()
if err != nil {
b.Fatalf("failed to create JSON grammar: %v", err)
}
e, err := NewEngine(grammar, vocab)
if err != nil {
b.Fatalf("failed to create engine: %v", err)
}
return e
}
// Vocabulary sizes to test (matching real models)
var vocabSizes = []int{
32000, // Llama 2
128000, // Llama 3
256000, // Large models
}
// createBenchVocabN creates a vocabulary of size n with realistic token distribution
func createBenchVocabN(n int) []string {
vocab := make([]string, n)
// JSON structural tokens (first 20)
jsonTokens := []string{
"{", "}", "[", "]", ":", ",",
"true", "false", "null",
" ", "\n", "\t", "\r",
"\"", "'",
}
for i, t := range jsonTokens {
if i < n {
vocab[i] = t
}
}
// String tokens (indices 20-1000)
stringIdx := 20
for i := 0; i < 980 && stringIdx+i < n; i++ {
vocab[stringIdx+i] = fmt.Sprintf("\"token%d\"", i)
}
// Number tokens (indices 1000-2000)
numberIdx := 1000
for i := 0; i < 1000 && numberIdx+i < n; i++ {
vocab[numberIdx+i] = fmt.Sprintf("%d", i)
}
// Generic tokens (rest)
for i := 2000; i < n; i++ {
vocab[i] = fmt.Sprintf("tok%d", i)
}
return vocab
}
// ============ Core Performance Benchmarks ============
// BenchmarkApplyMask_32k measures mask application with 32k vocab
func BenchmarkApplyMask_32k(b *testing.B) {
benchmarkApplyMask(b, 32000)
}
// BenchmarkApplyMask_128k measures mask application with 128k vocab
func BenchmarkApplyMask_128k(b *testing.B) {
benchmarkApplyMask(b, 128000)
}
// BenchmarkApplyMask_256k measures mask application with 256k vocab
func BenchmarkApplyMask_256k(b *testing.B) {
benchmarkApplyMask(b, 256000)
}
func benchmarkApplyMask(b *testing.B, vocabSize int) {
vocab := createBenchVocabN(vocabSize)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(vocabSize))
mlx.Keep(logits)
// Warm up
for i := 0; i < 10; i++ {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
}
b.ReportMetric(float64(vocabSize), "vocab_size")
}
// ============ state-Dependent Benchmarks ============
// BenchmarkApplyMaskAfterBrace measures mask after { (STRING or } valid)
func BenchmarkApplyMaskAfterBrace(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
e.AcceptString("{")
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
b.ResetTimer()
for i := 0; i < b.N; i++ {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
}
}
// BenchmarkApplyMaskMidObject measures mask in middle of object
func BenchmarkApplyMaskMidObject(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
// state: {"key": _value_
e.AcceptString("{")
e.AcceptString("\"key\"")
e.AcceptString(":")
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
b.ResetTimer()
for i := 0; i < b.N; i++ {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
}
}
// ============ Token Sequence Benchmarks ============
// BenchmarkSequence_SimpleObject benchmarks {"key": "value"}
func BenchmarkSequence_SimpleObject(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
e.Reset()
for _, token := range sequence {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
e.AcceptString(token)
}
}
b.ReportMetric(float64(len(sequence)), "tokens")
}
// BenchmarkSequence_NestedObject benchmarks {"a": {"b": {"c": 1}}}
func BenchmarkSequence_NestedObject(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
sequence := []string{
"{", "\"a\"", ":", "{", "\"b\"", ":", "{", "\"c\"", ":", "1", "}", "}", "}",
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
e.Reset()
for _, token := range sequence {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
e.AcceptString(token)
}
}
b.ReportMetric(float64(len(sequence)), "tokens")
}
// BenchmarkSequence_LargeArray benchmarks [1, 2, 3, ..., 100]
func BenchmarkSequence_LargeArray(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
// Build sequence: [1, 2, 3, ..., 50]
sequence := []string{"["}
for i := 1; i <= 50; i++ {
sequence = append(sequence, fmt.Sprintf("%d", i))
if i < 50 {
sequence = append(sequence, ",")
}
}
sequence = append(sequence, "]")
b.ResetTimer()
for i := 0; i < b.N; i++ {
e.Reset()
for _, token := range sequence {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
e.AcceptString(token)
}
}
b.ReportMetric(float64(len(sequence)), "tokens")
}
// BenchmarkSequence_MixedTypes benchmarks complex mixed-type object
func BenchmarkSequence_MixedTypes(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
sequence := []string{
"{",
"\"name\"", ":", "\"test\"", ",",
"\"count\"", ":", "42", ",",
"\"enabled\"", ":", "true", ",",
"\"data\"", ":", "null", ",",
"\"items\"", ":", "[", "1", ",", "2", ",", "3", "]",
"}",
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
e.Reset()
for _, token := range sequence {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
e.AcceptString(token)
}
}
b.ReportMetric(float64(len(sequence)), "tokens")
}
// ============ Component Benchmarks ============
// BenchmarkValidInputs measures pda valid input computation
func BenchmarkValidInputs(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = e.validTerminals()
}
}
// BenchmarkStateTransition measures pda state transition
func BenchmarkStateTransition(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
e.Reset()
for _, token := range sequence {
e.AcceptString(token)
}
}
}
// BenchmarkConstrainedGrammar_128k benchmarks x/grammar (graph only, no eval).
func BenchmarkConstrainedGrammar_128k(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
// Warm up
for i := 0; i < 10; i++ {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = e.ApplyMask(logits) // Graph only, no eval
}
}
// BenchmarkNewEngine measures one-time engine initialization.
func BenchmarkNewEngine_32k(b *testing.B) {
benchmarkNewEngine(b, 32000)
}
func BenchmarkNewEngine_128k(b *testing.B) {
benchmarkNewEngine(b, 128000)
}
func benchmarkNewEngine(b *testing.B, vocabSize int) {
vocab := createBenchVocabN(vocabSize)
b.ResetTimer()
for i := 0; i < b.N; i++ {
e := newBenchEngine(b, vocab)
e.Close()
}
}
// ============ Memory Benchmarks ============
func BenchmarkMemoryAllocs_32k(b *testing.B) {
benchmarkMemoryAllocs(b, 32000)
}
func BenchmarkMemoryAllocs_128k(b *testing.B) {
benchmarkMemoryAllocs(b, 128000)
}
func benchmarkMemoryAllocs(b *testing.B, vocabSize int) {
vocab := createBenchVocabN(vocabSize)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(vocabSize))
mlx.Keep(logits)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
}
}
// ============ No-Eval Benchmarks (simulating LLM graph integration) ============
// BenchmarkApplyMaskNoEval_128k measures mask generation WITHOUT GPU sync
// This simulates adding mask to LLM compute graph
func BenchmarkApplyMaskNoEval_128k(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
// Warm up
for i := 0; i < 10; i++ {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = e.ApplyMask(logits) // No Eval - just build graph
}
}
// BenchmarkSequenceNoEval simulates real LLM usage - build graph, eval once at end
func BenchmarkSequenceNoEval_SimpleObject(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
e.Reset()
var lastMasked *mlx.Array
for _, token := range sequence {
lastMasked = e.ApplyMask(logits) // Build graph only
e.AcceptString(token)
}
mlx.Eval(lastMasked) // Single eval at end
}
b.ReportMetric(float64(len(sequence)), "tokens")
}

689
x/grammar/engine_test.go Normal file
View File

@@ -0,0 +1,689 @@
//go:build mlx
package grammar
import (
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// newTestEngine creates a JSON engine for testing
func newTestEngine(t testing.TB, vocab []string) *Engine {
t.Helper()
grammar, err := JSONGrammar()
if err != nil {
t.Fatalf("failed to create JSON grammar: %v", err)
}
e, err := NewEngine(grammar, vocab)
if err != nil {
t.Fatalf("failed to create engine: %v", err)
}
return e
}
// Mock vocabulary for testing
func testVocab() []string {
return []string{
"{", // 0: object start
"}", // 1: object end
"[", // 2: array start
"]", // 3: array end
":", // 4: colon
",", // 5: comma
"\"key\"", // 6: string (quoted)
"\"val\"", // 7: string (quoted)
"123", // 8: number
"-42.5", // 9: number
"true", // 10: boolean
"false", // 11: boolean
"null", // 12: null
" ", // 13: whitespace (should be ignored)
"\n", // 14: whitespace (should be ignored)
"subword", // 15: bare word (NOT valid JSON - requires quotes)
"hello", // 16: bare word (NOT valid JSON - requires quotes)
}
}
func TestNewEngine(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
if e.vocabSize != int32(len(vocab)) {
t.Errorf("vocabSize = %d, want %d", e.vocabSize, len(vocab))
}
// Verify grammar is set
if e.grammar == nil {
t.Error("grammar should not be nil")
}
// Verify analyzer is set
if e.analyzer == nil {
t.Error("analyzer should not be nil")
}
}
func TestEngineValidTokens(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// At start, any value type should be valid
validTokens := e.validTokens()
// Should include object start, array start, strings, numbers, booleans, null
// Note: bare words like "subword" and "hello" are NOT valid JSON strings
// (JSON strings must be quoted)
expectedTokens := map[int]bool{
0: true, // {
2: true, // [
6: true, // "key"
7: true, // "val"
8: true, // 123
9: true, // -42.5
10: true, // true
11: true, // false
12: true, // null
}
// Check that expected tokens are present
validSet := make(map[int]bool)
for _, idx := range validTokens {
validSet[idx] = true
}
for idx := range expectedTokens {
if !validSet[idx] {
t.Errorf("expected token %d (%s) to be valid", idx, vocab[idx])
}
}
if validSet[15] || validSet[16] {
t.Error("bare words should not be valid JSON at the start state")
}
}
func TestEngineAccept(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// Accept { should work
if !e.Accept(0) { // {
t.Error("should accept {")
}
// After {, valid tokens should be STRING or }
validTokens := e.validTokens()
validSet := make(map[int]bool)
for _, idx := range validTokens {
validSet[idx] = true
}
// STRING tokens (indices 6, 7) and } (index 1) should be valid
if !validSet[1] {
t.Error("} should be valid after {")
}
if !validSet[6] && !validSet[7] {
t.Error("STRING should be valid after { (for keys)")
}
}
func TestEngineAcceptSequence(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// Accept {"key": "val"}
sequence := []int{0, 6, 4, 7, 1} // {, "key", :, "val", }
for i, tokenID := range sequence {
if !e.Accept(tokenID) {
t.Fatalf("failed to accept token %d (%s) at position %d",
tokenID, vocab[tokenID], i)
}
}
if !e.IsComplete() {
t.Error("should be in complete state after valid JSON")
}
}
func TestEngineReset(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// Accept some tokens
e.Accept(0) // {
e.Accept(1) // }
if !e.IsComplete() {
t.Error("should be complete after {}")
}
// Reset
e.Reset()
// Should be back to initial state
if e.IsComplete() {
t.Error("should not be complete after reset")
}
// Should be able to accept new sequence
if !e.Accept(0) { // {
t.Error("should accept { after reset")
}
}
func TestEngineInvalidTokenRejection(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// Accept { first
if !e.Accept(0) {
t.Fatal("should accept {")
}
// Now try to accept [ which is invalid after {
// (After {, only STRING or } are valid)
if e.Accept(2) { // [
t.Error("should not accept [ after { (expecting STRING or })")
}
}
func TestEngineAcceptString(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// Accept using string directly
if !e.AcceptString("{") {
t.Error("should accept {")
}
if !e.AcceptString("\"key\"") {
t.Error("should accept string key")
}
if !e.AcceptString(":") {
t.Error("should accept :")
}
if !e.AcceptString("123") {
t.Error("should accept number")
}
if !e.AcceptString("}") {
t.Error("should accept }")
}
if !e.IsComplete() {
t.Error("should be complete after valid JSON")
}
}
func TestJSONBackslashEscape(t *testing.T) {
vocab := []string{`"`, `\`, "n", "a"}
e := newTestEngine(t, vocab)
defer e.Close()
// Valid escape: "\n"
if !e.AcceptString(`"`) {
t.Fatal("should accept string start")
}
if !e.AcceptString(`\`) {
t.Fatal("should accept escape prefix")
}
if !e.AcceptString("n") {
t.Fatal("should accept escape code")
}
if !e.AcceptString(`"`) {
t.Fatal("should accept string end")
}
if !e.IsComplete() {
t.Error("should be complete after escaped string")
}
// Invalid escape: "\a"
e.Reset()
if !e.AcceptString(`"`) {
t.Fatal("should accept string start")
}
if !e.AcceptString(`\`) {
t.Fatal("should accept escape prefix")
}
if e.AcceptString("a") {
t.Error("should reject invalid escape code")
}
}
func TestEngineNegInfMask(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// Verify negInfMask exists and has correct shape
if e.negInfMask == nil {
t.Fatal("negInfMask should not be nil")
}
}
func TestEngineMaskCache(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// Create test logits
logits := mlx.Ones(int32(len(vocab)))
// Apply mask - should populate cache
_ = e.ApplyMask(logits)
// Check cache was populated
cacheSize := e.maskCache.size()
if cacheSize == 0 {
t.Error("mask cache should have at least one entry after ApplyMask")
}
}
func TestEngineEmptyVocab(t *testing.T) {
e := newTestEngine(t, []string{})
defer e.Close()
if e.vocabSize != 0 {
t.Errorf("vocabSize = %d, want 0", e.vocabSize)
}
}
func TestEngineLargeVocab(t *testing.T) {
// Create a large vocabulary (simulating real model vocab)
vocab := make([]string, 32000)
for i := range vocab {
vocab[i] = "token"
}
// Add some actual JSON tokens
vocab[0] = "{"
vocab[1] = "}"
vocab[2] = "["
vocab[3] = "]"
vocab[4] = ":"
vocab[5] = ","
vocab[6] = "\"test\""
vocab[7] = "123"
vocab[8] = "true"
vocab[9] = "false"
vocab[10] = "null"
e := newTestEngine(t, vocab)
defer e.Close()
if e.vocabSize != 32000 {
t.Errorf("vocabSize = %d, want 32000", e.vocabSize)
}
// Test that it still works correctly
if !e.Accept(0) { // {
t.Error("should accept {")
}
if !e.Accept(1) { // }
t.Error("should accept }")
}
if !e.IsComplete() {
t.Error("should be complete after {}")
}
}
// TestE2E_JSONDecoding tests end-to-end JSON constrained decoding.
func TestE2E_JSONDecoding(t *testing.T) {
// Create a realistic vocabulary with JSON tokens
vocab := []string{
// Structural tokens
"{", "}", "[", "]", ":", ",",
// Keywords
"true", "false", "null",
// Quoted strings
`"name"`, `"value"`, `"items"`, `"count"`, `"enabled"`,
`"hello"`, `"world"`, `"test"`,
// Numbers
"0", "1", "2", "3", "42", "123", "-1", "-42",
// Whitespace
" ", "\n", "\t",
// Multi-terminal tokens (span multiple JSON lexemes)
`"key":`, `},`, `],`, `{"`, `["`,
// Partial/invalid tokens (should be rejected)
"invalid", "foo", "bar",
}
grammar, err := JSONGrammar()
if err != nil {
t.Fatalf("failed to create JSON grammar: %v", err)
}
engine, err := NewEngine(grammar, vocab)
if err != nil {
t.Fatalf("failed to create engine: %v", err)
}
defer engine.Close()
tests := []struct {
name string
tokens []string
wantPass bool
}{
// Simple values
{"empty object", []string{"{", "}"}, true},
{"empty array", []string{"[", "]"}, true},
{"true literal", []string{"true"}, true},
{"null literal", []string{"null"}, true},
{"number", []string{"42"}, true},
{"negative number", []string{"-42"}, true},
{"quoted string", []string{`"hello"`}, true},
// Objects
{"simple object", []string{"{", `"name"`, ":", `"value"`, "}"}, true},
{"object with single-digit numbers", []string{"{", `"count"`, ":", "1", ",", `"value"`, ":", "2", "}"}, true},
{"multi-terminal key", []string{"{", `"key":`, `"value"`, "}"}, true},
// Arrays
{"array of numbers", []string{"[", "42", "]"}, true},
{"array of single digits", []string{"[", "1", ",", "2", "]"}, true},
{"array of strings", []string{"[", `"hello"`, ",", `"world"`, "]"}, true},
{"nested array", []string{"[", "[", "42", "]", "]"}, true},
// Nested structures
{"nested object", []string{"{", `"items"`, ":", "{", `"count"`, ":", "42", "}", "}"}, true},
{"object with array", []string{"{", `"items"`, ":", "[", "42", "]", "}"}, true},
// Invalid sequences
{"unclosed object", []string{"{", `"name"`, ":"}, false}, // incomplete
{"double comma", []string{"[", "42", ",", ",", "42", "]"}, false}, // invalid
{"missing value", []string{"{", `"name"`, ":", "}"}, false}, // missing value
{"bare word", []string{"invalid"}, false}, // not valid JSON
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
engine.Reset()
// Process each token
allAccepted := true
for i, token := range tt.tokens {
if !engine.AcceptString(token) {
if tt.wantPass {
t.Errorf("token %d (%q) rejected unexpectedly", i, token)
}
allAccepted = false
break
}
}
if tt.wantPass {
if !allAccepted {
return // Already reported error
}
if !engine.IsComplete() {
t.Errorf("expected complete parse, but not in accepting state")
}
} else {
// For invalid sequences, we expect either rejection or incomplete
if allAccepted && engine.IsComplete() {
t.Errorf("expected rejection or incomplete, but parse succeeded")
}
}
})
}
}
// TestE2E_SimpleExpressionGrammar tests a custom expression grammar.
func TestE2E_SimpleExpressionGrammar(t *testing.T) {
// Simple expression grammar: expr = term { ("+" | "-") term }
// term = number | "(" expr ")"
// number = digit { digit }
// digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9"
exprGrammar := `
expr = term { addop term } .
addop = "+" | "-" .
term = factor { mulop factor } .
mulop = "*" | "/" .
factor = number | "(" expr ")" .
number = digit { digit } .
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
`
grammar, err := ParseEBNF(exprGrammar, "expr")
if err != nil {
t.Fatalf("failed to parse expression grammar: %v", err)
}
// Vocabulary for expression tokens
vocab := []string{
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
"+", "-", "*", "/",
"(", ")",
// Multi-digit numbers as single tokens
"10", "42", "100", "123",
// Invalid tokens
"x", "y", "invalid",
}
engine, err := NewEngine(grammar, vocab)
if err != nil {
t.Fatalf("failed to create engine: %v", err)
}
defer engine.Close()
tests := []struct {
name string
tokens []string
wantPass bool
}{
{"single digit", []string{"5"}, true},
{"multi-digit", []string{"1", "2", "3"}, true},
{"addition", []string{"1", "+", "2"}, true},
{"subtraction", []string{"5", "-", "3"}, true},
{"multiplication", []string{"2", "*", "3"}, true},
{"division", []string{"8", "/", "2"}, true},
{"complex expr", []string{"1", "+", "2", "*", "3"}, true},
{"parentheses", []string{"(", "1", "+", "2", ")", "*", "3"}, true},
{"nested parens", []string{"(", "(", "1", ")", ")"}, true},
// Invalid
{"just operator", []string{"+"}, false},
{"double operator", []string{"1", "+", "+", "2"}, false},
{"unclosed paren", []string{"(", "1", "+", "2"}, false},
{"variable", []string{"x"}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
engine.Reset()
allAccepted := true
for i, token := range tt.tokens {
if !engine.AcceptString(token) {
if tt.wantPass {
t.Errorf("token %d (%q) rejected unexpectedly", i, token)
}
allAccepted = false
break
}
}
if tt.wantPass {
if !allAccepted {
return
}
if !engine.IsComplete() {
t.Errorf("expected complete parse, but not in accepting state")
}
} else {
if allAccepted && engine.IsComplete() {
t.Errorf("expected rejection or incomplete, but parse succeeded")
}
}
})
}
}
// TestE2E_IdentifierGrammar tests a grammar with character ranges.
func TestE2E_IdentifierGrammar(t *testing.T) {
// Identifier grammar using character ranges
identGrammar := `
ident = letter { letter | digit } .
letter = "a" … "z" | "A" … "Z" | "_" .
digit = "0" … "9" .
`
grammar, err := ParseEBNF(identGrammar, "ident")
if err != nil {
t.Fatalf("failed to parse identifier grammar: %v", err)
}
// Vocabulary with letters and digits
vocab := []string{
"a", "b", "c", "x", "y", "z",
"A", "B", "C", "X", "Y", "Z",
"_",
"0", "1", "2", "9",
// Multi-char tokens
"foo", "bar", "myVar", "test123",
// Invalid starting chars
"1abc", "123",
}
engine, err := NewEngine(grammar, vocab)
if err != nil {
t.Fatalf("failed to create engine: %v", err)
}
defer engine.Close()
tests := []struct {
name string
tokens []string
wantPass bool
}{
{"single letter", []string{"a"}, true},
{"uppercase", []string{"A"}, true},
{"underscore", []string{"_"}, true},
{"multi-letter", []string{"a", "b", "c"}, true},
{"letter then digit", []string{"x", "1"}, true},
{"underscore prefix", []string{"_", "a", "1"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
engine.Reset()
allAccepted := true
for i, token := range tt.tokens {
if !engine.AcceptString(token) {
if tt.wantPass {
t.Errorf("token %d (%q) rejected unexpectedly", i, token)
}
allAccepted = false
break
}
}
if tt.wantPass && allAccepted && !engine.IsComplete() {
t.Errorf("expected complete parse, but not in accepting state")
}
})
}
}
// TestE2E_UnicodeRange ensures unicode ranges compile and match tokens.
func TestE2E_UnicodeRange(t *testing.T) {
greekGrammar := `
greek = "α" … "ω" .
`
grammar, err := ParseEBNF(greekGrammar, "greek")
if err != nil {
t.Fatalf("failed to parse unicode grammar: %v", err)
}
vocab := []string{"α", "β", "ω", "a"}
engine, err := NewEngine(grammar, vocab)
if err != nil {
t.Fatalf("failed to create engine: %v", err)
}
defer engine.Close()
if !engine.AcceptString("β") {
t.Error("should accept beta")
}
if !engine.IsComplete() {
t.Error("should be complete after single rune")
}
engine.Reset()
if engine.AcceptString("a") {
t.Error("should reject ASCII outside unicode range")
}
}
// TestE2E_NondeterminismPreserved tests that nondeterministic paths are preserved.
func TestE2E_NondeterminismPreserved(t *testing.T) {
// This grammar has nondeterminism: "ab" could be parsed as
// a single token or as two tokens "a" "b"
ambiguousGrammar := `
start = item item .
item = "a" | "b" | "ab" .
`
grammar, err := ParseEBNF(ambiguousGrammar, "start")
if err != nil {
t.Fatalf("failed to parse grammar: %v", err)
}
// Vocabulary with both single and combined tokens
vocab := []string{"a", "b", "ab"}
engine, err := NewEngine(grammar, vocab)
if err != nil {
t.Fatalf("failed to create engine: %v", err)
}
defer engine.Close()
// Test: "ab" "a" should be valid (ab as first item, a as second)
t.Run("ab then a", func(t *testing.T) {
engine.Reset()
if !engine.AcceptString("ab") {
t.Error("should accept ab")
}
if !engine.AcceptString("a") {
t.Error("should accept a after ab")
}
if !engine.IsComplete() {
t.Error("should be complete")
}
})
t.Run("a then ab", func(t *testing.T) {
engine.Reset()
if !engine.AcceptString("a") {
t.Error("should accept a")
}
if !engine.AcceptString("ab") {
t.Error("should accept ab after a")
}
if !engine.IsComplete() {
t.Error("should be complete")
}
})
t.Run("a then a", func(t *testing.T) {
engine.Reset()
if !engine.AcceptString("a") {
t.Error("should accept first a")
}
if !engine.AcceptString("a") {
t.Error("should accept second a")
}
if !engine.IsComplete() {
t.Error("should be complete")
}
})
}

614
x/grammar/grammar.go Normal file
View File

@@ -0,0 +1,614 @@
//go:build mlx
// Package grammar provides GPU-accelerated constrained decoding using MLX.
// It compiles EBNF grammars to pushdown automata (pda) with precomputed token masks.
// For JSON Schema conversion, see the grammar/schema subpackage.
package grammar
import (
"encoding/binary"
"fmt"
"io"
"strings"
"golang.org/x/exp/ebnf"
)
// stackSymbol represents a symbol that can be pushed onto the pda stack.
type stackSymbol int
const (
stackEmpty stackSymbol = iota
// Additional stack symbols will be generated per-grammar
)
// state represents a pda state.
type state int
const (
stateError state = -1
stateStart state = 0
stateAccept state = 1
// Additional states will be generated per-grammar
)
// transition represents a pda transition.
// On input matching Pattern, from FromState with stackTop:
// - Move to ToState
// - Pop StackPop symbols, push StackPush symbols
type transition struct {
FromState state
stackTop stackSymbol // What must be on stack top (stackEmpty = don't care)
Pattern string // Input pattern to match (token or character class)
ToState state
StackPop int // Number of symbols to pop
StackPush []stackSymbol // Symbols to push (in order, first pushed first)
}
// pda represents a compiled pushdown automaton.
type pda struct {
States int // Total number of states
StackSymbols int // Total number of stack symbols
StartState state // Initial state
AcceptStates map[state]bool // Set of accepting states
Transitions map[state][]transition // Transitions indexed by from-state
// For token-level matching
Terminals []string // All terminal symbols (patterns to match)
}
// newPDA creates an empty pda.
func newPDA() *pda {
return &pda{
States: 2, // Error and Start
StackSymbols: 1, // Empty
StartState: stateStart,
AcceptStates: make(map[state]bool),
Transitions: make(map[state][]transition),
Terminals: make([]string, 0),
}
}
// addState adds a new state and returns its ID.
func (p *pda) addState() state {
s := state(p.States)
p.States++
return s
}
// addStackSymbol adds a new stack symbol and returns its ID.
func (p *pda) addStackSymbol() stackSymbol {
s := stackSymbol(p.StackSymbols)
p.StackSymbols++
return s
}
// addTransition adds a transition to the pda.
func (p *pda) addTransition(t transition) {
p.Transitions[t.FromState] = append(p.Transitions[t.FromState], t)
}
// addTerminal registers a terminal pattern and returns its index.
func (p *pda) addTerminal(pattern string) int {
for i, t := range p.Terminals {
if t == pattern {
return i
}
}
p.Terminals = append(p.Terminals, pattern)
return len(p.Terminals) - 1
}
// compiler compiles EBNF grammars to PDAs.
type compiler struct {
grammar ebnf.Grammar
pda *pda
// Maps production names to their entry/exit states
prodEntry map[string]state
prodExit map[string]state
}
// compile parses an EBNF grammar and compiles it to a pda.
func compile(name string, src io.Reader, start string) (*pda, error) {
grammar, err := ebnf.Parse(name, src)
if err != nil {
return nil, fmt.Errorf("parse grammar: %w", err)
}
if err := ebnf.Verify(grammar, start); err != nil {
return nil, fmt.Errorf("verify grammar: %w", err)
}
c := &compiler{
grammar: grammar,
pda: newPDA(),
prodEntry: make(map[string]state),
prodExit: make(map[string]state),
}
// Create entry/exit states for each production
for name := range grammar {
c.prodEntry[name] = c.pda.addState()
c.prodExit[name] = c.pda.addState()
}
// compile each production
for name, prod := range grammar {
if err := c.compileProduction(name, prod); err != nil {
return nil, fmt.Errorf("compile production %q: %w", name, err)
}
}
// Set start state to entry of start production
if entry, ok := c.prodEntry[start]; ok {
// Add epsilon transition from pda start to grammar start
c.pda.addTransition(transition{
FromState: stateStart,
Pattern: "", // epsilon
ToState: entry,
})
} else {
return nil, fmt.Errorf("start production %q not found", start)
}
// Mark exit of start production as accepting
if exit, ok := c.prodExit[start]; ok {
c.pda.AcceptStates[exit] = true
}
return c.pda, nil
}
// compileString is a convenience function to compile from a string.
func compileString(grammar string, start string) (*pda, error) {
return compile("grammar", strings.NewReader(grammar), start)
}
func (c *compiler) compileProduction(name string, prod *ebnf.Production) error {
entry := c.prodEntry[name]
exit := c.prodExit[name]
return c.compileExpr(prod.Expr, entry, exit)
}
func (c *compiler) compileExpr(expr ebnf.Expression, entry, exit state) error {
switch e := expr.(type) {
case *ebnf.Name:
return c.compileName(e, entry, exit)
case *ebnf.Token:
return c.compileToken(e, entry, exit)
case ebnf.Sequence:
return c.compileSequence(e, entry, exit)
case ebnf.Alternative:
return c.compileAlternative(e, entry, exit)
case *ebnf.Option:
return c.compileOption(e, entry, exit)
case *ebnf.Repetition:
return c.compileRepetition(e, entry, exit)
case *ebnf.Group:
return c.compileExpr(e.Body, entry, exit)
case *ebnf.Range:
return c.compileRange(e, entry, exit)
case nil:
// Empty production - direct epsilon transition
c.pda.addTransition(transition{
FromState: entry,
Pattern: "",
ToState: exit,
})
return nil
default:
return fmt.Errorf("unsupported expression type: %T", expr)
}
}
func (c *compiler) compileName(n *ebnf.Name, entry, exit state) error {
// Reference to another production
prodName := n.String
prodEntry, ok := c.prodEntry[prodName]
if !ok {
return fmt.Errorf("undefined production: %s", prodName)
}
prodExit := c.prodExit[prodName]
// Use a unique stack symbol per call site so returns are unambiguous.
stackSym := c.pda.addStackSymbol()
// Push return address, go to production entry
c.pda.addTransition(transition{
FromState: entry,
Pattern: "", // epsilon
ToState: prodEntry,
StackPush: []stackSymbol{stackSym},
})
// On production exit, pop and return
c.pda.addTransition(transition{
FromState: prodExit,
stackTop: stackSym,
Pattern: "", // epsilon
ToState: exit,
StackPop: 1,
})
return nil
}
func (c *compiler) compileToken(t *ebnf.Token, entry, exit state) error {
// terminal symbol - add transition that consumes this token
pattern := t.String
c.pda.addTerminal(pattern)
c.pda.addTransition(transition{
FromState: entry,
Pattern: pattern,
ToState: exit,
})
return nil
}
func (c *compiler) compileSequence(seq ebnf.Sequence, entry, exit state) error {
if len(seq) == 0 {
// Empty sequence - epsilon transition
c.pda.addTransition(transition{
FromState: entry,
Pattern: "",
ToState: exit,
})
return nil
}
// Chain: entry -> s1 -> s2 -> ... -> exit
current := entry
for i, expr := range seq {
var next state
if i == len(seq)-1 {
next = exit
} else {
next = c.pda.addState()
}
if err := c.compileExpr(expr, current, next); err != nil {
return err
}
current = next
}
return nil
}
func (c *compiler) compileAlternative(alt ebnf.Alternative, entry, exit state) error {
// Each alternative goes from entry to exit
for _, expr := range alt {
if err := c.compileExpr(expr, entry, exit); err != nil {
return err
}
}
return nil
}
func (c *compiler) compileOption(opt *ebnf.Option, entry, exit state) error {
// Optional: can skip (epsilon) or take the body
// Epsilon transition (skip)
c.pda.addTransition(transition{
FromState: entry,
Pattern: "",
ToState: exit,
})
// Or take the body
return c.compileExpr(opt.Body, entry, exit)
}
func (c *compiler) compileRepetition(rep *ebnf.Repetition, entry, exit state) error {
// Repetition {body}: zero or more
// entry -> exit (skip)
// entry -> body -> entry (loop back)
// Skip transition
c.pda.addTransition(transition{
FromState: entry,
Pattern: "",
ToState: exit,
})
// Loop: entry -> (body) -> entry
return c.compileExpr(rep.Body, entry, entry)
}
func (c *compiler) compileRange(r *ebnf.Range, entry, exit state) error {
// Character range like "a" … "z" or "\u03b1" … "\u03c9"
begin := strings.Trim(r.Begin.String, "\"")
end := strings.Trim(r.End.String, "\"")
// Unescape bounds first (so "\u03b1" works)
beginUnesc, err := unescapeLiteral(begin)
if err != nil {
return fmt.Errorf("invalid range begin: %w", err)
}
endUnesc, err := unescapeLiteral(end)
if err != nil {
return fmt.Errorf("invalid range end: %w", err)
}
// Validate as single runes (not bytes) for Unicode support
beginRunes := []rune(beginUnesc)
endRunes := []rune(endUnesc)
if len(beginRunes) != 1 || len(endRunes) != 1 {
return fmt.Errorf("range bounds must be single characters: %q..%q", r.Begin.String, r.End.String)
}
// Use unescaped rune strings in pattern (consistent with matcher)
pattern := fmt.Sprintf("[%s-%s]", string(beginRunes[0]), string(endRunes[0]))
c.pda.addTerminal(pattern)
c.pda.addTransition(transition{
FromState: entry,
Pattern: pattern,
ToState: exit,
})
return nil
}
// runtime represents a pda execution instance.
type runtime struct {
pda *pda
state state
stack []stackSymbol
}
// newRuntime creates a new pda runtime.
func newRuntime(pda *pda) *runtime {
return &runtime{
pda: pda,
state: pda.StartState,
stack: make([]stackSymbol, 0, 32),
}
}
// stackTop returns the top of the stack, or stackEmpty if empty.
func (r *runtime) stackTop() stackSymbol {
if len(r.stack) == 0 {
return stackEmpty
}
return r.stack[len(r.stack)-1]
}
// isAccepting returns true if we can reach an accepting state via epsilon transitions
// with an empty stack.
func (r *runtime) isAccepting() bool {
return r.canReachAccept(r.state, r.stack, make(map[stateStackKey]bool))
}
func (r *runtime) canReachAccept(state state, stack []stackSymbol, visited map[stateStackKey]bool) bool {
// Check if this state is accepting with empty stack
if r.pda.AcceptStates[state] && len(stack) == 0 {
return true
}
// Avoid infinite loops
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
if visited[key] {
return false
}
visited[key] = true
// Try epsilon transitions
for _, t := range r.pda.Transitions[state] {
if t.Pattern != "" {
continue // Not epsilon
}
// Check stack constraint
stackTop := stackEmpty
if len(stack) > 0 {
stackTop = stack[len(stack)-1]
}
if t.stackTop != stackEmpty && t.stackTop != stackTop {
continue
}
// Simulate stack operations
newStack := make([]stackSymbol, len(stack))
copy(newStack, stack)
if t.StackPop > 0 && len(newStack) >= t.StackPop {
newStack = newStack[:len(newStack)-t.StackPop]
}
newStack = append(newStack, t.StackPush...)
if r.canReachAccept(t.ToState, newStack, visited) {
return true
}
}
return false
}
// Reset resets the runtime to initial state.
func (r *runtime) Reset() {
r.state = r.pda.StartState
r.stack = r.stack[:0]
}
// validInputs returns all valid input patterns from current state.
func (r *runtime) validInputs() []string {
var valid []string
seen := make(map[string]bool)
visited := make(map[stateStackKey]bool)
// Make a copy of the stack for simulation
simStack := make([]stackSymbol, len(r.stack))
copy(simStack, r.stack)
r.collectValidInputs(r.state, simStack, seen, visited, &valid)
return valid
}
// stateStackKey is used to detect cycles in epsilon closure
type stateStackKey struct {
state state
stackSig string
}
func stackSignature(stack []stackSymbol) string {
if len(stack) == 0 {
return ""
}
buf := make([]byte, len(stack)*8)
for i, sym := range stack {
binary.LittleEndian.PutUint64(buf[i*8:], uint64(sym))
}
return string(buf)
}
func (r *runtime) collectValidInputs(state state, simStack []stackSymbol, seen map[string]bool, visited map[stateStackKey]bool, valid *[]string) {
// Get stack top for comparisons
stackTop := stackEmpty
if len(simStack) > 0 {
stackTop = simStack[len(simStack)-1]
}
// Check for cycles to avoid infinite loops
key := stateStackKey{state: state, stackSig: stackSignature(simStack)}
if visited[key] {
return
}
visited[key] = true
transitions := r.pda.Transitions[state]
for _, t := range transitions {
// Check stack constraint
if t.stackTop != stackEmpty && t.stackTop != stackTop {
continue
}
if t.Pattern == "" {
// Epsilon transition - simulate stack operations
newStack := make([]stackSymbol, len(simStack))
copy(newStack, simStack)
// Pop
if t.StackPop > 0 {
if len(newStack) < t.StackPop {
continue // Can't pop, skip this transition
}
newStack = newStack[:len(newStack)-t.StackPop]
}
// Push
newStack = append(newStack, t.StackPush...)
r.collectValidInputs(t.ToState, newStack, seen, visited, valid)
} else {
// terminal - add if not seen
if !seen[t.Pattern] {
seen[t.Pattern] = true
*valid = append(*valid, t.Pattern)
}
}
}
}
// matchesPattern checks if input matches a pattern.
// Patterns can be:
// - Exact strings: "a", "{", "true"
// - Character ranges: "[a-z]", "[0-9]", "[#-~]"
func matchesPattern(input, pattern string) bool {
// Exact match
if input == pattern {
return true
}
// Check for character range pattern [X-Y]
if len(pattern) == 5 && pattern[0] == '[' && pattern[2] == '-' && pattern[4] == ']' {
if len(input) != 1 {
return false
}
ch := input[0]
low := pattern[1]
high := pattern[3]
return ch >= low && ch <= high
}
return false
}
// Accept tries to accept an input, returning true if successful.
func (r *runtime) Accept(input string) bool {
return r.accept(input, make(map[stateStackKey]bool))
}
func (r *runtime) accept(input string, visited map[stateStackKey]bool) bool {
key := stateStackKey{state: r.state, stackSig: stackSignature(r.stack)}
if visited[key] {
return false
}
visited[key] = true
transitions := r.pda.Transitions[r.state]
// First, process any epsilon transitions to reach a state that can accept input
// This is a simplified version - full implementation would need epsilon closure
for _, t := range transitions {
if matchesPattern(input, t.Pattern) {
if t.stackTop != stackEmpty && t.stackTop != r.stackTop() {
continue
}
if t.StackPop > len(r.stack) {
continue
}
// Apply transition
r.applyTransition(t)
return true
}
}
// Try epsilon transitions first
for _, t := range transitions {
if t.Pattern == "" {
if t.stackTop != stackEmpty && t.stackTop != r.stackTop() {
continue
}
if t.StackPop > len(r.stack) {
continue
}
// Save state for backtracking
oldState := r.state
oldStack := make([]stackSymbol, len(r.stack))
copy(oldStack, r.stack)
r.applyTransition(t)
if r.accept(input, visited) {
return true
}
// Backtrack
r.state = oldState
r.stack = oldStack
}
}
return false
}
func (r *runtime) applyTransition(t transition) {
// Pop
if t.StackPop > 0 && len(r.stack) >= t.StackPop {
r.stack = r.stack[:len(r.stack)-t.StackPop]
}
// Push
r.stack = append(r.stack, t.StackPush...)
// Move to new state
r.state = t.ToState
}

540
x/grammar/grammar_test.go Normal file
View File

@@ -0,0 +1,540 @@
//go:build mlx
package grammar
import (
"testing"
)
func TestCompileSimpleGrammar(t *testing.T) {
// Simple grammar: S = "a" "b" .
grammar := `S = "a" "b" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
if pda == nil {
t.Fatal("pda is nil")
}
// Should have terminals "a" and "b"
if len(pda.Terminals) != 2 {
t.Errorf("expected 2 terminals, got %d: %v", len(pda.Terminals), pda.Terminals)
}
// Test runtime
rt := newRuntime(pda)
// Should accept "a" then "b"
if !rt.Accept("a") {
t.Error("should accept 'a'")
}
if !rt.Accept("b") {
t.Error("should accept 'b'")
}
if !rt.isAccepting() {
t.Error("should be in accepting state")
}
}
func TestCompileAlternative(t *testing.T) {
// Grammar: S = "a" | "b" .
grammar := `S = "a" | "b" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
// Test accepting "a"
rt := newRuntime(pda)
if !rt.Accept("a") {
t.Error("should accept 'a'")
}
if !rt.isAccepting() {
t.Error("should be accepting after 'a'")
}
// Test accepting "b"
rt.Reset()
if !rt.Accept("b") {
t.Error("should accept 'b'")
}
if !rt.isAccepting() {
t.Error("should be accepting after 'b'")
}
// Test rejecting "c"
rt.Reset()
if rt.Accept("c") {
t.Error("should not accept 'c'")
}
}
func TestCompileRepetition(t *testing.T) {
// Grammar: S = {"a"} .
grammar := `S = {"a"} .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
// Empty should be accepted (zero repetitions)
rt := newRuntime(pda)
if !rt.isAccepting() {
t.Error("empty should be accepting")
}
// "a" should be accepted
rt.Reset()
if !rt.Accept("a") {
t.Error("should accept first 'a'")
}
if !rt.isAccepting() {
t.Error("should be accepting after one 'a'")
}
// "aa" should be accepted
if !rt.Accept("a") {
t.Error("should accept second 'a'")
}
if !rt.isAccepting() {
t.Error("should be accepting after two 'a's")
}
}
func TestCompileOption(t *testing.T) {
// Grammar: S = ["a"] "b" .
grammar := `S = ["a"] "b" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
// "b" alone should be accepted
rt := newRuntime(pda)
if !rt.Accept("b") {
t.Error("should accept 'b' alone")
}
if !rt.isAccepting() {
t.Error("should be accepting after 'b'")
}
// "ab" should be accepted
rt.Reset()
if !rt.Accept("a") {
t.Error("should accept 'a'")
}
if !rt.Accept("b") {
t.Error("should accept 'b' after 'a'")
}
if !rt.isAccepting() {
t.Error("should be accepting after 'ab'")
}
}
func TestCompileRecursive(t *testing.T) {
// Grammar with recursion: S = "(" S ")" | "x" .
grammar := `S = "(" S ")" | "x" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
// "x" should be accepted
rt := newRuntime(pda)
if !rt.Accept("x") {
t.Error("should accept 'x'")
}
if !rt.isAccepting() {
t.Error("should be accepting after 'x'")
}
// "(x)" should be accepted
rt.Reset()
if !rt.Accept("(") {
t.Error("should accept '('")
}
if !rt.Accept("x") {
t.Error("should accept 'x' inside parens")
}
if !rt.Accept(")") {
t.Error("should accept ')'")
}
if !rt.isAccepting() {
t.Error("should be accepting after '(x)'")
}
// "((x))" should be accepted
rt.Reset()
if !rt.Accept("(") {
t.Error("should accept first '('")
}
if !rt.Accept("(") {
t.Error("should accept second '('")
}
if !rt.Accept("x") {
t.Error("should accept 'x'")
}
if !rt.Accept(")") {
t.Error("should accept first ')'")
}
if !rt.Accept(")") {
t.Error("should accept second ')'")
}
if !rt.isAccepting() {
t.Error("should be accepting after '((x))'")
}
}
func TestValidInputs(t *testing.T) {
// Grammar: S = "a" | "b" .
grammar := `S = "a" | "b" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
rt := newRuntime(pda)
valid := rt.validInputs()
// Should have both "a" and "b" as valid
hasA, hasB := false, false
for _, v := range valid {
if v == "a" {
hasA = true
}
if v == "b" {
hasB = true
}
}
if !hasA {
t.Error("'a' should be valid input")
}
if !hasB {
t.Error("'b' should be valid input")
}
}
// TestValidInputsAfterAccept tests that validInputs returns correct values
// after accepting tokens, ensuring proper stack simulation.
func TestValidInputsAfterAccept(t *testing.T) {
// Grammar: S = "a" "b" "c" .
grammar := `S = "a" "b" "c" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
rt := newRuntime(pda)
// Initially only "a" should be valid
valid := rt.validInputs()
if len(valid) != 1 || valid[0] != "a" {
t.Errorf("initially expected only 'a', got %v", valid)
}
// After accepting "a", only "b" should be valid
if !rt.Accept("a") {
t.Fatal("failed to accept 'a'")
}
valid = rt.validInputs()
if len(valid) != 1 || valid[0] != "b" {
t.Errorf("after 'a', expected only 'b', got %v", valid)
}
// After accepting "b", only "c" should be valid
if !rt.Accept("b") {
t.Fatal("failed to accept 'b'")
}
valid = rt.validInputs()
if len(valid) != 1 || valid[0] != "c" {
t.Errorf("after 'ab', expected only 'c', got %v", valid)
}
}
// TestValidInputsWithRepetitionInProduction tests the critical case where
// a repetition exists inside a called production. This requires proper
// stack simulation to determine when closing symbols are valid.
func TestValidInputsWithRepetitionInProduction(t *testing.T) {
// Grammar similar to JSON:
// S = "(" items ")" .
// items = item { "," item } .
// item = "x" .
grammar := `
S = "(" items ")" .
items = item { "," item } .
item = "x" .
`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
rt := newRuntime(pda)
// Initially only "(" should be valid
valid := rt.validInputs()
if len(valid) != 1 || valid[0] != "(" {
t.Errorf("initially expected only '(', got %v", valid)
}
// Accept "("
if !rt.Accept("(") {
t.Fatal("failed to accept '('")
}
// After "(", should be able to accept "x" (item)
valid = rt.validInputs()
hasX := false
for _, v := range valid {
if v == "x" {
hasX = true
}
}
if !hasX {
t.Errorf("after '(', expected 'x' to be valid, got %v", valid)
}
// Accept first item "x"
if !rt.Accept("x") {
t.Fatal("failed to accept 'x'")
}
// After "(x", should be able to accept "," (more items) OR ")" (end)
valid = rt.validInputs()
hasComma, hasClose := false, false
for _, v := range valid {
if v == "," {
hasComma = true
}
if v == ")" {
hasClose = true
}
}
if !hasComma {
t.Errorf("after '(x', expected ',' to be valid, got %v", valid)
}
if !hasClose {
t.Errorf("after '(x', expected ')' to be valid, got %v", valid)
}
// Accept comma for another item
if !rt.Accept(",") {
t.Fatal("failed to accept ','")
}
// After "(x,", should only be able to accept "x" (next item)
valid = rt.validInputs()
if len(valid) != 1 || valid[0] != "x" {
t.Errorf("after '(x,', expected only 'x', got %v", valid)
}
// Accept second item "x"
if !rt.Accept("x") {
t.Fatal("failed to accept second 'x'")
}
// CRITICAL: After "(x,x", should be able to accept "," OR ")"
// This tests the stack simulation fix - we need to properly
// follow epsilon transitions through the production call stack.
valid = rt.validInputs()
hasComma, hasClose = false, false
for _, v := range valid {
if v == "," {
hasComma = true
}
if v == ")" {
hasClose = true
}
}
if !hasComma {
t.Errorf("after '(x,x', expected ',' to be valid, got %v", valid)
}
if !hasClose {
t.Errorf("after '(x,x', expected ')' to be valid, got %v", valid)
}
// Close with ")"
if !rt.Accept(")") {
t.Fatal("failed to accept ')'")
}
if !rt.isAccepting() {
t.Error("should be accepting after '(x,x)'")
}
}
// TestValidInputsNestedCalls tests validInputs with deeply nested production calls.
func TestValidInputsNestedCalls(t *testing.T) {
// Grammar: A = "start" B "end" . B = "middle" .
grammar := `
A = "start" B "end" .
B = "middle" .
`
pda, err := compileString(grammar, "A")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
rt := newRuntime(pda)
// After "start", should accept "middle" (from B)
rt.Accept("start")
valid := rt.validInputs()
if len(valid) != 1 || valid[0] != "middle" {
t.Errorf("after 'start', expected 'middle', got %v", valid)
}
// After "start middle", should accept "end"
rt.Accept("middle")
valid = rt.validInputs()
if len(valid) != 1 || valid[0] != "end" {
t.Errorf("after 'start middle', expected 'end', got %v", valid)
}
}
func TestReturnAddressDisambiguation(t *testing.T) {
// Grammar where the same production is called from different contexts:
// S = A "x" | "c" A "y" .
// A = "a" .
grammar := `
S = A "x" | "c" A "y" .
A = "a" .
`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
rt := newRuntime(pda)
if !rt.Accept("c") {
t.Fatal("failed to accept 'c'")
}
if !rt.Accept("a") {
t.Fatal("failed to accept 'a'")
}
valid := rt.validInputs()
if len(valid) != 1 || valid[0] != "y" {
t.Errorf("after 'ca', expected only 'y', got %v", valid)
}
rt.Reset()
rt.Accept("c")
rt.Accept("a")
if rt.Accept("x") {
t.Error("should not accept 'x' after 'ca'")
}
}
// TestValidInputsRecursiveWithStack tests validInputs with recursive grammars
// which heavily exercise the stack simulation.
func TestValidInputsRecursiveWithStack(t *testing.T) {
// Grammar: S = "(" S ")" | "x" .
grammar := `S = "(" S ")" | "x" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
rt := newRuntime(pda)
// Initially: "(" or "x" should be valid
valid := rt.validInputs()
hasParen, hasX := false, false
for _, v := range valid {
if v == "(" {
hasParen = true
}
if v == "x" {
hasX = true
}
}
if !hasParen || !hasX {
t.Errorf("initially expected '(' and 'x', got %v", valid)
}
// After "(": "(" or "x" should be valid (nested S)
rt.Accept("(")
valid = rt.validInputs()
hasParen, hasX = false, false
for _, v := range valid {
if v == "(" {
hasParen = true
}
if v == "x" {
hasX = true
}
}
if !hasParen || !hasX {
t.Errorf("after '(', expected '(' and 'x', got %v", valid)
}
// After "((": "(" or "x" should still be valid
rt.Accept("(")
valid = rt.validInputs()
hasParen, hasX = false, false
for _, v := range valid {
if v == "(" {
hasParen = true
}
if v == "x" {
hasX = true
}
}
if !hasParen || !hasX {
t.Errorf("after '((', expected '(' and 'x', got %v", valid)
}
// After "((x": only ")" should be valid
rt.Accept("x")
valid = rt.validInputs()
if len(valid) != 1 || valid[0] != ")" {
t.Errorf("after '((x', expected only ')', got %v", valid)
}
// After "((x)": only ")" should be valid (closing outer)
rt.Accept(")")
valid = rt.validInputs()
if len(valid) != 1 || valid[0] != ")" {
t.Errorf("after '((x)', expected only ')', got %v", valid)
}
}
// TestRejectionAfterValid tests that invalid inputs are rejected
// at various points in the grammar.
func TestRejectionAfterValid(t *testing.T) {
// Grammar: S = "a" "b" .
grammar := `S = "a" "b" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
rt := newRuntime(pda)
// "b" should be rejected initially
if rt.Accept("b") {
t.Error("'b' should be rejected initially")
}
// Accept "a"
rt.Accept("a")
// "a" should be rejected after "a"
if rt.Accept("a") {
t.Error("'a' should be rejected after 'a'")
}
// "c" should be rejected (not in grammar)
if rt.Accept("c") {
t.Error("'c' should be rejected (not in grammar)")
}
}

View File

@@ -0,0 +1,56 @@
# Example Grammars
This directory contains example EBNF grammars for constrained decoding.
## Usage
```bash
go run -tags mlx ./x/imagegen/cmd/engine/ \
-model /path/to/model \
-prompt "Your prompt" \
-grammar x/grammar/grammars/json.ebnf \
-grammar-start value
```
## Available Grammars
| File | Start Rule | Description |
|------|------------|-------------|
| `json.ebnf` | `value` | Standard JSON (RFC 8259) |
| `expression.ebnf` | `expr` | Arithmetic expressions (+, -, *, /, parens) |
| `identifier.ebnf` | `ident` | Programming language identifiers |
| `boolean.ebnf` | `expr` | Boolean expressions (AND, OR, NOT) |
| `list.ebnf` | `list` | Comma-separated word list |
| `yesno.ebnf` | `response` | Simple yes/no responses |
| `date.ebnf` | `date` | Dates in YYYY-MM-DD format |
| `email.ebnf` | `email` | Basic email addresses |
| `phone.ebnf` | `phone` | US phone numbers |
| `hexcolor.ebnf` | `color` | CSS hex colors (#RGB or #RRGGBB) |
| `url.ebnf` | `url` | HTTP/HTTPS URLs |
## Grammar Syntax
**Note:** Comments are not supported. Grammar files must contain only EBNF productions.
The grammars use EBNF notation:
- `=` defines a production rule
- `|` is alternation (or)
- `{ }` is repetition (zero or more)
- `[ ]` is optional (zero or one)
- `" "` is a literal string
- `…` is a character range (e.g., `"a" … "z"`)
- `.` ends a production
## Writing Custom Grammars
1. Define your grammar in a `.ebnf` file
2. Choose a start rule name
3. Pass `-grammar path/to/grammar.ebnf -grammar-start rulename`
Example custom grammar for RGB colors:
```ebnf
color = "#" hexdigit hexdigit hexdigit hexdigit hexdigit hexdigit .
hexdigit = "0" "9" | "a" "f" | "A" "F" .
```

View File

@@ -0,0 +1,7 @@
expr = term { " OR " term } .
term = factor { " AND " factor } .
factor = "NOT " factor | atom | "(" expr ")" .
atom = "true" | "false" | ident .
ident = letter { letter | digit } .
letter = "a" "z" | "A" "Z" .
digit = "0" "9" .

View File

@@ -0,0 +1,6 @@
date = year "-" month "-" day .
year = digit digit digit digit .
month = ( "0" digit1to9 ) | ( "1" ( "0" | "1" | "2" ) ) .
day = ( "0" digit1to9 ) | ( ( "1" | "2" ) digit ) | ( "3" ( "0" | "1" ) ) .
digit1to9 = "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .

View File

@@ -0,0 +1,5 @@
email = localpart "@" domain .
localpart = word { "." word } .
domain = word { "." word } .
word = alphanum { alphanum | "-" | "_" } .
alphanum = "a" "z" | "A" "Z" | "0" "9" .

View File

@@ -0,0 +1,7 @@
expr = term { addop term } .
addop = "+" | "-" .
term = factor { mulop factor } .
mulop = "*" | "/" .
factor = number | "(" expr ")" .
number = [ "-" ] digit { digit } .
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .

View File

@@ -0,0 +1,4 @@
color = "#" ( hex6 | hex3 ) .
hex6 = hexdigit hexdigit hexdigit hexdigit hexdigit hexdigit .
hex3 = hexdigit hexdigit hexdigit .
hexdigit = "0" "9" | "a" "f" | "A" "F" .

View File

@@ -0,0 +1,3 @@
ident = letter { letter | digit | "_" } .
letter = "a" "z" | "A" "Z" | "_" .
digit = "0" "9" .

View File

@@ -0,0 +1,16 @@
value = object | array | string | number | "true" | "false" | "null" .
object = "{" [ members ] "}" .
members = pair { "," pair } .
pair = string ":" value .
array = "[" [ elements ] "]" .
elements = value { "," value } .
string = "\"" { char } "\"" .
char = unescaped | escaped .
unescaped = " " | "!" | "#" "[" | "]" "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" ) .
number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
onenine = "1" "9" .
digit = "0" "9" .

View File

@@ -0,0 +1,27 @@
root = array .
value = object | array | string | number | "true" | "false" | "null" .
object = "{" ws "}" | "{" members "}" .
members = member { "," member } .
member = ws string ws ":" element .
array = "[" ws "]" | "[" elements "]" .
elements = element { "," element } .
element = ws value ws .
string = "\"" { character } "\"" .
character = unescaped | escaped .
unescaped = " " | "!" | "#" "[" | "]" "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
unicode = "u" hex hex hex hex .
hex = "0" … "9" | "A" … "F" | "a" … "f" .
number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
digit = "0" "9" .
onenine = "1" "9" .
ws = { " " | "\t" | "\n" | "\r" } .

View File

@@ -0,0 +1,4 @@
list = item { ", " item } .
item = word .
word = letter { letter } .
letter = "a" "z" | "A" "Z" .

View File

@@ -0,0 +1,19 @@
root = "[" ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person { "," ws person } ws "]" .
person = "{" ws name_field "," ws age_field "," ws email_field ws "}" .
name_field = "\"" "n" "a" "m" "e" "\"" ws ":" ws string .
age_field = "\"" "a" "g" "e" "\"" ws ":" ws number .
email_field = "\"" "e" "m" "a" "i" "l" "\"" ws ":" ws string .
string = "\"" { character } "\"" .
character = unescaped | escaped .
unescaped = " " | "!" | "#" "[" | "]" "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" ) .
number = [ "-" ] integer .
integer = "0" | onenine { digit } .
digit = "0" … "9" .
onenine = "1" … "9" .
ws = { " " | "\t" | "\n" | "\r" } .

View File

@@ -0,0 +1,15 @@
root = "{" ws name_field "," ws age_field "," ws email_field ws "}" .
name_field = "\"name\"" ws ":" ws string .
age_field = "\"age\"" ws ":" ws number .
email_field = "\"email\"" ws ":" ws string .
string = "\"" { character } "\"" .
character = " " | "!" | "#" "~" .
number = [ "-" ] integer .
integer = "0" | onenine { digit } .
digit = "0" "9" .
onenine = "1" "9" .
ws = { " " | "\t" | "\n" | "\r" } .

View File

@@ -0,0 +1,7 @@
phone = parenformat | dashformat .
parenformat = "(" areacode ") " exchange "-" subscriber .
dashformat = areacode "-" exchange "-" subscriber .
areacode = digit digit digit .
exchange = digit digit digit .
subscriber = digit digit digit digit .
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .

View File

@@ -0,0 +1,11 @@
url = scheme "://" host [ ":" port ] [ path ] [ query ] .
scheme = "http" | "https" .
host = word { "." word } .
port = digit { digit } .
path = "/" { pathseg } .
pathseg = word [ "/" ] .
query = "?" param { "&" param } .
param = word "=" word .
word = alphanum { alphanum | "-" | "_" } .
alphanum = "a" "z" | "A" "Z" | "0" "9" .
digit = "0" "9" .

View File

@@ -0,0 +1,3 @@
response = affirmative | negative .
affirmative = "yes" | "Yes" | "YES" | "y" | "Y" | "true" | "True" .
negative = "no" | "No" | "NO" | "n" | "N" | "false" | "False" .

69
x/grammar/json.go Normal file
View File

@@ -0,0 +1,69 @@
//go:build mlx
package grammar
// JSONGrammarEBNF is the EBNF grammar for JSON (character-level).
// Based on https://www.json.org/json-en.html
//
// This grammar operates at the character level. The engine validates
// tokens by matching them as sequences of these character-level terminals.
const JSONGrammarEBNF = `
json = value .
value = object | array | string | number | "true" | "false" | "null" .
object = "{" ws "}" | "{" members "}" .
members = member { "," member } .
member = ws string ws ":" element .
array = "[" ws "]" | "[" elements "]" .
elements = element { "," element } .
element = ws value ws .
string = "\"" { character } "\"" .
character = unescaped | escaped .
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
unicode = "u" hex hex hex hex .
hex = "0" … "9" | "A" … "F" | "a" … "f" .
number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
digit = "0" … "9" .
onenine = "1" … "9" .
ws = { " " | "\t" | "\n" | "\r" } .
`
// JSONObjectGrammarEBNF is like JSONGrammarEBNF but only allows objects at the top level.
const JSONObjectGrammarEBNF = `
json = object .
value = object | array | string | number | "true" | "false" | "null" .
object = "{" ws "}" | "{" members "}" .
members = member { "," member } .
member = ws string ws ":" element .
array = "[" ws "]" | "[" elements "]" .
elements = element { "," element } .
element = ws value ws .
string = "\"" { character } "\"" .
character = unescaped | escaped .
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
unicode = "u" hex hex hex hex .
hex = "0" … "9" | "A" … "F" | "a" … "f" .
number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
digit = "0" … "9" .
onenine = "1" … "9" .
ws = { " " | "\t" | "\n" | "\r" } .
`

726
x/grammar/schema/schema.go Normal file
View File

@@ -0,0 +1,726 @@
//go:build mlx
// Package schema converts OpenAI-compatible JSON Schema into constrained grammars.
package schema
import (
"encoding/json"
"fmt"
"regexp"
"sort"
"strings"
"github.com/ollama/ollama/x/grammar"
)
// schemaNode represents OpenAI-compatible JSON Schema for structured outputs.
// See: https://platform.openai.com/docs/guides/structured-outputs
type schemaNode struct {
// Core types
Type interface{} `json:"type"` // string, []string, or nil
// Object properties
Properties map[string]*schemaNode `json:"properties"`
Required []string `json:"required"`
AdditionalProperties interface{} `json:"additionalProperties"`
// Array properties
Items *schemaNode `json:"items"`
MinItems *int `json:"minItems"`
MaxItems *int `json:"maxItems"`
// String properties
Pattern string `json:"pattern"` // Regex pattern
Format string `json:"format"` // date-time, email, uuid, etc.
// Number properties (noted but not enforced in grammar - validated post-generation)
Minimum *float64 `json:"minimum"`
Maximum *float64 `json:"maximum"`
ExclusiveMinimum *float64 `json:"exclusiveMinimum"`
ExclusiveMaximum *float64 `json:"exclusiveMaximum"`
MultipleOf *float64 `json:"multipleOf"`
// Enum and const
Enum []interface{} `json:"enum"`
Const interface{} `json:"const"`
// Composition
AnyOf []*schemaNode `json:"anyOf"`
OneOf []*schemaNode `json:"oneOf"` // Treated same as anyOf for grammar
// References and definitions
Ref string `json:"$ref"`
Defs map[string]*schemaNode `json:"$defs"`
// Description (ignored for grammar but useful for docs)
Description string `json:"description"`
}
// converter handles JSON Schema to EBNF conversion with state.
type converter struct {
schema *schemaNode
definitions map[string]*schemaNode // Resolved $defs
usedTypes map[string]bool
rules []string
ruleNum int
definedRefs map[string]bool // Track which refs we've already defined as rules
}
// EBNF converts a JSON Schema to EBNF grammar
func EBNF(schemaJSON string) (string, error) {
var schema schemaNode
if err := json.Unmarshal([]byte(schemaJSON), &schema); err != nil {
return "", fmt.Errorf("failed to parse JSON Schema: %w", err)
}
conv := &converter{
schema: &schema,
definitions: schema.Defs,
usedTypes: make(map[string]bool),
definedRefs: make(map[string]bool),
}
return conv.convert()
}
func (c *converter) convert() (string, error) {
var b strings.Builder
// Generate root rule
rootExpr := c.schemaToExpr(c.schema, "root")
b.WriteString("root = ")
b.WriteString(rootExpr)
b.WriteString(" .\n")
// Add generated rules (refs, items, etc.)
for _, rule := range c.rules {
b.WriteString(rule)
b.WriteString("\n")
}
// Add primitives based on usage
c.addPrimitives(&b)
return b.String(), nil
}
func (c *converter) addPrimitives(b *strings.Builder) {
if c.usedTypes["string"] {
b.WriteString(`
string = "\"" { character } "\"" .
`)
}
if c.usedTypes["string"] || c.usedTypes["character"] {
b.WriteString(`
character = unescaped | escaped .
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
unicode = "u" hex hex hex hex .
`)
}
if c.usedTypes["number"] {
b.WriteString(`
number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
`)
}
if c.usedTypes["integer"] {
b.WriteString(`
int = [ "-" ] ( "0" | onenine { digit } ) .
`)
}
if c.usedTypes["number"] || c.usedTypes["integer"] || c.usedTypes["digit"] {
b.WriteString(`
digit = "0" … "9" .
`)
}
// onenine only needed for number/integer, not for digit-only formats
if c.usedTypes["number"] || c.usedTypes["integer"] {
b.WriteString(`onenine = "1" … "9" .
`)
}
if c.usedTypes["string"] || c.usedTypes["character"] || c.usedTypes["hex"] {
b.WriteString(`
hex = "0" … "9" | "A" … "F" | "a" … "f" .
`)
}
if c.usedTypes["ws"] {
b.WriteString(`
ws = { " " | "\t" | "\n" | "\r" } .
`)
}
}
func (c *converter) schemaToExpr(schema *schemaNode, name string) string {
if schema == nil {
c.usedTypes["string"] = true
c.usedTypes["number"] = true
return "( string | number | object | array | \"true\" | \"false\" | \"null\" )"
}
// Handle $ref first
if schema.Ref != "" {
return c.resolveRef(schema.Ref)
}
// Handle const
if schema.Const != nil {
return c.constToExpr(schema.Const)
}
// Handle enum
if len(schema.Enum) > 0 {
return c.enumToExpr(schema.Enum)
}
// Handle anyOf / oneOf
if len(schema.AnyOf) > 0 {
return c.anyOfToExpr(schema.AnyOf, name)
}
if len(schema.OneOf) > 0 {
return c.anyOfToExpr(schema.OneOf, name)
}
// Handle type
types := c.getTypes(schema.Type)
if len(types) == 0 {
// No type specified, could be anything
c.usedTypes["string"] = true
c.usedTypes["number"] = true
return "( string | number | \"true\" | \"false\" | \"null\" )"
}
if len(types) == 1 {
return c.typeToExpr(types[0], schema, name)
}
// Multiple types (e.g., ["string", "null"])
var parts []string
for _, t := range types {
parts = append(parts, c.typeToExpr(t, schema, name))
}
return "( " + strings.Join(parts, " | ") + " )"
}
func (c *converter) typeToExpr(typeName string, schema *schemaNode, name string) string {
switch typeName {
case "object":
return c.objectToExpr(schema, name)
case "array":
return c.arrayToExpr(schema, name)
case "string":
return c.stringToExpr(schema, name)
case "number":
c.usedTypes["number"] = true
return "number"
case "integer":
c.usedTypes["integer"] = true
c.usedTypes["digit"] = true
return "int"
case "boolean":
return `( "true" | "false" )`
case "null":
return `"null"`
default:
c.usedTypes["string"] = true
c.usedTypes["number"] = true
return "string"
}
}
func (c *converter) objectToExpr(schema *schemaNode, name string) string {
c.usedTypes["ws"] = true
if len(schema.Properties) == 0 {
return `"{" ws "}"`
}
// Sort properties for deterministic output
// Required properties come first, in their required order
var propOrder []string
requiredSet := make(map[string]bool)
for _, r := range schema.Required {
requiredSet[r] = true
propOrder = append(propOrder, r)
}
// Add any non-required properties (though OpenAI requires all to be required)
var optionalProps []string
for propName := range schema.Properties {
if !requiredSet[propName] {
optionalProps = append(optionalProps, propName)
}
}
sort.Strings(optionalProps)
propOrder = append(propOrder, optionalProps...)
var propExprs []string
first := true
for _, propName := range propOrder {
propSchema, exists := schema.Properties[propName]
if !exists {
continue
}
propExpr := c.schemaToExpr(propSchema, propName)
prefix := ""
if !first {
prefix = `"," ws `
}
first = false
propExprs = append(propExprs, fmt.Sprintf(`%s"\"%s\"" ws ":" ws %s`, prefix, propName, propExpr))
}
if len(propExprs) == 0 {
return `"{" ws "}"`
}
return `"{" ws ` + strings.Join(propExprs, " ") + ` ws "}"`
}
func (c *converter) arrayToExpr(schema *schemaNode, name string) string {
c.usedTypes["ws"] = true
itemExpr := "value"
if schema.Items != nil {
itemExpr = c.schemaToExpr(schema.Items, name+"_item")
} else {
c.usedTypes["string"] = true
c.usedTypes["number"] = true
}
// Create item rule
c.ruleNum++
itemRule := fmt.Sprintf("item%d", c.ruleNum)
c.rules = append(c.rules, fmt.Sprintf("%s = %s .", itemRule, itemExpr))
// Handle minItems/maxItems
if schema.MinItems != nil || schema.MaxItems != nil {
return c.arrayWithBounds(itemRule, schema.MinItems, schema.MaxItems)
}
// Default: zero or more items
return fmt.Sprintf(`( "[" ws "]" | "[" ws %s { "," ws %s } ws "]" )`, itemRule, itemRule)
}
func (c *converter) arrayWithBounds(itemRule string, minItems, maxItems *int) string {
min := 0
max := -1 // unlimited
if minItems != nil {
min = *minItems
}
if maxItems != nil {
max = *maxItems
}
if min == 0 && max < 0 {
// No constraints
return fmt.Sprintf(`( "[" ws "]" | "[" ws %s { "," ws %s } ws "]" )`, itemRule, itemRule)
}
if min == 0 && max == 0 {
return `"[" ws "]"`
}
// Build pattern for bounded array
// For min=2, max=4: item "," item [ "," item ] [ "," item ]
var parts []string
// Required items
for i := 0; i < min; i++ {
if i > 0 {
parts = append(parts, `"," ws`)
}
parts = append(parts, itemRule)
}
// Optional items up to max
if max > min {
for i := min; i < max; i++ {
if i == 0 {
parts = append(parts, fmt.Sprintf(`[ %s`, itemRule))
} else {
parts = append(parts, fmt.Sprintf(`[ "," ws %s`, itemRule))
}
}
// Close all optional brackets
for i := min; i < max; i++ {
parts = append(parts, "]")
}
} else if max < 0 {
// Unlimited after min
if min > 0 {
parts = append(parts, fmt.Sprintf(`{ "," ws %s }`, itemRule))
} else {
parts = append(parts, fmt.Sprintf(`[ %s { "," ws %s } ]`, itemRule, itemRule))
}
}
if min == 0 {
return fmt.Sprintf(`( "[" ws "]" | "[" ws %s ws "]" )`, strings.Join(parts, " "))
}
return fmt.Sprintf(`"[" ws %s ws "]"`, strings.Join(parts, " "))
}
func (c *converter) stringToExpr(schema *schemaNode, name string) string {
// Handle format
if schema.Format != "" {
return c.formatToExpr(schema.Format)
}
// Handle pattern (regex)
if schema.Pattern != "" {
return c.patternToExpr(schema.Pattern, name)
}
// Default string
c.usedTypes["string"] = true
if name == "root" {
c.usedTypes["character"] = true
return `"\"" { character } "\""`
}
return "string"
}
func (c *converter) formatToExpr(format string) string {
switch format {
case "date":
// YYYY-MM-DD
c.ruleNum++
c.usedTypes["digit"] = true
ruleName := fmt.Sprintf("date%d", c.ruleNum)
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit digit digit "-" digit digit "-" digit digit "\"" .`, ruleName))
return ruleName
case "time":
// HH:MM:SS
c.ruleNum++
c.usedTypes["digit"] = true
ruleName := fmt.Sprintf("time%d", c.ruleNum)
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit ":" digit digit ":" digit digit "\"" .`, ruleName))
return ruleName
case "date-time":
// YYYY-MM-DDTHH:MM:SSZ or with offset
c.ruleNum++
c.usedTypes["digit"] = true
ruleName := fmt.Sprintf("datetime%d", c.ruleNum)
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit digit digit "-" digit digit "-" digit digit "T" digit digit ":" digit digit ":" digit digit ( "Z" | ( "+" | "-" ) digit digit ":" digit digit ) "\"" .`, ruleName))
return ruleName
case "email":
// Simplified email pattern
c.ruleNum++
ruleName := fmt.Sprintf("email%d", c.ruleNum)
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" emailchar { emailchar } "@" emailchar { emailchar } "." emailchar { emailchar } "\"" .`, ruleName))
c.rules = append(c.rules, `emailchar = "a" … "z" | "A" … "Z" | "0" … "9" | "." | "-" | "_" .`)
return ruleName
case "uuid":
// 8-4-4-4-12 hex pattern
c.ruleNum++
ruleName := fmt.Sprintf("uuid%d", c.ruleNum)
c.usedTypes["hex"] = true
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" hex hex hex hex hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex hex hex hex hex hex hex hex hex "\"" .`, ruleName))
return ruleName
case "ipv4":
c.ruleNum++
c.usedTypes["digit"] = true
ruleName := fmt.Sprintf("ipv4_%d", c.ruleNum)
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit { digit } "." digit { digit } "." digit { digit } "." digit { digit } "\"" .`, ruleName))
return ruleName
case "uri", "hostname":
// Fallback to general string for complex formats
c.usedTypes["string"] = true
return "string"
default:
c.usedTypes["string"] = true
return "string"
}
}
func (c *converter) patternToExpr(pattern string, name string) string {
// Try to convert simple regex patterns to EBNF
// This handles common cases; complex regex falls back to string
// Remove anchors
pattern = strings.TrimPrefix(pattern, "^")
pattern = strings.TrimSuffix(pattern, "$")
// Try to parse and convert
expr, ok := c.regexToEBNF(pattern)
if !ok {
// Fallback to general string
c.usedTypes["string"] = true
return "string"
}
c.ruleNum++
ruleName := fmt.Sprintf("pattern%d", c.ruleNum)
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" %s "\"" .`, ruleName, expr))
return ruleName
}
func (c *converter) regexToEBNF(pattern string) (string, bool) {
// Simple regex to EBNF converter
// Handles: literals, [a-z], [A-Z], [0-9], +, *, ?, basic groups
var result strings.Builder
i := 0
for i < len(pattern) {
ch := pattern[i]
switch ch {
case '[':
// Character class
end := strings.Index(pattern[i:], "]")
if end == -1 {
return "", false
}
class := pattern[i+1 : i+end]
ebnfClass, ok := c.charClassToEBNF(class)
if !ok {
return "", false
}
result.WriteString(ebnfClass)
i += end + 1
case '(':
// Group - find matching )
depth := 1
start := i + 1
j := start
for j < len(pattern) && depth > 0 {
if pattern[j] == '(' {
depth++
} else if pattern[j] == ')' {
depth--
}
j++
}
if depth != 0 {
return "", false
}
groupContent := pattern[start : j-1]
groupExpr, ok := c.regexToEBNF(groupContent)
if !ok {
return "", false
}
result.WriteString("( ")
result.WriteString(groupExpr)
result.WriteString(" )")
i = j
case '|':
result.WriteString(" | ")
i++
case '+':
// One or more - wrap previous in { } and add one required
// This is a simplification
return "", false // TODO: handle properly
case '*':
// Zero or more - need to wrap previous
return "", false // TODO: handle properly
case '?':
// Optional - need to wrap previous in [ ]
return "", false // TODO: handle properly
case '\\':
// Escape sequence
if i+1 >= len(pattern) {
return "", false
}
next := pattern[i+1]
switch next {
case 'd':
result.WriteString("digit")
c.usedTypes["digit"] = true
case 'w':
result.WriteString(`( "a" … "z" | "A" … "Z" | "0" … "9" | "_" )`)
case 's':
result.WriteString(`( " " | "\t" )`)
default:
result.WriteString(fmt.Sprintf(`"%c"`, next))
}
i += 2
default:
// Literal character
if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' || ch == '-' || ch == '.' {
result.WriteString(fmt.Sprintf(`"%c" `, ch))
} else {
// Special char, try to escape
result.WriteString(fmt.Sprintf(`"%c" `, ch))
}
i++
}
}
return strings.TrimSpace(result.String()), true
}
func (c *converter) charClassToEBNF(class string) (string, bool) {
// Handle character classes like a-z, A-Z, 0-9
if class == "a-zA-Z0-9_" || class == "a-zA-Z_" {
return `( "a" … "z" | "A" … "Z" | "0" … "9" | "_" )`, true
}
if class == "a-zA-Z0-9" {
return `( "a" … "z" | "A" … "Z" | "0" … "9" )`, true
}
if class == "a-z" {
return `"a" … "z"`, true
}
if class == "A-Z" {
return `"A" … "Z"`, true
}
if class == "0-9" {
c.usedTypes["digit"] = true
return "digit", true
}
// Try to parse range patterns
if matched, _ := regexp.MatchString(`^[a-zA-Z]-[a-zA-Z]$`, class); matched {
return fmt.Sprintf(`"%c" … "%c"`, class[0], class[2]), true
}
if matched, _ := regexp.MatchString(`^[0-9]-[0-9]$`, class); matched {
return fmt.Sprintf(`"%c" … "%c"`, class[0], class[2]), true
}
return "", false
}
func (c *converter) anyOfToExpr(schemas []*schemaNode, name string) string {
var parts []string
for i, s := range schemas {
expr := c.schemaToExpr(s, fmt.Sprintf("%s_opt%d", name, i))
parts = append(parts, expr)
}
return "( " + strings.Join(parts, " | ") + " )"
}
func (c *converter) enumToExpr(values []interface{}) string {
var parts []string
for _, v := range values {
parts = append(parts, c.constToExpr(v))
}
return "( " + strings.Join(parts, " | ") + " )"
}
func (c *converter) constToExpr(v interface{}) string {
switch val := v.(type) {
case string:
return fmt.Sprintf(`"\"%s\""`, c.escapeString(val))
case float64:
if val == float64(int(val)) {
return fmt.Sprintf(`"%d"`, int(val))
}
return fmt.Sprintf(`"%v"`, val)
case bool:
if val {
return `"true"`
}
return `"false"`
case nil:
return `"null"`
default:
c.usedTypes["string"] = true
return "string"
}
}
func (c *converter) resolveRef(ref string) string {
// Handle #/$defs/name references
if strings.HasPrefix(ref, "#/$defs/") {
defName := strings.TrimPrefix(ref, "#/$defs/")
return c.resolveDefRef(defName)
}
// Handle root recursion #
if ref == "#" {
return "root"
}
// Unknown ref format
c.usedTypes["string"] = true
return "string"
}
func (c *converter) resolveDefRef(defName string) string {
// Check if we've already defined this as a rule
ruleName := "def_" + defName
if c.definedRefs[defName] {
return ruleName
}
// Mark as defined to prevent infinite recursion
c.definedRefs[defName] = true
// Look up the definition
if c.definitions == nil {
c.usedTypes["string"] = true
return "string"
}
defSchema, ok := c.definitions[defName]
if !ok {
c.usedTypes["string"] = true
return "string"
}
// Generate the rule
expr := c.schemaToExpr(defSchema, ruleName)
c.rules = append(c.rules, fmt.Sprintf("%s = %s .", ruleName, expr))
return ruleName
}
func (c *converter) getTypes(t interface{}) []string {
switch v := t.(type) {
case string:
return []string{v}
case []interface{}:
var types []string
for _, item := range v {
if s, ok := item.(string); ok {
types = append(types, s)
}
}
return types
}
return nil
}
func (c *converter) escapeString(s string) string {
s = strings.ReplaceAll(s, `\`, `\\`)
s = strings.ReplaceAll(s, `"`, `\"`)
return s
}
// Grammar converts a JSON Schema string into a compiled grammar.
func Grammar(schemaJSON string) (*grammar.Grammar, error) {
ebnf, err := EBNF(schemaJSON)
if err != nil {
return nil, err
}
return grammar.ParseEBNF(ebnf, "root")
}

View File

@@ -0,0 +1,336 @@
//go:build mlx
package schema
import (
"testing"
gram "github.com/ollama/ollama/x/grammar"
"github.com/ollama/ollama/x/imagegen/mlx"
)
func TestJSONEBNF(t *testing.T) {
tests := []struct {
name string
schema string
}{
{
name: "simple object",
schema: `{
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"}
},
"required": ["name", "age"]
}`,
},
{
name: "with enum",
schema: `{
"type": "object",
"properties": {
"status": {"enum": ["active", "inactive", "pending"]}
},
"required": ["status"]
}`,
},
{
name: "array of objects",
schema: `{
"type": "array",
"items": {
"type": "object",
"properties": {
"id": {"type": "integer"}
},
"required": ["id"]
}
}`,
},
{
name: "nested object",
schema: `{
"type": "object",
"properties": {
"user": {
"type": "object",
"properties": {
"email": {"type": "string"}
},
"required": ["email"]
}
},
"required": ["user"]
}`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ebnf, err := EBNF(tc.schema)
if err != nil {
t.Fatalf("EBNF failed: %v", err)
}
// Try to compile it
grammar, err := gram.ParseEBNF(ebnf, "root")
if err != nil {
t.Fatalf("ParseEBNF failed: %v", err)
}
if grammar == nil {
t.Fatal("grammar is nil")
}
})
}
}
func TestGrammarEngine(t *testing.T) {
schema := `{
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"}
},
"required": ["name", "age"]
}`
grammar, err := Grammar(schema)
if err != nil {
t.Fatalf("Grammar failed: %v", err)
}
vocab := []string{
"{", "}", "[", "]", ":", ",",
"\"name\"", "\"age\"", "\"test\"",
"\"", "a", "b", "c",
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
" ", "\n",
"true", "false", "null",
}
engine, err := gram.NewEngine(grammar, vocab)
if err != nil {
t.Fatalf("grammar.NewEngine failed: %v", err)
}
defer engine.Close()
logits := mlx.Ones(int32(len(vocab)))
mlx.Keep(logits)
// Test that we can apply mask
masked := engine.ApplyMask(logits)
mlx.Eval(masked)
}
// TestOpenAIStructuredOutputs tests features required for OpenAI compatibility
func TestOpenAIStructuredOutputs(t *testing.T) {
tests := []struct {
name string
schema string
}{
{
name: "anyOf union",
schema: `{
"type": "object",
"properties": {
"value": {
"anyOf": [
{"type": "string"},
{"type": "integer"}
]
}
},
"required": ["value"]
}`,
},
{
name: "nullable string via type array",
schema: `{
"type": "object",
"properties": {
"name": {"type": ["string", "null"]}
},
"required": ["name"]
}`,
},
{
name: "$ref with $defs",
schema: `{
"type": "object",
"properties": {
"person": {"$ref": "#/$defs/Person"}
},
"required": ["person"],
"$defs": {
"Person": {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"}
},
"required": ["name", "age"]
}
}
}`,
},
{
name: "const value",
schema: `{
"type": "object",
"properties": {
"type": {"const": "user"}
},
"required": ["type"]
}`,
},
{
name: "format date-time",
schema: `{
"type": "object",
"properties": {
"created": {"type": "string", "format": "date-time"}
},
"required": ["created"]
}`,
},
{
name: "format date",
schema: `{
"type": "object",
"properties": {
"birthday": {"type": "string", "format": "date"}
},
"required": ["birthday"]
}`,
},
{
name: "format email",
schema: `{
"type": "object",
"properties": {
"email": {"type": "string", "format": "email"}
},
"required": ["email"]
}`,
},
{
name: "format uuid",
schema: `{
"type": "object",
"properties": {
"id": {"type": "string", "format": "uuid"}
},
"required": ["id"]
}`,
},
{
name: "array with minItems maxItems",
schema: `{
"type": "object",
"properties": {
"tags": {
"type": "array",
"items": {"type": "string"},
"minItems": 1,
"maxItems": 3
}
},
"required": ["tags"]
}`,
},
{
name: "deeply nested with refs",
schema: `{
"type": "object",
"properties": {
"company": {
"type": "object",
"properties": {
"name": {"type": "string"},
"employees": {
"type": "array",
"items": {"$ref": "#/$defs/Employee"}
}
},
"required": ["name", "employees"]
}
},
"required": ["company"],
"$defs": {
"Employee": {
"type": "object",
"properties": {
"name": {"type": "string"},
"role": {"enum": ["engineer", "manager", "intern"]}
},
"required": ["name", "role"]
}
}
}`,
},
{
name: "multiple refs same def",
schema: `{
"type": "object",
"properties": {
"from": {"$ref": "#/$defs/Address"},
"to": {"$ref": "#/$defs/Address"}
},
"required": ["from", "to"],
"$defs": {
"Address": {
"type": "object",
"properties": {
"city": {"type": "string"},
"zip": {"type": "string"}
},
"required": ["city", "zip"]
}
}
}`,
},
{
name: "oneOf variant",
schema: `{
"type": "object",
"properties": {
"result": {
"oneOf": [
{
"type": "object",
"properties": {"success": {"type": "boolean"}},
"required": ["success"]
},
{
"type": "object",
"properties": {"error": {"type": "string"}},
"required": ["error"]
}
]
}
},
"required": ["result"]
}`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ebnf, err := EBNF(tc.schema)
if err != nil {
t.Fatalf("EBNF failed: %v", err)
}
grammar, err := gram.ParseEBNF(ebnf, "root")
if err != nil {
t.Fatalf("ParseEBNF failed: %v", err)
}
if grammar == nil {
t.Fatal("grammar is nil")
}
})
}
}

105
x/grammar/terminal.go Normal file
View File

@@ -0,0 +1,105 @@
//go:build mlx
package grammar
import "unicode/utf8"
// terminalType distinguishes different kinds of grammar terminals
type terminalType int
const (
terminalLiteral terminalType = iota // Exact string: "true", "{"
terminalRange // Character range: [a-z], [0-9]
)
// terminal represents a compiled grammar terminal
type terminal struct {
ID int
Type terminalType
Pattern string // Original pattern from grammar
Unescaped string // Unescaped literal (for terminalLiteral)
LowRune rune // For unicode ranges: low bound
HighRune rune // For unicode ranges: high bound
}
// terminalMatch represents a terminal that matched at a position
type terminalMatch struct {
TerminalID int
Length int // Number of bytes consumed
}
// trieNode is a node in the literal matching trie
type trieNode struct {
children [256]*trieNode // Byte-indexed children
terminalID int // -1 if not accepting, else terminal ID
}
// terminalMatcher tests which terminals match at a position in a byte slice
type terminalMatcher struct {
// Trie for literal matching (fast path)
literalTrie *trieNode
// Range terminals (single-byte matches)
ranges []terminal
// All terminals for enumeration
terminals []terminal
// Pattern to terminal ID map for fast lookup (keyed by raw pattern)
patternToID map[string]int
}
// addLiteralToTrie adds a literal pattern to the trie
func (m *terminalMatcher) addLiteralToTrie(pattern string, terminalID int) {
node := m.literalTrie
for i := 0; i < len(pattern); i++ {
c := pattern[i]
if node.children[c] == nil {
node.children[c] = &trieNode{terminalID: -1}
}
node = node.children[c]
}
node.terminalID = terminalID
}
// matchesAt returns all terminals that match at pos in data
func (m *terminalMatcher) matchesAt(data []byte, pos int) []terminalMatch {
if pos >= len(data) {
return nil
}
var matches []terminalMatch
// Check literal matches via trie
node := m.literalTrie
for i := pos; i < len(data) && node != nil; i++ {
c := data[i]
node = node.children[c]
if node != nil && node.terminalID >= 0 {
matches = append(matches, terminalMatch{
TerminalID: node.terminalID,
Length: i - pos + 1,
})
}
}
// Check range matches (unicode-aware)
r, runeLen := utf8.DecodeRune(data[pos:])
if r != utf8.RuneError {
for _, rng := range m.ranges {
if r >= rng.LowRune && r <= rng.HighRune {
matches = append(matches, terminalMatch{
TerminalID: rng.ID,
Length: runeLen,
})
}
}
}
return matches
}
// terminalCount returns the number of terminals
func (m *terminalMatcher) terminalCount() int {
return len(m.terminals)
}

38
x/imagegen/.gitignore vendored Normal file
View File

@@ -0,0 +1,38 @@
# Build directories
build/
dist/
# CMake
CMakeCache.txt
CMakeFiles/
cmake_install.cmake
Makefile
*.cmake
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
# macOS
.DS_Store
*.dSYM/
# Go
*.exe
*.exe~
*.dll
*.so
*.dylib
# Python
*.npy
/engine
weights
outputs
prompt.txt
negative.txt

Some files were not shown because too many files have changed in this diff Show More