Compare commits

..

1 Commits

Author SHA1 Message Date
jmorganca
533541d97d macOS 2025-12-22 20:12:52 -05:00
259 changed files with 2324 additions and 43899 deletions

View File

@@ -13,7 +13,7 @@ body:
id: logs
attributes:
label: Relevant log output
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.mdx#how-to-troubleshoot-issues) for details.
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
render: shell
validations:
required: false

View File

@@ -68,7 +68,6 @@ jobs:
name: bundles-darwin
path: |
dist/*.tgz
dist/*.tar.zst
dist/*.zip
dist/*.dmg
@@ -372,17 +371,13 @@ jobs:
outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
cache-to: type=inline
- name: Deduplicate CUDA libraries
run: |
./scripts/deduplicate_cuda_libs.sh dist/${{ matrix.os }}-${{ matrix.arch }}
- run: |
for COMPONENT in bin/* lib/ollama/*; do
case "$COMPONENT" in
bin/ollama*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
@@ -397,13 +392,13 @@ jobs:
done
- run: |
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | zstd --ultra -22 -T0 >$(basename ${ARCHIVE//.*/}.tar.zst);
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
done
- uses: actions/upload-artifact@v4
with:
name: bundles-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.target }}
path: |
*.tar.zst
*.tgz
# Build each Docker variant (OS, arch, and flavor) separately. Using QEMU is unreliable and slower.
docker-build-push:
@@ -536,7 +531,7 @@ jobs:
- name: Upload release artifacts
run: |
pids=()
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg ; do
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.exe dist/*.dmg ; do
echo "Uploading $payload"
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
pids[$!]=$!

View File

@@ -2,22 +2,6 @@ cmake_minimum_required(VERSION 3.21)
project(Ollama C CXX)
# Handle cross-compilation on macOS: when CMAKE_OSX_ARCHITECTURES is set to a
# single architecture different from the host, override CMAKE_SYSTEM_PROCESSOR
# to match. This is necessary because CMAKE_SYSTEM_PROCESSOR defaults to the
# host architecture, but downstream projects (like MLX) use it to detect the
# target architecture.
if(CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES ";")
# Single architecture specified
if(CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
message(STATUS "Cross-compiling for x86_64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to x86_64")
set(CMAKE_SYSTEM_PROCESSOR "x86_64")
elseif(CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
message(STATUS "Cross-compiling for arm64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to arm64")
set(CMAKE_SYSTEM_PROCESSOR "arm64")
endif()
endif()
include(CheckLanguage)
include(GNUInstallDirs)
@@ -28,7 +12,7 @@ set(BUILD_SHARED_LIBS ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS ON) # Recent versions of MLX Requires gnu++17 extensions to compile properly
set(CMAKE_CXX_EXTENSIONS OFF)
set(GGML_BUILD ON)
set(GGML_SHARED ON)
@@ -48,10 +32,9 @@ if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
set(GGML_CPU_ALL_VARIANTS ON)
endif()
if(APPLE)
if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
set(CMAKE_BUILD_RPATH "@loader_path")
set(CMAKE_INSTALL_RPATH "@loader_path")
set(CMAKE_BUILD_WITH_INSTALL_RPATH ON)
endif()
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
@@ -164,56 +147,14 @@ if(CMAKE_HIP_COMPILER)
endif()
endif()
if(NOT APPLE)
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
endif()
endif()
option(MLX_ENGINE "Enable MLX backend" OFF)
if(MLX_ENGINE)
message(STATUS "Setting up MLX (this takes a while...)")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
# Find CUDA toolkit if MLX is built with CUDA support
find_package(CUDAToolkit)
install(TARGETS mlx mlxc
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
# Install the Metal library for macOS arm64 (must be colocated with the binary)
# Metal backend is only built for arm64, not x86_64
if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
install(FILES ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/backend/metal/kernels/mlx.metallib
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
# Manually install cudart and cublas since they might not be picked up as direct dependencies
if(CUDAToolkit_FOUND)
file(GLOB CUDART_LIBS
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
if(CUDART_LIBS)
install(FILES ${CUDART_LIBS}
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
endif()
endif()
endif()

View File

@@ -41,7 +41,7 @@
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
"CMAKE_CUDA_FLAGS": "-t 4",
"CMAKE_CUDA_FLAGS": "-t 2",
"OLLAMA_RUNNER_DIR": "cuda_v13"
}
},
@@ -83,28 +83,6 @@
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "vulkan"
}
},
{
"name": "MLX",
"inherits": [ "Default" ],
"cacheVariables": {
"MLX_ENGINE": "ON",
"OLLAMA_RUNNER_DIR": "mlx"
}
},
{
"name": "MLX CUDA 12",
"inherits": [ "MLX", "CUDA 12" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v12"
}
},
{
"name": "MLX CUDA 13",
"inherits": [ "MLX", "CUDA 13" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
}
}
],
"buildPresets": [
@@ -162,21 +140,6 @@
"name": "Vulkan",
"targets": [ "ggml-vulkan" ],
"configurePreset": "Vulkan"
},
{
"name": "MLX",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX"
},
{
"name": "MLX CUDA 12",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 12"
},
{
"name": "MLX CUDA 13",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 13"
}
]
}

View File

@@ -131,39 +131,8 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'Vulkan' \
&& cmake --build --parallel --preset 'Vulkan' \
&& cmake --install build --component Vulkan --strip --parallel 8
&& cmake --install build --component Vulkan --strip --parallel 8
FROM base AS mlx
ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} \
&& dnf install -y openblas-devel lapack-devel \
&& dnf install -y libcudnn9-cuda-13 libcudnn9-devel-cuda-13 \
&& dnf install -y libnccl libnccl-devel
ENV PATH=/usr/local/cuda-13/bin:$PATH
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
ARG PARALLEL
WORKDIR /go/src/github.com/ollama/ollama
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/ml/backend/mlx x/ml/backend/mlx
COPY go.mod go.sum .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
COPY . .
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
RUN mkdir -p dist/bin
RUN --mount=type=cache,target=/root/.cache/go-build \
go build -tags mlx -trimpath -buildmode=pie -o dist/bin/ollama-mlx .
FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama
@@ -184,8 +153,6 @@ FROM --platform=linux/amd64 scratch AS amd64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
COPY --from=vulkan dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/
FROM --platform=linux/arm64 scratch AS arm64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
@@ -204,7 +171,7 @@ COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04
RUN apt-get update \
&& apt-get install -y ca-certificates libvulkan1 libopenblas0 \
&& apt-get install -y ca-certificates libvulkan1 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
COPY --from=archive /bin /usr/bin

View File

@@ -1,778 +0,0 @@
package anthropic
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"time"
"github.com/ollama/ollama/api"
)
// Error types matching Anthropic API
type Error struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ErrorResponse struct {
Type string `json:"type"` // always "error"
Error Error `json:"error"`
RequestID string `json:"request_id,omitempty"`
}
// NewError creates a new ErrorResponse with the appropriate error type based on HTTP status code
func NewError(code int, message string) ErrorResponse {
var etype string
switch code {
case http.StatusBadRequest:
etype = "invalid_request_error"
case http.StatusUnauthorized:
etype = "authentication_error"
case http.StatusForbidden:
etype = "permission_error"
case http.StatusNotFound:
etype = "not_found_error"
case http.StatusTooManyRequests:
etype = "rate_limit_error"
case http.StatusServiceUnavailable, 529:
etype = "overloaded_error"
default:
etype = "api_error"
}
return ErrorResponse{
Type: "error",
Error: Error{Type: etype, Message: message},
RequestID: generateID("req"),
}
}
// Request types
// MessagesRequest represents an Anthropic Messages API request
type MessagesRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
Messages []MessageParam `json:"messages"`
System any `json:"system,omitempty"` // string or []ContentBlock
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
Metadata *Metadata `json:"metadata,omitempty"`
}
// MessageParam represents a message in the request
type MessageParam struct {
Role string `json:"role"` // "user" or "assistant"
Content any `json:"content"` // string or []ContentBlock
}
// ContentBlock represents a content block in a message.
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
// only when set, which is required for SDK streaming accumulation.
type ContentBlock struct {
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
Text *string `json:"text,omitempty"`
// For image blocks
Source *ImageSource `json:"source,omitempty"`
// For tool_use blocks
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
// For tool_result blocks
ToolUseID string `json:"tool_use_id,omitempty"`
Content any `json:"content,omitempty"` // string or []ContentBlock
IsError bool `json:"is_error,omitempty"`
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
Thinking *string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ImageSource represents the source of an image
type ImageSource struct {
Type string `json:"type"` // "base64" or "url"
MediaType string `json:"media_type,omitempty"`
Data string `json:"data,omitempty"`
URL string `json:"url,omitempty"`
}
// Tool represents a tool definition
type Tool struct {
Type string `json:"type,omitempty"` // "custom" for user-defined tools
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema json.RawMessage `json:"input_schema,omitempty"`
}
// ToolChoice controls how the model uses tools
type ToolChoice struct {
Type string `json:"type"` // "auto", "any", "tool", "none"
Name string `json:"name,omitempty"`
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
}
// ThinkingConfig controls extended thinking
type ThinkingConfig struct {
Type string `json:"type"` // "enabled" or "disabled"
BudgetTokens int `json:"budget_tokens,omitempty"`
}
// Metadata for the request
type Metadata struct {
UserID string `json:"user_id,omitempty"`
}
// Response types
// MessagesResponse represents an Anthropic Messages API response
type MessagesResponse struct {
ID string `json:"id"`
Type string `json:"type"` // "message"
Role string `json:"role"` // "assistant"
Model string `json:"model"`
Content []ContentBlock `json:"content"`
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
Usage Usage `json:"usage"`
}
// Usage contains token usage information
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
// Streaming event types
// MessageStartEvent is sent at the start of streaming
type MessageStartEvent struct {
Type string `json:"type"` // "message_start"
Message MessagesResponse `json:"message"`
}
// ContentBlockStartEvent signals the start of a content block
type ContentBlockStartEvent struct {
Type string `json:"type"` // "content_block_start"
Index int `json:"index"`
ContentBlock ContentBlock `json:"content_block"`
}
// ContentBlockDeltaEvent contains incremental content updates
type ContentBlockDeltaEvent struct {
Type string `json:"type"` // "content_block_delta"
Index int `json:"index"`
Delta Delta `json:"delta"`
}
// Delta represents an incremental update
type Delta struct {
Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta"
Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ContentBlockStopEvent signals the end of a content block
type ContentBlockStopEvent struct {
Type string `json:"type"` // "content_block_stop"
Index int `json:"index"`
}
// MessageDeltaEvent contains updates to the message
type MessageDeltaEvent struct {
Type string `json:"type"` // "message_delta"
Delta MessageDelta `json:"delta"`
Usage DeltaUsage `json:"usage"`
}
// MessageDelta contains stop information
type MessageDelta struct {
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
}
// DeltaUsage contains cumulative token usage
type DeltaUsage struct {
OutputTokens int `json:"output_tokens"`
}
// MessageStopEvent signals the end of the message
type MessageStopEvent struct {
Type string `json:"type"` // "message_stop"
}
// PingEvent is a keepalive event
type PingEvent struct {
Type string `json:"type"` // "ping"
}
// StreamErrorEvent is an error during streaming
type StreamErrorEvent struct {
Type string `json:"type"` // "error"
Error Error `json:"error"`
}
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
var messages []api.Message
if r.System != nil {
switch sys := r.System.(type) {
case string:
if sys != "" {
messages = append(messages, api.Message{Role: "system", Content: sys})
}
case []any:
// System can be an array of content blocks
var content strings.Builder
for _, block := range sys {
if blockMap, ok := block.(map[string]any); ok {
if blockMap["type"] == "text" {
if text, ok := blockMap["text"].(string); ok {
content.WriteString(text)
}
}
}
}
if content.Len() > 0 {
messages = append(messages, api.Message{Role: "system", Content: content.String()})
}
}
}
for _, msg := range r.Messages {
converted, err := convertMessage(msg)
if err != nil {
return nil, err
}
messages = append(messages, converted...)
}
options := make(map[string]any)
options["num_predict"] = r.MaxTokens
if r.Temperature != nil {
options["temperature"] = *r.Temperature
}
if r.TopP != nil {
options["top_p"] = *r.TopP
}
if r.TopK != nil {
options["top_k"] = *r.TopK
}
if len(r.StopSequences) > 0 {
options["stop"] = r.StopSequences
}
var tools api.Tools
for _, t := range r.Tools {
tool, err := convertTool(t)
if err != nil {
return nil, err
}
tools = append(tools, tool)
}
var think *api.ThinkValue
if r.Thinking != nil && r.Thinking.Type == "enabled" {
think = &api.ThinkValue{Value: true}
}
stream := r.Stream
return &api.ChatRequest{
Model: r.Model,
Messages: messages,
Options: options,
Stream: &stream,
Tools: tools,
Think: think,
}, nil
}
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
func convertMessage(msg MessageParam) ([]api.Message, error) {
var messages []api.Message
role := strings.ToLower(msg.Role)
switch content := msg.Content.(type) {
case string:
messages = append(messages, api.Message{Role: role, Content: content})
case []any:
var textContent strings.Builder
var images []api.ImageData
var toolCalls []api.ToolCall
var thinking string
var toolResults []api.Message
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
return nil, errors.New("invalid content block format")
}
blockType, _ := blockMap["type"].(string)
switch blockType {
case "text":
if text, ok := blockMap["text"].(string); ok {
textContent.WriteString(text)
}
case "image":
source, ok := blockMap["source"].(map[string]any)
if !ok {
return nil, errors.New("invalid image source")
}
sourceType, _ := source["type"].(string)
if sourceType == "base64" {
data, _ := source["data"].(string)
decoded, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return nil, fmt.Errorf("invalid base64 image data: %w", err)
}
images = append(images, decoded)
} else {
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
}
// URL images would need to be fetched - skip for now
case "tool_use":
id, ok := blockMap["id"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'id' field")
}
name, ok := blockMap["name"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'name' field")
}
tc := api.ToolCall{
ID: id,
Function: api.ToolCallFunction{
Name: name,
},
}
if input, ok := blockMap["input"].(map[string]any); ok {
tc.Function.Arguments = mapToArgs(input)
}
toolCalls = append(toolCalls, tc)
case "tool_result":
toolUseID, _ := blockMap["tool_use_id"].(string)
var resultContent string
switch c := blockMap["content"].(type) {
case string:
resultContent = c
case []any:
for _, cb := range c {
if cbMap, ok := cb.(map[string]any); ok {
if cbMap["type"] == "text" {
if text, ok := cbMap["text"].(string); ok {
resultContent += text
}
}
}
}
}
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: resultContent,
ToolCallID: toolUseID,
})
case "thinking":
if t, ok := blockMap["thinking"].(string); ok {
thinking = t
}
}
}
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
m := api.Message{
Role: role,
Content: textContent.String(),
Images: images,
ToolCalls: toolCalls,
Thinking: thinking,
}
messages = append(messages, m)
}
// Add tool results as separate messages
messages = append(messages, toolResults...)
default:
return nil, fmt.Errorf("invalid message content type: %T", content)
}
return messages, nil
}
// convertTool converts an Anthropic Tool to an Ollama api.Tool
func convertTool(t Tool) (api.Tool, error) {
var params api.ToolFunctionParameters
if len(t.InputSchema) > 0 {
if err := json.Unmarshal(t.InputSchema, &params); err != nil {
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
}
}
return api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: t.Name,
Description: t.Description,
Parameters: params,
},
}, nil
}
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
func ToMessagesResponse(id string, r api.ChatResponse) MessagesResponse {
var content []ContentBlock
if r.Message.Thinking != "" {
content = append(content, ContentBlock{
Type: "thinking",
Thinking: ptr(r.Message.Thinking),
})
}
if r.Message.Content != "" {
content = append(content, ContentBlock{
Type: "text",
Text: ptr(r.Message.Content),
})
}
for _, tc := range r.Message.ToolCalls {
content = append(content, ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: tc.Function.Arguments,
})
}
stopReason := mapStopReason(r.DoneReason, len(r.Message.ToolCalls) > 0)
return MessagesResponse{
ID: id,
Type: "message",
Role: "assistant",
Model: r.Model,
Content: content,
StopReason: stopReason,
Usage: Usage{
InputTokens: r.Metrics.PromptEvalCount,
OutputTokens: r.Metrics.EvalCount,
},
}
}
// mapStopReason converts Ollama done_reason to Anthropic stop_reason
func mapStopReason(reason string, hasToolCalls bool) string {
if hasToolCalls {
return "tool_use"
}
switch reason {
case "stop":
return "end_turn"
case "length":
return "max_tokens"
default:
if reason != "" {
return "stop_sequence"
}
return ""
}
}
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
type StreamConverter struct {
ID string
Model string
firstWrite bool
contentIndex int
inputTokens int
outputTokens int
thinkingStarted bool
thinkingDone bool
textStarted bool
toolCallsSent map[string]bool
}
func NewStreamConverter(id, model string) *StreamConverter {
return &StreamConverter{
ID: id,
Model: model,
firstWrite: true,
toolCallsSent: make(map[string]bool),
}
}
// StreamEvent represents a streaming event to be sent to the client
type StreamEvent struct {
Event string
Data any
}
// Process converts an Ollama ChatResponse to Anthropic streaming events
func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
var events []StreamEvent
if c.firstWrite {
c.firstWrite = false
c.inputTokens = r.Metrics.PromptEvalCount
events = append(events, StreamEvent{
Event: "message_start",
Data: MessageStartEvent{
Type: "message_start",
Message: MessagesResponse{
ID: c.ID,
Type: "message",
Role: "assistant",
Model: c.Model,
Content: []ContentBlock{},
Usage: Usage{
InputTokens: c.inputTokens,
OutputTokens: 0,
},
},
},
})
}
if r.Message.Thinking != "" && !c.thinkingDone {
if !c.thinkingStarted {
c.thinkingStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "thinking_delta",
Thinking: r.Message.Thinking,
},
},
})
}
if r.Message.Content != "" {
if c.thinkingStarted && !c.thinkingDone {
c.thinkingDone = true
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
}
if !c.textStarted {
c.textStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "text",
Text: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "text_delta",
Text: r.Message.Content,
},
},
})
}
for _, tc := range r.Message.ToolCalls {
if c.toolCallsSent[tc.ID] {
continue
}
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
c.textStarted = false
}
argsJSON, err := json.Marshal(tc.Function.Arguments)
if err != nil {
slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
continue
}
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: map[string]any{},
},
},
})
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "input_json_delta",
PartialJSON: string(argsJSON),
},
},
})
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.toolCallsSent[tc.ID] = true
c.contentIndex++
}
if r.Done {
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
} else if c.thinkingStarted && !c.thinkingDone {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
}
c.outputTokens = r.Metrics.EvalCount
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
events = append(events, StreamEvent{
Event: "message_delta",
Data: MessageDeltaEvent{
Type: "message_delta",
Delta: MessageDelta{
StopReason: stopReason,
},
Usage: DeltaUsage{
OutputTokens: c.outputTokens,
},
},
})
events = append(events, StreamEvent{
Event: "message_stop",
Data: MessageStopEvent{
Type: "message_stop",
},
})
}
return events
}
// generateID generates a unique ID with the given prefix using crypto/rand
func generateID(prefix string) string {
b := make([]byte, 12)
if _, err := rand.Read(b); err != nil {
// Fallback to time-based ID if crypto/rand fails
return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
}
return fmt.Sprintf("%s_%x", prefix, b)
}
// GenerateMessageID generates a unique message ID
func GenerateMessageID() string {
return generateID("msg")
}
// ptr returns a pointer to the given string value
func ptr(s string) *string {
return &s
}
// mapToArgs converts a map to ToolCallFunctionArguments
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}

View File

@@ -1,953 +0,0 @@
package anthropic
import (
"encoding/base64"
"encoding/json"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
const (
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
func TestFromMessagesRequest_Basic(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", result.Model)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
t.Errorf("unexpected message: %+v", result.Messages[0])
}
if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 1024 {
t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"])
}
}
func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: "You are a helpful assistant.",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." {
t.Errorf("unexpected system message: %+v", result.Messages[0])
}
}
func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: []any{
map[string]any{"type": "text", "text": "You are helpful."},
map[string]any{"type": "text", "text": " Be concise."},
},
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Content != "You are helpful. Be concise." {
t.Errorf("unexpected system message content: %q", result.Messages[0].Content)
}
}
func TestFromMessagesRequest_WithOptions(t *testing.T) {
temp := 0.7
topP := 0.9
topK := 40
req := MessagesRequest{
Model: "test-model",
MaxTokens: 2048,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Temperature: &temp,
TopP: &topP,
TopK: &topK,
StopSequences: []string{"\n", "END"},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Options["temperature"] != 0.7 {
t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"])
}
if result.Options["top_p"] != 0.9 {
t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"])
}
if result.Options["top_k"] != 40 {
t.Errorf("expected top_k 40, got %v", result.Options["top_k"])
}
if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" {
t.Errorf("stop sequences mismatch: %s", diff)
}
}
func TestFromMessagesRequest_WithImage(t *testing.T) {
imgData, _ := base64.StdEncoding.DecodeString(testImage)
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{"type": "text", "text": "What's in this image?"},
map[string]any{
"type": "image",
"source": map[string]any{
"type": "base64",
"media_type": "image/png",
"data": testImage,
},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Content != "What's in this image?" {
t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content)
}
if len(result.Messages[0].Images) != 1 {
t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images))
}
if string(result.Messages[0].Images[0]) != string(imgData) {
t.Error("image data mismatch")
}
}
func TestFromMessagesRequest_WithToolUse(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
"name": "get_weather",
"input": map[string]any{"location": "Paris"},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if len(result.Messages[1].ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls))
}
tc := result.Messages[1].ToolCalls[0]
if tc.ID != "call_123" {
t.Errorf("expected tool call ID 'call_123', got %q", tc.ID)
}
if tc.Function.Name != "get_weather" {
t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name)
}
}
func TestFromMessagesRequest_WithToolResult(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{
"type": "tool_result",
"tool_use_id": "call_123",
"content": "The weather in Paris is sunny, 22°C",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
msg := result.Messages[0]
if msg.Role != "tool" {
t.Errorf("expected role 'tool', got %q", msg.Role)
}
if msg.ToolCallID != "call_123" {
t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID)
}
if msg.Content != "The weather in Paris is sunny, 22°C" {
t.Errorf("unexpected content: %q", msg.Content)
}
}
func TestFromMessagesRequest_WithTools(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "get_weather",
Description: "Get current weather",
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(result.Tools))
}
tool := result.Tools[0]
if tool.Type != "function" {
t.Errorf("expected type 'function', got %q", tool.Type)
}
if tool.Function.Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", tool.Function.Name)
}
if tool.Function.Description != "Get current weather" {
t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description)
}
}
func TestFromMessagesRequest_WithThinking(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Think == nil {
t.Fatal("expected Think to be set")
}
if v, ok := result.Think.Value.(bool); !ok || !v {
t.Errorf("expected Think.Value to be true, got %v", result.Think.Value)
}
}
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "thinking",
"thinking": "Let me think about this...",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
assistantMsg := result.Messages[1]
if assistantMsg.Thinking != "Let me think about this..." {
t.Errorf("expected thinking content, got %q", assistantMsg.Thinking)
}
}
func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"name": "get_weather",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use id")
}
if err.Error() != "tool_use block missing required 'id' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use name")
}
if err.Error() != "tool_use block missing required 'name' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "bad_tool",
InputSchema: json.RawMessage(`{invalid json`),
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for invalid tool schema")
}
}
func TestToMessagesResponse_Basic(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello there!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 5,
},
}
result := ToMessagesResponse("msg_123", resp)
if result.ID != "msg_123" {
t.Errorf("expected ID 'msg_123', got %q", result.ID)
}
if result.Type != "message" {
t.Errorf("expected type 'message', got %q", result.Type)
}
if result.Role != "assistant" {
t.Errorf("expected role 'assistant', got %q", result.Role)
}
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "text" || result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
t.Errorf("unexpected content: %+v", result.Content[0])
}
if result.StopReason != "end_turn" {
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
}
if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 {
t.Errorf("unexpected usage: %+v", result.Usage)
}
}
func TestToMessagesResponse_WithToolCalls(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "tool_use" {
t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type)
}
if result.Content[0].ID != "call_123" {
t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID)
}
if result.Content[0].Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name)
}
if result.StopReason != "tool_use" {
t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason)
}
}
func TestToMessagesResponse_WithThinking(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "The answer is 42.",
Thinking: "Let me think about this...",
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 2 {
t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
}
if result.Content[0].Type != "thinking" {
t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type)
}
if result.Content[0].Thinking == nil || *result.Content[0].Thinking != "Let me think about this..." {
t.Errorf("unexpected thinking content: %v", result.Content[0].Thinking)
}
if result.Content[1].Type != "text" {
t.Errorf("expected second block type 'text', got %q", result.Content[1].Type)
}
}
func TestMapStopReason(t *testing.T) {
tests := []struct {
reason string
hasToolCalls bool
want string
}{
{"stop", false, "end_turn"},
{"length", false, "max_tokens"},
{"stop", true, "tool_use"},
{"other", false, "stop_sequence"},
{"", false, ""},
}
for _, tt := range tests {
got := mapStopReason(tt.reason, tt.hasToolCalls)
if got != tt.want {
t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want)
}
}
}
func TestNewError(t *testing.T) {
tests := []struct {
code int
want string
}{
{400, "invalid_request_error"},
{401, "authentication_error"},
{403, "permission_error"},
{404, "not_found_error"},
{429, "rate_limit_error"},
{500, "api_error"},
{503, "overloaded_error"},
{529, "overloaded_error"},
}
for _, tt := range tests {
result := NewError(tt.code, "test message")
if result.Type != "error" {
t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type)
}
if result.Error.Type != tt.want {
t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want)
}
if result.Error.Message != "test message" {
t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message)
}
if result.RequestID == "" {
t.Errorf("NewError(%d) request_id should not be empty", tt.code)
}
}
}
func TestGenerateMessageID(t *testing.T) {
id1 := GenerateMessageID()
id2 := GenerateMessageID()
if id1 == "" {
t.Error("GenerateMessageID returned empty string")
}
if id1 == id2 {
t.Error("GenerateMessageID returned duplicate IDs")
}
if len(id1) < 10 {
t.Errorf("GenerateMessageID returned short ID: %q", id1)
}
if id1[:4] != "msg_" {
t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4])
}
}
func TestStreamConverter_Basic(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
// First chunk
resp1 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello",
},
Metrics: api.Metrics{PromptEvalCount: 10},
}
events1 := conv.Process(resp1)
if len(events1) < 3 {
t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1))
}
// Should have message_start, content_block_start, content_block_delta
if events1[0].Event != "message_start" {
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
}
if events1[1].Event != "content_block_start" {
t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event)
}
if events1[2].Event != "content_block_delta" {
t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event)
}
// Final chunk
resp2 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: " world!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{EvalCount: 5},
}
events2 := conv.Process(resp2)
// Should have content_block_delta, content_block_stop, message_delta, message_stop
hasStop := false
for _, e := range events2 {
if e.Event == "message_stop" {
hasStop = true
}
}
if !hasStop {
t.Error("expected message_stop event in final chunk")
}
}
func TestStreamConverter_WithToolCalls(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
}
events := conv.Process(resp)
hasToolStart := false
hasToolDelta := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
hasToolDelta = true
}
}
}
}
if !hasToolStart {
t.Error("expected tool_use content_block_start event")
}
if !hasToolDelta {
t.Error("expected input_json_delta event")
}
}
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
// Test that unmarshalable arguments (like channels) are handled gracefully
// and don't cause a panic or corrupt stream
conv := NewStreamConverter("msg_123", "test-model")
// Create a channel which cannot be JSON marshaled
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
// Should not panic and should skip the unmarshalable tool call
events := conv.Process(resp)
// Verify no tool_use block was started (since marshal failed before block start)
hasToolStart := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
}
if hasToolStart {
t.Error("expected no tool_use block when arguments cannot be marshaled")
}
}
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
// Test that valid tool calls still work when mixed with invalid ones
conv := NewStreamConverter("msg_123", "test-model")
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_good",
Function: api.ToolCallFunction{
Name: "good_function",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
events := conv.Process(resp)
// Count tool_use blocks - should only have 1 (the valid one)
toolStartCount := 0
toolDeltaCount := 0
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
toolStartCount++
if start.ContentBlock.Name != "good_function" {
t.Errorf("expected tool name 'good_function', got %q", start.ContentBlock.Name)
}
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
toolDeltaCount++
}
}
}
}
if toolStartCount != 1 {
t.Errorf("expected 1 tool_use block, got %d", toolStartCount)
}
if toolDeltaCount != 1 {
t.Errorf("expected 1 input_json_delta, got %d", toolDeltaCount)
}
}
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
tests := []struct {
name string
block ContentBlock
wantKeys []string
}{
{
name: "text block includes empty text field",
block: ContentBlock{
Type: "text",
Text: ptr(""),
},
wantKeys: []string{"type", "text"},
},
{
name: "thinking block includes empty thinking field",
block: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
wantKeys: []string{"type", "thinking"},
},
{
name: "text block with content",
block: ContentBlock{
Type: "text",
Text: ptr("hello"),
},
wantKeys: []string{"type", "text"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.block)
if err != nil {
t.Fatalf("failed to marshal: %v", err)
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
for _, key := range tt.wantKeys {
if _, ok := result[key]; !ok {
t.Errorf("expected key %q to be present in JSON output, got: %s", key, string(data))
}
}
})
}
}
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
// events include the required empty fields for SDK compatibility.
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
t.Run("text block start includes empty text", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Content: "hello"},
}
events := conv.Process(resp)
var foundTextStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "text" {
foundTextStart = true
// Marshal and verify the text field is present
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["text"]; !ok {
t.Error("content_block_start for text should include 'text' field")
}
}
}
}
}
if !foundTextStart {
t.Error("expected text content_block_start event")
}
})
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Thinking: "let me think..."},
}
events := conv.Process(resp)
var foundThinkingStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "thinking" {
foundThinkingStart = true
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["thinking"]; !ok {
t.Error("content_block_start for thinking should include 'thinking' field")
}
}
}
}
}
if !foundThinkingStart {
t.Error("expected thinking content_block_start event")
}
})
}

View File

@@ -165,7 +165,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
return nil
}
const maxBufferSize = 8 * format.MegaByte
const maxBufferSize = 512 * format.KiloByte
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf io.Reader

View File

@@ -3,7 +3,6 @@ package api
import (
"encoding/json"
"fmt"
"iter"
"log/slog"
"math"
"os"
@@ -15,7 +14,6 @@ import (
"github.com/google/uuid"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/internal/orderedmap"
"github.com/ollama/ollama/types/model"
)
@@ -229,79 +227,13 @@ type ToolCallFunction struct {
Arguments ToolCallFunctionArguments `json:"arguments"`
}
// ToolCallFunctionArguments holds tool call arguments in insertion order.
type ToolCallFunctionArguments struct {
om *orderedmap.Map[string, any]
}
// NewToolCallFunctionArguments creates a new empty ToolCallFunctionArguments.
func NewToolCallFunctionArguments() ToolCallFunctionArguments {
return ToolCallFunctionArguments{om: orderedmap.New[string, any]()}
}
// Get retrieves a value by key.
func (t *ToolCallFunctionArguments) Get(key string) (any, bool) {
if t == nil || t.om == nil {
return nil, false
}
return t.om.Get(key)
}
// Set sets a key-value pair, preserving insertion order.
func (t *ToolCallFunctionArguments) Set(key string, value any) {
if t == nil {
return
}
if t.om == nil {
t.om = orderedmap.New[string, any]()
}
t.om.Set(key, value)
}
// Len returns the number of arguments.
func (t *ToolCallFunctionArguments) Len() int {
if t == nil || t.om == nil {
return 0
}
return t.om.Len()
}
// All returns an iterator over all key-value pairs in insertion order.
func (t *ToolCallFunctionArguments) All() iter.Seq2[string, any] {
if t == nil || t.om == nil {
return func(yield func(string, any) bool) {}
}
return t.om.All()
}
// ToMap returns a regular map (order not preserved).
func (t *ToolCallFunctionArguments) ToMap() map[string]any {
if t == nil || t.om == nil {
return nil
}
return t.om.ToMap()
}
type ToolCallFunctionArguments map[string]any
func (t *ToolCallFunctionArguments) String() string {
if t == nil || t.om == nil {
return "{}"
}
bts, _ := json.Marshal(t.om)
bts, _ := json.Marshal(t)
return string(bts)
}
func (t *ToolCallFunctionArguments) UnmarshalJSON(data []byte) error {
t.om = orderedmap.New[string, any]()
return json.Unmarshal(data, t.om)
}
func (t ToolCallFunctionArguments) MarshalJSON() ([]byte, error) {
if t.om == nil {
return []byte("{}"), nil
}
return json.Marshal(t.om)
}
type Tool struct {
Type string `json:"type"`
Items any `json:"items,omitempty"`
@@ -350,78 +282,13 @@ func (pt PropertyType) String() string {
return fmt.Sprintf("%v", []string(pt))
}
// ToolPropertiesMap holds tool properties in insertion order.
type ToolPropertiesMap struct {
om *orderedmap.Map[string, ToolProperty]
}
// NewToolPropertiesMap creates a new empty ToolPropertiesMap.
func NewToolPropertiesMap() *ToolPropertiesMap {
return &ToolPropertiesMap{om: orderedmap.New[string, ToolProperty]()}
}
// Get retrieves a property by name.
func (t *ToolPropertiesMap) Get(key string) (ToolProperty, bool) {
if t == nil || t.om == nil {
return ToolProperty{}, false
}
return t.om.Get(key)
}
// Set sets a property, preserving insertion order.
func (t *ToolPropertiesMap) Set(key string, value ToolProperty) {
if t == nil {
return
}
if t.om == nil {
t.om = orderedmap.New[string, ToolProperty]()
}
t.om.Set(key, value)
}
// Len returns the number of properties.
func (t *ToolPropertiesMap) Len() int {
if t == nil || t.om == nil {
return 0
}
return t.om.Len()
}
// All returns an iterator over all properties in insertion order.
func (t *ToolPropertiesMap) All() iter.Seq2[string, ToolProperty] {
if t == nil || t.om == nil {
return func(yield func(string, ToolProperty) bool) {}
}
return t.om.All()
}
// ToMap returns a regular map (order not preserved).
func (t *ToolPropertiesMap) ToMap() map[string]ToolProperty {
if t == nil || t.om == nil {
return nil
}
return t.om.ToMap()
}
func (t ToolPropertiesMap) MarshalJSON() ([]byte, error) {
if t.om == nil {
return []byte("null"), nil
}
return json.Marshal(t.om)
}
func (t *ToolPropertiesMap) UnmarshalJSON(data []byte) error {
t.om = orderedmap.New[string, ToolProperty]()
return json.Unmarshal(data, t.om)
}
type ToolProperty struct {
AnyOf []ToolProperty `json:"anyOf,omitempty"`
Type PropertyType `json:"type,omitempty"`
Items any `json:"items,omitempty"`
Description string `json:"description,omitempty"`
Enum []any `json:"enum,omitempty"`
Properties *ToolPropertiesMap `json:"properties,omitempty"`
AnyOf []ToolProperty `json:"anyOf,omitempty"`
Type PropertyType `json:"type,omitempty"`
Items any `json:"items,omitempty"`
Description string `json:"description,omitempty"`
Enum []any `json:"enum,omitempty"`
Properties map[string]ToolProperty `json:"properties,omitempty"`
}
// ToTypeScriptType converts a ToolProperty to a TypeScript type string
@@ -470,11 +337,11 @@ func mapToTypeScriptType(jsonType string) string {
}
type ToolFunctionParameters struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required,omitempty"`
Properties *ToolPropertiesMap `json:"properties"`
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required,omitempty"`
Properties map[string]ToolProperty `json:"properties"`
}
func (t *ToolFunctionParameters) String() string {

View File

@@ -11,24 +11,6 @@ import (
"github.com/stretchr/testify/require"
)
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
func testPropsMap(m map[string]ToolProperty) *ToolPropertiesMap {
props := NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
func testArgs(m map[string]any) ToolCallFunctionArguments {
args := NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
func TestKeepAliveParsingFromJSON(t *testing.T) {
tests := []struct {
name string
@@ -327,9 +309,9 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
input: ToolFunctionParameters{
Type: "object",
Required: []string{"name"},
Properties: testPropsMap(map[string]ToolProperty{
Properties: map[string]ToolProperty{
"name": {Type: PropertyType{"string"}},
}),
},
},
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string"}}}`,
},
@@ -337,9 +319,9 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
name: "no required",
input: ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]ToolProperty{
Properties: map[string]ToolProperty{
"name": {Type: PropertyType{"string"}},
}),
},
},
expected: `{"type":"object","properties":{"name":{"type":"string"}}}`,
},
@@ -357,7 +339,7 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
func TestToolCallFunction_IndexAlwaysMarshals(t *testing.T) {
fn := ToolCallFunction{
Name: "echo",
Arguments: testArgs(map[string]any{"message": "hi"}),
Arguments: ToolCallFunctionArguments{"message": "hi"},
}
data, err := json.Marshal(fn)
@@ -547,7 +529,7 @@ func TestToolPropertyNestedProperties(t *testing.T) {
expected: ToolProperty{
Type: PropertyType{"object"},
Description: "Location details",
Properties: testPropsMap(map[string]ToolProperty{
Properties: map[string]ToolProperty{
"address": {
Type: PropertyType{"string"},
Description: "Street address",
@@ -556,7 +538,7 @@ func TestToolPropertyNestedProperties(t *testing.T) {
Type: PropertyType{"string"},
Description: "City name",
},
}),
},
},
},
{
@@ -584,22 +566,22 @@ func TestToolPropertyNestedProperties(t *testing.T) {
expected: ToolProperty{
Type: PropertyType{"object"},
Description: "Event",
Properties: testPropsMap(map[string]ToolProperty{
Properties: map[string]ToolProperty{
"location": {
Type: PropertyType{"object"},
Description: "Location",
Properties: testPropsMap(map[string]ToolProperty{
Properties: map[string]ToolProperty{
"coordinates": {
Type: PropertyType{"object"},
Description: "GPS coordinates",
Properties: testPropsMap(map[string]ToolProperty{
Properties: map[string]ToolProperty{
"lat": {Type: PropertyType{"number"}, Description: "Latitude"},
"lng": {Type: PropertyType{"number"}, Description: "Longitude"},
}),
},
},
}),
},
},
}),
},
},
},
}
@@ -609,13 +591,7 @@ func TestToolPropertyNestedProperties(t *testing.T) {
var prop ToolProperty
err := json.Unmarshal([]byte(tt.input), &prop)
require.NoError(t, err)
// Compare JSON representations since pointer comparison doesn't work
expectedJSON, err := json.Marshal(tt.expected)
require.NoError(t, err)
actualJSON, err := json.Marshal(prop)
require.NoError(t, err)
assert.JSONEq(t, string(expectedJSON), string(actualJSON))
assert.Equal(t, tt.expected, prop)
// Round-trip test: marshal and unmarshal again
data, err := json.Marshal(prop)
@@ -624,10 +600,7 @@ func TestToolPropertyNestedProperties(t *testing.T) {
var prop2 ToolProperty
err = json.Unmarshal(data, &prop2)
require.NoError(t, err)
prop2JSON, err := json.Marshal(prop2)
require.NoError(t, err)
assert.JSONEq(t, string(expectedJSON), string(prop2JSON))
assert.Equal(t, tt.expected, prop2)
})
}
}
@@ -643,12 +616,12 @@ func TestToolFunctionParameters_String(t *testing.T) {
params: ToolFunctionParameters{
Type: "object",
Required: []string{"name"},
Properties: testPropsMap(map[string]ToolProperty{
Properties: map[string]ToolProperty{
"name": {
Type: PropertyType{"string"},
Description: "The name of the person",
},
}),
},
},
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
},
@@ -665,7 +638,7 @@ func TestToolFunctionParameters_String(t *testing.T) {
s.Self = s
return s
}(),
Properties: testPropsMap(map[string]ToolProperty{}),
Properties: map[string]ToolProperty{},
},
expected: "",
},
@@ -678,235 +651,3 @@ func TestToolFunctionParameters_String(t *testing.T) {
})
}
}
func TestToolCallFunctionArguments_OrderPreservation(t *testing.T) {
t.Run("marshal preserves insertion order", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("zebra", "z")
args.Set("apple", "a")
args.Set("mango", "m")
data, err := json.Marshal(args)
require.NoError(t, err)
// Should preserve insertion order, not alphabetical
assert.Equal(t, `{"zebra":"z","apple":"a","mango":"m"}`, string(data))
})
t.Run("unmarshal preserves JSON order", func(t *testing.T) {
jsonData := `{"zebra":"z","apple":"a","mango":"m"}`
var args ToolCallFunctionArguments
err := json.Unmarshal([]byte(jsonData), &args)
require.NoError(t, err)
// Verify iteration order matches JSON order
var keys []string
for k := range args.All() {
keys = append(keys, k)
}
assert.Equal(t, []string{"zebra", "apple", "mango"}, keys)
})
t.Run("round trip preserves order", func(t *testing.T) {
original := `{"z":1,"a":2,"m":3,"b":4}`
var args ToolCallFunctionArguments
err := json.Unmarshal([]byte(original), &args)
require.NoError(t, err)
data, err := json.Marshal(args)
require.NoError(t, err)
assert.Equal(t, original, string(data))
})
t.Run("String method returns ordered JSON", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("c", 3)
args.Set("a", 1)
args.Set("b", 2)
assert.Equal(t, `{"c":3,"a":1,"b":2}`, args.String())
})
t.Run("Get retrieves correct values", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("key1", "value1")
args.Set("key2", 42)
v, ok := args.Get("key1")
assert.True(t, ok)
assert.Equal(t, "value1", v)
v, ok = args.Get("key2")
assert.True(t, ok)
assert.Equal(t, 42, v)
_, ok = args.Get("nonexistent")
assert.False(t, ok)
})
t.Run("Len returns correct count", func(t *testing.T) {
args := NewToolCallFunctionArguments()
assert.Equal(t, 0, args.Len())
args.Set("a", 1)
assert.Equal(t, 1, args.Len())
args.Set("b", 2)
assert.Equal(t, 2, args.Len())
})
t.Run("empty args marshal to empty object", func(t *testing.T) {
args := NewToolCallFunctionArguments()
data, err := json.Marshal(args)
require.NoError(t, err)
assert.Equal(t, `{}`, string(data))
})
t.Run("zero value args marshal to empty object", func(t *testing.T) {
var args ToolCallFunctionArguments
assert.Equal(t, "{}", args.String())
})
}
func TestToolPropertiesMap_OrderPreservation(t *testing.T) {
t.Run("marshal preserves insertion order", func(t *testing.T) {
props := NewToolPropertiesMap()
props.Set("zebra", ToolProperty{Type: PropertyType{"string"}})
props.Set("apple", ToolProperty{Type: PropertyType{"number"}})
props.Set("mango", ToolProperty{Type: PropertyType{"boolean"}})
data, err := json.Marshal(props)
require.NoError(t, err)
// Should preserve insertion order, not alphabetical
expected := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}`
assert.Equal(t, expected, string(data))
})
t.Run("unmarshal preserves JSON order", func(t *testing.T) {
jsonData := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}`
var props ToolPropertiesMap
err := json.Unmarshal([]byte(jsonData), &props)
require.NoError(t, err)
// Verify iteration order matches JSON order
var keys []string
for k := range props.All() {
keys = append(keys, k)
}
assert.Equal(t, []string{"zebra", "apple", "mango"}, keys)
})
t.Run("round trip preserves order", func(t *testing.T) {
original := `{"z":{"type":"string"},"a":{"type":"number"},"m":{"type":"boolean"}}`
var props ToolPropertiesMap
err := json.Unmarshal([]byte(original), &props)
require.NoError(t, err)
data, err := json.Marshal(props)
require.NoError(t, err)
assert.Equal(t, original, string(data))
})
t.Run("Get retrieves correct values", func(t *testing.T) {
props := NewToolPropertiesMap()
props.Set("name", ToolProperty{Type: PropertyType{"string"}, Description: "The name"})
props.Set("age", ToolProperty{Type: PropertyType{"integer"}, Description: "The age"})
v, ok := props.Get("name")
assert.True(t, ok)
assert.Equal(t, "The name", v.Description)
v, ok = props.Get("age")
assert.True(t, ok)
assert.Equal(t, "The age", v.Description)
_, ok = props.Get("nonexistent")
assert.False(t, ok)
})
t.Run("Len returns correct count", func(t *testing.T) {
props := NewToolPropertiesMap()
assert.Equal(t, 0, props.Len())
props.Set("a", ToolProperty{})
assert.Equal(t, 1, props.Len())
props.Set("b", ToolProperty{})
assert.Equal(t, 2, props.Len())
})
t.Run("nil props marshal to null", func(t *testing.T) {
var props *ToolPropertiesMap
data, err := json.Marshal(props)
require.NoError(t, err)
assert.Equal(t, `null`, string(data))
})
t.Run("ToMap returns regular map", func(t *testing.T) {
props := NewToolPropertiesMap()
props.Set("a", ToolProperty{Type: PropertyType{"string"}})
props.Set("b", ToolProperty{Type: PropertyType{"number"}})
m := props.ToMap()
assert.Equal(t, 2, len(m))
assert.Equal(t, PropertyType{"string"}, m["a"].Type)
assert.Equal(t, PropertyType{"number"}, m["b"].Type)
})
}
func TestToolCallFunctionArguments_ComplexValues(t *testing.T) {
t.Run("nested objects preserve order", func(t *testing.T) {
jsonData := `{"outer":{"z":1,"a":2},"simple":"value"}`
var args ToolCallFunctionArguments
err := json.Unmarshal([]byte(jsonData), &args)
require.NoError(t, err)
// Outer keys should be in order
var keys []string
for k := range args.All() {
keys = append(keys, k)
}
assert.Equal(t, []string{"outer", "simple"}, keys)
})
t.Run("arrays as values", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("items", []string{"a", "b", "c"})
args.Set("numbers", []int{1, 2, 3})
data, err := json.Marshal(args)
require.NoError(t, err)
assert.Equal(t, `{"items":["a","b","c"],"numbers":[1,2,3]}`, string(data))
})
}
func TestToolPropertiesMap_NestedProperties(t *testing.T) {
t.Run("nested properties preserve order", func(t *testing.T) {
props := NewToolPropertiesMap()
nestedProps := NewToolPropertiesMap()
nestedProps.Set("z_field", ToolProperty{Type: PropertyType{"string"}})
nestedProps.Set("a_field", ToolProperty{Type: PropertyType{"number"}})
props.Set("outer", ToolProperty{
Type: PropertyType{"object"},
Properties: nestedProps,
})
data, err := json.Marshal(props)
require.NoError(t, err)
// Both outer and inner should preserve order
expected := `{"outer":{"type":"object","properties":{"z_field":{"type":"string"},"a_field":{"type":"number"}}}}`
assert.Equal(t, expected, string(data))
})
}

View File

@@ -209,6 +209,9 @@ func main() {
st := &store.Store{}
// Initialize native settings with store
SetSettingsStore(st)
// Enable CORS in development mode
if devMode {
os.Setenv("OLLAMA_CORS", "1")
@@ -253,22 +256,27 @@ func main() {
done <- osrv.Run(octx)
}()
restartServer := func() {
ocancel()
<-done
octx, ocancel = context.WithCancel(ctx)
go func() {
done <- osrv.Run(octx)
}()
}
uiServer := ui.Server{
Token: token,
Restart: func() {
ocancel()
<-done
octx, ocancel = context.WithCancel(ctx)
go func() {
done <- osrv.Run(octx)
}()
},
Token: token,
Restart: restartServer,
Store: st,
ToolRegistry: toolRegistry,
Dev: devMode,
Logger: slog.Default(),
}
// Set restart callback for native settings
SetRestartCallback(restartServer)
srv := &http.Server{
Handler: uiServer.Handler(),
}

View File

@@ -1,5 +1,6 @@
#import "app_darwin.h"
#import "menu.h"
#import "settings_darwin.h"
#import "../../updater/updater_darwin.h"
#import <AppKit/AppKit.h>
#import <Cocoa/Cocoa.h>
@@ -252,7 +253,7 @@ bool firstTimeRun,startHidden; // Set in run before initialization
}
- (void)settingsUI {
[self uiRequest:@"/settings"];
openNativeSettings();
}
- (void)openUI {

View File

@@ -0,0 +1,438 @@
//go:build darwin
package main
/*
#cgo CFLAGS: -x objective-c
#cgo LDFLAGS: -framework Cocoa
#include <stdlib.h>
#include "settings_darwin.h"
*/
import "C"
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/json"
"encoding/pem"
"fmt"
"log/slog"
"net/http"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"time"
"unsafe"
"golang.org/x/crypto/ssh"
appauth "github.com/ollama/ollama/app/auth"
"github.com/ollama/ollama/app/store"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
)
// settingsStore is a reference to the app's store for settings
var settingsStore *store.Store
// SetSettingsStore sets the store reference for settings callbacks
func SetSettingsStore(s *store.Store) {
settingsStore = s
}
//export getSettingsExpose
func getSettingsExpose() C.bool {
if settingsStore == nil {
return C.bool(false)
}
settings, err := settingsStore.Settings()
if err != nil {
slog.Error("failed to get settings", "error", err)
return C.bool(false)
}
return C.bool(settings.Expose)
}
//export setSettingsExpose
func setSettingsExpose(expose C.bool) {
if settingsStore == nil {
return
}
settings, err := settingsStore.Settings()
if err != nil {
slog.Error("failed to get settings", "error", err)
return
}
settings.Expose = bool(expose)
if err := settingsStore.SetSettings(settings); err != nil {
slog.Error("failed to save settings", "error", err)
}
}
//export getSettingsBrowser
func getSettingsBrowser() C.bool {
if settingsStore == nil {
return C.bool(false)
}
settings, err := settingsStore.Settings()
if err != nil {
slog.Error("failed to get settings", "error", err)
return C.bool(false)
}
return C.bool(settings.Browser)
}
//export setSettingsBrowser
func setSettingsBrowser(browser C.bool) {
if settingsStore == nil {
return
}
settings, err := settingsStore.Settings()
if err != nil {
slog.Error("failed to get settings", "error", err)
return
}
settings.Browser = bool(browser)
if err := settingsStore.SetSettings(settings); err != nil {
slog.Error("failed to save settings", "error", err)
}
}
//export getSettingsModels
func getSettingsModels() *C.char {
if settingsStore == nil {
return C.CString(envconfig.Models())
}
settings, err := settingsStore.Settings()
if err != nil {
slog.Error("failed to get settings", "error", err)
return C.CString(envconfig.Models())
}
if settings.Models == "" {
return C.CString(envconfig.Models())
}
return C.CString(settings.Models)
}
//export setSettingsModels
func setSettingsModels(path *C.char) {
if settingsStore == nil {
return
}
settings, err := settingsStore.Settings()
if err != nil {
slog.Error("failed to get settings", "error", err)
return
}
settings.Models = C.GoString(path)
if err := settingsStore.SetSettings(settings); err != nil {
slog.Error("failed to save settings", "error", err)
}
}
//export getSettingsContextLength
func getSettingsContextLength() C.int {
if settingsStore == nil {
return C.int(4096)
}
settings, err := settingsStore.Settings()
if err != nil {
slog.Error("failed to get settings", "error", err)
return C.int(4096)
}
if settings.ContextLength <= 0 {
return C.int(4096)
}
return C.int(settings.ContextLength)
}
//export setSettingsContextLength
func setSettingsContextLength(length C.int) {
if settingsStore == nil {
return
}
settings, err := settingsStore.Settings()
if err != nil {
slog.Error("failed to get settings", "error", err)
return
}
settings.ContextLength = int(length)
if err := settingsStore.SetSettings(settings); err != nil {
slog.Error("failed to save settings", "error", err)
}
}
// restartCallback is set by the app to restart the ollama server
var restartCallback func()
// SetRestartCallback sets the function to call when settings change requires a restart
func SetRestartCallback(cb func()) {
restartCallback = cb
}
//export restartOllamaServer
func restartOllamaServer() {
if restartCallback != nil {
slog.Info("restarting ollama server due to settings change")
go restartCallback()
}
}
// hasOllamaKey checks if the user has an Ollama key file
func hasOllamaKey() bool {
home, err := os.UserHomeDir()
if err != nil {
return false
}
keyPath := filepath.Join(home, ".ollama", "id_ed25519")
_, err = os.Stat(keyPath)
return err == nil
}
// ensureKeypair generates a new keypair if one doesn't exist
func ensureKeypair() error {
home, err := os.UserHomeDir()
if err != nil {
return err
}
privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
// Check if key already exists
if _, err := os.Stat(privKeyPath); err == nil {
return nil // Key exists
}
// Generate new keypair
slog.Info("generating new keypair for ollama account")
pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return fmt.Errorf("failed to generate key: %w", err)
}
// Marshal private key
privKeyBytes, err := ssh.MarshalPrivateKey(privKey, "")
if err != nil {
return fmt.Errorf("failed to marshal private key: %w", err)
}
// Ensure directory exists
if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
// Write private key
if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privKeyBytes), 0o600); err != nil {
return fmt.Errorf("failed to write private key: %w", err)
}
// Write public key
sshPubKey, err := ssh.NewPublicKey(pubKey)
if err != nil {
return fmt.Errorf("failed to create ssh public key: %w", err)
}
pubKeyBytes := ssh.MarshalAuthorizedKey(sshPubKey)
pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
if err := os.WriteFile(pubKeyPath, pubKeyBytes, 0o644); err != nil {
return fmt.Errorf("failed to write public key: %w", err)
}
slog.Info("keypair generated successfully")
return nil
}
// userResponse matches the API response from ollama.com/api/me
type userResponse struct {
Name string `json:"name"`
Email string `json:"email"`
Plan string `json:"plan"`
AvatarURL string `json:"avatarurl"`
}
// fetchUserFromAPI fetches user data from ollama.com using signed request
func fetchUserFromAPI() (*userResponse, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
timestamp := strconv.FormatInt(time.Now().Unix(), 10)
signString := fmt.Sprintf("POST,/api/me?ts=%s", timestamp)
signature, err := auth.Sign(ctx, []byte(signString))
if err != nil {
return nil, fmt.Errorf("failed to sign request: %w", err)
}
endpoint := fmt.Sprintf("https://ollama.com/api/me?ts=%s", timestamp)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signature))
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to call ollama.com: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status: %d", resp.StatusCode)
}
var user userResponse
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
// Make avatar URL absolute
if user.AvatarURL != "" && !strings.HasPrefix(user.AvatarURL, "http") {
user.AvatarURL = "https://ollama.com/" + user.AvatarURL
}
// Cache the avatar URL
cachedAvatarURL = user.AvatarURL
// Cache the user data
if settingsStore != nil {
storeUser := store.User{
Name: user.Name,
Email: user.Email,
Plan: user.Plan,
}
if err := settingsStore.SetUser(storeUser); err != nil {
slog.Warn("failed to cache user", "error", err)
}
}
return &user, nil
}
//export getAccountName
func getAccountName() *C.char {
// Only return cached data - never block on network
if settingsStore == nil {
return C.CString("")
}
user, err := settingsStore.User()
if err != nil || user == nil {
return C.CString("")
}
return C.CString(user.Name)
}
// cachedAvatarURL stores the avatar URL from the last API fetch
var cachedAvatarURL string
//export getAccountAvatarURL
func getAccountAvatarURL() *C.char {
return C.CString(cachedAvatarURL)
}
//export getAccountEmail
func getAccountEmail() *C.char {
if settingsStore != nil {
user, err := settingsStore.User()
if err == nil && user != nil {
return C.CString(user.Email)
}
}
return C.CString("")
}
//export getAccountPlan
func getAccountPlan() *C.char {
if settingsStore != nil {
user, err := settingsStore.User()
if err == nil && user != nil {
return C.CString(user.Plan)
}
}
return C.CString("")
}
//export signOutAccount
func signOutAccount() {
if settingsStore != nil {
if err := settingsStore.ClearUser(); err != nil {
slog.Error("failed to clear user", "error", err)
}
}
// Also remove the key file
home, err := os.UserHomeDir()
if err != nil {
slog.Error("failed to get home dir", "error", err)
return
}
keyPath := filepath.Join(home, ".ollama", "id_ed25519")
if err := os.Remove(keyPath); err != nil && !os.IsNotExist(err) {
slog.Error("failed to remove key file", "error", err)
}
}
//export openConnectUrl
func openConnectUrl() {
// Ensure keypair exists (generate if needed)
if err := ensureKeypair(); err != nil {
slog.Error("failed to ensure keypair", "error", err)
// Fallback to basic connect page
cmd := exec.Command("open", "https://ollama.com/connect")
cmd.Start()
return
}
// Build connect URL with public key
connectURL, err := appauth.BuildConnectURL("https://ollama.com")
if err != nil {
slog.Error("failed to build connect URL", "error", err)
// Fallback to basic connect page
connectURL = "https://ollama.com/connect"
}
cmd := exec.Command("open", connectURL)
if err := cmd.Start(); err != nil {
slog.Error("failed to open connect URL", "error", err)
}
}
//export refreshAccountFromAPI
func refreshAccountFromAPI() {
if !hasOllamaKey() {
return
}
_, err := fetchUserFromAPI()
if err != nil {
slog.Debug("failed to refresh account", "error", err)
}
}
//export prefetchAccountData
func prefetchAccountData() {
// Run in background goroutine to not block app startup
go func() {
if !hasOllamaKey() {
return
}
_, err := fetchUserFromAPI()
if err != nil {
slog.Debug("failed to prefetch account data", "error", err)
} else {
slog.Debug("prefetched account data successfully")
}
}()
}
// OpenNativeSettings opens the native settings window
func OpenNativeSettings() {
C.openNativeSettings()
}
// Ensure the CString is freed (caller must free)
func freeCString(s *C.char) {
C.free(unsafe.Pointer(s))
}

View File

@@ -0,0 +1,38 @@
#import <Cocoa/Cocoa.h>
@interface SettingsWindowController : NSWindowController <NSWindowDelegate>
// General tab
@property(nonatomic, strong) NSButton *exposeCheckbox;
@property(nonatomic, strong) NSButton *browserCheckbox;
@property(nonatomic, strong) NSSlider *contextLengthSlider;
// Models tab
@property(nonatomic, strong) NSPathControl *modelsPathControl;
@property(nonatomic, strong) NSButton *modelsPathButton;
// Account tab
@property(nonatomic, strong) NSView *avatarView;
@property(nonatomic, strong) NSTextField *avatarInitialLabel;
@property(nonatomic, strong) NSImageView *avatarImageView;
@property(nonatomic, strong) NSTextField *accountNameLabel;
@property(nonatomic, strong) NSTextField *accountEmailLabel;
@property(nonatomic, strong) NSButton *manageButton;
@property(nonatomic, strong) NSButton *signOutButton;
@property(nonatomic, strong) NSButton *signInButton;
@property(nonatomic, strong) NSView *signedInContainer;
@property(nonatomic, strong) NSView *signedOutContainer;
// Plan section
@property(nonatomic, strong) NSView *planContainer;
@property(nonatomic, strong) NSTextField *planNameLabel;
@property(nonatomic, strong) NSButton *upgradeButton;
@property(nonatomic, strong) NSButton *viewUsageButton;
+ (instancetype)sharedController;
- (void)showSettings;
@end
// Go callbacks for settings
void openNativeSettings(void);

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,16 @@
//go:build windows
package main
import "github.com/ollama/ollama/app/store"
// SetSettingsStore sets the store reference for settings callbacks (stub for Windows)
func SetSettingsStore(s *store.Store) {
// TODO: Implement Windows native settings
}
// SetRestartCallback sets the function to call when settings change requires a restart (stub for Windows)
func SetRestartCallback(cb func()) {
// TODO: Implement Windows native settings
}

View File

@@ -147,7 +147,6 @@ export const highlighterPromise = createHighlighter({
"c",
"cpp",
"sql",
"swift",
"yaml",
"markdown",
],

View File

@@ -997,7 +997,7 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
for _, toolCall := range res.Message.ToolCalls {
// continues loop as tools were executed
toolsExecuted = true
result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments.ToMap())
result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
if err != nil {
errContent := fmt.Sprintf("Error: %v", err)
toolErrMsg := store.NewMessage("tool", errContent, nil)
@@ -1558,13 +1558,13 @@ func convertToOllamaTool(toolSchema map[string]any) api.Tool {
tool.Function.Parameters.Type = "object"
tool.Function.Parameters.Required = []string{}
tool.Function.Parameters.Properties = api.NewToolPropertiesMap()
tool.Function.Parameters.Properties = make(map[string]api.ToolProperty)
if schemaProps, ok := toolSchema["schema"].(map[string]any); ok {
tool.Function.Parameters.Type = getStringFromMap(schemaProps, "type", "object")
if props, ok := schemaProps["properties"].(map[string]any); ok {
tool.Function.Parameters.Properties = api.NewToolPropertiesMap()
tool.Function.Parameters.Properties = make(map[string]api.ToolProperty)
for propName, propDef := range props {
if propMap, ok := propDef.(map[string]any); ok {
@@ -1572,7 +1572,7 @@ func convertToOllamaTool(toolSchema map[string]any) api.Tool {
Type: api.PropertyType{getStringFromMap(propMap, "type", "string")},
Description: getStringFromMap(propMap, "description", ""),
}
tool.Function.Parameters.Properties.Set(propName, prop)
tool.Function.Parameters.Properties[propName] = prop
}
}
}

View File

@@ -45,9 +45,6 @@ import (
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd"
"github.com/ollama/ollama/x/imagegen"
imagegenclient "github.com/ollama/ollama/x/imagegen/client"
)
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
@@ -98,11 +95,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) {
if filename == "" {
// No Modelfile found - check if current directory is an image gen model
if imagegen.IsTensorModelDir(".") {
quantize, _ := cmd.Flags().GetString("quantize")
return imagegenclient.CreateModel(args[0], ".", quantize, p)
}
reader = strings.NewReader("FROM .\n")
} else {
return errModelfileNotFound
@@ -464,7 +456,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
name := args[0]
info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: name}
info, err := client.Show(cmd.Context(), showReq)
@@ -526,19 +517,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
}
// Check if this is an image generation model
if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
if opts.Prompt == "" && !interactive {
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
}
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
}
// Check for experimental flag
isExperimental, _ := cmd.Flags().GetBool("experimental")
yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
enableWebsearch, _ := cmd.Flags().GetBool("experimental-websearch")
if interactive {
if err := loadOrUnloadModel(cmd, &opts); err != nil {
var sErr api.AuthorizationError
@@ -565,11 +543,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
}
// Use experimental agent loop with tools
if isExperimental {
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode, enableWebsearch)
}
return generateInteractive(cmd, opts)
}
return generate(cmd, opts)
@@ -673,11 +646,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
bar, ok := bars[resp.Digest]
if !ok {
msg := resp.Status
if msg == "" {
msg = fmt.Sprintf("pushing %s...", resp.Digest[7:19])
}
bar = progress.NewBar(msg, resp.Total, resp.Completed)
bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
@@ -1785,12 +1754,6 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
runCmd.Flags().Bool("experimental-websearch", false, "Enable web search tool in experimental mode")
// Image generation flags (width, height, steps, seed, etc.)
imagegen.RegisterFlags(runCmd)
stopCmd := &cobra.Command{
Use: "stop MODEL",

View File

@@ -1547,79 +1547,6 @@ func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) {
}
}
func TestShowInfoImageGen(t *testing.T) {
var b bytes.Buffer
err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "ZImagePipeline",
ParameterSize: "10.3B",
QuantizationLevel: "FP8",
},
Capabilities: []model.Capability{model.CapabilityImageGeneration},
Requires: "0.14.0",
}, false, &b)
if err != nil {
t.Fatal(err)
}
expect := " Model\n" +
" architecture ZImagePipeline \n" +
" parameters 10.3B \n" +
" quantization FP8 \n" +
" requires 0.14.0 \n" +
"\n" +
" Capabilities\n" +
" image \n" +
"\n"
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
}
func TestPushProgressMessage(t *testing.T) {
tests := []struct {
name string
status string
digest string
wantMsg string
}{
{
name: "uses status when provided",
status: "uploading model",
digest: "sha256:abc123456789def",
wantMsg: "uploading model",
},
{
name: "falls back to digest when status empty",
status: "",
digest: "sha256:abc123456789def",
wantMsg: "pushing abc123456789...",
},
{
name: "handles short digest gracefully",
status: "",
digest: "sha256:abc",
wantMsg: "pushing sha256:abc...",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msg := tt.status
if msg == "" {
if len(tt.digest) >= 19 {
msg = fmt.Sprintf("pushing %s...", tt.digest[7:19])
} else {
msg = fmt.Sprintf("pushing %s...", tt.digest)
}
}
if msg != tt.wantMsg {
t.Errorf("got %q, want %q", msg, tt.wantMsg)
}
})
}
}
func TestRunOptions_Copy_Independence(t *testing.T) {
// Test that modifications to original don't affect copy
originalThink := &api.ThinkValue{Value: "original"}

View File

@@ -40,7 +40,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")

View File

@@ -6,14 +6,11 @@ import (
"errors"
"fmt"
"io/fs"
"iter"
"log/slog"
"maps"
"os"
"slices"
"strings"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -21,13 +18,8 @@ type ModelParameters struct {
Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"`
// TODO is this needed?
ModelType string `json:"model_type"`
TextModel struct {
VocabSize uint32 `json:"vocab_size"`
HiddenSize uint32 `json:"hidden_size"`
ModelType string `json:"model_type"`
VocabSize uint32 `json:"vocab_size"`
} `json:"text_config"`
}
@@ -41,94 +33,8 @@ type AdapterParameters struct {
} `json:"lora_parameters"`
}
type KV map[string]any
func (kv KV) Architecture() string {
return kv.String("general.architecture", "unknown")
}
type valueTypes interface {
uint8 | int8 | uint16 | int16 |
uint32 | int32 | uint64 | int64 |
string | float32 | float64 | bool
}
type arrayValueTypes interface {
[]uint8 | []int8 | []uint16 | []int16 |
[]uint32 | []int32 | []uint64 | []int64 |
[]string | []float32 | []float64 | []bool
}
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key
}
if val, ok := kv[key].(T); ok {
return val, true
}
return defaultValue[0], false
}
func (kv KV) String(key string, defaultValue ...string) string {
val, _ := keyValue(kv, key, append(defaultValue, "")...)
return val
}
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
}
func (kv KV) Float(key string, defaultValue ...float32) float32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
}
func (kv KV) Bool(key string, defaultValue ...bool) bool {
val, _ := keyValue(kv, key, append(defaultValue, false)...)
return val
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
val, _ := keyValue(kv, key, append(defaultValue, []string{""})...)
return val
}
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
val, _ := keyValue(kv, key, append(defaultValue, []int32{0})...)
return val
}
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
val, _ := keyValue(kv, key, append(defaultValue, []uint32{0})...)
return val
}
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
val, _ := keyValue(kv, key, append(defaultValue, []float32{0})...)
return val
}
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
val, _ := keyValue(kv, key, append(defaultValue, []bool{false})...)
return val
}
func (kv KV) Len() int {
return len(kv)
}
func (kv KV) Keys() iter.Seq[string] {
return maps.Keys(kv)
}
func (kv KV) Value(key string) any {
return kv[key]
}
func (ModelParameters) KV(t *Tokenizer) KV {
kv := KV{
func (ModelParameters) KV(t *Tokenizer) ggml.KV {
kv := ggml.KV{
"general.file_type": uint32(1),
"general.quantization_version": uint32(2),
"tokenizer.ggml.pre": t.Pre,
@@ -157,7 +63,7 @@ func (ModelParameters) KV(t *Tokenizer) KV {
return kv
}
func (p AdapterParameters) KV() KV {
func (p AdapterParameters) KV() ggml.KV {
var alpha float32
if p.LoraParameters.Alpha == 0 {
alpha = float32(p.Alpha)
@@ -165,7 +71,7 @@ func (p AdapterParameters) KV() KV {
alpha = p.LoraParameters.Alpha
}
kv := KV{
kv := ggml.KV{
"adapter.lora.alpha": alpha,
"adapter.type": "lora",
"general.file_type": uint32(1),
@@ -182,14 +88,9 @@ func (ModelParameters) specialTokenTypes() []string {
}
}
type ModelKV interface {
// KV maps parameters to LLM key-values
KV(*Tokenizer) KV
}
type ModelConverter interface {
ModelKV
// KV maps parameters to LLM key-values
KV(*Tokenizer) ggml.KV
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
@@ -206,7 +107,7 @@ type moreParser interface {
type AdapterConverter interface {
// KV maps parameters to LLM key-values
KV(ofs.Config) KV
KV(ggml.KV) ggml.KV
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
@@ -214,7 +115,7 @@ type AdapterConverter interface {
Replacements() []string
}
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
bts, err := fs.ReadFile(fsys, "adapter_config.json")
if err != nil {
return err
@@ -225,8 +126,8 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
return err
}
arch := baseKV.Architecture()
if arch == "" {
arch, ok := baseKV["general.architecture"]
if !ok {
return errors.New("architecture not set for the base model")
}
@@ -252,19 +153,23 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
}
func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
bts, err := fs.ReadFile(fsys, "config.json")
if err != nil {
return nil, nil, err
return err
}
var p ModelParameters
if err := json.Unmarshal(bts, &p); err != nil {
return nil, nil, err
return err
}
if len(p.Architectures) < 1 {
return nil, nil, errors.New("unknown architecture")
return errors.New("unknown architecture")
}
var conv ModelConverter
@@ -312,22 +217,22 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
case "DeepseekV3ForCausalLM":
conv = &deepseek2Model{}
default:
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}
if err := json.Unmarshal(bts, conv); err != nil {
return nil, nil, err
return err
}
if t, ok := conv.(moreParser); ok {
if err := t.parseMore(fsys); err != nil {
return nil, nil, err
return err
}
}
t, err := parseTokenizer(fsys, conv.specialTokenTypes())
if err != nil {
return nil, nil, err
return err
}
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
@@ -349,19 +254,6 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
default:
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
}
return conv, t, nil
}
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
kv, t, err := LoadModelMetadata(fsys)
if err != nil {
return err
}
conv := kv.(ModelConverter)
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
if err != nil {
@@ -371,7 +263,7 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
return writeFile(f, conv.KV(t), conv.Tensors(ts))
}
func writeFile(f *os.File, kv KV, ts []*ggml.Tensor) error {
func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
for i := range ts {
ts[i].Shape = slices.Clone(ts[i].Shape)
slices.Reverse(ts[i].Shape)

View File

@@ -88,7 +88,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
return nil
}
func (p *bertModel) KV(t *Tokenizer) KV {
func (p *bertModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "bert"
kv["bert.attention.causal"] = false

View File

@@ -24,7 +24,7 @@ type commandrModel struct {
var _ ModelConverter = (*commandrModel)(nil)
func (p *commandrModel) KV(t *Tokenizer) KV {
func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "command-r"
kv["general.name"] = "command-r"

View File

@@ -47,7 +47,7 @@ type deepseek2Model struct {
Architecture string
}
func (p *deepseek2Model) KV(t *Tokenizer) KV {
func (p *deepseek2Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "deepseek2"
kv["general.type"] = "model"

View File

@@ -41,7 +41,7 @@ type deepseekocr struct {
} `json:"vision_config"`
}
func (m *deepseekocr) KV(t *Tokenizer) KV {
func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "deepseekocr"
kv["block_count"] = m.LanguageConfig.HiddenLayers

View File

@@ -23,7 +23,7 @@ type gemmaModel struct {
var _ ModelConverter = (*gemmaModel)(nil)
func (p *gemmaModel) KV(t *Tokenizer) KV {
func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma"
kv["gemma.context_length"] = p.MaxPositionEmbeddings

View File

@@ -1,5 +1,7 @@
package convert
import "github.com/ollama/ollama/fs/ggml"
type gemma2Model struct {
gemmaModel
SlidingWindow uint32 `json:"sliding_window"`
@@ -7,7 +9,7 @@ type gemma2Model struct {
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
}
func (p *gemma2Model) KV(t *Tokenizer) KV {
func (p *gemma2Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma2"
kv["gemma2.context_length"] = p.MaxPositionEmbeddings

View File

@@ -6,7 +6,6 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -16,7 +15,7 @@ type gemma2Adapter struct {
var _ AdapterConverter = (*gemma2Adapter)(nil)
func (p *gemma2Adapter) KV(baseKV fs.Config) KV {
func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "gemma2"
return kv

View File

@@ -3,6 +3,8 @@ package convert
import (
"cmp"
"slices"
"github.com/ollama/ollama/fs/ggml"
)
type gemma3Model struct {
@@ -53,7 +55,7 @@ const (
gemma27BLayerCount = 62
)
func (p *gemma3Model) KV(t *Tokenizer) KV {
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3"

View File

@@ -38,7 +38,7 @@ type gemma3nModel struct {
VisionModel struct{} `json:"vision_config"`
}
func (m *gemma3nModel) KV(t *Tokenizer) KV {
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3n"
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {

View File

@@ -37,7 +37,7 @@ type gptossModel struct {
var _ ModelConverter = (*gptossModel)(nil)
func (m *gptossModel) KV(t *Tokenizer) KV {
func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gptoss"
kv["general.file_type"] = uint32(4)

View File

@@ -48,7 +48,7 @@ type llamaModel struct {
var _ ModelConverter = (*llamaModel)(nil)
func (p *llamaModel) KV(t *Tokenizer) KV {
func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama"
kv["llama.vocab_size"] = p.VocabSize

View File

@@ -35,7 +35,7 @@ type llama4Model struct {
}
// KV implements ModelConverter.
func (p *llama4Model) KV(t *Tokenizer) KV {
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama4"

View File

@@ -7,7 +7,6 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -19,13 +18,13 @@ type llamaAdapter struct {
var _ AdapterConverter = (*llamaAdapter)(nil)
func (p *llamaAdapter) KV(baseKV fs.Config) KV {
func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "llama"
kv["llama.attention.head_count"] = baseKV.Value("llama.attention.head_count")
kv["llama.attention.head_count_kv"] = baseKV.Value("llama.attention.head_count_kv")
kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
kv["llama.attention.head_count_kv"] = baseKV["llama.attention.head_count_kv"]
p.NumAttentionHeads = baseKV.Value("llama.attention.head_count").(uint32)
p.NumAttentionHeads = baseKV["llama.attention.head_count"].(uint32)
return kv
}

View File

@@ -60,7 +60,7 @@ type mistral3Model struct {
ProjectorHiddenAct string `json:"projector_hidden_act"`
}
func (p *mistral3Model) KV(t *Tokenizer) KV {
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.TextModel.VocabSize

View File

@@ -39,7 +39,7 @@ type mistral3CausalModel struct {
} `json:"rope_parameters"`
}
func (p *mistral3CausalModel) KV(t *Tokenizer) KV {
func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.VocabSize

View File

@@ -12,7 +12,7 @@ type mixtralModel struct {
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
}
func (p *mixtralModel) KV(t *Tokenizer) KV {
func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
kv := p.llamaModel.KV(t)
if p.NumLocalExperts > 0 {

View File

@@ -34,7 +34,7 @@ type mllamaModel struct {
} `json:"vision_config"`
}
func (m *mllamaModel) KV(t *Tokenizer) KV {
func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "mllama"

View File

@@ -87,7 +87,7 @@ func (p *nomicbertModel) parseMore(fsys fs.FS) error {
return nil
}
func (p *nomicbertModel) KV(t *Tokenizer) KV {
func (p *nomicbertModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
// Determine architecture based on MoE parameters (following qwen3 pattern)

View File

@@ -34,7 +34,7 @@ type olmoModel struct {
var _ ModelConverter = (*olmoModel)(nil)
func (p *olmoModel) KV(t *Tokenizer) KV {
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "olmo3"
kv["olmo3.block_count"] = p.NumHiddenLayers

View File

@@ -37,7 +37,7 @@ type phi3Model struct {
var _ ModelConverter = (*phi3Model)(nil)
func (p *phi3Model) KV(t *Tokenizer) KV {
func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "phi3"
kv["phi3.context_length"] = p.MaxPositionEmbeddings

View File

@@ -22,7 +22,7 @@ type qwen2Model struct {
var _ ModelConverter = (*qwen2Model)(nil)
func (q *qwen2Model) KV(t *Tokenizer) KV {
func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen2"
kv["qwen2.block_count"] = q.HiddenLayers

View File

@@ -29,7 +29,7 @@ type qwen25VLModel struct {
var _ ModelConverter = (*qwen25VLModel)(nil)
func (q *qwen25VLModel) KV(t *Tokenizer) KV {
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen25vl"

View File

@@ -32,7 +32,7 @@ type qwen3Model struct {
}
// KV implements ModelConverter.
func (q *qwen3Model) KV(t *Tokenizer) KV {
func (q *qwen3Model) KV(t *Tokenizer) ggml.KV {
arch := "qwen3"
if q.NumExperts > 0 {
arch += "moe"

View File

@@ -45,7 +45,7 @@ func (m *qwen3VLModel) parseMore(fsys fs.FS) error {
return json.Unmarshal(bts, &m.VisionModel)
}
func (m *qwen3VLModel) KV(t *Tokenizer) KV {
func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV {
kv := m.qwen3Model.KV(t)
arch := "qwen3vl"

View File

@@ -19,7 +19,6 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
fsc "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -29,7 +28,7 @@ type tensorData struct {
Shape []int `json:"shape"`
}
func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors) {
func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "f16")
@@ -60,10 +59,9 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors)
return r, m.KV(), m.Tensors()
}
func generateResultsJSON(t *testing.T, f *os.File, kv fsc.Config, tensors ggml.Tensors) map[string]string {
func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string {
actual := make(map[string]string)
for k := range kv.Keys() {
v := kv.Value(k)
for k, v := range kv {
if s, ok := v.(json.Marshaler); !ok {
actual[k] = fmt.Sprintf("%v", v)
} else {
@@ -279,7 +277,7 @@ func generateSafetensorTestData(t *testing.T, tempDir string, tensorData map[str
func TestConvertAdapter(t *testing.T) {
type AdapterCase struct {
Name string
BaseKV KV
BaseKV map[string]any
Expected map[string]string
}

View File

@@ -14,7 +14,6 @@
* [API Reference](https://docs.ollama.com/api)
* [Modelfile Reference](https://docs.ollama.com/modelfile)
* [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility)
* [Anthropic Compatibility](./api/anthropic-compatibility.mdx)
### Resources

View File

@@ -895,11 +895,11 @@ curl http://localhost:11434/api/chat -d '{
"tool_calls": [
{
"function": {
"name": "get_weather",
"name": "get_temperature",
"arguments": {
"city": "Toronto"
}
}
},
}
]
},
@@ -907,7 +907,7 @@ curl http://localhost:11434/api/chat -d '{
{
"role": "tool",
"content": "11 degrees celsius",
"tool_name": "get_weather"
"tool_name": "get_temperature",
}
],
"stream": false,

View File

@@ -1,406 +0,0 @@
---
title: Anthropic compatibility
---
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
## Recommended models
For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.
Pull a model before use:
```shell
ollama pull qwen3-coder
ollama pull glm-4.7:cloud
```
## Usage
### Environment variables
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama # required but ignored
```
### Simple `/v1/messages` example
<CodeGroup dropdown>
```python basic.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama', # required but ignored
)
message = client.messages.create(
model='qwen3-coder',
max_tokens=1024,
messages=[
{'role': 'user', 'content': 'Hello, how are you?'}
]
)
print(message.content[0].text)
```
```javascript basic.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama", // required but ignored
});
const message = await anthropic.messages.create({
model: "qwen3-coder",
max_tokens: 1024,
messages: [{ role: "user", content: "Hello, how are you?" }],
});
console.log(message.content[0].text);
```
```shell basic.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-H "x-api-key: ollama" \
-H "anthropic-version: 2023-06-01" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"messages": [{ "role": "user", "content": "Hello, how are you?" }]
}'
```
</CodeGroup>
### Streaming example
<CodeGroup dropdown>
```python streaming.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
with client.messages.stream(
model='qwen3-coder',
max_tokens=1024,
messages=[{'role': 'user', 'content': 'Count from 1 to 10'}]
) as stream:
for text in stream.text_stream:
print(text, end='', flush=True)
```
```javascript streaming.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const stream = await anthropic.messages.stream({
model: "qwen3-coder",
max_tokens: 1024,
messages: [{ role: "user", content: "Count from 1 to 10" }],
});
for await (const event of stream) {
if (
event.type === "content_block_delta" &&
event.delta.type === "text_delta"
) {
process.stdout.write(event.delta.text);
}
}
```
```shell streaming.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"stream": true,
"messages": [{ "role": "user", "content": "Count from 1 to 10" }]
}'
```
</CodeGroup>
### Tool calling example
<CodeGroup dropdown>
```python tools.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
message = client.messages.create(
model='qwen3-coder',
max_tokens=1024,
tools=[
{
'name': 'get_weather',
'description': 'Get the current weather in a location',
'input_schema': {
'type': 'object',
'properties': {
'location': {
'type': 'string',
'description': 'The city and state, e.g. San Francisco, CA'
}
},
'required': ['location']
}
}
],
messages=[{'role': 'user', 'content': "What's the weather in San Francisco?"}]
)
for block in message.content:
if block.type == 'tool_use':
print(f'Tool: {block.name}')
print(f'Input: {block.input}')
```
```javascript tools.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const message = await anthropic.messages.create({
model: "qwen3-coder",
max_tokens: 1024,
tools: [
{
name: "get_weather",
description: "Get the current weather in a location",
input_schema: {
type: "object",
properties: {
location: {
type: "string",
description: "The city and state, e.g. San Francisco, CA",
},
},
required: ["location"],
},
},
],
messages: [{ role: "user", content: "What's the weather in San Francisco?" }],
});
for (const block of message.content) {
if (block.type === "tool_use") {
console.log("Tool:", block.name);
console.log("Input:", block.input);
}
}
```
```shell tools.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"tools": [
{
"name": "get_weather",
"description": "Get the current weather in a location",
"input_schema": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state"
}
},
"required": ["location"]
}
}
],
"messages": [{ "role": "user", "content": "What is the weather in San Francisco?" }]
}'
```
</CodeGroup>
## Using with Claude Code
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
Or set the environment variables in your shell profile:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
Then run Claude Code with any Ollama model:
```shell
# Local models
claude --model qwen3-coder
claude --model gpt-oss:20b
# Cloud models
claude --model glm-4.7:cloud
claude --model minimax-m2.1:cloud
```
## Endpoints
### `/v1/messages`
#### Supported features
- [x] Messages
- [x] Streaming
- [x] System prompts
- [x] Multi-turn conversations
- [x] Vision (images)
- [x] Tools (function calling)
- [x] Tool results
- [x] Thinking/extended thinking
#### Supported request fields
- [x] `model`
- [x] `max_tokens`
- [x] `messages`
- [x] Text `content`
- [x] Image `content` (base64)
- [x] Array of content blocks
- [x] `tool_use` blocks
- [x] `tool_result` blocks
- [x] `thinking` blocks
- [x] `system` (string or array)
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `top_k`
- [x] `stop_sequences`
- [x] `tools`
- [x] `thinking`
- [ ] `tool_choice`
- [ ] `metadata`
#### Supported response fields
- [x] `id`
- [x] `type`
- [x] `role`
- [x] `model`
- [x] `content` (text, tool_use, thinking blocks)
- [x] `stop_reason` (end_turn, max_tokens, tool_use)
- [x] `usage` (input_tokens, output_tokens)
#### Streaming events
- [x] `message_start`
- [x] `content_block_start`
- [x] `content_block_delta` (text_delta, input_json_delta, thinking_delta)
- [x] `content_block_stop`
- [x] `message_delta`
- [x] `message_stop`
- [x] `ping`
- [x] `error`
## Models
Ollama supports both local and cloud models.
### Local models
Pull a local model before use:
```shell
ollama pull qwen3-coder
```
Recommended local models:
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model
### Cloud models
Cloud models are available immediately without pulling:
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
### Default model names
For tooling that relies on default Anthropic model names such as `claude-3-5-sonnet`, use `ollama cp` to copy an existing model name:
```shell
ollama cp qwen3-coder claude-3-5-sonnet
```
Afterwards, this new model name can be specified in the `model` field:
```shell
curl http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "claude-3-5-sonnet",
"max_tokens": 1024,
"messages": [
{
"role": "user",
"content": "Hello!"
}
]
}'
```
## Differences from the Anthropic API
### Behavior differences
- API key is accepted but not validated
- `anthropic-version` header is accepted but not used
- Token counts are approximations based on the underlying model's tokenizer
### Not supported
The following Anthropic API features are not currently supported:
| Feature | Description |
|---------|-------------|
| `/v1/messages/count_tokens` | Token counting endpoint |
| `tool_choice` | Forcing specific tool use or disabling tools |
| `metadata` | Request metadata (user_id) |
| Prompt caching | `cache_control` blocks for caching prefixes |
| Batches API | `/v1/messages/batches` for async batch processing |
| Citations | `citations` content blocks |
| PDF support | `document` content blocks with PDF files |
| Server-sent errors | `error` events during streaming (errors return HTTP status) |
### Partial support
| Feature | Status |
|---------|--------|
| Image content | Base64 images supported; URL images not supported |
| Extended thinking | Basic support; `budget_tokens` accepted but not enforced |

View File

@@ -277,8 +277,6 @@ curl -X POST http://localhost:11434/v1/chat/completions \
### `/v1/responses`
> Note: Added in Ollama v0.13.3
Ollama supports the [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses). Only the non-stateful flavor is supported (i.e., there is no `previous_response_id` or `conversation` support).
#### Supported features

View File

@@ -36,6 +36,7 @@ Provide an `images` array. SDKs accept file paths, URLs or raw bytes while the R
}],
"stream": false
}'
"
```
</Tab>
<Tab title="Python">

View File

@@ -32,9 +32,7 @@
"codeblocks": "system"
},
"contextual": {
"options": [
"copy"
]
"options": ["copy"]
},
"navbar": {
"links": [
@@ -54,9 +52,7 @@
"display": "simple"
},
"examples": {
"languages": [
"curl"
]
"languages": ["curl"]
}
},
"redirects": [
@@ -101,7 +97,6 @@
{
"group": "Integrations",
"pages": [
"/integrations/claude-code",
"/integrations/vscode",
"/integrations/jetbrains",
"/integrations/codex",
@@ -144,8 +139,7 @@
"/api/streaming",
"/api/usage",
"/api/errors",
"/api/openai-compatibility",
"/api/anthropic-compatibility"
"/api/openai-compatibility"
]
},
{

View File

@@ -14,11 +14,11 @@ curl -fsSL https://ollama.com/install.sh | sh
## How can I view the logs?
Review the [Troubleshooting](./troubleshooting) docs for more about using logs.
Review the [Troubleshooting](./troubleshooting.md) docs for more about using logs.
## Is my GPU compatible with Ollama?
Please refer to the [GPU docs](./gpu).
Please refer to the [GPU docs](./gpu.md).
## How can I specify the context window size?

View File

@@ -33,7 +33,7 @@ Check your compute compatibility to see if your card is supported:
| 5.0 | GeForce GTX | `GTX 750 Ti` `GTX 750` `NVS 810` |
| | Quadro | `K2200` `K1200` `K620` `M1200` `M520` `M5000M` `M4000M` `M3000M` `M2000M` `M1000M` `K620M` `M600M` `M500M` |
For building locally to support older GPUs, see [developer](./development#linux-cuda-nvidia)
For building locally to support older GPUs, see [developer.md](./development.md#linux-cuda-nvidia)
### GPU Selection
@@ -54,7 +54,7 @@ sudo modprobe nvidia_uvm`
Ollama supports the following AMD GPUs via the ROCm library:
> **NOTE:**
> [!NOTE]
> Additional AMD GPU support is provided by the Vulkan Library - see below.
@@ -132,9 +132,9 @@ Ollama supports GPU acceleration on Apple devices via the Metal API.
## Vulkan GPU Support
> **NOTE:**
> [!NOTE]
> Vulkan is currently an Experimental feature. To enable, you must set OLLAMA_VULKAN=1 for the Ollama server as
described in the [FAQ](faq#how-do-i-configure-ollama-server)
described in the [FAQ](faq.md#how-do-i-configure-ollama-server)
Additional GPU support on Windows and Linux is provided via
[Vulkan](https://www.vulkan.org/). On Windows most GPU vendors drivers come
@@ -161,6 +161,6 @@ sudo setcap cap_perfmon+ep /usr/local/bin/ollama
To select specific Vulkan GPU(s), you can set the environment variable
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
described in the [FAQ](faq#how-do-i-configure-ollama-server). If you
described in the [FAQ](faq.md#how-do-i-configure-ollama-server). If you
encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
by setting `GGML_VK_VISIBLE_DEVICES=-1`

View File

@@ -1,69 +0,0 @@
---
title: Claude Code
---
## Install
Install [Claude Code](https://code.claude.com/docs/en/overview):
<CodeGroup>
```shell macOS / Linux
curl -fsSL https://claude.ai/install.sh | bash
```
```powershell Windows
irm https://claude.ai/install.ps1 | iex
```
</CodeGroup>
## Usage with Ollama
Claude Code connects to Ollama using the Anthropic-compatible API.
1. Set the environment variables:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
2. Run Claude Code with an Ollama model:
```shell
claude --model qwen3-coder
```
Or run with environment variables inline:
```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
## Connecting to ollama.com
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
2. Set the environment variables:
```shell
export ANTHROPIC_BASE_URL=https://ollama.com
export ANTHROPIC_API_KEY=<your-api-key>
```
3. Run Claude Code with a cloud model:
```shell
claude --model glm-4.7:cloud
```
## Recommended Models
### Cloud models
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
- `qwen3-coder:480b` - Large coding model
### Local models
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model

View File

@@ -1,5 +1,5 @@
---
title: "Linux"
title: Linux
---
## Install
@@ -13,7 +13,8 @@ curl -fsSL https://ollama.com/install.sh | sh
## Manual install
<Note>
If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
If you are upgrading from a prior version, you should remove the old libraries
with `sudo rm -rf /usr/lib/ollama` first.
</Note>
Download and extract the package:
@@ -112,7 +113,11 @@ sudo systemctl status ollama
```
<Note>
While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
While AMD has contributed the `amdgpu` driver upstream to the official linux
kernel source, the version is older and may not support all ROCm features. We
recommend you install the latest driver from
https://www.amd.com/en/support/linux-drivers for best support of your Radeon
GPU.
</Note>
## Customizing
@@ -191,4 +196,4 @@ Remove the downloaded models and Ollama service user and group:
sudo userdel ollama
sudo groupdel ollama
sudo rm -r /usr/share/ollama
```
```

3
docs/troubleshooting.md Normal file
View File

@@ -0,0 +1,3 @@
# Troubleshooting
For troubleshooting, see [https://docs.ollama.com/troubleshooting](https://docs.ollama.com/troubleshooting)

View File

@@ -87,7 +87,7 @@ When Ollama starts up, it takes inventory of the GPUs present in the system to d
### Linux NVIDIA Troubleshooting
If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker](./docker)
If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker.md](./docker.md)
Sometimes the Ollama can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem

View File

@@ -1,7 +1,5 @@
package fs
import "iter"
type Config interface {
Architecture() string
String(string, ...string) string
@@ -13,8 +11,4 @@ type Config interface {
Ints(string, ...[]int32) []int32
Floats(string, ...[]float32) []float32
Bools(string, ...[]bool) []bool
Len() int
Keys() iter.Seq[string]
Value(key string) any
}

View File

@@ -6,9 +6,7 @@ import (
"errors"
"fmt"
"io"
"iter"
"log/slog"
"maps"
"math"
"slices"
"strings"
@@ -241,18 +239,6 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
return val.values
}
func (kv KV) Len() int {
return len(kv)
}
func (kv KV) Keys() iter.Seq[string] {
return maps.Keys(kv)
}
func (kv KV) Value(key string) any {
return kv[key]
}
func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{
"bert",

View File

@@ -8,12 +8,12 @@ import (
"fmt"
"io"
"log/slog"
"maps"
"os"
"runtime"
"slices"
"strings"
"github.com/ollama/ollama/fs"
"golang.org/x/sync/errgroup"
)
@@ -508,7 +508,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
return binary.Write(w, binary.LittleEndian, s)
}
func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
arch := kv.String("general.architecture")
if arch == "" {
return fmt.Errorf("architecture not set")
@@ -526,12 +526,12 @@ func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
return err
}
if err := binary.Write(f, binary.LittleEndian, uint64(kv.Len())); err != nil {
if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
return err
}
for _, key := range slices.Sorted(kv.Keys()) {
if err := ggufWriteKV(f, arch, key, kv.Value(key)); err != nil {
for _, key := range slices.Sorted(maps.Keys(kv)) {
if err := ggufWriteKV(f, arch, key, kv[key]); err != nil {
return err
}
}

4
go.mod
View File

@@ -28,7 +28,6 @@ require (
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
github.com/tkrajina/typescriptify-golang-structs v0.2.0
github.com/wk8/go-ordered-map/v2 v2.1.8
golang.org/x/image v0.22.0
golang.org/x/mod v0.30.0
golang.org/x/tools v0.38.0
@@ -37,8 +36,6 @@ require (
require (
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/chewxy/hm v1.0.0 // indirect
github.com/chewxy/math32 v1.11.0 // indirect
@@ -48,7 +45,6 @@ require (
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect

9
go.sum
View File

@@ -14,11 +14,7 @@ github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6IC
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
@@ -127,7 +123,6 @@ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
@@ -148,8 +143,6 @@ github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 h1:QwWKgMY28TAXaDl+
github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
@@ -214,8 +207,6 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=

View File

@@ -11,15 +11,6 @@ import (
"github.com/ollama/ollama/api"
)
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
func TestAPIToolCalling(t *testing.T) {
initialTimeout := 60 * time.Second
streamTimeout := 60 * time.Second
@@ -66,12 +57,12 @@ func TestAPIToolCalling(t *testing.T) {
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "The city and state, e.g. San Francisco, CA",
},
}),
},
},
},
},

View File

@@ -1,94 +0,0 @@
// Package orderedmap provides a generic ordered map that maintains insertion order.
// It wraps github.com/wk8/go-ordered-map/v2 to encapsulate the dependency.
package orderedmap
import (
"encoding/json"
"iter"
orderedmap "github.com/wk8/go-ordered-map/v2"
)
// Map is a generic ordered map that maintains insertion order.
type Map[K comparable, V any] struct {
om *orderedmap.OrderedMap[K, V]
}
// New creates a new empty ordered map.
func New[K comparable, V any]() *Map[K, V] {
return &Map[K, V]{
om: orderedmap.New[K, V](),
}
}
// Get retrieves a value by key.
func (m *Map[K, V]) Get(key K) (V, bool) {
if m == nil || m.om == nil {
var zero V
return zero, false
}
return m.om.Get(key)
}
// Set sets a key-value pair. If the key already exists, its value is updated
// but its position in the iteration order is preserved. If the key is new,
// it is appended to the end.
func (m *Map[K, V]) Set(key K, value V) {
if m == nil {
return
}
if m.om == nil {
m.om = orderedmap.New[K, V]()
}
m.om.Set(key, value)
}
// Len returns the number of entries.
func (m *Map[K, V]) Len() int {
if m == nil || m.om == nil {
return 0
}
return m.om.Len()
}
// All returns an iterator over all key-value pairs in insertion order.
func (m *Map[K, V]) All() iter.Seq2[K, V] {
return func(yield func(K, V) bool) {
if m == nil || m.om == nil {
return
}
for pair := m.om.Oldest(); pair != nil; pair = pair.Next() {
if !yield(pair.Key, pair.Value) {
return
}
}
}
}
// ToMap converts to a regular Go map.
// Note: The resulting map does not preserve order.
func (m *Map[K, V]) ToMap() map[K]V {
if m == nil || m.om == nil {
return nil
}
result := make(map[K]V, m.om.Len())
for pair := m.om.Oldest(); pair != nil; pair = pair.Next() {
result[pair.Key] = pair.Value
}
return result
}
// MarshalJSON implements json.Marshaler. The JSON output preserves key order.
func (m *Map[K, V]) MarshalJSON() ([]byte, error) {
if m == nil || m.om == nil {
return []byte("null"), nil
}
return json.Marshal(m.om)
}
// UnmarshalJSON implements json.Unmarshaler. The insertion order matches the
// order of keys in the JSON input.
func (m *Map[K, V]) UnmarshalJSON(data []byte) error {
m.om = orderedmap.New[K, V]()
return json.Unmarshal(data, &m.om)
}

View File

@@ -1,348 +0,0 @@
package orderedmap
import (
"encoding/json"
"slices"
"testing"
)
func TestMap_BasicOperations(t *testing.T) {
m := New[string, int]()
// Test empty map
if m.Len() != 0 {
t.Errorf("expected Len() = 0, got %d", m.Len())
}
v, ok := m.Get("a")
if ok {
t.Error("expected Get on empty map to return false")
}
if v != 0 {
t.Errorf("expected zero value, got %d", v)
}
// Test Set and Get
m.Set("a", 1)
m.Set("b", 2)
m.Set("c", 3)
if m.Len() != 3 {
t.Errorf("expected Len() = 3, got %d", m.Len())
}
v, ok = m.Get("a")
if !ok || v != 1 {
t.Errorf("expected Get(a) = (1, true), got (%d, %v)", v, ok)
}
v, ok = m.Get("b")
if !ok || v != 2 {
t.Errorf("expected Get(b) = (2, true), got (%d, %v)", v, ok)
}
v, ok = m.Get("c")
if !ok || v != 3 {
t.Errorf("expected Get(c) = (3, true), got (%d, %v)", v, ok)
}
// Test updating existing key preserves position
m.Set("a", 10)
v, ok = m.Get("a")
if !ok || v != 10 {
t.Errorf("expected Get(a) = (10, true), got (%d, %v)", v, ok)
}
if m.Len() != 3 {
t.Errorf("expected Len() = 3 after update, got %d", m.Len())
}
}
func TestMap_InsertionOrderPreserved(t *testing.T) {
m := New[string, int]()
// Insert in non-alphabetical order
m.Set("z", 1)
m.Set("a", 2)
m.Set("m", 3)
m.Set("b", 4)
// Verify iteration order matches insertion order
var keys []string
var values []int
for k, v := range m.All() {
keys = append(keys, k)
values = append(values, v)
}
expectedKeys := []string{"z", "a", "m", "b"}
expectedValues := []int{1, 2, 3, 4}
if !slices.Equal(keys, expectedKeys) {
t.Errorf("expected keys %v, got %v", expectedKeys, keys)
}
if !slices.Equal(values, expectedValues) {
t.Errorf("expected values %v, got %v", expectedValues, values)
}
}
func TestMap_UpdatePreservesPosition(t *testing.T) {
m := New[string, int]()
m.Set("first", 1)
m.Set("second", 2)
m.Set("third", 3)
// Update middle element
m.Set("second", 20)
var keys []string
for k := range m.All() {
keys = append(keys, k)
}
// Order should still be first, second, third
expected := []string{"first", "second", "third"}
if !slices.Equal(keys, expected) {
t.Errorf("expected keys %v, got %v", expected, keys)
}
}
func TestMap_MarshalJSON_PreservesOrder(t *testing.T) {
m := New[string, int]()
// Insert in non-alphabetical order
m.Set("z", 1)
m.Set("a", 2)
m.Set("m", 3)
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
// JSON should preserve insertion order, not alphabetical
expected := `{"z":1,"a":2,"m":3}`
if string(data) != expected {
t.Errorf("expected %s, got %s", expected, string(data))
}
}
func TestMap_UnmarshalJSON_PreservesOrder(t *testing.T) {
// JSON with non-alphabetical key order
jsonData := `{"z":1,"a":2,"m":3}`
m := New[string, int]()
if err := json.Unmarshal([]byte(jsonData), m); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
// Verify iteration order matches JSON order
var keys []string
for k := range m.All() {
keys = append(keys, k)
}
expected := []string{"z", "a", "m"}
if !slices.Equal(keys, expected) {
t.Errorf("expected keys %v, got %v", expected, keys)
}
}
func TestMap_JSONRoundTrip(t *testing.T) {
// Test that unmarshal -> marshal produces identical JSON
original := `{"zebra":"z","apple":"a","mango":"m","banana":"b"}`
m := New[string, string]()
if err := json.Unmarshal([]byte(original), m); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
if string(data) != original {
t.Errorf("round trip failed: expected %s, got %s", original, string(data))
}
}
func TestMap_ToMap(t *testing.T) {
m := New[string, int]()
m.Set("a", 1)
m.Set("b", 2)
regular := m.ToMap()
if len(regular) != 2 {
t.Errorf("expected len 2, got %d", len(regular))
}
if regular["a"] != 1 {
t.Errorf("expected regular[a] = 1, got %d", regular["a"])
}
if regular["b"] != 2 {
t.Errorf("expected regular[b] = 2, got %d", regular["b"])
}
}
func TestMap_NilSafety(t *testing.T) {
var m *Map[string, int]
// All operations should be safe on nil
if m.Len() != 0 {
t.Errorf("expected Len() = 0 on nil map, got %d", m.Len())
}
v, ok := m.Get("a")
if ok {
t.Error("expected Get on nil map to return false")
}
if v != 0 {
t.Errorf("expected zero value from nil map, got %d", v)
}
// Set on nil is a no-op
m.Set("a", 1)
if m.Len() != 0 {
t.Errorf("expected Len() = 0 after Set on nil, got %d", m.Len())
}
// All returns empty iterator
var keys []string
for k := range m.All() {
keys = append(keys, k)
}
if len(keys) != 0 {
t.Errorf("expected empty iteration on nil map, got %v", keys)
}
// ToMap returns nil
if m.ToMap() != nil {
t.Error("expected ToMap to return nil on nil map")
}
// MarshalJSON returns null
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
if string(data) != "null" {
t.Errorf("expected null, got %s", string(data))
}
}
func TestMap_EmptyMapMarshal(t *testing.T) {
m := New[string, int]()
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
if string(data) != "{}" {
t.Errorf("expected {}, got %s", string(data))
}
}
func TestMap_NestedValues(t *testing.T) {
m := New[string, any]()
m.Set("string", "hello")
m.Set("number", 42)
m.Set("bool", true)
m.Set("nested", map[string]int{"x": 1})
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
expected := `{"string":"hello","number":42,"bool":true,"nested":{"x":1}}`
if string(data) != expected {
t.Errorf("expected %s, got %s", expected, string(data))
}
}
func TestMap_AllIteratorEarlyExit(t *testing.T) {
m := New[string, int]()
m.Set("a", 1)
m.Set("b", 2)
m.Set("c", 3)
m.Set("d", 4)
// Collect only first 2
var keys []string
for k := range m.All() {
keys = append(keys, k)
if len(keys) == 2 {
break
}
}
expected := []string{"a", "b"}
if !slices.Equal(keys, expected) {
t.Errorf("expected %v, got %v", expected, keys)
}
}
func TestMap_IntegerKeys(t *testing.T) {
m := New[int, string]()
m.Set(3, "three")
m.Set(1, "one")
m.Set(2, "two")
var keys []int
for k := range m.All() {
keys = append(keys, k)
}
// Should preserve insertion order, not numerical order
expected := []int{3, 1, 2}
if !slices.Equal(keys, expected) {
t.Errorf("expected %v, got %v", expected, keys)
}
}
func TestMap_UnmarshalIntoExisting(t *testing.T) {
m := New[string, int]()
m.Set("existing", 999)
// Unmarshal should replace contents
if err := json.Unmarshal([]byte(`{"new":1}`), m); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
_, ok := m.Get("existing")
if ok {
t.Error("existing key should be gone after unmarshal")
}
v, ok := m.Get("new")
if !ok || v != 1 {
t.Errorf("expected Get(new) = (1, true), got (%d, %v)", v, ok)
}
}
func TestMap_LargeOrderPreservation(t *testing.T) {
m := New[string, int]()
// Create many keys in specific order
keys := make([]string, 100)
for i := range 100 {
keys[i] = string(rune('a' + (99 - i))) // reverse order: 'd', 'c', 'b', 'a' (extended)
if i >= 26 {
keys[i] = string(rune('A'+i-26)) + string(rune('a'+i%26))
}
}
for i, k := range keys {
m.Set(k, i)
}
// Verify order preserved
var resultKeys []string
for k := range m.All() {
resultKeys = append(resultKeys, k)
}
if !slices.Equal(keys, resultKeys) {
t.Error("large map should preserve insertion order")
}
}

View File

@@ -20,10 +20,10 @@ fix vulkan PCI ID and ID handling
ggml/src/ggml-cuda/vendors/hip.h | 3 +
ggml/src/ggml-impl.h | 8 +
ggml/src/ggml-metal/ggml-metal.cpp | 2 +
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 169 +++++++-
ggml/src/mem_hip.cpp | 558 +++++++++++++++++++++++++++
ggml/src/mem_nvml.cpp | 209 ++++++++++
9 files changed, 1005 insertions(+), 17 deletions(-)
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 169 ++++++++-
ggml/src/mem_hip.cpp | 529 +++++++++++++++++++++++++++
ggml/src/mem_nvml.cpp | 209 +++++++++++
9 files changed, 976 insertions(+), 17 deletions(-)
create mode 100644 ggml/src/mem_hip.cpp
create mode 100644 ggml/src/mem_nvml.cpp
@@ -58,7 +58,7 @@ index d55aed348..99ae293cc 100644
set_target_properties(ggml-base PROPERTIES
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 6852d2e20..334a30135 100644
index 6852d2e20..48cdb1dcf 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -267,6 +267,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
@@ -109,7 +109,7 @@ index 6852d2e20..334a30135 100644
+
+#if defined(GGML_USE_HIP)
+ if (ggml_hip_mgmt_init() == 0) {
+ int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total, ctx->integrated != 0);
+ int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total);
+ if (status == 0) {
+ GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total);
+ ggml_hip_mgmt_release();
@@ -204,7 +204,7 @@ index 4e162258d..d89e35a8e 100644
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index fe57d4c58..dba8f4695 100644
index fe57d4c58..1c07e767a 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -677,6 +677,14 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
@@ -216,7 +216,7 @@ index fe57d4c58..dba8f4695 100644
+GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total);
+GGML_API void ggml_nvml_release();
+GGML_API int ggml_hip_mgmt_init();
+GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu);
+GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total);
+GGML_API void ggml_hip_mgmt_release();
+
#ifdef __cplusplus
@@ -243,7 +243,7 @@ index ba95b4acc..f6f8f7a10 100644
/* .async = */ true,
/* .host_buffer = */ false,
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 5349bce24..0103fd03a 100644
index 5349bce24..d43d46d1d 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -236,6 +236,7 @@ class vk_memory_logger;
@@ -334,7 +334,7 @@ index 5349bce24..0103fd03a 100644
+ switch (props2.properties.vendorID) {
+ case VK_VENDOR_ID_AMD:
+ if (ggml_hip_mgmt_init() == 0) {
+ int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total, ctx->is_integrated_gpu);
+ int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total);
+ if (status == 0) {
+ GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
+ ggml_hip_mgmt_release();
@@ -505,10 +505,10 @@ index 5349bce24..0103fd03a 100644
}
diff --git a/ggml/src/mem_hip.cpp b/ggml/src/mem_hip.cpp
new file mode 100644
index 000000000..23c765806
index 000000000..c1949b899
--- /dev/null
+++ b/ggml/src/mem_hip.cpp
@@ -0,0 +1,558 @@
@@ -0,0 +1,529 @@
+#include "ggml.h"
+#include "ggml-impl.h"
+
@@ -842,7 +842,7 @@ index 000000000..23c765806
+ if (gpus != NULL) gpus->pVtbl->Release(gpus); \
+ if (gpu != NULL) gpu->pVtbl->Release(gpu)
+
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) {
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
+ std::lock_guard<std::mutex> lock(ggml_adlx_lock);
+ if (adlx.handle == NULL) {
+ GGML_LOG_INFO("%s ADLX was not initialized\n", __func__);
@@ -966,16 +966,13 @@ index 000000000..23c765806
+ return 0;
+}
+void ggml_hip_mgmt_release() {}
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) {
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
+ GGML_LOG_INFO("%s searching for device %s\n", __func__, id);
+ const std::string drmDeviceGlob = "/sys/class/drm/card*/device/uevent";
+ const std::string drmTotalMemoryFile = "mem_info_vram_total";
+ const std::string drmUsedMemoryFile = "mem_info_vram_used";
+ const std::string drmGTTTotalMemoryFile = "mem_info_gtt_total";
+ const std::string drmGTTUsedMemoryFile = "mem_info_gtt_used";
+ const std::string drmUeventPCISlotLabel = "PCI_SLOT_NAME=";
+
+
+ glob_t glob_result;
+ glob(drmDeviceGlob.c_str(), GLOB_NOSORT, NULL, &glob_result);
+
@@ -1009,6 +1006,7 @@ index 000000000..23c765806
+
+ uint64_t memory;
+ totalFileStream >> memory;
+ *total = memory;
+
+ std::string usedFile = dir + "/" + drmUsedMemoryFile;
+ std::ifstream usedFileStream(usedFile.c_str());
@@ -1021,33 +1019,6 @@ index 000000000..23c765806
+
+ uint64_t memoryUsed;
+ usedFileStream >> memoryUsed;
+
+ if (is_integrated_gpu) {
+ std::string totalFile = dir + "/" + drmGTTTotalMemoryFile;
+ std::ifstream totalFileStream(totalFile.c_str());
+ if (!totalFileStream.is_open()) {
+ GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, totalFile.c_str());
+ file.close();
+ globfree(&glob_result);
+ return 1;
+ }
+ uint64_t gtt;
+ totalFileStream >> gtt;
+ std::string usedFile = dir + "/" + drmGTTUsedMemoryFile;
+ std::ifstream usedFileStream(usedFile.c_str());
+ if (!usedFileStream.is_open()) {
+ GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, usedFile.c_str());
+ file.close();
+ globfree(&glob_result);
+ return 1;
+ }
+ uint64_t gttUsed;
+ usedFileStream >> gttUsed;
+ memory += gtt;
+ memoryUsed += gttUsed;
+ }
+
+ *total = memory;
+ *free = memory - memoryUsed;
+
+ file.close();

View File

@@ -24,12 +24,12 @@ index 99ae293cc..9a134b7af 100644
set_target_properties(ggml-base PROPERTIES
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index dba8f4695..7e17032c7 100644
index 1c07e767a..0da3e065b 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -684,6 +684,9 @@ GGML_API void ggml_nvml_release();
GGML_API int ggml_hip_mgmt_init();
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu);
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total);
GGML_API void ggml_hip_mgmt_release();
+GGML_API int ggml_dxgi_pdh_init();
+GGML_API int ggml_dxgi_pdh_get_device_memory(const char* luid, size_t *free, size_t *total, bool is_integrated_gpu);
@@ -38,7 +38,7 @@ index dba8f4695..7e17032c7 100644
#ifdef __cplusplus
}
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 0103fd03a..9cc4ebdef 100644
index d43d46d1d..df79f9f79 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -74,6 +74,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();

View File

@@ -10,7 +10,7 @@ fallback to cpu
1 file changed, 3 insertions(+)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 334a30135..5c9dfd032 100644
index 48cdb1dcf..3102d7ea7 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -4633,6 +4633,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g

View File

@@ -1,152 +0,0 @@
package middleware
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/anthropic"
"github.com/ollama/ollama/api"
)
// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format
type AnthropicWriter struct {
BaseWriter
stream bool
id string
model string
converter *anthropic.StreamConverter
}
func (w *AnthropicWriter) writeError(data []byte) (int, error) {
var errData struct {
Error string `json:"error"`
}
if err := json.Unmarshal(data, &errData); err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.ResponseWriter.Status(), errData.Error))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *AnthropicWriter) writeEvent(eventType string, data any) error {
d, err := json.Marshal(data)
if err != nil {
return err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
if err != nil {
return err
}
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
return nil
}
func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
var chatResponse api.ChatResponse
err := json.Unmarshal(data, &chatResponse)
if err != nil {
return 0, err
}
if w.stream {
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
events := w.converter.Process(chatResponse)
for _, event := range events {
if err := w.writeEvent(event.Event, event.Data); err != nil {
return 0, err
}
}
return len(data), nil
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
response := anthropic.ToMessagesResponse(w.id, chatResponse)
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
}
func (w *AnthropicWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
// AnthropicMessagesMiddleware handles Anthropic Messages API requests
func AnthropicMessagesMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req anthropic.MessagesRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "model is required"))
return
}
if req.MaxTokens <= 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "max_tokens is required and must be positive"))
return
}
if len(req.Messages) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "messages is required"))
return
}
chatReq, err := anthropic.FromMessagesRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
return
}
// Set think to nil when being used with Anthropic API to connect to tools like claude code
c.Set("relax_thinking", true)
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
messageID := anthropic.GenerateMessageID()
w := &AnthropicWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: messageID,
model: req.Model,
converter: anthropic.NewStreamConverter(messageID, req.Model),
}
if req.Stream {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
}
c.Writer = w
c.Next()
}
}

View File

@@ -1,607 +0,0 @@
package middleware
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/anthropic"
"github.com/ollama/ollama/api"
)
func captureAnthropicRequest(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
_ = json.Unmarshal(bodyBytes, capturedRequest)
c.Next()
}
}
// testProps creates ToolPropertiesMap from a map (convenience function for tests)
func testProps(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
func TestAnthropicMessagesMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.ChatRequest
err anthropic.ErrorResponse
}
var capturedRequest *api.ChatRequest
stream := true
testCases := []testCase{
{
name: "basic message",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with system prompt",
body: `{
"model": "test-model",
"max_tokens": 1024,
"system": "You are helpful.",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with options",
body: `{
"model": "test-model",
"max_tokens": 2048,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"stop_sequences": ["\n", "END"],
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{
"num_predict": 2048,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"stop": []string{"\n", "END"},
},
Stream: &False,
},
},
{
name: "streaming",
body: `{
"model": "test-model",
"max_tokens": 1024,
"stream": true,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &stream,
},
},
{
name: "with tools",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "What's the weather?"}
],
"tools": [{
"name": "get_weather",
"description": "Get current weather",
"input_schema": {
"type": "object",
"properties": {
"location": {"type": "string"}
},
"required": ["location"]
}
}]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
},
Tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: testProps(map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
}),
},
},
},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with tool result",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "What's the weather?"},
{"role": "assistant", "content": [
{"type": "tool_use", "id": "call_123", "name": "get_weather", "input": {"location": "Paris"}}
]},
{"role": "user", "content": [
{"type": "tool_result", "tool_use_id": "call_123", "content": "Sunny, 22°C"}
]}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call_123"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with thinking enabled",
body: `{
"model": "test-model",
"max_tokens": 1024,
"thinking": {"type": "enabled", "budget_tokens": 1000},
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
Think: &api.ThinkValue{Value: true},
},
},
{
name: "missing model error",
body: `{
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "model is required",
},
},
},
{
name: "missing max_tokens error",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "max_tokens is required and must be positive",
},
},
},
{
name: "missing messages error",
body: `{
"model": "test-model",
"max_tokens": 1024
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "messages is required",
},
},
},
{
name: "tool_use missing id error",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "assistant", "content": [
{"type": "tool_use", "name": "test"}
]}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "tool_use block missing required 'id' field",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware(), captureAnthropicRequest(&capturedRequest))
router.Handle(http.MethodPost, "/v1/messages", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
defer func() { capturedRequest = nil }()
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if tc.err.Type != "" {
// Expect error
if resp.Code == http.StatusOK {
t.Fatalf("expected error response, got 200 OK")
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error: %v", err)
}
if errResp.Type != tc.err.Type {
t.Errorf("expected error type %q, got %q", tc.err.Type, errResp.Type)
}
if errResp.Error.Type != tc.err.Error.Type {
t.Errorf("expected error.type %q, got %q", tc.err.Error.Type, errResp.Error.Type)
}
if errResp.Error.Message != tc.err.Error.Message {
t.Errorf("expected error.message %q, got %q", tc.err.Error.Message, errResp.Error.Message)
}
return
}
if resp.Code != http.StatusOK {
t.Fatalf("unexpected status code: %d, body: %s", resp.Code, resp.Body.String())
}
if capturedRequest == nil {
t.Fatal("request was not captured")
}
// Compare relevant fields
if capturedRequest.Model != tc.req.Model {
t.Errorf("model mismatch: got %q, want %q", capturedRequest.Model, tc.req.Model)
}
if diff := cmp.Diff(tc.req.Messages, capturedRequest.Messages,
cmpopts.IgnoreUnexported(api.ToolCallFunctionArguments{}, api.ToolPropertiesMap{})); diff != "" {
t.Errorf("messages mismatch (-want +got):\n%s", diff)
}
if tc.req.Stream != nil && capturedRequest.Stream != nil {
if *tc.req.Stream != *capturedRequest.Stream {
t.Errorf("stream mismatch: got %v, want %v", *capturedRequest.Stream, *tc.req.Stream)
}
}
if tc.req.Think != nil {
if capturedRequest.Think == nil {
t.Error("expected Think to be set")
} else if capturedRequest.Think.Value != tc.req.Think.Value {
t.Errorf("Think mismatch: got %v, want %v", capturedRequest.Think.Value, tc.req.Think.Value)
}
}
})
}
}
func TestAnthropicMessagesMiddleware_Headers(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Run("streaming sets correct headers", func(t *testing.T) {
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Check headers were set
if c.Writer.Header().Get("Content-Type") != "text/event-stream" {
t.Errorf("expected Content-Type text/event-stream, got %q", c.Writer.Header().Get("Content-Type"))
}
if c.Writer.Header().Get("Cache-Control") != "no-cache" {
t.Errorf("expected Cache-Control no-cache, got %q", c.Writer.Header().Get("Cache-Control"))
}
c.Status(http.StatusOK)
})
body := `{"model": "test", "max_tokens": 100, "stream": true, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
})
}
func TestAnthropicMessagesMiddleware_InvalidJSON(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{invalid json`))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", resp.Code)
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error: %v", err)
}
if errResp.Type != "error" {
t.Errorf("expected type 'error', got %q", errResp.Type)
}
if errResp.Error.Type != "invalid_request_error" {
t.Errorf("expected error type 'invalid_request_error', got %q", errResp.Error.Type)
}
}
func TestAnthropicWriter_NonStreaming(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Simulate Ollama response
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello there!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 5,
},
}
data, _ := json.Marshal(resp)
c.Writer.WriteHeader(http.StatusOK)
_, _ = c.Writer.Write(data)
})
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", resp.Code)
}
var result anthropic.MessagesResponse
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if result.Type != "message" {
t.Errorf("expected type 'message', got %q", result.Type)
}
if result.Role != "assistant" {
t.Errorf("expected role 'assistant', got %q", result.Role)
}
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
t.Errorf("expected text 'Hello there!', got %v", result.Content[0].Text)
}
if result.StopReason != "end_turn" {
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
}
if result.Usage.InputTokens != 10 {
t.Errorf("expected input_tokens 10, got %d", result.Usage.InputTokens)
}
if result.Usage.OutputTokens != 5 {
t.Errorf("expected output_tokens 5, got %d", result.Usage.OutputTokens)
}
}
// TestAnthropicWriter_ErrorFromRoutes tests error handling when routes.go sends
// gin.H{"error": "message"} without a StatusCode field (which is the common case)
func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
statusCode int
errorPayload any
wantErrorType string
wantMessage string
}{
// routes.go sends errors without StatusCode in JSON, so we must use HTTP status
{
name: "404 with gin.H error (model not found)",
statusCode: http.StatusNotFound,
errorPayload: gin.H{"error": "model 'nonexistent' not found"},
wantErrorType: "not_found_error",
wantMessage: "model 'nonexistent' not found",
},
{
name: "400 with gin.H error (bad request)",
statusCode: http.StatusBadRequest,
errorPayload: gin.H{"error": "model is required"},
wantErrorType: "invalid_request_error",
wantMessage: "model is required",
},
{
name: "500 with gin.H error (internal error)",
statusCode: http.StatusInternalServerError,
errorPayload: gin.H{"error": "something went wrong"},
wantErrorType: "api_error",
wantMessage: "something went wrong",
},
{
name: "404 with api.StatusError",
statusCode: http.StatusNotFound,
errorPayload: api.StatusError{
StatusCode: http.StatusNotFound,
ErrorMessage: "model not found via StatusError",
},
wantErrorType: "not_found_error",
wantMessage: "model not found via StatusError",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Simulate what routes.go does - set status and write error JSON
data, _ := json.Marshal(tt.errorPayload)
c.Writer.WriteHeader(tt.statusCode)
_, _ = c.Writer.Write(data)
})
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != tt.statusCode {
t.Errorf("expected status %d, got %d", tt.statusCode, resp.Code)
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error response: %v\nbody: %s", err, resp.Body.String())
}
if errResp.Type != "error" {
t.Errorf("expected type 'error', got %q", errResp.Type)
}
if errResp.Error.Type != tt.wantErrorType {
t.Errorf("expected error type %q, got %q", tt.wantErrorType, errResp.Error.Type)
}
if errResp.Error.Message != tt.wantMessage {
t.Errorf("expected message %q, got %q", tt.wantMessage, errResp.Error.Message)
}
})
}
}
func TestAnthropicMessagesMiddleware_SetsRelaxThinkingFlag(t *testing.T) {
gin.SetMode(gin.TestMode)
var flagSet bool
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
_, flagSet = c.Get("relax_thinking")
c.Status(http.StatusOK)
})
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if !flagSet {
t.Error("expected relax_thinking flag to be set in context")
}
}

View File

@@ -19,40 +19,6 @@ import (
"github.com/ollama/ollama/openai"
)
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
// argsComparer provides cmp options for comparing ToolCallFunctionArguments by value
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
return cmp.Equal(a.ToMap(), b.ToMap())
})
// propsComparer provides cmp options for comparing ToolPropertiesMap by value
var propsComparer = cmp.Comparer(func(a, b *api.ToolPropertiesMap) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
return cmp.Equal(a.ToMap(), b.ToMap())
})
const (
prefix = `data:image/jpeg;base64,`
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
@@ -255,10 +221,10 @@ func TestChatMiddleware(t *testing.T) {
ID: "id",
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
}),
},
},
},
},
@@ -295,10 +261,10 @@ func TestChatMiddleware(t *testing.T) {
ID: "id",
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
}),
},
},
},
},
@@ -334,10 +300,10 @@ func TestChatMiddleware(t *testing.T) {
ID: "id",
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
}),
},
},
},
},
@@ -374,10 +340,10 @@ func TestChatMiddleware(t *testing.T) {
ID: "id",
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
}),
},
},
},
},
@@ -414,10 +380,10 @@ func TestChatMiddleware(t *testing.T) {
ID: "id_abc",
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
}),
},
},
},
},
@@ -460,10 +426,10 @@ func TestChatMiddleware(t *testing.T) {
ID: "id",
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
}),
},
},
},
},
@@ -528,7 +494,7 @@ func TestChatMiddleware(t *testing.T) {
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "The city and state",
@@ -537,7 +503,7 @@ func TestChatMiddleware(t *testing.T) {
Type: api.PropertyType{"string"},
Enum: []any{"celsius", "fahrenheit"},
},
}),
},
},
},
},
@@ -592,7 +558,7 @@ func TestChatMiddleware(t *testing.T) {
}
return
}
if diff := cmp.Diff(&tc.req, capturedRequest, argsComparer, propsComparer); diff != "" {
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
t.Fatalf("requests did not match: %+v", diff)
}
if diff := cmp.Diff(tc.err, errResp); diff != "" {

View File

@@ -4436,7 +4436,7 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t *
#if defined(GGML_USE_HIP)
if (ggml_hip_mgmt_init() == 0) {
int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total, ctx->integrated != 0);
int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total);
if (status == 0) {
GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total);
ggml_hip_mgmt_release();

View File

@@ -682,7 +682,7 @@ GGML_API int ggml_nvml_init();
GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total);
GGML_API void ggml_nvml_release();
GGML_API int ggml_hip_mgmt_init();
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu);
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total);
GGML_API void ggml_hip_mgmt_release();
GGML_API int ggml_dxgi_pdh_init();
GGML_API int ggml_dxgi_pdh_get_device_memory(const char* luid, size_t *free, size_t *total, bool is_integrated_gpu);

View File

@@ -13710,7 +13710,7 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size
switch (props2.properties.vendorID) {
case VK_VENDOR_ID_AMD:
if (ggml_hip_mgmt_init() == 0) {
int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total, ctx->is_integrated_gpu);
int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total);
if (status == 0) {
GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
ggml_hip_mgmt_release();

View File

@@ -331,7 +331,7 @@ void ggml_hip_mgmt_release() {
if (gpus != NULL) gpus->pVtbl->Release(gpus); \
if (gpu != NULL) gpu->pVtbl->Release(gpu)
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) {
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
std::lock_guard<std::mutex> lock(ggml_adlx_lock);
if (adlx.handle == NULL) {
GGML_LOG_INFO("%s ADLX was not initialized\n", __func__);
@@ -455,16 +455,13 @@ int ggml_hip_mgmt_init() {
return 0;
}
void ggml_hip_mgmt_release() {}
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) {
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
GGML_LOG_INFO("%s searching for device %s\n", __func__, id);
const std::string drmDeviceGlob = "/sys/class/drm/card*/device/uevent";
const std::string drmTotalMemoryFile = "mem_info_vram_total";
const std::string drmUsedMemoryFile = "mem_info_vram_used";
const std::string drmGTTTotalMemoryFile = "mem_info_gtt_total";
const std::string drmGTTUsedMemoryFile = "mem_info_gtt_used";
const std::string drmUeventPCISlotLabel = "PCI_SLOT_NAME=";
glob_t glob_result;
glob(drmDeviceGlob.c_str(), GLOB_NOSORT, NULL, &glob_result);
@@ -498,6 +495,7 @@ int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool
uint64_t memory;
totalFileStream >> memory;
*total = memory;
std::string usedFile = dir + "/" + drmUsedMemoryFile;
std::ifstream usedFileStream(usedFile.c_str());
@@ -510,33 +508,6 @@ int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool
uint64_t memoryUsed;
usedFileStream >> memoryUsed;
if (is_integrated_gpu) {
std::string totalFile = dir + "/" + drmGTTTotalMemoryFile;
std::ifstream totalFileStream(totalFile.c_str());
if (!totalFileStream.is_open()) {
GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, totalFile.c_str());
file.close();
globfree(&glob_result);
return 1;
}
uint64_t gtt;
totalFileStream >> gtt;
std::string usedFile = dir + "/" + drmGTTUsedMemoryFile;
std::ifstream usedFileStream(usedFile.c_str());
if (!usedFileStream.is_open()) {
GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, usedFile.c_str());
file.close();
globfree(&glob_result);
return 1;
}
uint64_t gttUsed;
usedFileStream >> gttUsed;
memory += gtt;
memoryUsed += gttUsed;
}
*total = memory;
*free = memory - memoryUsed;
file.close();

View File

@@ -40,9 +40,9 @@ func TestCogitoParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
},
@@ -52,9 +52,9 @@ func TestCogitoParser(t *testing.T) {
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
}),
},
},
},
},
@@ -71,9 +71,9 @@ func TestCogitoParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
},
@@ -83,9 +83,9 @@ func TestCogitoParser(t *testing.T) {
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
}),
},
},
},
},
@@ -103,17 +103,17 @@ func TestCogitoParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "London",
}),
},
},
},
},
@@ -123,9 +123,9 @@ func TestCogitoParser(t *testing.T) {
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
}),
},
},
},
},
@@ -140,11 +140,11 @@ func TestCogitoParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "process_data",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"items": []any{"item1", "item2"},
"config": map[string]any{"enabled": true, "threshold": 0.95},
"count": 42.0,
}),
},
},
},
},
@@ -238,7 +238,7 @@ This is line 3</think>Final response here.`,
t.Errorf("thinking mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedToolCalls, toolCalls, argsComparer); diff != "" {
if diff := cmp.Diff(tt.expectedToolCalls, toolCalls); diff != "" {
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
}
})
@@ -277,9 +277,9 @@ func TestCogitoParser_Streaming(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "test_tool",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"arg": "value",
}),
},
},
},
}
@@ -292,7 +292,7 @@ func TestCogitoParser_Streaming(t *testing.T) {
t.Errorf("expected thinking %q, got %q", expectedThinking, finalThinking.String())
}
if diff := cmp.Diff(expectedToolCalls, finalToolCalls, argsComparer); diff != "" {
if diff := cmp.Diff(expectedToolCalls, finalToolCalls); diff != "" {
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
}
}
@@ -367,7 +367,7 @@ func TestCogitoParser_StreamingEdgeCases(t *testing.T) {
t.Errorf("expected thinking %q, got %q", tt.expectedThinking, finalThinking.String())
}
if diff := cmp.Diff(tt.expectedToolCalls, finalToolCalls, argsComparer); diff != "" {
if diff := cmp.Diff(tt.expectedToolCalls, finalToolCalls); diff != "" {
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
}
})
@@ -412,9 +412,9 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
expectError: false,
@@ -427,11 +427,11 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "process_data",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"items": []any{"item1", "item2"},
"config": map[string]any{"enabled": true},
"count": 42.0,
}),
},
},
},
expectError: false,
@@ -444,7 +444,7 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "no_args_tool",
Arguments: api.NewToolCallFunctionArguments(),
Arguments: api.ToolCallFunctionArguments{},
},
},
expectError: false,
@@ -493,9 +493,9 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
expectError: false,
@@ -511,10 +511,10 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
"units": "metric",
}),
},
},
},
expectError: false,
@@ -527,13 +527,13 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "complex_tool",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"nested": map[string]any{
"deep": map[string]any{
"value": 123.0,
},
},
}),
},
},
},
expectError: false,
@@ -557,7 +557,7 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
t.Fatalf("unexpected error: %v", err)
}
if diff := cmp.Diff(tt.expected, result, argsComparer); diff != "" {
if diff := cmp.Diff(tt.expected, result); diff != "" {
t.Errorf("tool call mismatch (-want +got):\n%s", diff)
}
})

View File

@@ -51,9 +51,9 @@ func TestDeepSeekParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
},
@@ -67,17 +67,17 @@ func TestDeepSeekParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "London",
}),
},
},
},
},
@@ -97,10 +97,10 @@ func TestDeepSeekParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "process_data",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"items": []interface{}{"item1", "item2"},
"config": map[string]interface{}{"enabled": true, "threshold": 0.95},
}),
},
},
},
},
@@ -115,9 +115,9 @@ func TestDeepSeekParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
},
@@ -162,9 +162,9 @@ func TestDeepSeekParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
}),
},
},
},
},
@@ -191,10 +191,10 @@ func TestDeepSeekParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "search",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"query": "北京天气",
"language": "中文",
}),
},
},
},
},
@@ -220,10 +220,10 @@ func TestDeepSeekParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "execute_command",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"command": "ls && echo \"done\"",
"path": "/home/user",
}),
},
},
},
},
@@ -244,7 +244,7 @@ func TestDeepSeekParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "ping",
Arguments: api.NewToolCallFunctionArguments(),
Arguments: api.ToolCallFunctionArguments{},
},
},
},
@@ -276,7 +276,7 @@ func TestDeepSeekParser(t *testing.T) {
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedCalls, calls, argsComparer); diff != "" {
if diff := cmp.Diff(tt.expectedCalls, calls); diff != "" {
t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
}
})
@@ -313,9 +313,9 @@ func TestDeepSeekParser_Streaming(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
},
@@ -342,7 +342,7 @@ func TestDeepSeekParser_Streaming(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "test",
Arguments: api.NewToolCallFunctionArguments(),
Arguments: api.ToolCallFunctionArguments{},
},
},
},
@@ -375,10 +375,10 @@ func TestDeepSeekParser_Streaming(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "calc",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"x": float64(42),
"y": float64(24),
}),
},
},
},
},
@@ -414,7 +414,7 @@ func TestDeepSeekParser_Streaming(t *testing.T) {
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedCalls, allCalls, argsComparer); diff != "" {
if diff := cmp.Diff(tt.expectedCalls, allCalls); diff != "" {
t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
}
})
@@ -469,7 +469,7 @@ func TestDeepSeekParser_Init(t *testing.T) {
returnedTools := parser.Init(tools, nil, &api.ThinkValue{Value: true})
if diff := cmp.Diff(tools, returnedTools, toolsComparer); diff != "" {
if diff := cmp.Diff(tools, returnedTools); diff != "" {
t.Errorf("Init() returned tools mismatch (-want +got):\n%s", diff)
}
@@ -492,9 +492,9 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
},
@@ -504,10 +504,10 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "process_data",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"items": []interface{}{"a", "b"},
"config": map[string]interface{}{"enabled": true},
}),
},
},
},
},
@@ -517,7 +517,7 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "ping",
Arguments: api.NewToolCallFunctionArguments(),
Arguments: api.ToolCallFunctionArguments{},
},
},
},
@@ -527,9 +527,9 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "获取天气",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"城市": "北京",
}),
},
},
},
},
@@ -539,10 +539,10 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "execute",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"command": "ls && echo \"done\"",
"path": "/home/user",
}),
},
},
},
},
@@ -552,11 +552,11 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "calculate",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"x": 3.14,
"y": float64(42),
"enabled": true,
}),
},
},
},
},
@@ -577,9 +577,9 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"arg": "value",
}),
},
},
},
},
@@ -606,7 +606,7 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
t.Fatalf("Unexpected error: %v", err)
}
if diff := cmp.Diff(tt.expected, result, argsComparer); diff != "" {
if diff := cmp.Diff(tt.expected, result); diff != "" {
t.Errorf("parseToolCallContent() mismatch (-want +got):\n%s", diff)
}
})

View File

@@ -166,7 +166,7 @@ func (p *FunctionGemmaParser) parseToolCall(content string) (api.ToolCall, error
// parseArguments parses the key:value,key:value format
func (p *FunctionGemmaParser) parseArguments(argsStr string) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
args := make(api.ToolCallFunctionArguments)
if argsStr == "" {
return args
}
@@ -185,7 +185,7 @@ func (p *FunctionGemmaParser) parseArguments(argsStr string) api.ToolCallFunctio
value := part[colonIdx+1:]
// Parse the value
args.Set(key, p.parseValue(value))
args[key] = p.parseValue(value)
}
return args

View File

@@ -3,7 +3,6 @@ package parsers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/stretchr/testify/assert"
)
@@ -37,9 +36,9 @@ func TestFunctionGemmaParser(t *testing.T) {
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}},
}),
},
},
},
},
@@ -48,7 +47,7 @@ func TestFunctionGemmaParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
},
},
},
@@ -67,7 +66,7 @@ func TestFunctionGemmaParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
},
},
},
@@ -85,7 +84,7 @@ func TestFunctionGemmaParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "add",
Arguments: testArgs(map[string]any{"a": int64(1), "b": int64(2)}),
Arguments: api.ToolCallFunctionArguments{"a": int64(1), "b": int64(2)},
},
},
},
@@ -103,7 +102,7 @@ func TestFunctionGemmaParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "set_flag",
Arguments: testArgs(map[string]any{"enabled": true, "verbose": false}),
Arguments: api.ToolCallFunctionArguments{"enabled": true, "verbose": false},
},
},
},
@@ -125,13 +124,13 @@ func TestFunctionGemmaParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "London"}),
Arguments: api.ToolCallFunctionArguments{"city": "London"},
},
},
},
@@ -153,7 +152,7 @@ func TestFunctionGemmaParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "process",
Arguments: testArgs(map[string]any{"items": []any{"a", "b", "c"}}),
Arguments: api.ToolCallFunctionArguments{"items": []any{"a", "b", "c"}},
},
},
},
@@ -174,9 +173,9 @@ func TestFunctionGemmaParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "update",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"data": map[string]any{"name": "test", "value": int64(42)},
}),
},
},
},
},
@@ -199,7 +198,7 @@ func TestFunctionGemmaParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_time",
Arguments: api.NewToolCallFunctionArguments(),
Arguments: api.ToolCallFunctionArguments{},
},
},
},
@@ -225,7 +224,7 @@ func TestFunctionGemmaParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "set_temp",
Arguments: testArgs(map[string]any{"value": 3.14}),
Arguments: api.ToolCallFunctionArguments{"value": 3.14},
},
},
},
@@ -243,7 +242,7 @@ func TestFunctionGemmaParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "test",
Arguments: api.NewToolCallFunctionArguments(),
Arguments: api.ToolCallFunctionArguments{},
},
},
},
@@ -262,7 +261,7 @@ func TestFunctionGemmaParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "greet",
Arguments: testArgs(map[string]any{"name": "日本語"}),
Arguments: api.ToolCallFunctionArguments{"name": "日本語"},
},
},
},
@@ -282,11 +281,11 @@ func TestFunctionGemmaParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "search",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"query": "test",
"limit": int64(10),
"offset": int64(0),
}),
},
},
},
},
@@ -309,14 +308,14 @@ func TestFunctionGemmaParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "create",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"config": map[string]any{
"settings": map[string]any{
"enabled": true,
"name": "test",
},
},
}),
},
},
},
},
@@ -346,13 +345,13 @@ func TestFunctionGemmaParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
},
},
{
Function: api.ToolCallFunction{
Name: "get_time",
Arguments: testArgs(map[string]any{"timezone": "UTC"}),
Arguments: api.ToolCallFunctionArguments{"timezone": "UTC"},
},
},
},
@@ -373,13 +372,13 @@ func TestFunctionGemmaParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "first",
Arguments: api.NewToolCallFunctionArguments(),
Arguments: api.ToolCallFunctionArguments{},
},
},
{
Function: api.ToolCallFunction{
Name: "second",
Arguments: api.NewToolCallFunctionArguments(),
Arguments: api.ToolCallFunctionArguments{},
},
},
},
@@ -412,9 +411,7 @@ func TestFunctionGemmaParser(t *testing.T) {
}
assert.Equal(t, tt.expectedText, allContent)
if diff := cmp.Diff(tt.expectedCalls, allCalls, argsComparer); diff != "" {
t.Errorf("calls mismatch (-want +got):\n%s", diff)
}
assert.Equal(t, tt.expectedCalls, allCalls)
})
}
}

View File

@@ -112,8 +112,8 @@ func (p *MinistralParser) Add(s string, done bool) (content string, thinking str
before, _ := splitAtTag(&p.buffer, "}", false)
before += "}"
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(before), &args); err != nil {
var data map[string]any
if err := json.Unmarshal([]byte(before), &data); err != nil {
// todo - throw a better error
return "", "", calls, err
}
@@ -123,7 +123,7 @@ func (p *MinistralParser) Add(s string, done bool) (content string, thinking str
call := api.ToolCall{
Function: api.ToolCallFunction{
Name: p.currentTool.Function.Name,
Arguments: args,
Arguments: api.ToolCallFunctionArguments(data),
},
}
calls = append(calls, call)

View File

@@ -225,7 +225,7 @@ func (p *Nemotron3NanoParser) parseToolCall(content string) (api.ToolCall, error
toolCall.Function.Name = fnMatch[1]
// Extract parameters
toolCall.Function.Arguments = api.NewToolCallFunctionArguments()
toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
paramMatches := nemotronParameterRegex.FindAllStringSubmatch(content, -1)
for _, match := range paramMatches {
if len(match) >= 3 {
@@ -233,7 +233,7 @@ func (p *Nemotron3NanoParser) parseToolCall(content string) (api.ToolCall, error
paramValue := strings.TrimSpace(match[2])
// Try to parse as typed value based on tool definition
toolCall.Function.Arguments.Set(paramName, p.parseParamValue(paramName, paramValue))
toolCall.Function.Arguments[paramName] = p.parseParamValue(paramName, paramValue)
}
}
@@ -244,11 +244,9 @@ func (p *Nemotron3NanoParser) parseParamValue(paramName string, raw string) any
// Find the matching tool to get parameter type
var paramType api.PropertyType
for _, tool := range p.tools {
if tool.Function.Parameters.Properties != nil {
if prop, ok := tool.Function.Parameters.Properties.Get(paramName); ok {
paramType = prop.Type
break
}
if prop, ok := tool.Function.Parameters.Properties[paramName]; ok {
paramType = prop.Type
break
}
}

View File

@@ -51,7 +51,7 @@ func TestNemotron3NanoParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
Arguments: map[string]any{"city": "Paris"},
},
},
},
@@ -65,7 +65,7 @@ func TestNemotron3NanoParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "NYC"}),
Arguments: map[string]any{"city": "NYC"},
},
},
},
@@ -78,10 +78,10 @@ func TestNemotron3NanoParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "book_flight",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"from": "SFO",
"to": "NYC",
}),
},
},
},
},
@@ -95,13 +95,13 @@ func TestNemotron3NanoParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "San Francisco"}),
Arguments: map[string]any{"city": "San Francisco"},
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "New York"}),
Arguments: map[string]any{"city": "New York"},
},
},
},
@@ -115,7 +115,7 @@ func TestNemotron3NanoParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
Arguments: map[string]any{"city": "Paris"},
},
},
},
@@ -130,7 +130,7 @@ func TestNemotron3NanoParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "search",
Arguments: testArgs(map[string]any{"query": "test"}),
Arguments: map[string]any{"query": "test"},
},
},
},
@@ -143,7 +143,7 @@ func TestNemotron3NanoParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "create_note",
Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}),
Arguments: map[string]any{"content": "Line 1\nLine 2\nLine 3"},
},
},
},
@@ -165,7 +165,7 @@ func TestNemotron3NanoParser(t *testing.T) {
name: "tool call with no function name - returns empty tool call",
input: "<tool_call>\n<function=>\n</function>\n</tool_call>",
thinkValue: nil,
expectedCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "", Arguments: api.NewToolCallFunctionArguments()}}},
expectedCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "", Arguments: nil}}},
},
{
name: "content with newlines preserved",
@@ -194,7 +194,7 @@ func TestNemotron3NanoParser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "set_temp",
Arguments: testArgs(map[string]any{"value": "42"}),
Arguments: map[string]any{"value": "42"},
},
},
},
@@ -226,7 +226,7 @@ func TestNemotron3NanoParser(t *testing.T) {
if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" {
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
}
if diff := cmp.Diff(calls, tt.expectedCalls, argsComparer); diff != "" {
if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff)
}
})
@@ -276,7 +276,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
Arguments: map[string]any{"city": "Paris"},
},
},
},
@@ -290,7 +290,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "NYC"}),
Arguments: map[string]any{"city": "NYC"},
},
},
},
@@ -302,7 +302,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "test",
Arguments: api.NewToolCallFunctionArguments(),
Arguments: map[string]any{},
},
},
},
@@ -329,10 +329,10 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "book_flight",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"from": "SFO",
"to": "NYC",
}),
},
},
},
},
@@ -347,7 +347,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "search",
Arguments: testArgs(map[string]any{"query": "test query"}),
Arguments: map[string]any{"query": "test query"},
},
},
},
@@ -367,13 +367,13 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "San Francisco"}),
Arguments: map[string]any{"city": "San Francisco"},
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "New York"}),
Arguments: map[string]any{"city": "New York"},
},
},
},
@@ -386,7 +386,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "create_note",
Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}),
Arguments: map[string]any{"content": "Line 1\nLine 2\nLine 3"},
},
},
},
@@ -413,7 +413,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "test",
Arguments: api.NewToolCallFunctionArguments(),
Arguments: map[string]any{},
},
},
},
@@ -426,7 +426,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "test",
Arguments: testArgs(map[string]any{"name": ""}),
Arguments: map[string]any{"name": ""},
},
},
},
@@ -473,7 +473,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
if diff := cmp.Diff(allThinking, tt.expectedThinking); diff != "" {
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
}
if diff := cmp.Diff(allCalls, tt.expectedCalls, argsComparer); diff != "" {
if diff := cmp.Diff(allCalls, tt.expectedCalls); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff)
}
})
@@ -537,9 +537,9 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) {
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}},
}),
},
},
},
},
@@ -548,7 +548,7 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) {
p := &Nemotron3NanoParser{}
returnedTools := p.Init(tools, nil, nil)
if diff := cmp.Diff(returnedTools, tools, toolsComparer); diff != "" {
if diff := cmp.Diff(returnedTools, tools); diff != "" {
t.Errorf("tools mismatch (-got +want):\n%s", diff)
}
@@ -563,12 +563,12 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
Arguments: map[string]any{"city": "Paris"},
},
},
}
if diff := cmp.Diff(calls, expectedCalls, argsComparer); diff != "" {
if diff := cmp.Diff(calls, expectedCalls); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff)
}
}

View File

@@ -242,8 +242,8 @@ func parseOlmo3SingleFunctionCall(s string) (api.ToolCall, error) {
// parseOlmo3Arguments parses comma-separated key=value pairs
// Handles nested parentheses, brackets, braces, and quoted strings
func parseOlmo3Arguments(s string) (api.ToolCallFunctionArguments, error) {
args := api.NewToolCallFunctionArguments()
func parseOlmo3Arguments(s string) (map[string]any, error) {
args := make(map[string]any)
s = strings.TrimSpace(s)
if s == "" {
return args, nil
@@ -261,7 +261,7 @@ func parseOlmo3Arguments(s string) (api.ToolCallFunctionArguments, error) {
// Find the first = sign
eqIdx := strings.Index(part, "=")
if eqIdx == -1 {
return api.ToolCallFunctionArguments{}, fmt.Errorf("invalid argument format: %s", part)
return nil, fmt.Errorf("invalid argument format: %s", part)
}
key := strings.TrimSpace(part[:eqIdx])
@@ -269,10 +269,10 @@ func parseOlmo3Arguments(s string) (api.ToolCallFunctionArguments, error) {
value, err := parseOlmo3Value(valueStr)
if err != nil {
return api.ToolCallFunctionArguments{}, fmt.Errorf("failed to parse value for %s: %w", key, err)
return nil, fmt.Errorf("failed to parse value for %s: %w", key, err)
}
args.Set(key, value)
args[key] = value
}
return args, nil

View File

@@ -28,7 +28,7 @@ func TestOlmo3Parser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "San Francisco"}),
Arguments: map[string]any{"location": "San Francisco"},
},
},
},
@@ -41,7 +41,7 @@ func TestOlmo3Parser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "NYC"}),
Arguments: map[string]any{"location": "NYC"},
},
},
},
@@ -53,11 +53,11 @@ func TestOlmo3Parser(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "book_flight",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"from": "SFO",
"to": "NYC",
"date": "2024-01-15",
}),
},
},
},
},
@@ -70,13 +70,13 @@ get_weather(location="New York")</function_calls>`,
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "San Francisco"}),
Arguments: map[string]any{"location": "San Francisco"},
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "New York"}),
Arguments: map[string]any{"location": "New York"},
},
},
},
@@ -88,7 +88,7 @@ get_weather(location="New York")</function_calls>`,
{
Function: api.ToolCallFunction{
Name: "set_temperature",
Arguments: testArgs(map[string]any{"value": int64(72)}),
Arguments: map[string]any{"value": int64(72)},
},
},
},
@@ -100,7 +100,7 @@ get_weather(location="New York")</function_calls>`,
{
Function: api.ToolCallFunction{
Name: "set_price",
Arguments: testArgs(map[string]any{"amount": 19.99}),
Arguments: map[string]any{"amount": 19.99},
},
},
},
@@ -112,7 +112,7 @@ get_weather(location="New York")</function_calls>`,
{
Function: api.ToolCallFunction{
Name: "toggle_setting",
Arguments: testArgs(map[string]any{"enabled": true}),
Arguments: map[string]any{"enabled": true},
},
},
},
@@ -124,7 +124,7 @@ get_weather(location="New York")</function_calls>`,
{
Function: api.ToolCallFunction{
Name: "clear_value",
Arguments: testArgs(map[string]any{"field": nil}),
Arguments: map[string]any{"field": nil},
},
},
},
@@ -136,7 +136,7 @@ get_weather(location="New York")</function_calls>`,
{
Function: api.ToolCallFunction{
Name: "process_items",
Arguments: testArgs(map[string]any{"items": []any{"apple", "banana", "cherry"}}),
Arguments: map[string]any{"items": []any{"apple", "banana", "cherry"}},
},
},
},
@@ -148,12 +148,12 @@ get_weather(location="New York")</function_calls>`,
{
Function: api.ToolCallFunction{
Name: "update_config",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"settings": map[string]any{
"theme": "dark",
"fontSize": int64(14),
},
}),
},
},
},
},
@@ -165,7 +165,7 @@ get_weather(location="New York")</function_calls>`,
{
Function: api.ToolCallFunction{
Name: "create_request",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"data": map[string]any{
"user": map[string]any{
"name": "John",
@@ -173,7 +173,7 @@ get_weather(location="New York")</function_calls>`,
},
"active": true,
},
}),
},
},
},
},
@@ -185,7 +185,7 @@ get_weather(location="New York")</function_calls>`,
{
Function: api.ToolCallFunction{
Name: "get_current_time",
Arguments: testArgs(map[string]any{}),
Arguments: map[string]any{},
},
},
},
@@ -197,7 +197,7 @@ get_weather(location="New York")</function_calls>`,
{
Function: api.ToolCallFunction{
Name: "search",
Arguments: testArgs(map[string]any{"query": "hello world"}),
Arguments: map[string]any{"query": "hello world"},
},
},
},
@@ -209,7 +209,7 @@ get_weather(location="New York")</function_calls>`,
{
Function: api.ToolCallFunction{
Name: "search",
Arguments: testArgs(map[string]any{"query": `say "hello"`}),
Arguments: map[string]any{"query": `say "hello"`},
},
},
},
@@ -221,11 +221,11 @@ get_weather(location="New York")</function_calls>`,
{
Function: api.ToolCallFunction{
Name: "create_user",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"name": "John",
"age": int64(30),
"active": true,
}),
},
},
},
},
@@ -257,7 +257,7 @@ get_weather(location="New York")</function_calls>`,
if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" {
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
}
if diff := cmp.Diff(calls, tt.expectedCalls, argsComparer); diff != "" {
if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff)
}
})
@@ -283,7 +283,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "SF"}),
Arguments: map[string]any{"location": "SF"},
},
},
},
@@ -296,7 +296,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "NYC"}),
Arguments: map[string]any{"location": "NYC"},
},
},
},
@@ -308,7 +308,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "test",
Arguments: testArgs(map[string]any{}),
Arguments: map[string]any{},
},
},
},
@@ -343,7 +343,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) {
if diff := cmp.Diff(allContent, tt.expectedContent); diff != "" {
t.Errorf("content mismatch (-got +want):\n%s", diff)
}
if diff := cmp.Diff(allCalls, tt.expectedCalls, argsComparer); diff != "" {
if diff := cmp.Diff(allCalls, tt.expectedCalls); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff)
}
})
@@ -378,7 +378,7 @@ func TestParseOlmo3FunctionCalls(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "SF"}),
Arguments: map[string]any{"location": "SF"},
},
},
},
@@ -390,11 +390,11 @@ func TestParseOlmo3FunctionCalls(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "send_email",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"to": "user@example.com",
"subject": "Hello",
"body": "Test message",
}),
},
},
},
},
@@ -407,13 +407,13 @@ get_time(timezone="PST")`,
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "SF"}),
Arguments: map[string]any{"location": "SF"},
},
},
{
Function: api.ToolCallFunction{
Name: "get_time",
Arguments: testArgs(map[string]any{"timezone": "PST"}),
Arguments: map[string]any{"timezone": "PST"},
},
},
},
@@ -437,7 +437,7 @@ get_time(timezone="PST")`,
t.Errorf("parseOlmo3FunctionCalls() error = %v, wantErr %v", err, tt.wantErr)
return
}
if diff := cmp.Diff(calls, tt.expected, argsComparer); diff != "" {
if diff := cmp.Diff(calls, tt.expected); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff)
}
})

View File

@@ -270,12 +270,12 @@ func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, er
}
}
toolCall.Function.Arguments = api.NewToolCallFunctionArguments()
toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
for _, parameter := range functionCall.Parameters {
// Look up the parameter type if we found the tool
var paramType api.PropertyType
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
if prop, ok := matchedTool.Function.Parameters.Properties.Get(parameter.Name); ok {
if prop, ok := matchedTool.Function.Parameters.Properties[parameter.Name]; ok {
// Handle anyOf by collecting all types from the union
if len(prop.AnyOf) > 0 {
for _, anyOfProp := range prop.AnyOf {
@@ -287,7 +287,7 @@ func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, er
}
}
toolCall.Function.Arguments.Set(parameter.Name, parseValue(parameter.Value, paramType))
toolCall.Function.Arguments[parameter.Name] = parseValue(parameter.Value, paramType)
}
return toolCall, nil

View File

@@ -11,7 +11,7 @@ import (
func tool(name string, props map[string]api.ToolProperty) api.Tool {
t := api.Tool{Type: "function", Function: api.ToolFunction{Name: name}}
t.Function.Parameters.Type = "object"
t.Function.Parameters.Properties = testPropsMap(props)
t.Function.Parameters.Properties = props
return t
}
@@ -369,10 +369,10 @@ celsius
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_current_temperature",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"location": "San Francisco",
"unit": "celsius",
}),
},
},
},
},
@@ -390,10 +390,10 @@ celsius
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get current temperature",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"location with spaces": "San Francisco",
"unit with spaces": "celsius",
}),
},
},
},
},
@@ -415,10 +415,10 @@ San Francisco
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "\"get current temperature\"",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"\"location with spaces\"": "San Francisco",
"\"unit with spaces\"": "\"celsius\"",
}),
},
},
},
},
@@ -449,12 +449,12 @@ true
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "calculate",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"x": 3.14,
"y": 42,
"enabled": true,
"items": []any{"a", "b", "c"},
}),
},
},
},
},
@@ -470,9 +470,9 @@ ls && echo "done"
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "exec",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"command": "ls && echo \"done\"",
}),
},
},
},
},
@@ -487,9 +487,9 @@ ls && echo "a > b and a < b"
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "exec",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"command": "ls && echo \"a > b and a < b\"",
}),
},
},
},
},
@@ -507,10 +507,10 @@ Hello! 你好! 🌟 مرحبا
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "获取天气",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"城市": "北京",
"message": "Hello! 你好! 🌟 مرحبا",
}),
},
},
},
},
@@ -521,7 +521,7 @@ Hello! 你好! 🌟 مرحبا
if err != nil {
t.Errorf("step %d (%s): %v", i, step.name, err)
}
if !toolCallEqual(gotToolCall, step.wantToolCall) {
if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
}
}

View File

@@ -550,10 +550,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get-current-weather",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"location": "San Francisco, CA",
"unit": "fahrenheit",
}),
},
},
},
},
@@ -564,10 +564,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get current temperature",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"location with spaces": "San Francisco",
"unit with spaces": "celsius",
}),
},
},
},
},
@@ -578,10 +578,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "\"get current temperature\"",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"\"location with spaces\"": "San Francisco",
"\"unit with spaces\"": "\"celsius\"",
}),
},
},
},
},
@@ -592,12 +592,12 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "calculate",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"x": 3.14,
"y": float64(42),
"enabled": true,
"items": []any{"a", "b", "c"},
}),
},
},
},
},
@@ -608,9 +608,9 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "exec",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"command": "ls && echo \"done\"",
}),
},
},
},
},
@@ -621,9 +621,9 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "exec",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"command": "ls && echo \"a > b and a < b\"",
}),
},
},
},
},
@@ -634,10 +634,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "获取天气",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"城市": "北京",
"message": "Hello! 你好! 🌟 مرحبا",
}),
},
},
},
},
@@ -648,7 +648,7 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
if err != nil {
t.Errorf("step %d (%s): %v", i, step.name, err)
}
if !toolCallEqual(gotToolCall, step.wantToolCall) {
if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
}
}

View File

@@ -241,10 +241,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get-current-weather",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"location": "San Francisco, CA",
"unit": "fahrenheit",
}),
},
},
},
},
@@ -255,10 +255,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get current temperature",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"location with spaces": "San Francisco",
"unit with spaces": "celsius",
}),
},
},
},
},
@@ -269,10 +269,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "\"get current temperature\"",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"\"location with spaces\"": "San Francisco",
"\"unit with spaces\"": "\"celsius\"",
}),
},
},
},
},
@@ -283,12 +283,12 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "calculate",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"x": 3.14,
"y": float64(42),
"enabled": true,
"items": []any{"a", "b", "c"},
}),
},
},
},
},
@@ -299,9 +299,9 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "exec",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"command": "ls && echo \"done\"",
}),
},
},
},
},
@@ -312,9 +312,9 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "exec",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"command": "ls && echo \"a > b and a < b\"",
}),
},
},
},
},
@@ -325,10 +325,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "获取天气",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"城市": "北京",
"message": "Hello! 你好! 🌟 مرحبا",
}),
},
},
},
},
@@ -339,7 +339,7 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
if err != nil {
t.Errorf("step %d (%s): %v", i, step.name, err)
}
if !toolCallEqual(gotToolCall, step.wantToolCall) {
if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
}
}

View File

@@ -1,98 +0,0 @@
package parsers
import (
"encoding/json"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
// argsComparer provides cmp options for comparing ToolCallFunctionArguments
// It compares by logical equality (same keys with same values) not by order
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
// Convert both to maps and compare
aMap := a.ToMap()
bMap := b.ToMap()
if len(aMap) != len(bMap) {
return false
}
for k, av := range aMap {
bv, ok := bMap[k]
if !ok {
return false
}
// Use JSON encoding for deep comparison of values
aJSON, _ := json.Marshal(av)
bJSON, _ := json.Marshal(bv)
if string(aJSON) != string(bJSON) {
return false
}
}
return true
})
// propsComparer provides cmp options for comparing ToolPropertiesMap
var propsComparer = cmp.Comparer(func(a, b *api.ToolPropertiesMap) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
aJSON, _ := json.Marshal(a)
bJSON, _ := json.Marshal(b)
return string(aJSON) == string(bJSON)
})
// toolsComparer combines argsComparer and propsComparer for comparing tools
var toolsComparer = cmp.Options{argsComparer, propsComparer}
// toolCallEqual compares two tool calls by comparing their components
// It compares arguments by logical equality (same keys with same values) not by order
func toolCallEqual(a, b api.ToolCall) bool {
if a.ID != b.ID {
return false
}
if a.Function.Index != b.Function.Index {
return false
}
if a.Function.Name != b.Function.Name {
return false
}
// Compare arguments by logical equality using argsComparer logic
aMap := a.Function.Arguments.ToMap()
bMap := b.Function.Arguments.ToMap()
if len(aMap) != len(bMap) {
return false
}
for k, av := range aMap {
bv, ok := bMap[k]
if !ok {
return false
}
aJSON, _ := json.Marshal(av)
bJSON, _ := json.Marshal(bv)
if string(aJSON) != string(bJSON) {
return false
}
}
return true
}
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}

View File

@@ -94,12 +94,12 @@ You are a helpful assistant.
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
},
Required: []string{"location"},
},
},
@@ -139,9 +139,9 @@ You have the following functions available:
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
},
@@ -162,9 +162,9 @@ You have the following functions available:
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
},
@@ -186,17 +186,17 @@ You have the following functions available:
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "London",
}),
},
},
},
},
@@ -226,12 +226,12 @@ You have the following functions available:
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
},
Required: []string{"location"},
},
},
@@ -378,9 +378,9 @@ You are a pirate chatbot who always responds in pirate speak!
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
},
@@ -401,14 +401,14 @@ You are a pirate chatbot who always responds in pirate speak!
{
Function: api.ToolCallFunction{
Name: "process_data",
Arguments: testArgsOrdered([]orderedArg{
{"config", map[string]any{
Arguments: api.ToolCallFunctionArguments{
"items": []any{"item1", "item2", "item3"},
"config": map[string]any{
"enabled": true,
"threshold": 0.95,
"tags": []string{"important", "urgent"},
}},
{"items", []any{"item1", "item2", "item3"}},
}),
},
},
},
},
},

View File

@@ -82,9 +82,9 @@ Second instruction<User>Hello<Assistant></think>`,
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
},
@@ -104,9 +104,9 @@ Second instruction<User>Hello<Assistant></think>`,
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
},
@@ -125,9 +125,9 @@ Second instruction<User>Hello<Assistant></think>`,
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
},
@@ -147,17 +147,17 @@ Second instruction<User>Hello<Assistant></think>`,
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "London",
}),
},
},
},
},
@@ -214,9 +214,9 @@ Second instruction<User>Hello<Assistant></think>`,
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
},
@@ -235,9 +235,9 @@ Second instruction<User>Hello<Assistant></think>`,
{
Function: api.ToolCallFunction{
Name: "process",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"data": "test",
}),
},
},
},
},
@@ -281,9 +281,9 @@ Second instruction<User>Hello<Assistant></think>`,
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
},
@@ -305,9 +305,9 @@ Second instruction<User>Hello<Assistant></think>`,
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
},
@@ -355,9 +355,9 @@ Second instruction<User>Hello<Assistant></think>`,
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
},
@@ -379,9 +379,9 @@ Second instruction<User>Hello<Assistant></think>`,
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
},
@@ -436,17 +436,17 @@ Second instruction<User>Hello<Assistant></think>`,
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
}),
},
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "New York",
}),
},
},
},
},
@@ -489,12 +489,12 @@ Second instruction<User>Hello<Assistant></think>`,
Description: "Get current weather information",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
},
Required: []string{"location"},
},
},
@@ -535,12 +535,12 @@ Where:
Description: "Get current weather information",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
},
Required: []string{"location"},
},
},
@@ -578,9 +578,9 @@ Where:
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
},
@@ -594,12 +594,12 @@ Where:
Description: "Get current weather information",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
},
Required: []string{"location"},
},
},
@@ -638,9 +638,9 @@ Where:
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
},
@@ -656,12 +656,12 @@ Where:
Description: "Get current weather information",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
},
Required: []string{"location"},
},
},
@@ -701,9 +701,9 @@ Where:
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
}),
},
},
},
},
@@ -724,12 +724,12 @@ Where:
Description: "Get current weather information",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
},
Required: []string{"location"},
},
},
@@ -770,12 +770,12 @@ Where:
Description: "Get current weather information",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
},
Required: []string{"location"},
},
},
@@ -787,12 +787,12 @@ Where:
Description: "Perform mathematical calculations",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"expression": {
Type: api.PropertyType{"string"},
Description: "Mathematical expression to evaluate",
},
}),
},
Required: []string{"expression"},
},
},
@@ -834,17 +834,17 @@ Where:
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
}),
},
},
},
{
Function: api.ToolCallFunction{
Name: "calculate",
Arguments: testArgs(map[string]any{
Arguments: api.ToolCallFunctionArguments{
"expression": "25 * 4",
}),
},
},
},
},
@@ -860,12 +860,12 @@ Where:
Description: "Get current weather information",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
},
Required: []string{"location"},
},
},
@@ -877,12 +877,12 @@ Where:
Description: "Perform mathematical calculations",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"expression": {
Type: api.PropertyType{"string"},
Description: "Mathematical expression to evaluate",
},
}),
},
Required: []string{"expression"},
},
},
@@ -927,12 +927,12 @@ Where:
Description: "Get current weather information",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
},
Required: []string{"location"},
},
},

View File

@@ -136,7 +136,7 @@ func (r *FunctionGemmaRenderer) renderToolDeclaration(tool api.Tool) string {
needsComma := false
// Only include properties:{} if there are actual properties
if fn.Parameters.Properties != nil && fn.Parameters.Properties.Len() > 0 {
if len(fn.Parameters.Properties) > 0 {
sb.WriteString("properties:{")
r.writeProperties(&sb, fn.Parameters.Properties)
sb.WriteString("}")
@@ -172,16 +172,16 @@ func (r *FunctionGemmaRenderer) renderToolDeclaration(tool api.Tool) string {
return sb.String()
}
func (r *FunctionGemmaRenderer) writeProperties(sb *strings.Builder, props *api.ToolPropertiesMap) {
keys := make([]string, 0, props.Len())
for k := range props.All() {
func (r *FunctionGemmaRenderer) writeProperties(sb *strings.Builder, props map[string]api.ToolProperty) {
keys := make([]string, 0, len(props))
for k := range props {
keys = append(keys, k)
}
sort.Strings(keys)
first := true
for _, name := range keys {
prop, _ := props.Get(name)
prop := props[name]
if !first {
sb.WriteString(",")
}
@@ -203,15 +203,15 @@ func (r *FunctionGemmaRenderer) formatToolCall(tc api.ToolCall) string {
var sb strings.Builder
sb.WriteString("<start_function_call>call:" + tc.Function.Name + "{")
keys := make([]string, 0, tc.Function.Arguments.Len())
for k := range tc.Function.Arguments.All() {
keys := make([]string, 0, len(tc.Function.Arguments))
for k := range tc.Function.Arguments {
keys = append(keys, k)
}
sort.Strings(keys)
first := true
for _, key := range keys {
value, _ := tc.Function.Arguments.Get(key)
value := tc.Function.Arguments[key]
if !first {
sb.WriteString(",")
}

View File

@@ -51,9 +51,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "City"},
}),
},
},
},
},
@@ -75,9 +75,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "City"},
}),
},
},
},
},
@@ -107,9 +107,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "City"},
}),
},
},
},
},
@@ -126,7 +126,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
},
},
},
@@ -141,9 +141,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "City"},
}),
},
},
},
},
@@ -161,7 +161,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
},
},
},
@@ -176,9 +176,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "City"},
}),
},
},
},
},
@@ -195,7 +195,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "add",
Arguments: testArgs(map[string]any{"a": float64(1), "b": float64(2)}),
Arguments: api.ToolCallFunctionArguments{"a": float64(1), "b": float64(2)},
},
},
},
@@ -210,10 +210,10 @@ func TestFunctionGemmaRenderer(t *testing.T) {
Description: "Add numbers",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"a": {Type: api.PropertyType{"number"}},
"b": {Type: api.PropertyType{"number"}},
}),
},
},
},
},
@@ -239,10 +239,10 @@ func TestFunctionGemmaRenderer(t *testing.T) {
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"city"},
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "City Name"},
"country": {Type: api.PropertyType{"string"}, Description: "Country Name"},
}),
},
},
},
},
@@ -263,9 +263,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "City"},
}),
},
},
},
},
@@ -276,9 +276,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
Description: "Get current time",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"timezone": {Type: api.PropertyType{"string"}, Description: "Timezone"},
}),
},
},
},
},
@@ -296,13 +296,13 @@ func TestFunctionGemmaRenderer(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
},
},
{
Function: api.ToolCallFunction{
Name: "get_time",
Arguments: testArgs(map[string]any{"timezone": "UTC"}),
Arguments: api.ToolCallFunctionArguments{"timezone": "UTC"},
},
},
},
@@ -318,9 +318,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "City"},
}),
},
},
},
},
@@ -331,9 +331,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
Description: "Get current time",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"timezone": {Type: api.PropertyType{"string"}, Description: "Timezone"},
}),
},
},
},
},
@@ -351,7 +351,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
},
},
},
@@ -367,9 +367,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "City"},
}),
},
},
},
},
@@ -391,7 +391,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
Description: "",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{}),
Properties: map[string]api.ToolProperty{},
},
},
},
@@ -430,7 +430,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "set_flag",
Arguments: testArgs(map[string]any{"enabled": true}),
Arguments: api.ToolCallFunctionArguments{"enabled": true},
},
},
},
@@ -445,9 +445,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
Description: "Set a flag",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"enabled": {Type: api.PropertyType{"boolean"}, Description: "Flag value"},
}),
},
},
},
},
@@ -468,11 +468,11 @@ func TestFunctionGemmaRenderer(t *testing.T) {
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"a", "b", "c"},
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"a": {Type: api.PropertyType{"string"}, Description: "A"},
"b": {Type: api.PropertyType{"string"}, Description: "B"},
"c": {Type: api.PropertyType{"string"}, Description: "C"},
}),
},
},
},
},
@@ -492,9 +492,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
Description: "Test",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"items": {Type: api.PropertyType{"array"}, Description: "List of items"},
}),
},
},
},
},

View File

@@ -114,7 +114,7 @@ func (r *Nemotron3NanoRenderer) renderTools(tools []api.Tool) string {
sb.WriteString("\n<parameters>")
if fn.Parameters.Properties != nil {
for paramName, paramFields := range fn.Parameters.Properties.All() {
for paramName, paramFields := range fn.Parameters.Properties {
sb.WriteString("\n<parameter>")
sb.WriteString("\n<name>" + paramName + "</name>")
@@ -202,7 +202,7 @@ func (r *Nemotron3NanoRenderer) formatContent(content string, truncate bool, add
func (r *Nemotron3NanoRenderer) writeToolCalls(sb *strings.Builder, toolCalls []api.ToolCall) {
for _, tc := range toolCalls {
sb.WriteString("<tool_call>\n<function=" + tc.Function.Name + ">\n")
for name, value := range tc.Function.Arguments.All() {
for name, value := range tc.Function.Arguments {
sb.WriteString("<parameter=" + name + ">\n" + r.formatArgValue(value) + "\n</parameter>\n")
}
sb.WriteString("</function>\n</tool_call>\n")

View File

@@ -75,9 +75,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"city"},
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "The city name"},
}),
},
},
},
},
@@ -113,7 +113,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
Arguments: map[string]any{"city": "Paris"},
},
},
},
@@ -129,9 +129,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"city"},
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "The city name"},
}),
},
},
},
},
@@ -171,7 +171,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
Arguments: map[string]any{"city": "Paris"},
},
},
},
@@ -185,9 +185,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}},
}),
},
},
},
},
@@ -238,13 +238,13 @@ func TestNemotron3NanoRenderer(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
Arguments: map[string]any{"city": "Paris"},
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "London"}),
Arguments: map[string]any{"city": "London"},
},
},
},
@@ -259,9 +259,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}},
}),
},
},
},
},
@@ -304,13 +304,13 @@ func TestNemotron3NanoRenderer(t *testing.T) {
msgs: []api.Message{
{Role: "user", Content: "What's the weather in Paris and London? Also, what's 2+2?"},
{Role: "assistant", Content: "", Thinking: "I need to check the weather for both cities and calculate 2+2. Let me start with the weather calls.", ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: testArgs(map[string]any{"city": "Paris"})}},
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: testArgs(map[string]any{"city": "London"})}},
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "Paris"}}},
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "London"}}},
}},
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call1"},
{Role: "tool", Content: "Rainy, 15°C", ToolCallID: "call2"},
{Role: "assistant", Content: "", Thinking: "Now I have the weather data. Let me calculate 2+2.", ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{Name: "calculate", Arguments: testArgs(map[string]any{"expression": "2+2"})}},
{Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}},
}},
{Role: "tool", Content: "4", ToolCallID: "call3"},
{Role: "assistant", Content: "Based on the weather data, Paris is sunny at 22°C and London is rainy at 15°C. Also, 2+2 equals 4.", Thinking: "Perfect! I have all the information needed to provide a complete answer."},
@@ -322,9 +322,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}},
}),
},
},
},
},
@@ -334,9 +334,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
Name: "calculate",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"expression": {Type: api.PropertyType{"string"}},
}),
},
},
},
},
@@ -389,7 +389,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{Name: "get_user", Arguments: testArgs(map[string]any{"id": "123"})}},
{Function: api.ToolCallFunction{Name: "get_user", Arguments: map[string]any{"id": "123"}}},
},
},
{Role: "tool", Content: `{"name": "John", "age": 30, "active": true}`},
@@ -401,7 +401,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
Name: "get_user",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{"id": {Type: api.PropertyType{"string"}}}),
Properties: map[string]api.ToolProperty{"id": {Type: api.PropertyType{"string"}}},
},
},
},
@@ -450,9 +450,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{
Name: "create",
Arguments: testArgs(map[string]any{
Arguments: map[string]any{
"data": map[string]any{"nested": "value", "count": 42},
}),
},
}},
},
},
@@ -465,7 +465,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
Name: "create",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{"data": {Type: api.PropertyType{"object"}}}),
Properties: map[string]api.ToolProperty{"data": {Type: api.PropertyType{"object"}}},
},
},
},
@@ -512,7 +512,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{Name: "translate", Arguments: testArgs(map[string]any{"text": "你好"})}},
{Function: api.ToolCallFunction{Name: "translate", Arguments: map[string]any{"text": "你好"}}},
},
},
{Role: "tool", Content: "Hello"},
@@ -524,9 +524,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
Name: "translate",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"text": {Type: api.PropertyType{"string"}},
}),
},
},
},
},

View File

@@ -100,8 +100,8 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.
sb.WriteString("(")
// Get sorted keys for deterministic output
keys := make([]string, 0, tc.Function.Arguments.Len())
for k := range tc.Function.Arguments.All() {
keys := make([]string, 0, len(tc.Function.Arguments))
for k := range tc.Function.Arguments {
keys = append(keys, k)
}
sort.Strings(keys)
@@ -110,8 +110,7 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.
if k > 0 {
sb.WriteString(", ")
}
val, _ := tc.Function.Arguments.Get(key)
value, err := json.Marshal(val)
value, err := json.Marshal(tc.Function.Arguments[key])
if err != nil {
return "", err
}

Some files were not shown because too many files have changed in this diff Show More