mirror of
https://github.com/ollama/ollama.git
synced 2026-01-18 20:39:13 -05:00
Compare commits
1 Commits
parth/decr
...
parth/prom
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6d9a37ad62 |
2
.github/ISSUE_TEMPLATE/10_bug_report.yml
vendored
2
.github/ISSUE_TEMPLATE/10_bug_report.yml
vendored
@@ -13,7 +13,7 @@ body:
|
||||
id: logs
|
||||
attributes:
|
||||
label: Relevant log output
|
||||
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.mdx#how-to-troubleshoot-issues) for details.
|
||||
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
|
||||
render: shell
|
||||
validations:
|
||||
required: false
|
||||
|
||||
13
.github/workflows/release.yaml
vendored
13
.github/workflows/release.yaml
vendored
@@ -68,7 +68,6 @@ jobs:
|
||||
name: bundles-darwin
|
||||
path: |
|
||||
dist/*.tgz
|
||||
dist/*.tar.zst
|
||||
dist/*.zip
|
||||
dist/*.dmg
|
||||
|
||||
@@ -372,17 +371,13 @@ jobs:
|
||||
outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
|
||||
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
|
||||
cache-to: type=inline
|
||||
- name: Deduplicate CUDA libraries
|
||||
run: |
|
||||
./scripts/deduplicate_cuda_libs.sh dist/${{ matrix.os }}-${{ matrix.arch }}
|
||||
- run: |
|
||||
for COMPONENT in bin/* lib/ollama/*; do
|
||||
case "$COMPONENT" in
|
||||
bin/ollama*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
||||
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
||||
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
||||
@@ -397,13 +392,13 @@ jobs:
|
||||
done
|
||||
- run: |
|
||||
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
|
||||
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | zstd --ultra -22 -T0 >$(basename ${ARCHIVE//.*/}.tar.zst);
|
||||
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
|
||||
done
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: bundles-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.target }}
|
||||
path: |
|
||||
*.tar.zst
|
||||
*.tgz
|
||||
|
||||
# Build each Docker variant (OS, arch, and flavor) separately. Using QEMU is unreliable and slower.
|
||||
docker-build-push:
|
||||
@@ -536,7 +531,7 @@ jobs:
|
||||
- name: Upload release artifacts
|
||||
run: |
|
||||
pids=()
|
||||
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg ; do
|
||||
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.exe dist/*.dmg ; do
|
||||
echo "Uploading $payload"
|
||||
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
|
||||
pids[$!]=$!
|
||||
|
||||
@@ -2,22 +2,6 @@ cmake_minimum_required(VERSION 3.21)
|
||||
|
||||
project(Ollama C CXX)
|
||||
|
||||
# Handle cross-compilation on macOS: when CMAKE_OSX_ARCHITECTURES is set to a
|
||||
# single architecture different from the host, override CMAKE_SYSTEM_PROCESSOR
|
||||
# to match. This is necessary because CMAKE_SYSTEM_PROCESSOR defaults to the
|
||||
# host architecture, but downstream projects (like MLX) use it to detect the
|
||||
# target architecture.
|
||||
if(CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES ";")
|
||||
# Single architecture specified
|
||||
if(CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
|
||||
message(STATUS "Cross-compiling for x86_64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to x86_64")
|
||||
set(CMAKE_SYSTEM_PROCESSOR "x86_64")
|
||||
elseif(CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
|
||||
message(STATUS "Cross-compiling for arm64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to arm64")
|
||||
set(CMAKE_SYSTEM_PROCESSOR "arm64")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
include(CheckLanguage)
|
||||
include(GNUInstallDirs)
|
||||
|
||||
@@ -28,7 +12,7 @@ set(BUILD_SHARED_LIBS ON)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS ON) # Recent versions of MLX Requires gnu++17 extensions to compile properly
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
|
||||
set(GGML_BUILD ON)
|
||||
set(GGML_SHARED ON)
|
||||
@@ -48,10 +32,9 @@ if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
|
||||
set(GGML_CPU_ALL_VARIANTS ON)
|
||||
endif()
|
||||
|
||||
if(APPLE)
|
||||
if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
|
||||
set(CMAKE_BUILD_RPATH "@loader_path")
|
||||
set(CMAKE_INSTALL_RPATH "@loader_path")
|
||||
set(CMAKE_BUILD_WITH_INSTALL_RPATH ON)
|
||||
endif()
|
||||
|
||||
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
|
||||
@@ -164,56 +147,14 @@ if(CMAKE_HIP_COMPILER)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(NOT APPLE)
|
||||
find_package(Vulkan)
|
||||
if(Vulkan_FOUND)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
||||
install(TARGETS ggml-vulkan
|
||||
RUNTIME_DEPENDENCIES
|
||||
PRE_INCLUDE_REGEXES vulkan
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
option(MLX_ENGINE "Enable MLX backend" OFF)
|
||||
|
||||
if(MLX_ENGINE)
|
||||
message(STATUS "Setting up MLX (this takes a while...)")
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
|
||||
|
||||
# Find CUDA toolkit if MLX is built with CUDA support
|
||||
find_package(CUDAToolkit)
|
||||
|
||||
install(TARGETS mlx mlxc
|
||||
find_package(Vulkan)
|
||||
if(Vulkan_FOUND)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
||||
install(TARGETS ggml-vulkan
|
||||
RUNTIME_DEPENDENCIES
|
||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
|
||||
PRE_INCLUDE_REGEXES vulkan
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
)
|
||||
|
||||
# Install the Metal library for macOS arm64 (must be colocated with the binary)
|
||||
# Metal backend is only built for arm64, not x86_64
|
||||
if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
|
||||
install(FILES ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/backend/metal/kernels/mlx.metallib
|
||||
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
|
||||
# Manually install cudart and cublas since they might not be picked up as direct dependencies
|
||||
if(CUDAToolkit_FOUND)
|
||||
file(GLOB CUDART_LIBS
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
|
||||
if(CUDART_LIBS)
|
||||
install(FILES ${CUDART_LIBS}
|
||||
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -41,7 +41,7 @@
|
||||
"inherits": [ "CUDA" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
|
||||
"CMAKE_CUDA_FLAGS": "-t 4",
|
||||
"CMAKE_CUDA_FLAGS": "-t 2",
|
||||
"OLLAMA_RUNNER_DIR": "cuda_v13"
|
||||
}
|
||||
},
|
||||
@@ -83,28 +83,6 @@
|
||||
"cacheVariables": {
|
||||
"OLLAMA_RUNNER_DIR": "vulkan"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "MLX",
|
||||
"inherits": [ "Default" ],
|
||||
"cacheVariables": {
|
||||
"MLX_ENGINE": "ON",
|
||||
"OLLAMA_RUNNER_DIR": "mlx"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 12",
|
||||
"inherits": [ "MLX", "CUDA 12" ],
|
||||
"cacheVariables": {
|
||||
"OLLAMA_RUNNER_DIR": "mlx_cuda_v12"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 13",
|
||||
"inherits": [ "MLX", "CUDA 13" ],
|
||||
"cacheVariables": {
|
||||
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
|
||||
}
|
||||
}
|
||||
],
|
||||
"buildPresets": [
|
||||
@@ -162,21 +140,6 @@
|
||||
"name": "Vulkan",
|
||||
"targets": [ "ggml-vulkan" ],
|
||||
"configurePreset": "Vulkan"
|
||||
},
|
||||
{
|
||||
"name": "MLX",
|
||||
"targets": [ "mlx", "mlxc" ],
|
||||
"configurePreset": "MLX"
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 12",
|
||||
"targets": [ "mlx", "mlxc" ],
|
||||
"configurePreset": "MLX CUDA 12"
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 13",
|
||||
"targets": [ "mlx", "mlxc" ],
|
||||
"configurePreset": "MLX CUDA 13"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
37
Dockerfile
37
Dockerfile
@@ -131,39 +131,8 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'Vulkan' \
|
||||
&& cmake --build --parallel --preset 'Vulkan' \
|
||||
&& cmake --install build --component Vulkan --strip --parallel 8
|
||||
&& cmake --install build --component Vulkan --strip --parallel 8
|
||||
|
||||
FROM base AS mlx
|
||||
ARG CUDA13VERSION=13.0
|
||||
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} \
|
||||
&& dnf install -y openblas-devel lapack-devel \
|
||||
&& dnf install -y libcudnn9-cuda-13 libcudnn9-devel-cuda-13 \
|
||||
&& dnf install -y libnccl libnccl-devel
|
||||
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
||||
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
|
||||
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
|
||||
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
|
||||
ARG PARALLEL
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
COPY x/ml/backend/mlx x/ml/backend/mlx
|
||||
COPY go.mod go.sum .
|
||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||
ENV PATH=/usr/local/go/bin:$PATH
|
||||
RUN go mod download
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
|
||||
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
|
||||
COPY . .
|
||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
RUN mkdir -p dist/bin
|
||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -tags mlx -trimpath -buildmode=pie -o dist/bin/ollama-mlx .
|
||||
|
||||
FROM base AS build
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
@@ -184,8 +153,6 @@ FROM --platform=linux/amd64 scratch AS amd64
|
||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=vulkan dist/lib/ollama /lib/ollama/
|
||||
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
|
||||
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/
|
||||
|
||||
FROM --platform=linux/arm64 scratch AS arm64
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
@@ -204,7 +171,7 @@ COPY --from=build /bin/ollama /bin/ollama
|
||||
|
||||
FROM ubuntu:24.04
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y ca-certificates libvulkan1 libopenblas0 \
|
||||
&& apt-get install -y ca-certificates libvulkan1 \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
COPY --from=archive /bin /usr/bin
|
||||
|
||||
@@ -1,778 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Error types matching Anthropic API
|
||||
type Error struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Type string `json:"type"` // always "error"
|
||||
Error Error `json:"error"`
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
}
|
||||
|
||||
// NewError creates a new ErrorResponse with the appropriate error type based on HTTP status code
|
||||
func NewError(code int, message string) ErrorResponse {
|
||||
var etype string
|
||||
switch code {
|
||||
case http.StatusBadRequest:
|
||||
etype = "invalid_request_error"
|
||||
case http.StatusUnauthorized:
|
||||
etype = "authentication_error"
|
||||
case http.StatusForbidden:
|
||||
etype = "permission_error"
|
||||
case http.StatusNotFound:
|
||||
etype = "not_found_error"
|
||||
case http.StatusTooManyRequests:
|
||||
etype = "rate_limit_error"
|
||||
case http.StatusServiceUnavailable, 529:
|
||||
etype = "overloaded_error"
|
||||
default:
|
||||
etype = "api_error"
|
||||
}
|
||||
|
||||
return ErrorResponse{
|
||||
Type: "error",
|
||||
Error: Error{Type: etype, Message: message},
|
||||
RequestID: generateID("req"),
|
||||
}
|
||||
}
|
||||
|
||||
// Request types
|
||||
|
||||
// MessagesRequest represents an Anthropic Messages API request
|
||||
type MessagesRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Messages []MessageParam `json:"messages"`
|
||||
System any `json:"system,omitempty"` // string or []ContentBlock
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
Metadata *Metadata `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// MessageParam represents a message in the request
|
||||
type MessageParam struct {
|
||||
Role string `json:"role"` // "user" or "assistant"
|
||||
Content any `json:"content"` // string or []ContentBlock
|
||||
}
|
||||
|
||||
// ContentBlock represents a content block in a message.
|
||||
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
|
||||
// only when set, which is required for SDK streaming accumulation.
|
||||
type ContentBlock struct {
|
||||
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
|
||||
|
||||
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
|
||||
Text *string `json:"text,omitempty"`
|
||||
|
||||
// For image blocks
|
||||
Source *ImageSource `json:"source,omitempty"`
|
||||
|
||||
// For tool_use blocks
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
|
||||
// For tool_result blocks
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content any `json:"content,omitempty"` // string or []ContentBlock
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
|
||||
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
|
||||
Thinking *string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
}
|
||||
|
||||
// ImageSource represents the source of an image
|
||||
type ImageSource struct {
|
||||
Type string `json:"type"` // "base64" or "url"
|
||||
MediaType string `json:"media_type,omitempty"`
|
||||
Data string `json:"data,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
}
|
||||
|
||||
// Tool represents a tool definition
|
||||
type Tool struct {
|
||||
Type string `json:"type,omitempty"` // "custom" for user-defined tools
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema,omitempty"`
|
||||
}
|
||||
|
||||
// ToolChoice controls how the model uses tools
|
||||
type ToolChoice struct {
|
||||
Type string `json:"type"` // "auto", "any", "tool", "none"
|
||||
Name string `json:"name,omitempty"`
|
||||
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
|
||||
}
|
||||
|
||||
// ThinkingConfig controls extended thinking
|
||||
type ThinkingConfig struct {
|
||||
Type string `json:"type"` // "enabled" or "disabled"
|
||||
BudgetTokens int `json:"budget_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// Metadata for the request
|
||||
type Metadata struct {
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
// Response types
|
||||
|
||||
// MessagesResponse represents an Anthropic Messages API response
|
||||
type MessagesResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"` // "message"
|
||||
Role string `json:"role"` // "assistant"
|
||||
Model string `json:"model"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
StopSequence string `json:"stop_sequence,omitempty"`
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
// Usage contains token usage information
|
||||
type Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
// Streaming event types
|
||||
|
||||
// MessageStartEvent is sent at the start of streaming
|
||||
type MessageStartEvent struct {
|
||||
Type string `json:"type"` // "message_start"
|
||||
Message MessagesResponse `json:"message"`
|
||||
}
|
||||
|
||||
// ContentBlockStartEvent signals the start of a content block
|
||||
type ContentBlockStartEvent struct {
|
||||
Type string `json:"type"` // "content_block_start"
|
||||
Index int `json:"index"`
|
||||
ContentBlock ContentBlock `json:"content_block"`
|
||||
}
|
||||
|
||||
// ContentBlockDeltaEvent contains incremental content updates
|
||||
type ContentBlockDeltaEvent struct {
|
||||
Type string `json:"type"` // "content_block_delta"
|
||||
Index int `json:"index"`
|
||||
Delta Delta `json:"delta"`
|
||||
}
|
||||
|
||||
// Delta represents an incremental update
|
||||
type Delta struct {
|
||||
Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta"
|
||||
Text string `json:"text,omitempty"`
|
||||
PartialJSON string `json:"partial_json,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
}
|
||||
|
||||
// ContentBlockStopEvent signals the end of a content block
|
||||
type ContentBlockStopEvent struct {
|
||||
Type string `json:"type"` // "content_block_stop"
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
// MessageDeltaEvent contains updates to the message
|
||||
type MessageDeltaEvent struct {
|
||||
Type string `json:"type"` // "message_delta"
|
||||
Delta MessageDelta `json:"delta"`
|
||||
Usage DeltaUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// MessageDelta contains stop information
|
||||
type MessageDelta struct {
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
StopSequence string `json:"stop_sequence,omitempty"`
|
||||
}
|
||||
|
||||
// DeltaUsage contains cumulative token usage
|
||||
type DeltaUsage struct {
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
// MessageStopEvent signals the end of the message
|
||||
type MessageStopEvent struct {
|
||||
Type string `json:"type"` // "message_stop"
|
||||
}
|
||||
|
||||
// PingEvent is a keepalive event
|
||||
type PingEvent struct {
|
||||
Type string `json:"type"` // "ping"
|
||||
}
|
||||
|
||||
// StreamErrorEvent is an error during streaming
|
||||
type StreamErrorEvent struct {
|
||||
Type string `json:"type"` // "error"
|
||||
Error Error `json:"error"`
|
||||
}
|
||||
|
||||
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
|
||||
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||
var messages []api.Message
|
||||
|
||||
if r.System != nil {
|
||||
switch sys := r.System.(type) {
|
||||
case string:
|
||||
if sys != "" {
|
||||
messages = append(messages, api.Message{Role: "system", Content: sys})
|
||||
}
|
||||
case []any:
|
||||
// System can be an array of content blocks
|
||||
var content strings.Builder
|
||||
for _, block := range sys {
|
||||
if blockMap, ok := block.(map[string]any); ok {
|
||||
if blockMap["type"] == "text" {
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
content.WriteString(text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if content.Len() > 0 {
|
||||
messages = append(messages, api.Message{Role: "system", Content: content.String()})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, msg := range r.Messages {
|
||||
converted, err := convertMessage(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages = append(messages, converted...)
|
||||
}
|
||||
|
||||
options := make(map[string]any)
|
||||
|
||||
options["num_predict"] = r.MaxTokens
|
||||
|
||||
if r.Temperature != nil {
|
||||
options["temperature"] = *r.Temperature
|
||||
}
|
||||
|
||||
if r.TopP != nil {
|
||||
options["top_p"] = *r.TopP
|
||||
}
|
||||
|
||||
if r.TopK != nil {
|
||||
options["top_k"] = *r.TopK
|
||||
}
|
||||
|
||||
if len(r.StopSequences) > 0 {
|
||||
options["stop"] = r.StopSequences
|
||||
}
|
||||
|
||||
var tools api.Tools
|
||||
for _, t := range r.Tools {
|
||||
tool, err := convertTool(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
|
||||
var think *api.ThinkValue
|
||||
if r.Thinking != nil && r.Thinking.Type == "enabled" {
|
||||
think = &api.ThinkValue{Value: true}
|
||||
}
|
||||
|
||||
stream := r.Stream
|
||||
|
||||
return &api.ChatRequest{
|
||||
Model: r.Model,
|
||||
Messages: messages,
|
||||
Options: options,
|
||||
Stream: &stream,
|
||||
Tools: tools,
|
||||
Think: think,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
|
||||
func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
var messages []api.Message
|
||||
role := strings.ToLower(msg.Role)
|
||||
|
||||
switch content := msg.Content.(type) {
|
||||
case string:
|
||||
messages = append(messages, api.Message{Role: role, Content: content})
|
||||
|
||||
case []any:
|
||||
var textContent strings.Builder
|
||||
var images []api.ImageData
|
||||
var toolCalls []api.ToolCall
|
||||
var thinking string
|
||||
var toolResults []api.Message
|
||||
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid content block format")
|
||||
}
|
||||
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
switch blockType {
|
||||
case "text":
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
textContent.WriteString(text)
|
||||
}
|
||||
|
||||
case "image":
|
||||
source, ok := blockMap["source"].(map[string]any)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid image source")
|
||||
}
|
||||
|
||||
sourceType, _ := source["type"].(string)
|
||||
if sourceType == "base64" {
|
||||
data, _ := source["data"].(string)
|
||||
decoded, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||
}
|
||||
images = append(images, decoded)
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
|
||||
}
|
||||
// URL images would need to be fetched - skip for now
|
||||
|
||||
case "tool_use":
|
||||
id, ok := blockMap["id"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("tool_use block missing required 'id' field")
|
||||
}
|
||||
name, ok := blockMap["name"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("tool_use block missing required 'name' field")
|
||||
}
|
||||
tc := api.ToolCall{
|
||||
ID: id,
|
||||
Function: api.ToolCallFunction{
|
||||
Name: name,
|
||||
},
|
||||
}
|
||||
if input, ok := blockMap["input"].(map[string]any); ok {
|
||||
tc.Function.Arguments = mapToArgs(input)
|
||||
}
|
||||
toolCalls = append(toolCalls, tc)
|
||||
|
||||
case "tool_result":
|
||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||
var resultContent string
|
||||
|
||||
switch c := blockMap["content"].(type) {
|
||||
case string:
|
||||
resultContent = c
|
||||
case []any:
|
||||
for _, cb := range c {
|
||||
if cbMap, ok := cb.(map[string]any); ok {
|
||||
if cbMap["type"] == "text" {
|
||||
if text, ok := cbMap["text"].(string); ok {
|
||||
resultContent += text
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: resultContent,
|
||||
ToolCallID: toolUseID,
|
||||
})
|
||||
|
||||
case "thinking":
|
||||
if t, ok := blockMap["thinking"].(string); ok {
|
||||
thinking = t
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
|
||||
m := api.Message{
|
||||
Role: role,
|
||||
Content: textContent.String(),
|
||||
Images: images,
|
||||
ToolCalls: toolCalls,
|
||||
Thinking: thinking,
|
||||
}
|
||||
messages = append(messages, m)
|
||||
}
|
||||
|
||||
// Add tool results as separate messages
|
||||
messages = append(messages, toolResults...)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid message content type: %T", content)
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// convertTool converts an Anthropic Tool to an Ollama api.Tool
|
||||
func convertTool(t Tool) (api.Tool, error) {
|
||||
var params api.ToolFunctionParameters
|
||||
if len(t.InputSchema) > 0 {
|
||||
if err := json.Unmarshal(t.InputSchema, ¶ms); err != nil {
|
||||
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return api.Tool{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
Parameters: params,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
|
||||
func ToMessagesResponse(id string, r api.ChatResponse) MessagesResponse {
|
||||
var content []ContentBlock
|
||||
|
||||
if r.Message.Thinking != "" {
|
||||
content = append(content, ContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: ptr(r.Message.Thinking),
|
||||
})
|
||||
}
|
||||
|
||||
if r.Message.Content != "" {
|
||||
content = append(content, ContentBlock{
|
||||
Type: "text",
|
||||
Text: ptr(r.Message.Content),
|
||||
})
|
||||
}
|
||||
|
||||
for _, tc := range r.Message.ToolCalls {
|
||||
content = append(content, ContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Input: tc.Function.Arguments,
|
||||
})
|
||||
}
|
||||
|
||||
stopReason := mapStopReason(r.DoneReason, len(r.Message.ToolCalls) > 0)
|
||||
|
||||
return MessagesResponse{
|
||||
ID: id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: r.Model,
|
||||
Content: content,
|
||||
StopReason: stopReason,
|
||||
Usage: Usage{
|
||||
InputTokens: r.Metrics.PromptEvalCount,
|
||||
OutputTokens: r.Metrics.EvalCount,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// mapStopReason converts Ollama done_reason to Anthropic stop_reason
|
||||
func mapStopReason(reason string, hasToolCalls bool) string {
|
||||
if hasToolCalls {
|
||||
return "tool_use"
|
||||
}
|
||||
|
||||
switch reason {
|
||||
case "stop":
|
||||
return "end_turn"
|
||||
case "length":
|
||||
return "max_tokens"
|
||||
default:
|
||||
if reason != "" {
|
||||
return "stop_sequence"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
|
||||
type StreamConverter struct {
|
||||
ID string
|
||||
Model string
|
||||
firstWrite bool
|
||||
contentIndex int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
thinkingStarted bool
|
||||
thinkingDone bool
|
||||
textStarted bool
|
||||
toolCallsSent map[string]bool
|
||||
}
|
||||
|
||||
func NewStreamConverter(id, model string) *StreamConverter {
|
||||
return &StreamConverter{
|
||||
ID: id,
|
||||
Model: model,
|
||||
firstWrite: true,
|
||||
toolCallsSent: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// StreamEvent represents a streaming event to be sent to the client
|
||||
type StreamEvent struct {
|
||||
Event string
|
||||
Data any
|
||||
}
|
||||
|
||||
// Process converts an Ollama ChatResponse to Anthropic streaming events
|
||||
func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
var events []StreamEvent
|
||||
|
||||
if c.firstWrite {
|
||||
c.firstWrite = false
|
||||
c.inputTokens = r.Metrics.PromptEvalCount
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_start",
|
||||
Data: MessageStartEvent{
|
||||
Type: "message_start",
|
||||
Message: MessagesResponse{
|
||||
ID: c.ID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: c.Model,
|
||||
Content: []ContentBlock{},
|
||||
Usage: Usage{
|
||||
InputTokens: c.inputTokens,
|
||||
OutputTokens: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if r.Message.Thinking != "" && !c.thinkingDone {
|
||||
if !c.thinkingStarted {
|
||||
c.thinkingStarted = true
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_start",
|
||||
Data: ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: c.contentIndex,
|
||||
ContentBlock: ContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: ptr(""),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: ContentBlockDeltaEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: c.contentIndex,
|
||||
Delta: Delta{
|
||||
Type: "thinking_delta",
|
||||
Thinking: r.Message.Thinking,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if r.Message.Content != "" {
|
||||
if c.thinkingStarted && !c.thinkingDone {
|
||||
c.thinkingDone = true
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
c.contentIndex++
|
||||
}
|
||||
|
||||
if !c.textStarted {
|
||||
c.textStarted = true
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_start",
|
||||
Data: ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: c.contentIndex,
|
||||
ContentBlock: ContentBlock{
|
||||
Type: "text",
|
||||
Text: ptr(""),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: ContentBlockDeltaEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: c.contentIndex,
|
||||
Delta: Delta{
|
||||
Type: "text_delta",
|
||||
Text: r.Message.Content,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
for _, tc := range r.Message.ToolCalls {
|
||||
if c.toolCallsSent[tc.ID] {
|
||||
continue
|
||||
}
|
||||
|
||||
if c.textStarted {
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
c.contentIndex++
|
||||
c.textStarted = false
|
||||
}
|
||||
|
||||
argsJSON, err := json.Marshal(tc.Function.Arguments)
|
||||
if err != nil {
|
||||
slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_start",
|
||||
Data: ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: c.contentIndex,
|
||||
ContentBlock: ContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Input: map[string]any{},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: ContentBlockDeltaEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: c.contentIndex,
|
||||
Delta: Delta{
|
||||
Type: "input_json_delta",
|
||||
PartialJSON: string(argsJSON),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
|
||||
c.toolCallsSent[tc.ID] = true
|
||||
c.contentIndex++
|
||||
}
|
||||
|
||||
if r.Done {
|
||||
if c.textStarted {
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
} else if c.thinkingStarted && !c.thinkingDone {
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
c.outputTokens = r.Metrics.EvalCount
|
||||
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_delta",
|
||||
Data: MessageDeltaEvent{
|
||||
Type: "message_delta",
|
||||
Delta: MessageDelta{
|
||||
StopReason: stopReason,
|
||||
},
|
||||
Usage: DeltaUsage{
|
||||
OutputTokens: c.outputTokens,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_stop",
|
||||
Data: MessageStopEvent{
|
||||
Type: "message_stop",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
// generateID generates a unique ID with the given prefix using crypto/rand
|
||||
func generateID(prefix string) string {
|
||||
b := make([]byte, 12)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// Fallback to time-based ID if crypto/rand fails
|
||||
return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
|
||||
}
|
||||
return fmt.Sprintf("%s_%x", prefix, b)
|
||||
}
|
||||
|
||||
// GenerateMessageID generates a unique message ID
|
||||
func GenerateMessageID() string {
|
||||
return generateID("msg")
|
||||
}
|
||||
|
||||
// ptr returns a pointer to the given string value
|
||||
func ptr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
// mapToArgs converts a map to ToolCallFunctionArguments
|
||||
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
@@ -1,953 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||
)
|
||||
|
||||
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
|
||||
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_Basic(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Model != "test-model" {
|
||||
t.Errorf("expected model 'test-model', got %q", result.Model)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
|
||||
t.Errorf("unexpected message: %+v", result.Messages[0])
|
||||
}
|
||||
|
||||
if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 1024 {
|
||||
t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
System: "You are a helpful assistant.",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." {
|
||||
t.Errorf("unexpected system message: %+v", result.Messages[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
System: []any{
|
||||
map[string]any{"type": "text", "text": "You are helpful."},
|
||||
map[string]any{"type": "text", "text": " Be concise."},
|
||||
},
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Content != "You are helpful. Be concise." {
|
||||
t.Errorf("unexpected system message content: %q", result.Messages[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithOptions(t *testing.T) {
|
||||
temp := 0.7
|
||||
topP := 0.9
|
||||
topK := 40
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 2048,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Temperature: &temp,
|
||||
TopP: &topP,
|
||||
TopK: &topK,
|
||||
StopSequences: []string{"\n", "END"},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Options["temperature"] != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"])
|
||||
}
|
||||
if result.Options["top_p"] != 0.9 {
|
||||
t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"])
|
||||
}
|
||||
if result.Options["top_k"] != 40 {
|
||||
t.Errorf("expected top_k 40, got %v", result.Options["top_k"])
|
||||
}
|
||||
if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" {
|
||||
t.Errorf("stop sequences mismatch: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithImage(t *testing.T) {
|
||||
imgData, _ := base64.StdEncoding.DecodeString(testImage)
|
||||
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{"type": "text", "text": "What's in this image?"},
|
||||
map[string]any{
|
||||
"type": "image",
|
||||
"source": map[string]any{
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": testImage,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Content != "What's in this image?" {
|
||||
t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content)
|
||||
}
|
||||
|
||||
if len(result.Messages[0].Images) != 1 {
|
||||
t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images))
|
||||
}
|
||||
|
||||
if string(result.Messages[0].Images[0]) != string(imgData) {
|
||||
t.Error("image data mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithToolUse(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "What's the weather in Paris?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "call_123",
|
||||
"name": "get_weather",
|
||||
"input": map[string]any{"location": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if len(result.Messages[1].ToolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls))
|
||||
}
|
||||
|
||||
tc := result.Messages[1].ToolCalls[0]
|
||||
if tc.ID != "call_123" {
|
||||
t.Errorf("expected tool call ID 'call_123', got %q", tc.ID)
|
||||
}
|
||||
if tc.Function.Name != "get_weather" {
|
||||
t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithToolResult(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "call_123",
|
||||
"content": "The weather in Paris is sunny, 22°C",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
msg := result.Messages[0]
|
||||
if msg.Role != "tool" {
|
||||
t.Errorf("expected role 'tool', got %q", msg.Role)
|
||||
}
|
||||
if msg.ToolCallID != "call_123" {
|
||||
t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID)
|
||||
}
|
||||
if msg.Content != "The weather in Paris is sunny, 22°C" {
|
||||
t.Errorf("unexpected content: %q", msg.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithTools(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get current weather",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(result.Tools))
|
||||
}
|
||||
|
||||
tool := result.Tools[0]
|
||||
if tool.Type != "function" {
|
||||
t.Errorf("expected type 'function', got %q", tool.Type)
|
||||
}
|
||||
if tool.Function.Name != "get_weather" {
|
||||
t.Errorf("expected name 'get_weather', got %q", tool.Function.Name)
|
||||
}
|
||||
if tool.Function.Description != "Get current weather" {
|
||||
t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithThinking(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Think == nil {
|
||||
t.Fatal("expected Think to be set")
|
||||
}
|
||||
if v, ok := result.Think.Value.(bool); !ok || !v {
|
||||
t.Errorf("expected Think.Value to be true, got %v", result.Think.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
|
||||
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
|
||||
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "Let me think about this...",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
assistantMsg := result.Messages[1]
|
||||
if assistantMsg.Thinking != "Let me think about this..." {
|
||||
t.Errorf("expected thinking content, got %q", assistantMsg.Thinking)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"name": "get_weather",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := FromMessagesRequest(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing tool_use id")
|
||||
}
|
||||
if err.Error() != "tool_use block missing required 'id' field" {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "call_123",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := FromMessagesRequest(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing tool_use name")
|
||||
}
|
||||
if err.Error() != "tool_use block missing required 'name' field" {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Name: "bad_tool",
|
||||
InputSchema: json.RawMessage(`{invalid json`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := FromMessagesRequest(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid tool schema")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessagesResponse_Basic(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "Hello there!",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 10,
|
||||
EvalCount: 5,
|
||||
},
|
||||
}
|
||||
|
||||
result := ToMessagesResponse("msg_123", resp)
|
||||
|
||||
if result.ID != "msg_123" {
|
||||
t.Errorf("expected ID 'msg_123', got %q", result.ID)
|
||||
}
|
||||
if result.Type != "message" {
|
||||
t.Errorf("expected type 'message', got %q", result.Type)
|
||||
}
|
||||
if result.Role != "assistant" {
|
||||
t.Errorf("expected role 'assistant', got %q", result.Role)
|
||||
}
|
||||
if len(result.Content) != 1 {
|
||||
t.Fatalf("expected 1 content block, got %d", len(result.Content))
|
||||
}
|
||||
if result.Content[0].Type != "text" || result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
|
||||
t.Errorf("unexpected content: %+v", result.Content[0])
|
||||
}
|
||||
if result.StopReason != "end_turn" {
|
||||
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
|
||||
}
|
||||
if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 {
|
||||
t.Errorf("unexpected usage: %+v", result.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessagesResponse_WithToolCalls(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
result := ToMessagesResponse("msg_123", resp)
|
||||
|
||||
if len(result.Content) != 1 {
|
||||
t.Fatalf("expected 1 content block, got %d", len(result.Content))
|
||||
}
|
||||
if result.Content[0].Type != "tool_use" {
|
||||
t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type)
|
||||
}
|
||||
if result.Content[0].ID != "call_123" {
|
||||
t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID)
|
||||
}
|
||||
if result.Content[0].Name != "get_weather" {
|
||||
t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name)
|
||||
}
|
||||
if result.StopReason != "tool_use" {
|
||||
t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessagesResponse_WithThinking(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "The answer is 42.",
|
||||
Thinking: "Let me think about this...",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
result := ToMessagesResponse("msg_123", resp)
|
||||
|
||||
if len(result.Content) != 2 {
|
||||
t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
|
||||
}
|
||||
if result.Content[0].Type != "thinking" {
|
||||
t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type)
|
||||
}
|
||||
if result.Content[0].Thinking == nil || *result.Content[0].Thinking != "Let me think about this..." {
|
||||
t.Errorf("unexpected thinking content: %v", result.Content[0].Thinking)
|
||||
}
|
||||
if result.Content[1].Type != "text" {
|
||||
t.Errorf("expected second block type 'text', got %q", result.Content[1].Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapStopReason(t *testing.T) {
|
||||
tests := []struct {
|
||||
reason string
|
||||
hasToolCalls bool
|
||||
want string
|
||||
}{
|
||||
{"stop", false, "end_turn"},
|
||||
{"length", false, "max_tokens"},
|
||||
{"stop", true, "tool_use"},
|
||||
{"other", false, "stop_sequence"},
|
||||
{"", false, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := mapStopReason(tt.reason, tt.hasToolCalls)
|
||||
if got != tt.want {
|
||||
t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewError(t *testing.T) {
|
||||
tests := []struct {
|
||||
code int
|
||||
want string
|
||||
}{
|
||||
{400, "invalid_request_error"},
|
||||
{401, "authentication_error"},
|
||||
{403, "permission_error"},
|
||||
{404, "not_found_error"},
|
||||
{429, "rate_limit_error"},
|
||||
{500, "api_error"},
|
||||
{503, "overloaded_error"},
|
||||
{529, "overloaded_error"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := NewError(tt.code, "test message")
|
||||
if result.Type != "error" {
|
||||
t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type)
|
||||
}
|
||||
if result.Error.Type != tt.want {
|
||||
t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want)
|
||||
}
|
||||
if result.Error.Message != "test message" {
|
||||
t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message)
|
||||
}
|
||||
if result.RequestID == "" {
|
||||
t.Errorf("NewError(%d) request_id should not be empty", tt.code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateMessageID(t *testing.T) {
|
||||
id1 := GenerateMessageID()
|
||||
id2 := GenerateMessageID()
|
||||
|
||||
if id1 == "" {
|
||||
t.Error("GenerateMessageID returned empty string")
|
||||
}
|
||||
if id1 == id2 {
|
||||
t.Error("GenerateMessageID returned duplicate IDs")
|
||||
}
|
||||
if len(id1) < 10 {
|
||||
t.Errorf("GenerateMessageID returned short ID: %q", id1)
|
||||
}
|
||||
if id1[:4] != "msg_" {
|
||||
t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_Basic(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
// First chunk
|
||||
resp1 := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "Hello",
|
||||
},
|
||||
Metrics: api.Metrics{PromptEvalCount: 10},
|
||||
}
|
||||
|
||||
events1 := conv.Process(resp1)
|
||||
if len(events1) < 3 {
|
||||
t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1))
|
||||
}
|
||||
|
||||
// Should have message_start, content_block_start, content_block_delta
|
||||
if events1[0].Event != "message_start" {
|
||||
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
|
||||
}
|
||||
if events1[1].Event != "content_block_start" {
|
||||
t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event)
|
||||
}
|
||||
if events1[2].Event != "content_block_delta" {
|
||||
t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event)
|
||||
}
|
||||
|
||||
// Final chunk
|
||||
resp2 := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: " world!",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{EvalCount: 5},
|
||||
}
|
||||
|
||||
events2 := conv.Process(resp2)
|
||||
|
||||
// Should have content_block_delta, content_block_stop, message_delta, message_stop
|
||||
hasStop := false
|
||||
for _, e := range events2 {
|
||||
if e.Event == "message_stop" {
|
||||
hasStop = true
|
||||
}
|
||||
}
|
||||
if !hasStop {
|
||||
t.Error("expected message_stop event in final chunk")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
hasToolStart := false
|
||||
hasToolDelta := false
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "tool_use" {
|
||||
hasToolStart = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if e.Event == "content_block_delta" {
|
||||
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
|
||||
if delta.Delta.Type == "input_json_delta" {
|
||||
hasToolDelta = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasToolStart {
|
||||
t.Error("expected tool_use content_block_start event")
|
||||
}
|
||||
if !hasToolDelta {
|
||||
t.Error("expected input_json_delta event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||
// Test that unmarshalable arguments (like channels) are handled gracefully
|
||||
// and don't cause a panic or corrupt stream
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
// Create a channel which cannot be JSON marshaled
|
||||
unmarshalable := make(chan int)
|
||||
badArgs := api.NewToolCallFunctionArguments()
|
||||
badArgs.Set("channel", unmarshalable)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_bad",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "bad_function",
|
||||
Arguments: badArgs,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
// Should not panic and should skip the unmarshalable tool call
|
||||
events := conv.Process(resp)
|
||||
|
||||
// Verify no tool_use block was started (since marshal failed before block start)
|
||||
hasToolStart := false
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "tool_use" {
|
||||
hasToolStart = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasToolStart {
|
||||
t.Error("expected no tool_use block when arguments cannot be marshaled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
||||
// Test that valid tool calls still work when mixed with invalid ones
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
unmarshalable := make(chan int)
|
||||
badArgs := api.NewToolCallFunctionArguments()
|
||||
badArgs.Set("channel", unmarshalable)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_good",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "good_function",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "call_bad",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "bad_function",
|
||||
Arguments: badArgs,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
// Count tool_use blocks - should only have 1 (the valid one)
|
||||
toolStartCount := 0
|
||||
toolDeltaCount := 0
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "tool_use" {
|
||||
toolStartCount++
|
||||
if start.ContentBlock.Name != "good_function" {
|
||||
t.Errorf("expected tool name 'good_function', got %q", start.ContentBlock.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if e.Event == "content_block_delta" {
|
||||
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
|
||||
if delta.Delta.Type == "input_json_delta" {
|
||||
toolDeltaCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if toolStartCount != 1 {
|
||||
t.Errorf("expected 1 tool_use block, got %d", toolStartCount)
|
||||
}
|
||||
if toolDeltaCount != 1 {
|
||||
t.Errorf("expected 1 input_json_delta, got %d", toolDeltaCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
|
||||
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
|
||||
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
|
||||
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
|
||||
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
block ContentBlock
|
||||
wantKeys []string
|
||||
}{
|
||||
{
|
||||
name: "text block includes empty text field",
|
||||
block: ContentBlock{
|
||||
Type: "text",
|
||||
Text: ptr(""),
|
||||
},
|
||||
wantKeys: []string{"type", "text"},
|
||||
},
|
||||
{
|
||||
name: "thinking block includes empty thinking field",
|
||||
block: ContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: ptr(""),
|
||||
},
|
||||
wantKeys: []string{"type", "thinking"},
|
||||
},
|
||||
{
|
||||
name: "text block with content",
|
||||
block: ContentBlock{
|
||||
Type: "text",
|
||||
Text: ptr("hello"),
|
||||
},
|
||||
wantKeys: []string{"type", "text"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.block)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
for _, key := range tt.wantKeys {
|
||||
if _, ok := result[key]; !ok {
|
||||
t.Errorf("expected key %q to be present in JSON output, got: %s", key, string(data))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
|
||||
// events include the required empty fields for SDK compatibility.
|
||||
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
t.Run("text block start includes empty text", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{Role: "assistant", Content: "hello"},
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
var foundTextStart bool
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "text" {
|
||||
foundTextStart = true
|
||||
// Marshal and verify the text field is present
|
||||
data, _ := json.Marshal(start)
|
||||
var result map[string]any
|
||||
json.Unmarshal(data, &result)
|
||||
cb := result["content_block"].(map[string]any)
|
||||
if _, ok := cb["text"]; !ok {
|
||||
t.Error("content_block_start for text should include 'text' field")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundTextStart {
|
||||
t.Error("expected text content_block_start event")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{Role: "assistant", Thinking: "let me think..."},
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
var foundThinkingStart bool
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "thinking" {
|
||||
foundThinkingStart = true
|
||||
data, _ := json.Marshal(start)
|
||||
var result map[string]any
|
||||
json.Unmarshal(data, &result)
|
||||
cb := result["content_block"].(map[string]any)
|
||||
if _, ok := cb["thinking"]; !ok {
|
||||
t.Error("content_block_start for thinking should include 'thinking' field")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundThinkingStart {
|
||||
t.Error("expected thinking content_block_start event")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -165,7 +165,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
|
||||
return nil
|
||||
}
|
||||
|
||||
const maxBufferSize = 8 * format.MegaByte
|
||||
const maxBufferSize = 512 * format.KiloByte
|
||||
|
||||
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
|
||||
var buf io.Reader
|
||||
|
||||
159
api/types.go
159
api/types.go
@@ -3,7 +3,6 @@ package api
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"math"
|
||||
"os"
|
||||
@@ -15,7 +14,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/internal/orderedmap"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
@@ -229,79 +227,13 @@ type ToolCallFunction struct {
|
||||
Arguments ToolCallFunctionArguments `json:"arguments"`
|
||||
}
|
||||
|
||||
// ToolCallFunctionArguments holds tool call arguments in insertion order.
|
||||
type ToolCallFunctionArguments struct {
|
||||
om *orderedmap.Map[string, any]
|
||||
}
|
||||
|
||||
// NewToolCallFunctionArguments creates a new empty ToolCallFunctionArguments.
|
||||
func NewToolCallFunctionArguments() ToolCallFunctionArguments {
|
||||
return ToolCallFunctionArguments{om: orderedmap.New[string, any]()}
|
||||
}
|
||||
|
||||
// Get retrieves a value by key.
|
||||
func (t *ToolCallFunctionArguments) Get(key string) (any, bool) {
|
||||
if t == nil || t.om == nil {
|
||||
return nil, false
|
||||
}
|
||||
return t.om.Get(key)
|
||||
}
|
||||
|
||||
// Set sets a key-value pair, preserving insertion order.
|
||||
func (t *ToolCallFunctionArguments) Set(key string, value any) {
|
||||
if t == nil {
|
||||
return
|
||||
}
|
||||
if t.om == nil {
|
||||
t.om = orderedmap.New[string, any]()
|
||||
}
|
||||
t.om.Set(key, value)
|
||||
}
|
||||
|
||||
// Len returns the number of arguments.
|
||||
func (t *ToolCallFunctionArguments) Len() int {
|
||||
if t == nil || t.om == nil {
|
||||
return 0
|
||||
}
|
||||
return t.om.Len()
|
||||
}
|
||||
|
||||
// All returns an iterator over all key-value pairs in insertion order.
|
||||
func (t *ToolCallFunctionArguments) All() iter.Seq2[string, any] {
|
||||
if t == nil || t.om == nil {
|
||||
return func(yield func(string, any) bool) {}
|
||||
}
|
||||
return t.om.All()
|
||||
}
|
||||
|
||||
// ToMap returns a regular map (order not preserved).
|
||||
func (t *ToolCallFunctionArguments) ToMap() map[string]any {
|
||||
if t == nil || t.om == nil {
|
||||
return nil
|
||||
}
|
||||
return t.om.ToMap()
|
||||
}
|
||||
type ToolCallFunctionArguments map[string]any
|
||||
|
||||
func (t *ToolCallFunctionArguments) String() string {
|
||||
if t == nil || t.om == nil {
|
||||
return "{}"
|
||||
}
|
||||
bts, _ := json.Marshal(t.om)
|
||||
bts, _ := json.Marshal(t)
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
func (t *ToolCallFunctionArguments) UnmarshalJSON(data []byte) error {
|
||||
t.om = orderedmap.New[string, any]()
|
||||
return json.Unmarshal(data, t.om)
|
||||
}
|
||||
|
||||
func (t ToolCallFunctionArguments) MarshalJSON() ([]byte, error) {
|
||||
if t.om == nil {
|
||||
return []byte("{}"), nil
|
||||
}
|
||||
return json.Marshal(t.om)
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
Type string `json:"type"`
|
||||
Items any `json:"items,omitempty"`
|
||||
@@ -350,78 +282,13 @@ func (pt PropertyType) String() string {
|
||||
return fmt.Sprintf("%v", []string(pt))
|
||||
}
|
||||
|
||||
// ToolPropertiesMap holds tool properties in insertion order.
|
||||
type ToolPropertiesMap struct {
|
||||
om *orderedmap.Map[string, ToolProperty]
|
||||
}
|
||||
|
||||
// NewToolPropertiesMap creates a new empty ToolPropertiesMap.
|
||||
func NewToolPropertiesMap() *ToolPropertiesMap {
|
||||
return &ToolPropertiesMap{om: orderedmap.New[string, ToolProperty]()}
|
||||
}
|
||||
|
||||
// Get retrieves a property by name.
|
||||
func (t *ToolPropertiesMap) Get(key string) (ToolProperty, bool) {
|
||||
if t == nil || t.om == nil {
|
||||
return ToolProperty{}, false
|
||||
}
|
||||
return t.om.Get(key)
|
||||
}
|
||||
|
||||
// Set sets a property, preserving insertion order.
|
||||
func (t *ToolPropertiesMap) Set(key string, value ToolProperty) {
|
||||
if t == nil {
|
||||
return
|
||||
}
|
||||
if t.om == nil {
|
||||
t.om = orderedmap.New[string, ToolProperty]()
|
||||
}
|
||||
t.om.Set(key, value)
|
||||
}
|
||||
|
||||
// Len returns the number of properties.
|
||||
func (t *ToolPropertiesMap) Len() int {
|
||||
if t == nil || t.om == nil {
|
||||
return 0
|
||||
}
|
||||
return t.om.Len()
|
||||
}
|
||||
|
||||
// All returns an iterator over all properties in insertion order.
|
||||
func (t *ToolPropertiesMap) All() iter.Seq2[string, ToolProperty] {
|
||||
if t == nil || t.om == nil {
|
||||
return func(yield func(string, ToolProperty) bool) {}
|
||||
}
|
||||
return t.om.All()
|
||||
}
|
||||
|
||||
// ToMap returns a regular map (order not preserved).
|
||||
func (t *ToolPropertiesMap) ToMap() map[string]ToolProperty {
|
||||
if t == nil || t.om == nil {
|
||||
return nil
|
||||
}
|
||||
return t.om.ToMap()
|
||||
}
|
||||
|
||||
func (t ToolPropertiesMap) MarshalJSON() ([]byte, error) {
|
||||
if t.om == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return json.Marshal(t.om)
|
||||
}
|
||||
|
||||
func (t *ToolPropertiesMap) UnmarshalJSON(data []byte) error {
|
||||
t.om = orderedmap.New[string, ToolProperty]()
|
||||
return json.Unmarshal(data, t.om)
|
||||
}
|
||||
|
||||
type ToolProperty struct {
|
||||
AnyOf []ToolProperty `json:"anyOf,omitempty"`
|
||||
Type PropertyType `json:"type,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
Properties *ToolPropertiesMap `json:"properties,omitempty"`
|
||||
AnyOf []ToolProperty `json:"anyOf,omitempty"`
|
||||
Type PropertyType `json:"type,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
Properties map[string]ToolProperty `json:"properties,omitempty"`
|
||||
}
|
||||
|
||||
// ToTypeScriptType converts a ToolProperty to a TypeScript type string
|
||||
@@ -470,11 +337,11 @@ func mapToTypeScriptType(jsonType string) string {
|
||||
}
|
||||
|
||||
type ToolFunctionParameters struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
Properties *ToolPropertiesMap `json:"properties"`
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
Properties map[string]ToolProperty `json:"properties"`
|
||||
}
|
||||
|
||||
func (t *ToolFunctionParameters) String() string {
|
||||
|
||||
@@ -11,24 +11,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
|
||||
func testPropsMap(m map[string]ToolProperty) *ToolPropertiesMap {
|
||||
props := NewToolPropertiesMap()
|
||||
for k, v := range m {
|
||||
props.Set(k, v)
|
||||
}
|
||||
return props
|
||||
}
|
||||
|
||||
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
|
||||
func testArgs(m map[string]any) ToolCallFunctionArguments {
|
||||
args := NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func TestKeepAliveParsingFromJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -327,9 +309,9 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
|
||||
input: ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"name"},
|
||||
Properties: testPropsMap(map[string]ToolProperty{
|
||||
Properties: map[string]ToolProperty{
|
||||
"name": {Type: PropertyType{"string"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string"}}}`,
|
||||
},
|
||||
@@ -337,9 +319,9 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
|
||||
name: "no required",
|
||||
input: ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]ToolProperty{
|
||||
Properties: map[string]ToolProperty{
|
||||
"name": {Type: PropertyType{"string"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
expected: `{"type":"object","properties":{"name":{"type":"string"}}}`,
|
||||
},
|
||||
@@ -357,7 +339,7 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
|
||||
func TestToolCallFunction_IndexAlwaysMarshals(t *testing.T) {
|
||||
fn := ToolCallFunction{
|
||||
Name: "echo",
|
||||
Arguments: testArgs(map[string]any{"message": "hi"}),
|
||||
Arguments: ToolCallFunctionArguments{"message": "hi"},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(fn)
|
||||
@@ -547,7 +529,7 @@ func TestToolPropertyNestedProperties(t *testing.T) {
|
||||
expected: ToolProperty{
|
||||
Type: PropertyType{"object"},
|
||||
Description: "Location details",
|
||||
Properties: testPropsMap(map[string]ToolProperty{
|
||||
Properties: map[string]ToolProperty{
|
||||
"address": {
|
||||
Type: PropertyType{"string"},
|
||||
Description: "Street address",
|
||||
@@ -556,7 +538,7 @@ func TestToolPropertyNestedProperties(t *testing.T) {
|
||||
Type: PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -584,22 +566,22 @@ func TestToolPropertyNestedProperties(t *testing.T) {
|
||||
expected: ToolProperty{
|
||||
Type: PropertyType{"object"},
|
||||
Description: "Event",
|
||||
Properties: testPropsMap(map[string]ToolProperty{
|
||||
Properties: map[string]ToolProperty{
|
||||
"location": {
|
||||
Type: PropertyType{"object"},
|
||||
Description: "Location",
|
||||
Properties: testPropsMap(map[string]ToolProperty{
|
||||
Properties: map[string]ToolProperty{
|
||||
"coordinates": {
|
||||
Type: PropertyType{"object"},
|
||||
Description: "GPS coordinates",
|
||||
Properties: testPropsMap(map[string]ToolProperty{
|
||||
Properties: map[string]ToolProperty{
|
||||
"lat": {Type: PropertyType{"number"}, Description: "Latitude"},
|
||||
"lng": {Type: PropertyType{"number"}, Description: "Longitude"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -609,13 +591,7 @@ func TestToolPropertyNestedProperties(t *testing.T) {
|
||||
var prop ToolProperty
|
||||
err := json.Unmarshal([]byte(tt.input), &prop)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Compare JSON representations since pointer comparison doesn't work
|
||||
expectedJSON, err := json.Marshal(tt.expected)
|
||||
require.NoError(t, err)
|
||||
actualJSON, err := json.Marshal(prop)
|
||||
require.NoError(t, err)
|
||||
assert.JSONEq(t, string(expectedJSON), string(actualJSON))
|
||||
assert.Equal(t, tt.expected, prop)
|
||||
|
||||
// Round-trip test: marshal and unmarshal again
|
||||
data, err := json.Marshal(prop)
|
||||
@@ -624,10 +600,7 @@ func TestToolPropertyNestedProperties(t *testing.T) {
|
||||
var prop2 ToolProperty
|
||||
err = json.Unmarshal(data, &prop2)
|
||||
require.NoError(t, err)
|
||||
|
||||
prop2JSON, err := json.Marshal(prop2)
|
||||
require.NoError(t, err)
|
||||
assert.JSONEq(t, string(expectedJSON), string(prop2JSON))
|
||||
assert.Equal(t, tt.expected, prop2)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -643,12 +616,12 @@ func TestToolFunctionParameters_String(t *testing.T) {
|
||||
params: ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"name"},
|
||||
Properties: testPropsMap(map[string]ToolProperty{
|
||||
Properties: map[string]ToolProperty{
|
||||
"name": {
|
||||
Type: PropertyType{"string"},
|
||||
Description: "The name of the person",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
|
||||
},
|
||||
@@ -665,7 +638,7 @@ func TestToolFunctionParameters_String(t *testing.T) {
|
||||
s.Self = s
|
||||
return s
|
||||
}(),
|
||||
Properties: testPropsMap(map[string]ToolProperty{}),
|
||||
Properties: map[string]ToolProperty{},
|
||||
},
|
||||
expected: "",
|
||||
},
|
||||
@@ -678,235 +651,3 @@ func TestToolFunctionParameters_String(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolCallFunctionArguments_OrderPreservation(t *testing.T) {
|
||||
t.Run("marshal preserves insertion order", func(t *testing.T) {
|
||||
args := NewToolCallFunctionArguments()
|
||||
args.Set("zebra", "z")
|
||||
args.Set("apple", "a")
|
||||
args.Set("mango", "m")
|
||||
|
||||
data, err := json.Marshal(args)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should preserve insertion order, not alphabetical
|
||||
assert.Equal(t, `{"zebra":"z","apple":"a","mango":"m"}`, string(data))
|
||||
})
|
||||
|
||||
t.Run("unmarshal preserves JSON order", func(t *testing.T) {
|
||||
jsonData := `{"zebra":"z","apple":"a","mango":"m"}`
|
||||
|
||||
var args ToolCallFunctionArguments
|
||||
err := json.Unmarshal([]byte(jsonData), &args)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify iteration order matches JSON order
|
||||
var keys []string
|
||||
for k := range args.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
assert.Equal(t, []string{"zebra", "apple", "mango"}, keys)
|
||||
})
|
||||
|
||||
t.Run("round trip preserves order", func(t *testing.T) {
|
||||
original := `{"z":1,"a":2,"m":3,"b":4}`
|
||||
|
||||
var args ToolCallFunctionArguments
|
||||
err := json.Unmarshal([]byte(original), &args)
|
||||
require.NoError(t, err)
|
||||
|
||||
data, err := json.Marshal(args)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, original, string(data))
|
||||
})
|
||||
|
||||
t.Run("String method returns ordered JSON", func(t *testing.T) {
|
||||
args := NewToolCallFunctionArguments()
|
||||
args.Set("c", 3)
|
||||
args.Set("a", 1)
|
||||
args.Set("b", 2)
|
||||
|
||||
assert.Equal(t, `{"c":3,"a":1,"b":2}`, args.String())
|
||||
})
|
||||
|
||||
t.Run("Get retrieves correct values", func(t *testing.T) {
|
||||
args := NewToolCallFunctionArguments()
|
||||
args.Set("key1", "value1")
|
||||
args.Set("key2", 42)
|
||||
|
||||
v, ok := args.Get("key1")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "value1", v)
|
||||
|
||||
v, ok = args.Get("key2")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, 42, v)
|
||||
|
||||
_, ok = args.Get("nonexistent")
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("Len returns correct count", func(t *testing.T) {
|
||||
args := NewToolCallFunctionArguments()
|
||||
assert.Equal(t, 0, args.Len())
|
||||
|
||||
args.Set("a", 1)
|
||||
assert.Equal(t, 1, args.Len())
|
||||
|
||||
args.Set("b", 2)
|
||||
assert.Equal(t, 2, args.Len())
|
||||
})
|
||||
|
||||
t.Run("empty args marshal to empty object", func(t *testing.T) {
|
||||
args := NewToolCallFunctionArguments()
|
||||
data, err := json.Marshal(args)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, `{}`, string(data))
|
||||
})
|
||||
|
||||
t.Run("zero value args marshal to empty object", func(t *testing.T) {
|
||||
var args ToolCallFunctionArguments
|
||||
assert.Equal(t, "{}", args.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestToolPropertiesMap_OrderPreservation(t *testing.T) {
|
||||
t.Run("marshal preserves insertion order", func(t *testing.T) {
|
||||
props := NewToolPropertiesMap()
|
||||
props.Set("zebra", ToolProperty{Type: PropertyType{"string"}})
|
||||
props.Set("apple", ToolProperty{Type: PropertyType{"number"}})
|
||||
props.Set("mango", ToolProperty{Type: PropertyType{"boolean"}})
|
||||
|
||||
data, err := json.Marshal(props)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should preserve insertion order, not alphabetical
|
||||
expected := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}`
|
||||
assert.Equal(t, expected, string(data))
|
||||
})
|
||||
|
||||
t.Run("unmarshal preserves JSON order", func(t *testing.T) {
|
||||
jsonData := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}`
|
||||
|
||||
var props ToolPropertiesMap
|
||||
err := json.Unmarshal([]byte(jsonData), &props)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify iteration order matches JSON order
|
||||
var keys []string
|
||||
for k := range props.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
assert.Equal(t, []string{"zebra", "apple", "mango"}, keys)
|
||||
})
|
||||
|
||||
t.Run("round trip preserves order", func(t *testing.T) {
|
||||
original := `{"z":{"type":"string"},"a":{"type":"number"},"m":{"type":"boolean"}}`
|
||||
|
||||
var props ToolPropertiesMap
|
||||
err := json.Unmarshal([]byte(original), &props)
|
||||
require.NoError(t, err)
|
||||
|
||||
data, err := json.Marshal(props)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, original, string(data))
|
||||
})
|
||||
|
||||
t.Run("Get retrieves correct values", func(t *testing.T) {
|
||||
props := NewToolPropertiesMap()
|
||||
props.Set("name", ToolProperty{Type: PropertyType{"string"}, Description: "The name"})
|
||||
props.Set("age", ToolProperty{Type: PropertyType{"integer"}, Description: "The age"})
|
||||
|
||||
v, ok := props.Get("name")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "The name", v.Description)
|
||||
|
||||
v, ok = props.Get("age")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "The age", v.Description)
|
||||
|
||||
_, ok = props.Get("nonexistent")
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("Len returns correct count", func(t *testing.T) {
|
||||
props := NewToolPropertiesMap()
|
||||
assert.Equal(t, 0, props.Len())
|
||||
|
||||
props.Set("a", ToolProperty{})
|
||||
assert.Equal(t, 1, props.Len())
|
||||
|
||||
props.Set("b", ToolProperty{})
|
||||
assert.Equal(t, 2, props.Len())
|
||||
})
|
||||
|
||||
t.Run("nil props marshal to null", func(t *testing.T) {
|
||||
var props *ToolPropertiesMap
|
||||
data, err := json.Marshal(props)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, `null`, string(data))
|
||||
})
|
||||
|
||||
t.Run("ToMap returns regular map", func(t *testing.T) {
|
||||
props := NewToolPropertiesMap()
|
||||
props.Set("a", ToolProperty{Type: PropertyType{"string"}})
|
||||
props.Set("b", ToolProperty{Type: PropertyType{"number"}})
|
||||
|
||||
m := props.ToMap()
|
||||
assert.Equal(t, 2, len(m))
|
||||
assert.Equal(t, PropertyType{"string"}, m["a"].Type)
|
||||
assert.Equal(t, PropertyType{"number"}, m["b"].Type)
|
||||
})
|
||||
}
|
||||
|
||||
func TestToolCallFunctionArguments_ComplexValues(t *testing.T) {
|
||||
t.Run("nested objects preserve order", func(t *testing.T) {
|
||||
jsonData := `{"outer":{"z":1,"a":2},"simple":"value"}`
|
||||
|
||||
var args ToolCallFunctionArguments
|
||||
err := json.Unmarshal([]byte(jsonData), &args)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Outer keys should be in order
|
||||
var keys []string
|
||||
for k := range args.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
assert.Equal(t, []string{"outer", "simple"}, keys)
|
||||
})
|
||||
|
||||
t.Run("arrays as values", func(t *testing.T) {
|
||||
args := NewToolCallFunctionArguments()
|
||||
args.Set("items", []string{"a", "b", "c"})
|
||||
args.Set("numbers", []int{1, 2, 3})
|
||||
|
||||
data, err := json.Marshal(args)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, `{"items":["a","b","c"],"numbers":[1,2,3]}`, string(data))
|
||||
})
|
||||
}
|
||||
|
||||
func TestToolPropertiesMap_NestedProperties(t *testing.T) {
|
||||
t.Run("nested properties preserve order", func(t *testing.T) {
|
||||
props := NewToolPropertiesMap()
|
||||
|
||||
nestedProps := NewToolPropertiesMap()
|
||||
nestedProps.Set("z_field", ToolProperty{Type: PropertyType{"string"}})
|
||||
nestedProps.Set("a_field", ToolProperty{Type: PropertyType{"number"}})
|
||||
|
||||
props.Set("outer", ToolProperty{
|
||||
Type: PropertyType{"object"},
|
||||
Properties: nestedProps,
|
||||
})
|
||||
|
||||
data, err := json.Marshal(props)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Both outer and inner should preserve order
|
||||
expected := `{"outer":{"type":"object","properties":{"z_field":{"type":"string"},"a_field":{"type":"number"}}}}`
|
||||
assert.Equal(t, expected, string(data))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -147,7 +147,6 @@ export const highlighterPromise = createHighlighter({
|
||||
"c",
|
||||
"cpp",
|
||||
"sql",
|
||||
"swift",
|
||||
"yaml",
|
||||
"markdown",
|
||||
],
|
||||
|
||||
@@ -997,7 +997,7 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
|
||||
for _, toolCall := range res.Message.ToolCalls {
|
||||
// continues loop as tools were executed
|
||||
toolsExecuted = true
|
||||
result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments.ToMap())
|
||||
result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
|
||||
if err != nil {
|
||||
errContent := fmt.Sprintf("Error: %v", err)
|
||||
toolErrMsg := store.NewMessage("tool", errContent, nil)
|
||||
@@ -1558,13 +1558,13 @@ func convertToOllamaTool(toolSchema map[string]any) api.Tool {
|
||||
|
||||
tool.Function.Parameters.Type = "object"
|
||||
tool.Function.Parameters.Required = []string{}
|
||||
tool.Function.Parameters.Properties = api.NewToolPropertiesMap()
|
||||
tool.Function.Parameters.Properties = make(map[string]api.ToolProperty)
|
||||
|
||||
if schemaProps, ok := toolSchema["schema"].(map[string]any); ok {
|
||||
tool.Function.Parameters.Type = getStringFromMap(schemaProps, "type", "object")
|
||||
|
||||
if props, ok := schemaProps["properties"].(map[string]any); ok {
|
||||
tool.Function.Parameters.Properties = api.NewToolPropertiesMap()
|
||||
tool.Function.Parameters.Properties = make(map[string]api.ToolProperty)
|
||||
|
||||
for propName, propDef := range props {
|
||||
if propMap, ok := propDef.(map[string]any); ok {
|
||||
@@ -1572,7 +1572,7 @@ func convertToOllamaTool(toolSchema map[string]any) api.Tool {
|
||||
Type: api.PropertyType{getStringFromMap(propMap, "type", "string")},
|
||||
Description: getStringFromMap(propMap, "description", ""),
|
||||
}
|
||||
tool.Function.Parameters.Properties.Set(propName, prop)
|
||||
tool.Function.Parameters.Properties[propName] = prop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
39
cmd/cmd.go
39
cmd/cmd.go
@@ -45,9 +45,6 @@ import (
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/types/syncmap"
|
||||
"github.com/ollama/ollama/version"
|
||||
xcmd "github.com/ollama/ollama/x/cmd"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
imagegenclient "github.com/ollama/ollama/x/imagegen/client"
|
||||
)
|
||||
|
||||
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
|
||||
@@ -98,11 +95,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
filename, err := getModelfileName(cmd)
|
||||
if os.IsNotExist(err) {
|
||||
if filename == "" {
|
||||
// No Modelfile found - check if current directory is an image gen model
|
||||
if imagegen.IsTensorModelDir(".") {
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
return imagegenclient.CreateModel(args[0], ".", quantize, p)
|
||||
}
|
||||
reader = strings.NewReader("FROM .\n")
|
||||
} else {
|
||||
return errModelfileNotFound
|
||||
@@ -464,7 +456,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
|
||||
info, err := func() (*api.ShowResponse, error) {
|
||||
showReq := &api.ShowRequest{Name: name}
|
||||
info, err := client.Show(cmd.Context(), showReq)
|
||||
@@ -526,19 +517,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
|
||||
}
|
||||
|
||||
// Check if this is an image generation model
|
||||
if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
|
||||
if opts.Prompt == "" && !interactive {
|
||||
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
|
||||
}
|
||||
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
|
||||
}
|
||||
|
||||
// Check for experimental flag
|
||||
isExperimental, _ := cmd.Flags().GetBool("experimental")
|
||||
yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
|
||||
enableWebsearch, _ := cmd.Flags().GetBool("experimental-websearch")
|
||||
|
||||
if interactive {
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
var sErr api.AuthorizationError
|
||||
@@ -565,11 +543,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Use experimental agent loop with tools
|
||||
if isExperimental {
|
||||
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode, enableWebsearch)
|
||||
}
|
||||
|
||||
return generateInteractive(cmd, opts)
|
||||
}
|
||||
return generate(cmd, opts)
|
||||
@@ -673,11 +646,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
bar, ok := bars[resp.Digest]
|
||||
if !ok {
|
||||
msg := resp.Status
|
||||
if msg == "" {
|
||||
msg = fmt.Sprintf("pushing %s...", resp.Digest[7:19])
|
||||
}
|
||||
bar = progress.NewBar(msg, resp.Total, resp.Completed)
|
||||
bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
|
||||
bars[resp.Digest] = bar
|
||||
p.Add(resp.Digest, bar)
|
||||
}
|
||||
@@ -1785,12 +1754,6 @@ func NewCLI() *cobra.Command {
|
||||
runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
|
||||
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
|
||||
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
|
||||
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
|
||||
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
|
||||
runCmd.Flags().Bool("experimental-websearch", false, "Enable web search tool in experimental mode")
|
||||
|
||||
// Image generation flags (width, height, steps, seed, etc.)
|
||||
imagegen.RegisterFlags(runCmd)
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
Use: "stop MODEL",
|
||||
|
||||
@@ -1547,79 +1547,6 @@ func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowInfoImageGen(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
err := showInfo(&api.ShowResponse{
|
||||
Details: api.ModelDetails{
|
||||
Family: "ZImagePipeline",
|
||||
ParameterSize: "10.3B",
|
||||
QuantizationLevel: "FP8",
|
||||
},
|
||||
Capabilities: []model.Capability{model.CapabilityImageGeneration},
|
||||
Requires: "0.14.0",
|
||||
}, false, &b)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expect := " Model\n" +
|
||||
" architecture ZImagePipeline \n" +
|
||||
" parameters 10.3B \n" +
|
||||
" quantization FP8 \n" +
|
||||
" requires 0.14.0 \n" +
|
||||
"\n" +
|
||||
" Capabilities\n" +
|
||||
" image \n" +
|
||||
"\n"
|
||||
if diff := cmp.Diff(expect, b.String()); diff != "" {
|
||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushProgressMessage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status string
|
||||
digest string
|
||||
wantMsg string
|
||||
}{
|
||||
{
|
||||
name: "uses status when provided",
|
||||
status: "uploading model",
|
||||
digest: "sha256:abc123456789def",
|
||||
wantMsg: "uploading model",
|
||||
},
|
||||
{
|
||||
name: "falls back to digest when status empty",
|
||||
status: "",
|
||||
digest: "sha256:abc123456789def",
|
||||
wantMsg: "pushing abc123456789...",
|
||||
},
|
||||
{
|
||||
name: "handles short digest gracefully",
|
||||
status: "",
|
||||
digest: "sha256:abc",
|
||||
wantMsg: "pushing sha256:abc...",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
msg := tt.status
|
||||
if msg == "" {
|
||||
if len(tt.digest) >= 19 {
|
||||
msg = fmt.Sprintf("pushing %s...", tt.digest[7:19])
|
||||
} else {
|
||||
msg = fmt.Sprintf("pushing %s...", tt.digest)
|
||||
}
|
||||
}
|
||||
if msg != tt.wantMsg {
|
||||
t.Errorf("got %q, want %q", msg, tt.wantMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunOptions_Copy_Independence(t *testing.T) {
|
||||
// Test that modifications to original don't affect copy
|
||||
originalThink := &api.ThinkValue{Value: "original"}
|
||||
|
||||
@@ -40,7 +40,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
|
||||
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
|
||||
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
|
||||
|
||||
|
||||
156
cmd/prompt-rendering/README.md
Normal file
156
cmd/prompt-rendering/README.md
Normal file
@@ -0,0 +1,156 @@
|
||||
# HuggingFace Prompt Renderer MCP Server
|
||||
|
||||
Model Context Protocol (MCP) server for rendering conversation messages into
|
||||
model-specific prompt strings using HuggingFace tokenizer chat templates.
|
||||
|
||||
## Requirements
|
||||
|
||||
- [uv](https://docs.astral.sh/uv/) - Fast Python package installer
|
||||
|
||||
## Usage
|
||||
|
||||
### MCP Server Mode
|
||||
|
||||
Run the MCP server over stdio for use with MCP clients:
|
||||
|
||||
```bash
|
||||
uv run cmd/prompt-rendering/server.py --mcp
|
||||
```
|
||||
|
||||
Add to your MCP client configuration (e.g., for Claude Desktop):
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"huggingface-prompt-renderer": {
|
||||
"command": "uv",
|
||||
"args": [
|
||||
"run",
|
||||
"--directory",
|
||||
"<path-to-ollama-repo>",
|
||||
"cmd/prompt-rendering/server.py",
|
||||
"--mcp"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### FastAPI Server Mode
|
||||
|
||||
Start a FastAPI server for manual HTTP testing:
|
||||
|
||||
```bash
|
||||
# Start on default port 8000
|
||||
uv run cmd/prompt-rendering/server.py --host 0.0.0.0 --port 8000
|
||||
|
||||
# Start on custom port
|
||||
uv run cmd/prompt-rendering/server.py --host 0.0.0.0 --port 9000
|
||||
```
|
||||
|
||||
#### Endpoints
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| POST | `/generate-prompt` | Generate prompt from messages |
|
||||
| GET | `/health` | Health check |
|
||||
|
||||
### Test with curl
|
||||
|
||||
```bash
|
||||
# Basic user message
|
||||
curl -X POST http://localhost:8000/generate-prompt \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [{"role": "user", "content": "Hello!"}]
|
||||
}'
|
||||
|
||||
# With tools
|
||||
curl -X POST http://localhost:8000/generate-prompt \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is the weather?"}
|
||||
],
|
||||
"model": "Qwen/Qwen3-Coder-480B-A35B-Instruct",
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"required": ["location"],
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "The city"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
}'
|
||||
|
||||
# With tool calls
|
||||
curl -X POST http://localhost:8000/generate-prompt \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the weather in SF?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": {"location": "San Francisco"}
|
||||
}
|
||||
}]
|
||||
},
|
||||
{"role": "tool", "content": "{\"temperature\": 68}", "tool_call_id": "call_1"}
|
||||
],
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"location": {"type": "string"}}
|
||||
}
|
||||
}
|
||||
}]
|
||||
}'
|
||||
```
|
||||
|
||||
## Supported Message Formats
|
||||
|
||||
The server supports multiple message formats:
|
||||
|
||||
| Format | Description |
|
||||
|--------|-------------|
|
||||
| OpenAI | Standard `role`, `content`, `tool_calls`, `tool_call_id` |
|
||||
| OLMo | Adds `functions` and `function_calls` fields |
|
||||
| DeepSeek | Tool call arguments must be JSON strings |
|
||||
|
||||
## Tool Support
|
||||
|
||||
| Setting | Description |
|
||||
|---------|-------------|
|
||||
| `inject_tools_as_functions=true` | Injects tools into system message as `functions` key (OLMo-style) |
|
||||
| `inject_tools_as_functions=false` | Passes tools separately to `apply_chat_template` (standard transformers) |
|
||||
|
||||
## Models
|
||||
|
||||
The server uses HuggingFace's `transformers` library and supports any model
|
||||
with a chat template. Default: `Qwen/Qwen3-Coder-480B-A35B-Instruct`
|
||||
|
||||
## Dependencies
|
||||
|
||||
The script uses PEP 723 inline dependency metadata. When run with `uv`,
|
||||
dependencies are automatically installed into an isolated environment:
|
||||
|
||||
- `fastapi` - Web framework
|
||||
- `uvicorn` - ASGI server
|
||||
- `transformers` - HuggingFace tokenizer
|
||||
- `jinja2` - Template engine
|
||||
- `mcp` - Model Context Protocol
|
||||
311
cmd/prompt-rendering/server.py
Normal file
311
cmd/prompt-rendering/server.py
Normal file
@@ -0,0 +1,311 @@
|
||||
#!/usr/bin/env python3
|
||||
# /// script
|
||||
# requires-python = ">=3.10"
|
||||
# dependencies = [
|
||||
# "fastapi",
|
||||
# "uvicorn",
|
||||
# "transformers",
|
||||
# "jinja2",
|
||||
# "mcp",
|
||||
# ]
|
||||
# ///
|
||||
"""
|
||||
HuggingFace Prompt Renderer MCP Server
|
||||
|
||||
Model Context Protocol (MCP) server for rendering conversation messages into
|
||||
model-specific prompt strings using HuggingFace tokenizer chat templates.
|
||||
|
||||
Usage:
|
||||
# Run MCP server over stdio
|
||||
uv run cmd/prompt-rendering/server.py --mcp
|
||||
|
||||
# Start FastAPI server for manual testing
|
||||
uv run cmd/prompt-rendering/server.py --host 0.0.0.0 --port 8000
|
||||
|
||||
# Test with curl
|
||||
curl -X POST http://localhost:8000/generate-prompt \\
|
||||
-H "Content-Type: application/json" \\
|
||||
-d '{"messages": [{"role": "user", "content": "Hello!"}]}'
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import uvicorn
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
try:
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
except Exception:
|
||||
FastMCP = None
|
||||
|
||||
# Cache for tokenizers to avoid reloading
|
||||
_tokenizer_cache: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = None
|
||||
tool_call_id: Optional[str] = None
|
||||
functions: Optional[str] = None # For OLMo-style function passing
|
||||
function_calls: Optional[str] = None # For OLMo-style function call results
|
||||
|
||||
|
||||
class GeneratePromptRequest(BaseModel):
|
||||
messages: List[Message]
|
||||
model: str
|
||||
tools: Optional[List[Dict[str, Any]]] = None
|
||||
# Whether to inject tools into system message as 'functions' key (for OLMo-style templates)
|
||||
inject_tools_as_functions: Optional[bool] = True
|
||||
|
||||
|
||||
class GeneratePromptResponse(BaseModel):
|
||||
prompt: str
|
||||
model: str
|
||||
|
||||
|
||||
# FastAPI app
|
||||
app = FastAPI(title="HuggingFace Prompt Generator", version="1.0.0")
|
||||
|
||||
|
||||
def get_tokenizer(model_name: str) -> Any:
|
||||
"""Get or create tokenizer for the given model."""
|
||||
if model_name not in _tokenizer_cache:
|
||||
_tokenizer_cache[model_name] = AutoTokenizer.from_pretrained(
|
||||
model_name, trust_remote_code=True
|
||||
)
|
||||
return _tokenizer_cache[model_name]
|
||||
|
||||
|
||||
def is_deepseek_model(model_name: str) -> bool:
|
||||
"""Check if this is a DeepSeek model."""
|
||||
return "deepseek" in model_name.lower()
|
||||
|
||||
|
||||
def normalize_messages(
|
||||
raw_messages: List[Any],
|
||||
tools: Optional[List[Dict[str, Any]]],
|
||||
inject_tools_as_functions: bool,
|
||||
model: str,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Normalize messages for different chat template formats."""
|
||||
messages: List[Dict[str, Any]] = []
|
||||
tools_json = json.dumps(tools) if tools else None
|
||||
is_deepseek = is_deepseek_model(model)
|
||||
|
||||
for msg in raw_messages:
|
||||
message = msg if isinstance(msg, Message) else Message(**msg)
|
||||
message_dict: Dict[str, Any] = {"role": message.role, "content": None}
|
||||
|
||||
if message.content is not None:
|
||||
message_dict["content"] = message.content
|
||||
|
||||
# Handle explicit functions field (OLMo-style)
|
||||
if message.functions is not None:
|
||||
message_dict["functions"] = message.functions
|
||||
# Inject tools into system message as 'functions' (for OLMo templates)
|
||||
elif inject_tools_as_functions and message.role == "system" and tools_json:
|
||||
message_dict["functions"] = tools_json
|
||||
|
||||
# Handle explicit function_calls field (OLMo-style)
|
||||
if message.function_calls is not None:
|
||||
message_dict["function_calls"] = message.function_calls
|
||||
# Convert tool_calls for templates
|
||||
elif message.tool_calls is not None:
|
||||
if is_deepseek:
|
||||
# DeepSeek format: arguments must be a JSON string
|
||||
tool_calls = []
|
||||
for tool_call in message.tool_calls:
|
||||
tc = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call["function"]["name"],
|
||||
"arguments": json.dumps(tool_call["function"]["arguments"])
|
||||
if isinstance(tool_call["function"]["arguments"], dict)
|
||||
else tool_call["function"]["arguments"],
|
||||
},
|
||||
}
|
||||
tool_calls.append(tc)
|
||||
message_dict["tool_calls"] = tool_calls
|
||||
elif inject_tools_as_functions:
|
||||
# Convert to OLMo function_calls format
|
||||
message_dict["function_calls"] = json.dumps(message.tool_calls)
|
||||
else:
|
||||
# Standard transformers format
|
||||
tool_calls = []
|
||||
for tool_call in message.tool_calls:
|
||||
tool_call_copy = tool_call.copy()
|
||||
if (
|
||||
"function" in tool_call_copy
|
||||
and "arguments" in tool_call_copy["function"]
|
||||
):
|
||||
try:
|
||||
tool_call_copy["function"]["arguments"] = json.loads(
|
||||
tool_call_copy["function"]["arguments"]
|
||||
)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
tool_calls.append(tool_call_copy)
|
||||
message_dict["tool_calls"] = tool_calls
|
||||
|
||||
if message.tool_call_id is not None:
|
||||
message_dict["tool_call_id"] = message.tool_call_id
|
||||
|
||||
messages.append(message_dict)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def build_prompt(
|
||||
raw_messages: List[Any],
|
||||
model: str,
|
||||
tools: Optional[List[Dict[str, Any]]],
|
||||
inject_tools_as_functions: bool,
|
||||
) -> str:
|
||||
"""Build prompt from messages using the model's chat template."""
|
||||
messages = normalize_messages(
|
||||
raw_messages=raw_messages,
|
||||
tools=tools,
|
||||
inject_tools_as_functions=inject_tools_as_functions,
|
||||
model=model,
|
||||
)
|
||||
|
||||
tokenizer = get_tokenizer(model)
|
||||
|
||||
# For OLMo-style templates, don't pass tools separately (they're in messages)
|
||||
if tools and not inject_tools_as_functions:
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tools=tools,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
else:
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
@app.post("/generate-prompt", response_model=GeneratePromptResponse)
|
||||
async def generate_prompt(request: GeneratePromptRequest):
|
||||
"""
|
||||
Generate a prompt from messages using the specified model's chat template.
|
||||
Optionally includes tool definitions if provided.
|
||||
"""
|
||||
try:
|
||||
prompt = build_prompt(
|
||||
raw_messages=request.messages,
|
||||
model=request.model,
|
||||
tools=request.tools,
|
||||
inject_tools_as_functions=request.inject_tools_as_functions,
|
||||
)
|
||||
return GeneratePromptResponse(prompt=prompt, model=request.model)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to generate prompt: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
if FastMCP is not None:
|
||||
mcp = FastMCP("huggingface-prompt-renderer")
|
||||
|
||||
@mcp.tool()
|
||||
def generate_prompt_tool(
|
||||
messages: List[Dict[str, Any]],
|
||||
model: str = "Qwen/Qwen3-Coder-480B-A35B-Instruct",
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
inject_tools_as_functions: bool = True,
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Render conversation messages into a model-specific prompt string using HuggingFace tokenizer chat templates.
|
||||
|
||||
This tool takes a list of message objects and applies the target model's chat template to produce
|
||||
the exact prompt string that would be fed to the model. It handles various message formats including
|
||||
standard OpenAI-style, OLMo-style (functions/function_calls), and DeepSeek-specific formatting.
|
||||
|
||||
Use this tool to:
|
||||
- Verify that a model's chat template correctly formats your conversation
|
||||
- Test edge cases: tool calling, tool responses, interleaved thinking and tool calls, multiple tools in single response
|
||||
- Compare prompt output across different models to understand template differences
|
||||
- Debug issues with message formatting that cause unexpected model behavior
|
||||
|
||||
Message format supports:
|
||||
- role: "user", "assistant", "system", "tool"
|
||||
- content: string content of the message
|
||||
- tool_calls: list of tool call objects (OpenAI format: {type, function: {name, arguments}})
|
||||
- tool_call_id: for tool role messages, references the call being responded to
|
||||
- functions: optional field for OLMo-style tool definitions
|
||||
- function_calls: optional field for OLMo-style tool call results
|
||||
|
||||
Parameters:
|
||||
- messages: List of message dictionaries forming the conversation
|
||||
- model: HuggingFace model identifier (default: Qwen/Qwen3-Coder-480B-A35B-Instruct)
|
||||
- tools: Optional list of tool/function definitions for function calling models
|
||||
- inject_tools_as_functions: If True, injects tools into system message as 'functions' key (OLMo-style). If False, passes tools separately to apply_chat_template.
|
||||
|
||||
Returns: Dictionary with 'prompt' (rendered string) and 'model' keys.
|
||||
|
||||
Recommended test cases:
|
||||
1. Simple conversation: user -> assistant
|
||||
2. Tool calling: user -> assistant with tool_call -> tool response -> assistant
|
||||
3. Multiple tool calls in one assistant message
|
||||
4. Multiple tool responses interleaved with assistant reasoning
|
||||
5. Nested tool calls (assistant calls tool, uses result to call another)
|
||||
6. System message with tool definitions
|
||||
7. Empty or None content in messages
|
||||
8. Very long messages to test truncation handling
|
||||
"""
|
||||
prompt = build_prompt(
|
||||
raw_messages=messages,
|
||||
model=model,
|
||||
tools=tools,
|
||||
inject_tools_as_functions=inject_tools_as_functions,
|
||||
)
|
||||
return {"prompt": prompt, "model": model}
|
||||
else:
|
||||
mcp = None
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="HuggingFace Prompt Renderer MCP Server",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=__doc__,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mcp", action="store_true", help="Run MCP server over stdio"
|
||||
)
|
||||
parser.add_argument("--host", default="0.0.0.0", help="FastAPI host")
|
||||
parser.add_argument("--port", type=int, default=8000, help="FastAPI port")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.mcp:
|
||||
if mcp is None:
|
||||
raise RuntimeError("MCP server requested but mcp is not installed.")
|
||||
mcp.run()
|
||||
else:
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -6,14 +6,11 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
ofs "github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -21,13 +18,8 @@ type ModelParameters struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
|
||||
// TODO is this needed?
|
||||
ModelType string `json:"model_type"`
|
||||
|
||||
TextModel struct {
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
ModelType string `json:"model_type"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
} `json:"text_config"`
|
||||
}
|
||||
|
||||
@@ -41,94 +33,8 @@ type AdapterParameters struct {
|
||||
} `json:"lora_parameters"`
|
||||
}
|
||||
|
||||
type KV map[string]any
|
||||
|
||||
func (kv KV) Architecture() string {
|
||||
return kv.String("general.architecture", "unknown")
|
||||
}
|
||||
|
||||
type valueTypes interface {
|
||||
uint8 | int8 | uint16 | int16 |
|
||||
uint32 | int32 | uint64 | int64 |
|
||||
string | float32 | float64 | bool
|
||||
}
|
||||
|
||||
type arrayValueTypes interface {
|
||||
[]uint8 | []int8 | []uint16 | []int16 |
|
||||
[]uint32 | []int32 | []uint64 | []int64 |
|
||||
[]string | []float32 | []float64 | []bool
|
||||
}
|
||||
|
||||
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
|
||||
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
|
||||
key = kv.Architecture() + "." + key
|
||||
}
|
||||
|
||||
if val, ok := kv[key].(T); ok {
|
||||
return val, true
|
||||
}
|
||||
return defaultValue[0], false
|
||||
}
|
||||
|
||||
func (kv KV) String(key string, defaultValue ...string) string {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, "")...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Float(key string, defaultValue ...float32) float32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Bool(key string, defaultValue ...bool) bool {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, false)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []string{""})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []int32{0})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []uint32{0})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []float32{0})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []bool{false})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Len() int {
|
||||
return len(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Keys() iter.Seq[string] {
|
||||
return maps.Keys(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Value(key string) any {
|
||||
return kv[key]
|
||||
}
|
||||
|
||||
func (ModelParameters) KV(t *Tokenizer) KV {
|
||||
kv := KV{
|
||||
func (ModelParameters) KV(t *Tokenizer) ggml.KV {
|
||||
kv := ggml.KV{
|
||||
"general.file_type": uint32(1),
|
||||
"general.quantization_version": uint32(2),
|
||||
"tokenizer.ggml.pre": t.Pre,
|
||||
@@ -157,7 +63,7 @@ func (ModelParameters) KV(t *Tokenizer) KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p AdapterParameters) KV() KV {
|
||||
func (p AdapterParameters) KV() ggml.KV {
|
||||
var alpha float32
|
||||
if p.LoraParameters.Alpha == 0 {
|
||||
alpha = float32(p.Alpha)
|
||||
@@ -165,7 +71,7 @@ func (p AdapterParameters) KV() KV {
|
||||
alpha = p.LoraParameters.Alpha
|
||||
}
|
||||
|
||||
kv := KV{
|
||||
kv := ggml.KV{
|
||||
"adapter.lora.alpha": alpha,
|
||||
"adapter.type": "lora",
|
||||
"general.file_type": uint32(1),
|
||||
@@ -182,14 +88,9 @@ func (ModelParameters) specialTokenTypes() []string {
|
||||
}
|
||||
}
|
||||
|
||||
type ModelKV interface {
|
||||
// KV maps parameters to LLM key-values
|
||||
KV(*Tokenizer) KV
|
||||
}
|
||||
|
||||
type ModelConverter interface {
|
||||
ModelKV
|
||||
|
||||
// KV maps parameters to LLM key-values
|
||||
KV(*Tokenizer) ggml.KV
|
||||
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
|
||||
Tensors([]Tensor) []*ggml.Tensor
|
||||
// Replacements returns a list of string pairs to replace in tensor names.
|
||||
@@ -206,7 +107,7 @@ type moreParser interface {
|
||||
|
||||
type AdapterConverter interface {
|
||||
// KV maps parameters to LLM key-values
|
||||
KV(ofs.Config) KV
|
||||
KV(ggml.KV) ggml.KV
|
||||
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
|
||||
Tensors([]Tensor) []*ggml.Tensor
|
||||
// Replacements returns a list of string pairs to replace in tensor names.
|
||||
@@ -214,7 +115,7 @@ type AdapterConverter interface {
|
||||
Replacements() []string
|
||||
}
|
||||
|
||||
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
|
||||
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
|
||||
bts, err := fs.ReadFile(fsys, "adapter_config.json")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -225,8 +126,8 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
|
||||
return err
|
||||
}
|
||||
|
||||
arch := baseKV.Architecture()
|
||||
if arch == "" {
|
||||
arch, ok := baseKV["general.architecture"]
|
||||
if !ok {
|
||||
return errors.New("architecture not set for the base model")
|
||||
}
|
||||
|
||||
@@ -252,19 +153,23 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
|
||||
return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
|
||||
}
|
||||
|
||||
func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
|
||||
// and files it finds in the input path.
|
||||
// Supported input model formats include safetensors.
|
||||
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
|
||||
func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
bts, err := fs.ReadFile(fsys, "config.json")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
var p ModelParameters
|
||||
if err := json.Unmarshal(bts, &p); err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
if len(p.Architectures) < 1 {
|
||||
return nil, nil, errors.New("unknown architecture")
|
||||
return errors.New("unknown architecture")
|
||||
}
|
||||
|
||||
var conv ModelConverter
|
||||
@@ -312,22 +217,22 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
case "DeepseekV3ForCausalLM":
|
||||
conv = &deepseek2Model{}
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(bts, conv); err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
if t, ok := conv.(moreParser); ok {
|
||||
if err := t.parseMore(fsys); err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
t, err := parseTokenizer(fsys, conv.specialTokenTypes())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
|
||||
@@ -349,19 +254,6 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
default:
|
||||
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
|
||||
}
|
||||
return conv, t, nil
|
||||
}
|
||||
|
||||
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
|
||||
// and files it finds in the input path.
|
||||
// Supported input model formats include safetensors.
|
||||
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
|
||||
func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
kv, t, err := LoadModelMetadata(fsys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conv := kv.(ModelConverter)
|
||||
|
||||
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
|
||||
if err != nil {
|
||||
@@ -371,7 +263,7 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
return writeFile(f, conv.KV(t), conv.Tensors(ts))
|
||||
}
|
||||
|
||||
func writeFile(f *os.File, kv KV, ts []*ggml.Tensor) error {
|
||||
func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
|
||||
for i := range ts {
|
||||
ts[i].Shape = slices.Clone(ts[i].Shape)
|
||||
slices.Reverse(ts[i].Shape)
|
||||
|
||||
@@ -88,7 +88,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *bertModel) KV(t *Tokenizer) KV {
|
||||
func (p *bertModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "bert"
|
||||
kv["bert.attention.causal"] = false
|
||||
|
||||
@@ -24,7 +24,7 @@ type commandrModel struct {
|
||||
|
||||
var _ ModelConverter = (*commandrModel)(nil)
|
||||
|
||||
func (p *commandrModel) KV(t *Tokenizer) KV {
|
||||
func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "command-r"
|
||||
kv["general.name"] = "command-r"
|
||||
|
||||
@@ -47,7 +47,7 @@ type deepseek2Model struct {
|
||||
Architecture string
|
||||
}
|
||||
|
||||
func (p *deepseek2Model) KV(t *Tokenizer) KV {
|
||||
func (p *deepseek2Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "deepseek2"
|
||||
kv["general.type"] = "model"
|
||||
|
||||
@@ -41,7 +41,7 @@ type deepseekocr struct {
|
||||
} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *deepseekocr) KV(t *Tokenizer) KV {
|
||||
func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "deepseekocr"
|
||||
kv["block_count"] = m.LanguageConfig.HiddenLayers
|
||||
|
||||
@@ -23,7 +23,7 @@ type gemmaModel struct {
|
||||
|
||||
var _ ModelConverter = (*gemmaModel)(nil)
|
||||
|
||||
func (p *gemmaModel) KV(t *Tokenizer) KV {
|
||||
func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma"
|
||||
kv["gemma.context_length"] = p.MaxPositionEmbeddings
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package convert
|
||||
|
||||
import "github.com/ollama/ollama/fs/ggml"
|
||||
|
||||
type gemma2Model struct {
|
||||
gemmaModel
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
@@ -7,7 +9,7 @@ type gemma2Model struct {
|
||||
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
|
||||
}
|
||||
|
||||
func (p *gemma2Model) KV(t *Tokenizer) KV {
|
||||
func (p *gemma2Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma2"
|
||||
kv["gemma2.context_length"] = p.MaxPositionEmbeddings
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -16,7 +15,7 @@ type gemma2Adapter struct {
|
||||
|
||||
var _ AdapterConverter = (*gemma2Adapter)(nil)
|
||||
|
||||
func (p *gemma2Adapter) KV(baseKV fs.Config) KV {
|
||||
func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
|
||||
kv := p.AdapterParameters.KV()
|
||||
kv["general.architecture"] = "gemma2"
|
||||
return kv
|
||||
|
||||
@@ -3,6 +3,8 @@ package convert
|
||||
import (
|
||||
"cmp"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type gemma3Model struct {
|
||||
@@ -53,7 +55,7 @@ const (
|
||||
gemma27BLayerCount = 62
|
||||
)
|
||||
|
||||
func (p *gemma3Model) KV(t *Tokenizer) KV {
|
||||
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma3"
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ type gemma3nModel struct {
|
||||
VisionModel struct{} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *gemma3nModel) KV(t *Tokenizer) KV {
|
||||
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma3n"
|
||||
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {
|
||||
|
||||
@@ -37,7 +37,7 @@ type gptossModel struct {
|
||||
|
||||
var _ ModelConverter = (*gptossModel)(nil)
|
||||
|
||||
func (m *gptossModel) KV(t *Tokenizer) KV {
|
||||
func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gptoss"
|
||||
kv["general.file_type"] = uint32(4)
|
||||
|
||||
@@ -48,7 +48,7 @@ type llamaModel struct {
|
||||
|
||||
var _ ModelConverter = (*llamaModel)(nil)
|
||||
|
||||
func (p *llamaModel) KV(t *Tokenizer) KV {
|
||||
func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "llama"
|
||||
kv["llama.vocab_size"] = p.VocabSize
|
||||
|
||||
@@ -35,7 +35,7 @@ type llama4Model struct {
|
||||
}
|
||||
|
||||
// KV implements ModelConverter.
|
||||
func (p *llama4Model) KV(t *Tokenizer) KV {
|
||||
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "llama4"
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -19,13 +18,13 @@ type llamaAdapter struct {
|
||||
|
||||
var _ AdapterConverter = (*llamaAdapter)(nil)
|
||||
|
||||
func (p *llamaAdapter) KV(baseKV fs.Config) KV {
|
||||
func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
|
||||
kv := p.AdapterParameters.KV()
|
||||
kv["general.architecture"] = "llama"
|
||||
kv["llama.attention.head_count"] = baseKV.Value("llama.attention.head_count")
|
||||
kv["llama.attention.head_count_kv"] = baseKV.Value("llama.attention.head_count_kv")
|
||||
kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
|
||||
kv["llama.attention.head_count_kv"] = baseKV["llama.attention.head_count_kv"]
|
||||
|
||||
p.NumAttentionHeads = baseKV.Value("llama.attention.head_count").(uint32)
|
||||
p.NumAttentionHeads = baseKV["llama.attention.head_count"].(uint32)
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ type mistral3Model struct {
|
||||
ProjectorHiddenAct string `json:"projector_hidden_act"`
|
||||
}
|
||||
|
||||
func (p *mistral3Model) KV(t *Tokenizer) KV {
|
||||
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mistral3"
|
||||
kv["mistral3.vocab_size"] = p.TextModel.VocabSize
|
||||
|
||||
@@ -39,7 +39,7 @@ type mistral3CausalModel struct {
|
||||
} `json:"rope_parameters"`
|
||||
}
|
||||
|
||||
func (p *mistral3CausalModel) KV(t *Tokenizer) KV {
|
||||
func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mistral3"
|
||||
kv["mistral3.vocab_size"] = p.VocabSize
|
||||
|
||||
@@ -12,7 +12,7 @@ type mixtralModel struct {
|
||||
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
|
||||
}
|
||||
|
||||
func (p *mixtralModel) KV(t *Tokenizer) KV {
|
||||
func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.llamaModel.KV(t)
|
||||
|
||||
if p.NumLocalExperts > 0 {
|
||||
|
||||
@@ -34,7 +34,7 @@ type mllamaModel struct {
|
||||
} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *mllamaModel) KV(t *Tokenizer) KV {
|
||||
func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mllama"
|
||||
|
||||
|
||||
@@ -87,7 +87,7 @@ func (p *nomicbertModel) parseMore(fsys fs.FS) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *nomicbertModel) KV(t *Tokenizer) KV {
|
||||
func (p *nomicbertModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
|
||||
// Determine architecture based on MoE parameters (following qwen3 pattern)
|
||||
|
||||
@@ -34,7 +34,7 @@ type olmoModel struct {
|
||||
|
||||
var _ ModelConverter = (*olmoModel)(nil)
|
||||
|
||||
func (p *olmoModel) KV(t *Tokenizer) KV {
|
||||
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "olmo3"
|
||||
kv["olmo3.block_count"] = p.NumHiddenLayers
|
||||
|
||||
@@ -37,7 +37,7 @@ type phi3Model struct {
|
||||
|
||||
var _ ModelConverter = (*phi3Model)(nil)
|
||||
|
||||
func (p *phi3Model) KV(t *Tokenizer) KV {
|
||||
func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "phi3"
|
||||
kv["phi3.context_length"] = p.MaxPositionEmbeddings
|
||||
|
||||
@@ -22,7 +22,7 @@ type qwen2Model struct {
|
||||
|
||||
var _ ModelConverter = (*qwen2Model)(nil)
|
||||
|
||||
func (q *qwen2Model) KV(t *Tokenizer) KV {
|
||||
func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := q.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "qwen2"
|
||||
kv["qwen2.block_count"] = q.HiddenLayers
|
||||
|
||||
@@ -29,7 +29,7 @@ type qwen25VLModel struct {
|
||||
|
||||
var _ ModelConverter = (*qwen25VLModel)(nil)
|
||||
|
||||
func (q *qwen25VLModel) KV(t *Tokenizer) KV {
|
||||
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := q.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "qwen25vl"
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ type qwen3Model struct {
|
||||
}
|
||||
|
||||
// KV implements ModelConverter.
|
||||
func (q *qwen3Model) KV(t *Tokenizer) KV {
|
||||
func (q *qwen3Model) KV(t *Tokenizer) ggml.KV {
|
||||
arch := "qwen3"
|
||||
if q.NumExperts > 0 {
|
||||
arch += "moe"
|
||||
|
||||
@@ -45,7 +45,7 @@ func (m *qwen3VLModel) parseMore(fsys fs.FS) error {
|
||||
return json.Unmarshal(bts, &m.VisionModel)
|
||||
}
|
||||
|
||||
func (m *qwen3VLModel) KV(t *Tokenizer) KV {
|
||||
func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.qwen3Model.KV(t)
|
||||
|
||||
arch := "qwen3vl"
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
fsc "github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -29,7 +28,7 @@ type tensorData struct {
|
||||
Shape []int `json:"shape"`
|
||||
}
|
||||
|
||||
func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors) {
|
||||
func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.CreateTemp(t.TempDir(), "f16")
|
||||
@@ -60,10 +59,9 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors)
|
||||
return r, m.KV(), m.Tensors()
|
||||
}
|
||||
|
||||
func generateResultsJSON(t *testing.T, f *os.File, kv fsc.Config, tensors ggml.Tensors) map[string]string {
|
||||
func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string {
|
||||
actual := make(map[string]string)
|
||||
for k := range kv.Keys() {
|
||||
v := kv.Value(k)
|
||||
for k, v := range kv {
|
||||
if s, ok := v.(json.Marshaler); !ok {
|
||||
actual[k] = fmt.Sprintf("%v", v)
|
||||
} else {
|
||||
@@ -279,7 +277,7 @@ func generateSafetensorTestData(t *testing.T, tempDir string, tensorData map[str
|
||||
func TestConvertAdapter(t *testing.T) {
|
||||
type AdapterCase struct {
|
||||
Name string
|
||||
BaseKV KV
|
||||
BaseKV map[string]any
|
||||
Expected map[string]string
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
* [API Reference](https://docs.ollama.com/api)
|
||||
* [Modelfile Reference](https://docs.ollama.com/modelfile)
|
||||
* [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility)
|
||||
* [Anthropic Compatibility](./api/anthropic-compatibility.mdx)
|
||||
|
||||
### Resources
|
||||
|
||||
|
||||
@@ -895,11 +895,11 @@ curl http://localhost:11434/api/chat -d '{
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"name": "get_temperature",
|
||||
"arguments": {
|
||||
"city": "Toronto"
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -907,7 +907,7 @@ curl http://localhost:11434/api/chat -d '{
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "11 degrees celsius",
|
||||
"tool_name": "get_weather"
|
||||
"tool_name": "get_temperature",
|
||||
}
|
||||
],
|
||||
"stream": false,
|
||||
|
||||
@@ -1,406 +0,0 @@
|
||||
---
|
||||
title: Anthropic compatibility
|
||||
---
|
||||
|
||||
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
|
||||
|
||||
## Recommended models
|
||||
|
||||
For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.
|
||||
|
||||
Pull a model before use:
|
||||
```shell
|
||||
ollama pull qwen3-coder
|
||||
ollama pull glm-4.7:cloud
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Environment variables
|
||||
|
||||
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama # required but ignored
|
||||
```
|
||||
|
||||
### Simple `/v1/messages` example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
```python basic.py
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
base_url='http://localhost:11434',
|
||||
api_key='ollama', # required but ignored
|
||||
)
|
||||
|
||||
message = client.messages.create(
|
||||
model='qwen3-coder',
|
||||
max_tokens=1024,
|
||||
messages=[
|
||||
{'role': 'user', 'content': 'Hello, how are you?'}
|
||||
]
|
||||
)
|
||||
print(message.content[0].text)
|
||||
```
|
||||
|
||||
```javascript basic.js
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
|
||||
const anthropic = new Anthropic({
|
||||
baseURL: "http://localhost:11434",
|
||||
apiKey: "ollama", // required but ignored
|
||||
});
|
||||
|
||||
const message = await anthropic.messages.create({
|
||||
model: "qwen3-coder",
|
||||
max_tokens: 1024,
|
||||
messages: [{ role: "user", content: "Hello, how are you?" }],
|
||||
});
|
||||
|
||||
console.log(message.content[0].text);
|
||||
```
|
||||
|
||||
```shell basic.sh
|
||||
curl -X POST http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-api-key: ollama" \
|
||||
-H "anthropic-version: 2023-06-01" \
|
||||
-d '{
|
||||
"model": "qwen3-coder",
|
||||
"max_tokens": 1024,
|
||||
"messages": [{ "role": "user", "content": "Hello, how are you?" }]
|
||||
}'
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
### Streaming example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
```python streaming.py
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
base_url='http://localhost:11434',
|
||||
api_key='ollama',
|
||||
)
|
||||
|
||||
with client.messages.stream(
|
||||
model='qwen3-coder',
|
||||
max_tokens=1024,
|
||||
messages=[{'role': 'user', 'content': 'Count from 1 to 10'}]
|
||||
) as stream:
|
||||
for text in stream.text_stream:
|
||||
print(text, end='', flush=True)
|
||||
```
|
||||
|
||||
```javascript streaming.js
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
|
||||
const anthropic = new Anthropic({
|
||||
baseURL: "http://localhost:11434",
|
||||
apiKey: "ollama",
|
||||
});
|
||||
|
||||
const stream = await anthropic.messages.stream({
|
||||
model: "qwen3-coder",
|
||||
max_tokens: 1024,
|
||||
messages: [{ role: "user", content: "Count from 1 to 10" }],
|
||||
});
|
||||
|
||||
for await (const event of stream) {
|
||||
if (
|
||||
event.type === "content_block_delta" &&
|
||||
event.delta.type === "text_delta"
|
||||
) {
|
||||
process.stdout.write(event.delta.text);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
```shell streaming.sh
|
||||
curl -X POST http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "qwen3-coder",
|
||||
"max_tokens": 1024,
|
||||
"stream": true,
|
||||
"messages": [{ "role": "user", "content": "Count from 1 to 10" }]
|
||||
}'
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
### Tool calling example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
```python tools.py
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
base_url='http://localhost:11434',
|
||||
api_key='ollama',
|
||||
)
|
||||
|
||||
message = client.messages.create(
|
||||
model='qwen3-coder',
|
||||
max_tokens=1024,
|
||||
tools=[
|
||||
{
|
||||
'name': 'get_weather',
|
||||
'description': 'Get the current weather in a location',
|
||||
'input_schema': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'location': {
|
||||
'type': 'string',
|
||||
'description': 'The city and state, e.g. San Francisco, CA'
|
||||
}
|
||||
},
|
||||
'required': ['location']
|
||||
}
|
||||
}
|
||||
],
|
||||
messages=[{'role': 'user', 'content': "What's the weather in San Francisco?"}]
|
||||
)
|
||||
|
||||
for block in message.content:
|
||||
if block.type == 'tool_use':
|
||||
print(f'Tool: {block.name}')
|
||||
print(f'Input: {block.input}')
|
||||
```
|
||||
|
||||
```javascript tools.js
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
|
||||
const anthropic = new Anthropic({
|
||||
baseURL: "http://localhost:11434",
|
||||
apiKey: "ollama",
|
||||
});
|
||||
|
||||
const message = await anthropic.messages.create({
|
||||
model: "qwen3-coder",
|
||||
max_tokens: 1024,
|
||||
tools: [
|
||||
{
|
||||
name: "get_weather",
|
||||
description: "Get the current weather in a location",
|
||||
input_schema: {
|
||||
type: "object",
|
||||
properties: {
|
||||
location: {
|
||||
type: "string",
|
||||
description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
required: ["location"],
|
||||
},
|
||||
},
|
||||
],
|
||||
messages: [{ role: "user", content: "What's the weather in San Francisco?" }],
|
||||
});
|
||||
|
||||
for (const block of message.content) {
|
||||
if (block.type === "tool_use") {
|
||||
console.log("Tool:", block.name);
|
||||
console.log("Input:", block.input);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
```shell tools.sh
|
||||
curl -X POST http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "qwen3-coder",
|
||||
"max_tokens": 1024,
|
||||
"tools": [
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather in a location",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}
|
||||
],
|
||||
"messages": [{ "role": "user", "content": "What is the weather in San Francisco?" }]
|
||||
}'
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
## Using with Claude Code
|
||||
|
||||
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Or set the environment variables in your shell profile:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
```
|
||||
|
||||
Then run Claude Code with any Ollama model:
|
||||
|
||||
```shell
|
||||
# Local models
|
||||
claude --model qwen3-coder
|
||||
claude --model gpt-oss:20b
|
||||
|
||||
# Cloud models
|
||||
claude --model glm-4.7:cloud
|
||||
claude --model minimax-m2.1:cloud
|
||||
```
|
||||
|
||||
## Endpoints
|
||||
|
||||
### `/v1/messages`
|
||||
|
||||
#### Supported features
|
||||
|
||||
- [x] Messages
|
||||
- [x] Streaming
|
||||
- [x] System prompts
|
||||
- [x] Multi-turn conversations
|
||||
- [x] Vision (images)
|
||||
- [x] Tools (function calling)
|
||||
- [x] Tool results
|
||||
- [x] Thinking/extended thinking
|
||||
|
||||
#### Supported request fields
|
||||
|
||||
- [x] `model`
|
||||
- [x] `max_tokens`
|
||||
- [x] `messages`
|
||||
- [x] Text `content`
|
||||
- [x] Image `content` (base64)
|
||||
- [x] Array of content blocks
|
||||
- [x] `tool_use` blocks
|
||||
- [x] `tool_result` blocks
|
||||
- [x] `thinking` blocks
|
||||
- [x] `system` (string or array)
|
||||
- [x] `stream`
|
||||
- [x] `temperature`
|
||||
- [x] `top_p`
|
||||
- [x] `top_k`
|
||||
- [x] `stop_sequences`
|
||||
- [x] `tools`
|
||||
- [x] `thinking`
|
||||
- [ ] `tool_choice`
|
||||
- [ ] `metadata`
|
||||
|
||||
#### Supported response fields
|
||||
|
||||
- [x] `id`
|
||||
- [x] `type`
|
||||
- [x] `role`
|
||||
- [x] `model`
|
||||
- [x] `content` (text, tool_use, thinking blocks)
|
||||
- [x] `stop_reason` (end_turn, max_tokens, tool_use)
|
||||
- [x] `usage` (input_tokens, output_tokens)
|
||||
|
||||
#### Streaming events
|
||||
|
||||
- [x] `message_start`
|
||||
- [x] `content_block_start`
|
||||
- [x] `content_block_delta` (text_delta, input_json_delta, thinking_delta)
|
||||
- [x] `content_block_stop`
|
||||
- [x] `message_delta`
|
||||
- [x] `message_stop`
|
||||
- [x] `ping`
|
||||
- [x] `error`
|
||||
|
||||
## Models
|
||||
|
||||
Ollama supports both local and cloud models.
|
||||
|
||||
### Local models
|
||||
|
||||
Pull a local model before use:
|
||||
|
||||
```shell
|
||||
ollama pull qwen3-coder
|
||||
```
|
||||
|
||||
Recommended local models:
|
||||
- `qwen3-coder` - Excellent for coding tasks
|
||||
- `gpt-oss:20b` - Strong general-purpose model
|
||||
|
||||
### Cloud models
|
||||
|
||||
Cloud models are available immediately without pulling:
|
||||
|
||||
- `glm-4.7:cloud` - High-performance cloud model
|
||||
- `minimax-m2.1:cloud` - Fast cloud model
|
||||
|
||||
### Default model names
|
||||
|
||||
For tooling that relies on default Anthropic model names such as `claude-3-5-sonnet`, use `ollama cp` to copy an existing model name:
|
||||
|
||||
```shell
|
||||
ollama cp qwen3-coder claude-3-5-sonnet
|
||||
```
|
||||
|
||||
Afterwards, this new model name can be specified in the `model` field:
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "claude-3-5-sonnet",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello!"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
## Differences from the Anthropic API
|
||||
|
||||
### Behavior differences
|
||||
|
||||
- API key is accepted but not validated
|
||||
- `anthropic-version` header is accepted but not used
|
||||
- Token counts are approximations based on the underlying model's tokenizer
|
||||
|
||||
### Not supported
|
||||
|
||||
The following Anthropic API features are not currently supported:
|
||||
|
||||
| Feature | Description |
|
||||
|---------|-------------|
|
||||
| `/v1/messages/count_tokens` | Token counting endpoint |
|
||||
| `tool_choice` | Forcing specific tool use or disabling tools |
|
||||
| `metadata` | Request metadata (user_id) |
|
||||
| Prompt caching | `cache_control` blocks for caching prefixes |
|
||||
| Batches API | `/v1/messages/batches` for async batch processing |
|
||||
| Citations | `citations` content blocks |
|
||||
| PDF support | `document` content blocks with PDF files |
|
||||
| Server-sent errors | `error` events during streaming (errors return HTTP status) |
|
||||
|
||||
### Partial support
|
||||
|
||||
| Feature | Status |
|
||||
|---------|--------|
|
||||
| Image content | Base64 images supported; URL images not supported |
|
||||
| Extended thinking | Basic support; `budget_tokens` accepted but not enforced |
|
||||
@@ -277,8 +277,6 @@ curl -X POST http://localhost:11434/v1/chat/completions \
|
||||
|
||||
### `/v1/responses`
|
||||
|
||||
> Note: Added in Ollama v0.13.3
|
||||
|
||||
Ollama supports the [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses). Only the non-stateful flavor is supported (i.e., there is no `previous_response_id` or `conversation` support).
|
||||
|
||||
#### Supported features
|
||||
|
||||
@@ -36,6 +36,7 @@ Provide an `images` array. SDKs accept file paths, URLs or raw bytes while the R
|
||||
}],
|
||||
"stream": false
|
||||
}'
|
||||
"
|
||||
```
|
||||
</Tab>
|
||||
<Tab title="Python">
|
||||
|
||||
@@ -32,9 +32,7 @@
|
||||
"codeblocks": "system"
|
||||
},
|
||||
"contextual": {
|
||||
"options": [
|
||||
"copy"
|
||||
]
|
||||
"options": ["copy"]
|
||||
},
|
||||
"navbar": {
|
||||
"links": [
|
||||
@@ -54,9 +52,7 @@
|
||||
"display": "simple"
|
||||
},
|
||||
"examples": {
|
||||
"languages": [
|
||||
"curl"
|
||||
]
|
||||
"languages": ["curl"]
|
||||
}
|
||||
},
|
||||
"redirects": [
|
||||
@@ -101,7 +97,6 @@
|
||||
{
|
||||
"group": "Integrations",
|
||||
"pages": [
|
||||
"/integrations/claude-code",
|
||||
"/integrations/vscode",
|
||||
"/integrations/jetbrains",
|
||||
"/integrations/codex",
|
||||
@@ -144,8 +139,7 @@
|
||||
"/api/streaming",
|
||||
"/api/usage",
|
||||
"/api/errors",
|
||||
"/api/openai-compatibility",
|
||||
"/api/anthropic-compatibility"
|
||||
"/api/openai-compatibility"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
---
|
||||
title: Claude Code
|
||||
---
|
||||
|
||||
## Install
|
||||
|
||||
Install [Claude Code](https://code.claude.com/docs/en/overview):
|
||||
|
||||
<CodeGroup>
|
||||
|
||||
```shell macOS / Linux
|
||||
curl -fsSL https://claude.ai/install.sh | bash
|
||||
```
|
||||
|
||||
```powershell Windows
|
||||
irm https://claude.ai/install.ps1 | iex
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
Claude Code connects to Ollama using the Anthropic-compatible API.
|
||||
|
||||
1. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
```
|
||||
|
||||
2. Run Claude Code with an Ollama model:
|
||||
|
||||
```shell
|
||||
claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Or run with environment variables inline:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
```
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
|
||||
2. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=https://ollama.com
|
||||
export ANTHROPIC_API_KEY=<your-api-key>
|
||||
```
|
||||
|
||||
3. Run Claude Code with a cloud model:
|
||||
|
||||
```shell
|
||||
claude --model glm-4.7:cloud
|
||||
```
|
||||
|
||||
## Recommended Models
|
||||
|
||||
### Cloud models
|
||||
- `glm-4.7:cloud` - High-performance cloud model
|
||||
- `minimax-m2.1:cloud` - Fast cloud model
|
||||
- `qwen3-coder:480b` - Large coding model
|
||||
|
||||
### Local models
|
||||
- `qwen3-coder` - Excellent for coding tasks
|
||||
- `gpt-oss:20b` - Strong general-purpose model
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: "Linux"
|
||||
title: Linux
|
||||
---
|
||||
|
||||
## Install
|
||||
@@ -13,7 +13,8 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
## Manual install
|
||||
|
||||
<Note>
|
||||
If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
|
||||
If you are upgrading from a prior version, you should remove the old libraries
|
||||
with `sudo rm -rf /usr/lib/ollama` first.
|
||||
</Note>
|
||||
|
||||
Download and extract the package:
|
||||
@@ -112,7 +113,11 @@ sudo systemctl status ollama
|
||||
```
|
||||
|
||||
<Note>
|
||||
While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
|
||||
While AMD has contributed the `amdgpu` driver upstream to the official linux
|
||||
kernel source, the version is older and may not support all ROCm features. We
|
||||
recommend you install the latest driver from
|
||||
https://www.amd.com/en/support/linux-drivers for best support of your Radeon
|
||||
GPU.
|
||||
</Note>
|
||||
|
||||
## Customizing
|
||||
@@ -191,4 +196,4 @@ Remove the downloaded models and Ollama service user and group:
|
||||
sudo userdel ollama
|
||||
sudo groupdel ollama
|
||||
sudo rm -r /usr/share/ollama
|
||||
```
|
||||
```
|
||||
|
||||
3
docs/troubleshooting.md
Normal file
3
docs/troubleshooting.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Troubleshooting
|
||||
|
||||
For troubleshooting, see [https://docs.ollama.com/troubleshooting](https://docs.ollama.com/troubleshooting)
|
||||
@@ -1,7 +1,5 @@
|
||||
package fs
|
||||
|
||||
import "iter"
|
||||
|
||||
type Config interface {
|
||||
Architecture() string
|
||||
String(string, ...string) string
|
||||
@@ -13,8 +11,4 @@ type Config interface {
|
||||
Ints(string, ...[]int32) []int32
|
||||
Floats(string, ...[]float32) []float32
|
||||
Bools(string, ...[]bool) []bool
|
||||
|
||||
Len() int
|
||||
Keys() iter.Seq[string]
|
||||
Value(key string) any
|
||||
}
|
||||
|
||||
@@ -6,9 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -241,18 +239,6 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) Len() int {
|
||||
return len(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Keys() iter.Seq[string] {
|
||||
return maps.Keys(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Value(key string) any {
|
||||
return kv[key]
|
||||
}
|
||||
|
||||
func (kv KV) OllamaEngineRequired() bool {
|
||||
return slices.Contains([]string{
|
||||
"bert",
|
||||
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
@@ -508,7 +508,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
|
||||
return binary.Write(w, binary.LittleEndian, s)
|
||||
}
|
||||
|
||||
func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
|
||||
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
arch := kv.String("general.architecture")
|
||||
if arch == "" {
|
||||
return fmt.Errorf("architecture not set")
|
||||
@@ -526,12 +526,12 @@ func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(kv.Len())); err != nil {
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, key := range slices.Sorted(kv.Keys()) {
|
||||
if err := ggufWriteKV(f, arch, key, kv.Value(key)); err != nil {
|
||||
for _, key := range slices.Sorted(maps.Keys(kv)) {
|
||||
if err := ggufWriteKV(f, arch, key, kv[key]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
4
go.mod
4
go.mod
@@ -28,7 +28,6 @@ require (
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8
|
||||
golang.org/x/image v0.22.0
|
||||
golang.org/x/mod v0.30.0
|
||||
golang.org/x/tools v0.38.0
|
||||
@@ -37,8 +36,6 @@ require (
|
||||
|
||||
require (
|
||||
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/chewxy/hm v1.0.0 // indirect
|
||||
github.com/chewxy/math32 v1.11.0 // indirect
|
||||
@@ -48,7 +45,6 @@ require (
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/flatbuffers v24.3.25+incompatible // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/mailru/easyjson v0.7.7 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
|
||||
9
go.sum
9
go.sum
@@ -14,11 +14,7 @@ github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6IC
|
||||
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
|
||||
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
|
||||
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
|
||||
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
|
||||
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
|
||||
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
|
||||
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
@@ -127,7 +123,6 @@ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
|
||||
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
|
||||
@@ -148,8 +143,6 @@ github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 h1:QwWKgMY28TAXaDl+
|
||||
github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
|
||||
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||
@@ -214,8 +207,6 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
|
||||
|
||||
@@ -11,15 +11,6 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
|
||||
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||
props := api.NewToolPropertiesMap()
|
||||
for k, v := range m {
|
||||
props.Set(k, v)
|
||||
}
|
||||
return props
|
||||
}
|
||||
|
||||
func TestAPIToolCalling(t *testing.T) {
|
||||
initialTimeout := 60 * time.Second
|
||||
streamTimeout := 60 * time.Second
|
||||
@@ -66,12 +57,12 @@ func TestAPIToolCalling(t *testing.T) {
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
// Package orderedmap provides a generic ordered map that maintains insertion order.
|
||||
// It wraps github.com/wk8/go-ordered-map/v2 to encapsulate the dependency.
|
||||
package orderedmap
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"iter"
|
||||
|
||||
orderedmap "github.com/wk8/go-ordered-map/v2"
|
||||
)
|
||||
|
||||
// Map is a generic ordered map that maintains insertion order.
|
||||
type Map[K comparable, V any] struct {
|
||||
om *orderedmap.OrderedMap[K, V]
|
||||
}
|
||||
|
||||
// New creates a new empty ordered map.
|
||||
func New[K comparable, V any]() *Map[K, V] {
|
||||
return &Map[K, V]{
|
||||
om: orderedmap.New[K, V](),
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a value by key.
|
||||
func (m *Map[K, V]) Get(key K) (V, bool) {
|
||||
if m == nil || m.om == nil {
|
||||
var zero V
|
||||
return zero, false
|
||||
}
|
||||
return m.om.Get(key)
|
||||
}
|
||||
|
||||
// Set sets a key-value pair. If the key already exists, its value is updated
|
||||
// but its position in the iteration order is preserved. If the key is new,
|
||||
// it is appended to the end.
|
||||
func (m *Map[K, V]) Set(key K, value V) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if m.om == nil {
|
||||
m.om = orderedmap.New[K, V]()
|
||||
}
|
||||
m.om.Set(key, value)
|
||||
}
|
||||
|
||||
// Len returns the number of entries.
|
||||
func (m *Map[K, V]) Len() int {
|
||||
if m == nil || m.om == nil {
|
||||
return 0
|
||||
}
|
||||
return m.om.Len()
|
||||
}
|
||||
|
||||
// All returns an iterator over all key-value pairs in insertion order.
|
||||
func (m *Map[K, V]) All() iter.Seq2[K, V] {
|
||||
return func(yield func(K, V) bool) {
|
||||
if m == nil || m.om == nil {
|
||||
return
|
||||
}
|
||||
for pair := m.om.Oldest(); pair != nil; pair = pair.Next() {
|
||||
if !yield(pair.Key, pair.Value) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ToMap converts to a regular Go map.
|
||||
// Note: The resulting map does not preserve order.
|
||||
func (m *Map[K, V]) ToMap() map[K]V {
|
||||
if m == nil || m.om == nil {
|
||||
return nil
|
||||
}
|
||||
result := make(map[K]V, m.om.Len())
|
||||
for pair := m.om.Oldest(); pair != nil; pair = pair.Next() {
|
||||
result[pair.Key] = pair.Value
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler. The JSON output preserves key order.
|
||||
func (m *Map[K, V]) MarshalJSON() ([]byte, error) {
|
||||
if m == nil || m.om == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return json.Marshal(m.om)
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler. The insertion order matches the
|
||||
// order of keys in the JSON input.
|
||||
func (m *Map[K, V]) UnmarshalJSON(data []byte) error {
|
||||
m.om = orderedmap.New[K, V]()
|
||||
return json.Unmarshal(data, &m.om)
|
||||
}
|
||||
@@ -1,348 +0,0 @@
|
||||
package orderedmap
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMap_BasicOperations(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
// Test empty map
|
||||
if m.Len() != 0 {
|
||||
t.Errorf("expected Len() = 0, got %d", m.Len())
|
||||
}
|
||||
v, ok := m.Get("a")
|
||||
if ok {
|
||||
t.Error("expected Get on empty map to return false")
|
||||
}
|
||||
if v != 0 {
|
||||
t.Errorf("expected zero value, got %d", v)
|
||||
}
|
||||
|
||||
// Test Set and Get
|
||||
m.Set("a", 1)
|
||||
m.Set("b", 2)
|
||||
m.Set("c", 3)
|
||||
|
||||
if m.Len() != 3 {
|
||||
t.Errorf("expected Len() = 3, got %d", m.Len())
|
||||
}
|
||||
|
||||
v, ok = m.Get("a")
|
||||
if !ok || v != 1 {
|
||||
t.Errorf("expected Get(a) = (1, true), got (%d, %v)", v, ok)
|
||||
}
|
||||
|
||||
v, ok = m.Get("b")
|
||||
if !ok || v != 2 {
|
||||
t.Errorf("expected Get(b) = (2, true), got (%d, %v)", v, ok)
|
||||
}
|
||||
|
||||
v, ok = m.Get("c")
|
||||
if !ok || v != 3 {
|
||||
t.Errorf("expected Get(c) = (3, true), got (%d, %v)", v, ok)
|
||||
}
|
||||
|
||||
// Test updating existing key preserves position
|
||||
m.Set("a", 10)
|
||||
v, ok = m.Get("a")
|
||||
if !ok || v != 10 {
|
||||
t.Errorf("expected Get(a) = (10, true), got (%d, %v)", v, ok)
|
||||
}
|
||||
if m.Len() != 3 {
|
||||
t.Errorf("expected Len() = 3 after update, got %d", m.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_InsertionOrderPreserved(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
// Insert in non-alphabetical order
|
||||
m.Set("z", 1)
|
||||
m.Set("a", 2)
|
||||
m.Set("m", 3)
|
||||
m.Set("b", 4)
|
||||
|
||||
// Verify iteration order matches insertion order
|
||||
var keys []string
|
||||
var values []int
|
||||
for k, v := range m.All() {
|
||||
keys = append(keys, k)
|
||||
values = append(values, v)
|
||||
}
|
||||
|
||||
expectedKeys := []string{"z", "a", "m", "b"}
|
||||
expectedValues := []int{1, 2, 3, 4}
|
||||
|
||||
if !slices.Equal(keys, expectedKeys) {
|
||||
t.Errorf("expected keys %v, got %v", expectedKeys, keys)
|
||||
}
|
||||
if !slices.Equal(values, expectedValues) {
|
||||
t.Errorf("expected values %v, got %v", expectedValues, values)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_UpdatePreservesPosition(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
m.Set("first", 1)
|
||||
m.Set("second", 2)
|
||||
m.Set("third", 3)
|
||||
|
||||
// Update middle element
|
||||
m.Set("second", 20)
|
||||
|
||||
var keys []string
|
||||
for k := range m.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
// Order should still be first, second, third
|
||||
expected := []string{"first", "second", "third"}
|
||||
if !slices.Equal(keys, expected) {
|
||||
t.Errorf("expected keys %v, got %v", expected, keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_MarshalJSON_PreservesOrder(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
// Insert in non-alphabetical order
|
||||
m.Set("z", 1)
|
||||
m.Set("a", 2)
|
||||
m.Set("m", 3)
|
||||
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
// JSON should preserve insertion order, not alphabetical
|
||||
expected := `{"z":1,"a":2,"m":3}`
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_UnmarshalJSON_PreservesOrder(t *testing.T) {
|
||||
// JSON with non-alphabetical key order
|
||||
jsonData := `{"z":1,"a":2,"m":3}`
|
||||
|
||||
m := New[string, int]()
|
||||
if err := json.Unmarshal([]byte(jsonData), m); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify iteration order matches JSON order
|
||||
var keys []string
|
||||
for k := range m.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
expected := []string{"z", "a", "m"}
|
||||
if !slices.Equal(keys, expected) {
|
||||
t.Errorf("expected keys %v, got %v", expected, keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_JSONRoundTrip(t *testing.T) {
|
||||
// Test that unmarshal -> marshal produces identical JSON
|
||||
original := `{"zebra":"z","apple":"a","mango":"m","banana":"b"}`
|
||||
|
||||
m := New[string, string]()
|
||||
if err := json.Unmarshal([]byte(original), m); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
if string(data) != original {
|
||||
t.Errorf("round trip failed: expected %s, got %s", original, string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_ToMap(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
m.Set("a", 1)
|
||||
m.Set("b", 2)
|
||||
|
||||
regular := m.ToMap()
|
||||
|
||||
if len(regular) != 2 {
|
||||
t.Errorf("expected len 2, got %d", len(regular))
|
||||
}
|
||||
if regular["a"] != 1 {
|
||||
t.Errorf("expected regular[a] = 1, got %d", regular["a"])
|
||||
}
|
||||
if regular["b"] != 2 {
|
||||
t.Errorf("expected regular[b] = 2, got %d", regular["b"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_NilSafety(t *testing.T) {
|
||||
var m *Map[string, int]
|
||||
|
||||
// All operations should be safe on nil
|
||||
if m.Len() != 0 {
|
||||
t.Errorf("expected Len() = 0 on nil map, got %d", m.Len())
|
||||
}
|
||||
|
||||
v, ok := m.Get("a")
|
||||
if ok {
|
||||
t.Error("expected Get on nil map to return false")
|
||||
}
|
||||
if v != 0 {
|
||||
t.Errorf("expected zero value from nil map, got %d", v)
|
||||
}
|
||||
|
||||
// Set on nil is a no-op
|
||||
m.Set("a", 1)
|
||||
if m.Len() != 0 {
|
||||
t.Errorf("expected Len() = 0 after Set on nil, got %d", m.Len())
|
||||
}
|
||||
|
||||
// All returns empty iterator
|
||||
var keys []string
|
||||
for k := range m.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
if len(keys) != 0 {
|
||||
t.Errorf("expected empty iteration on nil map, got %v", keys)
|
||||
}
|
||||
|
||||
// ToMap returns nil
|
||||
if m.ToMap() != nil {
|
||||
t.Error("expected ToMap to return nil on nil map")
|
||||
}
|
||||
|
||||
// MarshalJSON returns null
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
if string(data) != "null" {
|
||||
t.Errorf("expected null, got %s", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_EmptyMapMarshal(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
if string(data) != "{}" {
|
||||
t.Errorf("expected {}, got %s", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_NestedValues(t *testing.T) {
|
||||
m := New[string, any]()
|
||||
m.Set("string", "hello")
|
||||
m.Set("number", 42)
|
||||
m.Set("bool", true)
|
||||
m.Set("nested", map[string]int{"x": 1})
|
||||
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
expected := `{"string":"hello","number":42,"bool":true,"nested":{"x":1}}`
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_AllIteratorEarlyExit(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
m.Set("a", 1)
|
||||
m.Set("b", 2)
|
||||
m.Set("c", 3)
|
||||
m.Set("d", 4)
|
||||
|
||||
// Collect only first 2
|
||||
var keys []string
|
||||
for k := range m.All() {
|
||||
keys = append(keys, k)
|
||||
if len(keys) == 2 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
expected := []string{"a", "b"}
|
||||
if !slices.Equal(keys, expected) {
|
||||
t.Errorf("expected %v, got %v", expected, keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_IntegerKeys(t *testing.T) {
|
||||
m := New[int, string]()
|
||||
m.Set(3, "three")
|
||||
m.Set(1, "one")
|
||||
m.Set(2, "two")
|
||||
|
||||
var keys []int
|
||||
for k := range m.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
// Should preserve insertion order, not numerical order
|
||||
expected := []int{3, 1, 2}
|
||||
if !slices.Equal(keys, expected) {
|
||||
t.Errorf("expected %v, got %v", expected, keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_UnmarshalIntoExisting(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
m.Set("existing", 999)
|
||||
|
||||
// Unmarshal should replace contents
|
||||
if err := json.Unmarshal([]byte(`{"new":1}`), m); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
_, ok := m.Get("existing")
|
||||
if ok {
|
||||
t.Error("existing key should be gone after unmarshal")
|
||||
}
|
||||
|
||||
v, ok := m.Get("new")
|
||||
if !ok || v != 1 {
|
||||
t.Errorf("expected Get(new) = (1, true), got (%d, %v)", v, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_LargeOrderPreservation(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
// Create many keys in specific order
|
||||
keys := make([]string, 100)
|
||||
for i := range 100 {
|
||||
keys[i] = string(rune('a' + (99 - i))) // reverse order: 'd', 'c', 'b', 'a' (extended)
|
||||
if i >= 26 {
|
||||
keys[i] = string(rune('A'+i-26)) + string(rune('a'+i%26))
|
||||
}
|
||||
}
|
||||
|
||||
for i, k := range keys {
|
||||
m.Set(k, i)
|
||||
}
|
||||
|
||||
// Verify order preserved
|
||||
var resultKeys []string
|
||||
for k := range m.All() {
|
||||
resultKeys = append(resultKeys, k)
|
||||
}
|
||||
|
||||
if !slices.Equal(keys, resultKeys) {
|
||||
t.Error("large map should preserve insertion order")
|
||||
}
|
||||
}
|
||||
@@ -1,152 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/anthropic"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format
|
||||
type AnthropicWriter struct {
|
||||
BaseWriter
|
||||
stream bool
|
||||
id string
|
||||
model string
|
||||
converter *anthropic.StreamConverter
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) writeError(data []byte) (int, error) {
|
||||
var errData struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &errData); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.ResponseWriter.Status(), errData.Error))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) writeEvent(eventType string, data any) error {
|
||||
d, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
|
||||
var chatResponse api.ChatResponse
|
||||
err := json.Unmarshal(data, &chatResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if w.stream {
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
events := w.converter.Process(chatResponse)
|
||||
for _, event := range events {
|
||||
if err := w.writeEvent(event.Event, event.Data); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
response := anthropic.ToMessagesResponse(w.id, chatResponse)
|
||||
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
// AnthropicMessagesMiddleware handles Anthropic Messages API requests
|
||||
func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req anthropic.MessagesRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "model is required"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.MaxTokens <= 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "max_tokens is required and must be positive"))
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Messages) == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "messages is required"))
|
||||
return
|
||||
}
|
||||
|
||||
chatReq, err := anthropic.FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Set think to nil when being used with Anthropic API to connect to tools like claude code
|
||||
c.Set("relax_thinking", true)
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
messageID := anthropic.GenerateMessageID()
|
||||
|
||||
w := &AnthropicWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: messageID,
|
||||
model: req.Model,
|
||||
converter: anthropic.NewStreamConverter(messageID, req.Model),
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -1,607 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
|
||||
"github.com/ollama/ollama/anthropic"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func captureAnthropicRequest(capturedRequest any) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
_ = json.Unmarshal(bodyBytes, capturedRequest)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// testProps creates ToolPropertiesMap from a map (convenience function for tests)
|
||||
func testProps(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||
props := api.NewToolPropertiesMap()
|
||||
for k, v := range m {
|
||||
props.Set(k, v)
|
||||
}
|
||||
return props
|
||||
}
|
||||
|
||||
func TestAnthropicMessagesMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
body string
|
||||
req api.ChatRequest
|
||||
err anthropic.ErrorResponse
|
||||
}
|
||||
|
||||
var capturedRequest *api.ChatRequest
|
||||
stream := true
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "basic message",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with system prompt",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"system": "You are helpful.",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "system", Content: "You are helpful."},
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with options",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 2048,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
"top_k": 40,
|
||||
"stop_sequences": ["\n", "END"],
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"num_predict": 2048,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
"top_k": 40,
|
||||
"stop": []string{"\n", "END"},
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "streaming",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"stream": true,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &stream,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with tools",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather?"}
|
||||
],
|
||||
"tools": [{
|
||||
"name": "get_weather",
|
||||
"description": "Get current weather",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string"}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
},
|
||||
Tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: testProps(map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with tool result",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather?"},
|
||||
{"role": "assistant", "content": [
|
||||
{"type": "tool_use", "id": "call_123", "name": "get_weather", "input": {"location": "Paris"}}
|
||||
]},
|
||||
{"role": "user", "content": [
|
||||
{"type": "tool_result", "tool_use_id": "call_123", "content": "Sunny, 22°C"}
|
||||
]}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call_123"},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with thinking enabled",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"thinking": {"type": "enabled", "budget_tokens": 1000},
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &False,
|
||||
Think: &api.ThinkValue{Value: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing model error",
|
||||
body: `{
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
err: anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.Error{
|
||||
Type: "invalid_request_error",
|
||||
Message: "model is required",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing max_tokens error",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
err: anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.Error{
|
||||
Type: "invalid_request_error",
|
||||
Message: "max_tokens is required and must be positive",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing messages error",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024
|
||||
}`,
|
||||
err: anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.Error{
|
||||
Type: "invalid_request_error",
|
||||
Message: "messages is required",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool_use missing id error",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "assistant", "content": [
|
||||
{"type": "tool_use", "name": "test"}
|
||||
]}
|
||||
]
|
||||
}`,
|
||||
err: anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.Error{
|
||||
Type: "invalid_request_error",
|
||||
Message: "tool_use block missing required 'id' field",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware(), captureAnthropicRequest(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/v1/messages", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
defer func() { capturedRequest = nil }()
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if tc.err.Type != "" {
|
||||
// Expect error
|
||||
if resp.Code == http.StatusOK {
|
||||
t.Fatalf("expected error response, got 200 OK")
|
||||
}
|
||||
var errResp anthropic.ErrorResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatalf("failed to unmarshal error: %v", err)
|
||||
}
|
||||
if errResp.Type != tc.err.Type {
|
||||
t.Errorf("expected error type %q, got %q", tc.err.Type, errResp.Type)
|
||||
}
|
||||
if errResp.Error.Type != tc.err.Error.Type {
|
||||
t.Errorf("expected error.type %q, got %q", tc.err.Error.Type, errResp.Error.Type)
|
||||
}
|
||||
if errResp.Error.Message != tc.err.Error.Message {
|
||||
t.Errorf("expected error.message %q, got %q", tc.err.Error.Message, errResp.Error.Message)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status code: %d, body: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
if capturedRequest == nil {
|
||||
t.Fatal("request was not captured")
|
||||
}
|
||||
|
||||
// Compare relevant fields
|
||||
if capturedRequest.Model != tc.req.Model {
|
||||
t.Errorf("model mismatch: got %q, want %q", capturedRequest.Model, tc.req.Model)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tc.req.Messages, capturedRequest.Messages,
|
||||
cmpopts.IgnoreUnexported(api.ToolCallFunctionArguments{}, api.ToolPropertiesMap{})); diff != "" {
|
||||
t.Errorf("messages mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if tc.req.Stream != nil && capturedRequest.Stream != nil {
|
||||
if *tc.req.Stream != *capturedRequest.Stream {
|
||||
t.Errorf("stream mismatch: got %v, want %v", *capturedRequest.Stream, *tc.req.Stream)
|
||||
}
|
||||
}
|
||||
|
||||
if tc.req.Think != nil {
|
||||
if capturedRequest.Think == nil {
|
||||
t.Error("expected Think to be set")
|
||||
} else if capturedRequest.Think.Value != tc.req.Think.Value {
|
||||
t.Errorf("Think mismatch: got %v, want %v", capturedRequest.Think.Value, tc.req.Think.Value)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnthropicMessagesMiddleware_Headers(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Run("streaming sets correct headers", func(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware())
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
// Check headers were set
|
||||
if c.Writer.Header().Get("Content-Type") != "text/event-stream" {
|
||||
t.Errorf("expected Content-Type text/event-stream, got %q", c.Writer.Header().Get("Content-Type"))
|
||||
}
|
||||
if c.Writer.Header().Get("Cache-Control") != "no-cache" {
|
||||
t.Errorf("expected Cache-Control no-cache, got %q", c.Writer.Header().Get("Cache-Control"))
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
body := `{"model": "test", "max_tokens": 100, "stream": true, "messages": [{"role": "user", "content": "Hi"}]}`
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAnthropicMessagesMiddleware_InvalidJSON(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware())
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{invalid json`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", resp.Code)
|
||||
}
|
||||
|
||||
var errResp anthropic.ErrorResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatalf("failed to unmarshal error: %v", err)
|
||||
}
|
||||
|
||||
if errResp.Type != "error" {
|
||||
t.Errorf("expected type 'error', got %q", errResp.Type)
|
||||
}
|
||||
if errResp.Error.Type != "invalid_request_error" {
|
||||
t.Errorf("expected error type 'invalid_request_error', got %q", errResp.Error.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnthropicWriter_NonStreaming(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware())
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
// Simulate Ollama response
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "Hello there!",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 10,
|
||||
EvalCount: 5,
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
_, _ = c.Writer.Write(data)
|
||||
})
|
||||
|
||||
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", resp.Code)
|
||||
}
|
||||
|
||||
var result anthropic.MessagesResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if result.Type != "message" {
|
||||
t.Errorf("expected type 'message', got %q", result.Type)
|
||||
}
|
||||
if result.Role != "assistant" {
|
||||
t.Errorf("expected role 'assistant', got %q", result.Role)
|
||||
}
|
||||
if len(result.Content) != 1 {
|
||||
t.Fatalf("expected 1 content block, got %d", len(result.Content))
|
||||
}
|
||||
if result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
|
||||
t.Errorf("expected text 'Hello there!', got %v", result.Content[0].Text)
|
||||
}
|
||||
if result.StopReason != "end_turn" {
|
||||
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
|
||||
}
|
||||
if result.Usage.InputTokens != 10 {
|
||||
t.Errorf("expected input_tokens 10, got %d", result.Usage.InputTokens)
|
||||
}
|
||||
if result.Usage.OutputTokens != 5 {
|
||||
t.Errorf("expected output_tokens 5, got %d", result.Usage.OutputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAnthropicWriter_ErrorFromRoutes tests error handling when routes.go sends
|
||||
// gin.H{"error": "message"} without a StatusCode field (which is the common case)
|
||||
func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
errorPayload any
|
||||
wantErrorType string
|
||||
wantMessage string
|
||||
}{
|
||||
// routes.go sends errors without StatusCode in JSON, so we must use HTTP status
|
||||
{
|
||||
name: "404 with gin.H error (model not found)",
|
||||
statusCode: http.StatusNotFound,
|
||||
errorPayload: gin.H{"error": "model 'nonexistent' not found"},
|
||||
wantErrorType: "not_found_error",
|
||||
wantMessage: "model 'nonexistent' not found",
|
||||
},
|
||||
{
|
||||
name: "400 with gin.H error (bad request)",
|
||||
statusCode: http.StatusBadRequest,
|
||||
errorPayload: gin.H{"error": "model is required"},
|
||||
wantErrorType: "invalid_request_error",
|
||||
wantMessage: "model is required",
|
||||
},
|
||||
{
|
||||
name: "500 with gin.H error (internal error)",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
errorPayload: gin.H{"error": "something went wrong"},
|
||||
wantErrorType: "api_error",
|
||||
wantMessage: "something went wrong",
|
||||
},
|
||||
{
|
||||
name: "404 with api.StatusError",
|
||||
statusCode: http.StatusNotFound,
|
||||
errorPayload: api.StatusError{
|
||||
StatusCode: http.StatusNotFound,
|
||||
ErrorMessage: "model not found via StatusError",
|
||||
},
|
||||
wantErrorType: "not_found_error",
|
||||
wantMessage: "model not found via StatusError",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware())
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
// Simulate what routes.go does - set status and write error JSON
|
||||
data, _ := json.Marshal(tt.errorPayload)
|
||||
c.Writer.WriteHeader(tt.statusCode)
|
||||
_, _ = c.Writer.Write(data)
|
||||
})
|
||||
|
||||
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != tt.statusCode {
|
||||
t.Errorf("expected status %d, got %d", tt.statusCode, resp.Code)
|
||||
}
|
||||
|
||||
var errResp anthropic.ErrorResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatalf("failed to unmarshal error response: %v\nbody: %s", err, resp.Body.String())
|
||||
}
|
||||
|
||||
if errResp.Type != "error" {
|
||||
t.Errorf("expected type 'error', got %q", errResp.Type)
|
||||
}
|
||||
if errResp.Error.Type != tt.wantErrorType {
|
||||
t.Errorf("expected error type %q, got %q", tt.wantErrorType, errResp.Error.Type)
|
||||
}
|
||||
if errResp.Error.Message != tt.wantMessage {
|
||||
t.Errorf("expected message %q, got %q", tt.wantMessage, errResp.Error.Message)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnthropicMessagesMiddleware_SetsRelaxThinkingFlag(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var flagSet bool
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware())
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
_, flagSet = c.Get("relax_thinking")
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if !flagSet {
|
||||
t.Error("expected relax_thinking flag to be set in context")
|
||||
}
|
||||
}
|
||||
@@ -19,40 +19,6 @@ import (
|
||||
"github.com/ollama/ollama/openai"
|
||||
)
|
||||
|
||||
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
|
||||
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||
props := api.NewToolPropertiesMap()
|
||||
for k, v := range m {
|
||||
props.Set(k, v)
|
||||
}
|
||||
return props
|
||||
}
|
||||
|
||||
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
|
||||
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// argsComparer provides cmp options for comparing ToolCallFunctionArguments by value
|
||||
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
|
||||
return cmp.Equal(a.ToMap(), b.ToMap())
|
||||
})
|
||||
|
||||
// propsComparer provides cmp options for comparing ToolPropertiesMap by value
|
||||
var propsComparer = cmp.Comparer(func(a, b *api.ToolPropertiesMap) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
return cmp.Equal(a.ToMap(), b.ToMap())
|
||||
})
|
||||
|
||||
const (
|
||||
prefix = `data:image/jpeg;base64,`
|
||||
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||
@@ -255,10 +221,10 @@ func TestChatMiddleware(t *testing.T) {
|
||||
ID: "id",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -295,10 +261,10 @@ func TestChatMiddleware(t *testing.T) {
|
||||
ID: "id",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -334,10 +300,10 @@ func TestChatMiddleware(t *testing.T) {
|
||||
ID: "id",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -374,10 +340,10 @@ func TestChatMiddleware(t *testing.T) {
|
||||
ID: "id",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -414,10 +380,10 @@ func TestChatMiddleware(t *testing.T) {
|
||||
ID: "id_abc",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -460,10 +426,10 @@ func TestChatMiddleware(t *testing.T) {
|
||||
ID: "id",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -528,7 +494,7 @@ func TestChatMiddleware(t *testing.T) {
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state",
|
||||
@@ -537,7 +503,7 @@ func TestChatMiddleware(t *testing.T) {
|
||||
Type: api.PropertyType{"string"},
|
||||
Enum: []any{"celsius", "fahrenheit"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -592,7 +558,7 @@ func TestChatMiddleware(t *testing.T) {
|
||||
}
|
||||
return
|
||||
}
|
||||
if diff := cmp.Diff(&tc.req, capturedRequest, argsComparer, propsComparer); diff != "" {
|
||||
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
|
||||
t.Fatalf("requests did not match: %+v", diff)
|
||||
}
|
||||
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
||||
|
||||
@@ -40,9 +40,9 @@ func TestCogitoParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -52,9 +52,9 @@ func TestCogitoParser(t *testing.T) {
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -71,9 +71,9 @@ func TestCogitoParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -83,9 +83,9 @@ func TestCogitoParser(t *testing.T) {
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -103,17 +103,17 @@ func TestCogitoParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "London",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -123,9 +123,9 @@ func TestCogitoParser(t *testing.T) {
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -140,11 +140,11 @@ func TestCogitoParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process_data",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"items": []any{"item1", "item2"},
|
||||
"config": map[string]any{"enabled": true, "threshold": 0.95},
|
||||
"count": 42.0,
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -238,7 +238,7 @@ This is line 3</think>Final response here.`,
|
||||
t.Errorf("thinking mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedToolCalls, toolCalls, argsComparer); diff != "" {
|
||||
if diff := cmp.Diff(tt.expectedToolCalls, toolCalls); diff != "" {
|
||||
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
@@ -277,9 +277,9 @@ func TestCogitoParser_Streaming(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test_tool",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"arg": "value",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -292,7 +292,7 @@ func TestCogitoParser_Streaming(t *testing.T) {
|
||||
t.Errorf("expected thinking %q, got %q", expectedThinking, finalThinking.String())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(expectedToolCalls, finalToolCalls, argsComparer); diff != "" {
|
||||
if diff := cmp.Diff(expectedToolCalls, finalToolCalls); diff != "" {
|
||||
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
@@ -367,7 +367,7 @@ func TestCogitoParser_StreamingEdgeCases(t *testing.T) {
|
||||
t.Errorf("expected thinking %q, got %q", tt.expectedThinking, finalThinking.String())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedToolCalls, finalToolCalls, argsComparer); diff != "" {
|
||||
if diff := cmp.Diff(tt.expectedToolCalls, finalToolCalls); diff != "" {
|
||||
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
@@ -412,9 +412,9 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
@@ -427,11 +427,11 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process_data",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"items": []any{"item1", "item2"},
|
||||
"config": map[string]any{"enabled": true},
|
||||
"count": 42.0,
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
@@ -444,7 +444,7 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "no_args_tool",
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
@@ -493,9 +493,9 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
@@ -511,10 +511,10 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
"units": "metric",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
@@ -527,13 +527,13 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "complex_tool",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"nested": map[string]any{
|
||||
"deep": map[string]any{
|
||||
"value": 123.0,
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
@@ -557,7 +557,7 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expected, result, argsComparer); diff != "" {
|
||||
if diff := cmp.Diff(tt.expected, result); diff != "" {
|
||||
t.Errorf("tool call mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -51,9 +51,9 @@ func TestDeepSeekParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -67,17 +67,17 @@ func TestDeepSeekParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "London",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -97,10 +97,10 @@ func TestDeepSeekParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process_data",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"items": []interface{}{"item1", "item2"},
|
||||
"config": map[string]interface{}{"enabled": true, "threshold": 0.95},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -115,9 +115,9 @@ func TestDeepSeekParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -162,9 +162,9 @@ func TestDeepSeekParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Tokyo",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -191,10 +191,10 @@ func TestDeepSeekParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"query": "北京天气",
|
||||
"language": "中文",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -220,10 +220,10 @@ func TestDeepSeekParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "execute_command",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"command": "ls && echo \"done\"",
|
||||
"path": "/home/user",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -244,7 +244,7 @@ func TestDeepSeekParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "ping",
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -276,7 +276,7 @@ func TestDeepSeekParser(t *testing.T) {
|
||||
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedCalls, calls, argsComparer); diff != "" {
|
||||
if diff := cmp.Diff(tt.expectedCalls, calls); diff != "" {
|
||||
t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
@@ -313,9 +313,9 @@ func TestDeepSeekParser_Streaming(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -342,7 +342,7 @@ func TestDeepSeekParser_Streaming(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test",
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -375,10 +375,10 @@ func TestDeepSeekParser_Streaming(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "calc",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"x": float64(42),
|
||||
"y": float64(24),
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -414,7 +414,7 @@ func TestDeepSeekParser_Streaming(t *testing.T) {
|
||||
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedCalls, allCalls, argsComparer); diff != "" {
|
||||
if diff := cmp.Diff(tt.expectedCalls, allCalls); diff != "" {
|
||||
t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
@@ -469,7 +469,7 @@ func TestDeepSeekParser_Init(t *testing.T) {
|
||||
|
||||
returnedTools := parser.Init(tools, nil, &api.ThinkValue{Value: true})
|
||||
|
||||
if diff := cmp.Diff(tools, returnedTools, toolsComparer); diff != "" {
|
||||
if diff := cmp.Diff(tools, returnedTools); diff != "" {
|
||||
t.Errorf("Init() returned tools mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
@@ -492,9 +492,9 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -504,10 +504,10 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process_data",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"items": []interface{}{"a", "b"},
|
||||
"config": map[string]interface{}{"enabled": true},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -517,7 +517,7 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "ping",
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -527,9 +527,9 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "获取天气",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"城市": "北京",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -539,10 +539,10 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "execute",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"command": "ls && echo \"done\"",
|
||||
"path": "/home/user",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -552,11 +552,11 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "calculate",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"x": 3.14,
|
||||
"y": float64(42),
|
||||
"enabled": true,
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -577,9 +577,9 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"arg": "value",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -606,7 +606,7 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expected, result, argsComparer); diff != "" {
|
||||
if diff := cmp.Diff(tt.expected, result); diff != "" {
|
||||
t.Errorf("parseToolCallContent() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -166,7 +166,7 @@ func (p *FunctionGemmaParser) parseToolCall(content string) (api.ToolCall, error
|
||||
|
||||
// parseArguments parses the key:value,key:value format
|
||||
func (p *FunctionGemmaParser) parseArguments(argsStr string) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
args := make(api.ToolCallFunctionArguments)
|
||||
if argsStr == "" {
|
||||
return args
|
||||
}
|
||||
@@ -185,7 +185,7 @@ func (p *FunctionGemmaParser) parseArguments(argsStr string) api.ToolCallFunctio
|
||||
value := part[colonIdx+1:]
|
||||
|
||||
// Parse the value
|
||||
args.Set(key, p.parseValue(value))
|
||||
args[key] = p.parseValue(value)
|
||||
}
|
||||
|
||||
return args
|
||||
|
||||
@@ -3,7 +3,6 @@ package parsers
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@@ -37,9 +36,9 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -48,7 +47,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -67,7 +66,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -85,7 +84,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "add",
|
||||
Arguments: testArgs(map[string]any{"a": int64(1), "b": int64(2)}),
|
||||
Arguments: api.ToolCallFunctionArguments{"a": int64(1), "b": int64(2)},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -103,7 +102,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_flag",
|
||||
Arguments: testArgs(map[string]any{"enabled": true, "verbose": false}),
|
||||
Arguments: api.ToolCallFunctionArguments{"enabled": true, "verbose": false},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -125,13 +124,13 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "London"}),
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "London"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -153,7 +152,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process",
|
||||
Arguments: testArgs(map[string]any{"items": []any{"a", "b", "c"}}),
|
||||
Arguments: api.ToolCallFunctionArguments{"items": []any{"a", "b", "c"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -174,9 +173,9 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "update",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"data": map[string]any{"name": "test", "value": int64(42)},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -199,7 +198,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_time",
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -225,7 +224,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_temp",
|
||||
Arguments: testArgs(map[string]any{"value": 3.14}),
|
||||
Arguments: api.ToolCallFunctionArguments{"value": 3.14},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -243,7 +242,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test",
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -262,7 +261,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "greet",
|
||||
Arguments: testArgs(map[string]any{"name": "日本語"}),
|
||||
Arguments: api.ToolCallFunctionArguments{"name": "日本語"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -282,11 +281,11 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"query": "test",
|
||||
"limit": int64(10),
|
||||
"offset": int64(0),
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -309,14 +308,14 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "create",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"config": map[string]any{
|
||||
"settings": map[string]any{
|
||||
"enabled": true,
|
||||
"name": "test",
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -346,13 +345,13 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_time",
|
||||
Arguments: testArgs(map[string]any{"timezone": "UTC"}),
|
||||
Arguments: api.ToolCallFunctionArguments{"timezone": "UTC"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -373,13 +372,13 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "first",
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "second",
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -412,9 +411,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.expectedText, allContent)
|
||||
if diff := cmp.Diff(tt.expectedCalls, allCalls, argsComparer); diff != "" {
|
||||
t.Errorf("calls mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
assert.Equal(t, tt.expectedCalls, allCalls)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,8 +112,8 @@ func (p *MinistralParser) Add(s string, done bool) (content string, thinking str
|
||||
before, _ := splitAtTag(&p.buffer, "}", false)
|
||||
before += "}"
|
||||
|
||||
var args api.ToolCallFunctionArguments
|
||||
if err := json.Unmarshal([]byte(before), &args); err != nil {
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(before), &data); err != nil {
|
||||
// todo - throw a better error
|
||||
return "", "", calls, err
|
||||
}
|
||||
@@ -123,7 +123,7 @@ func (p *MinistralParser) Add(s string, done bool) (content string, thinking str
|
||||
call := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: p.currentTool.Function.Name,
|
||||
Arguments: args,
|
||||
Arguments: api.ToolCallFunctionArguments(data),
|
||||
},
|
||||
}
|
||||
calls = append(calls, call)
|
||||
|
||||
@@ -225,7 +225,7 @@ func (p *Nemotron3NanoParser) parseToolCall(content string) (api.ToolCall, error
|
||||
toolCall.Function.Name = fnMatch[1]
|
||||
|
||||
// Extract parameters
|
||||
toolCall.Function.Arguments = api.NewToolCallFunctionArguments()
|
||||
toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
|
||||
paramMatches := nemotronParameterRegex.FindAllStringSubmatch(content, -1)
|
||||
for _, match := range paramMatches {
|
||||
if len(match) >= 3 {
|
||||
@@ -233,7 +233,7 @@ func (p *Nemotron3NanoParser) parseToolCall(content string) (api.ToolCall, error
|
||||
paramValue := strings.TrimSpace(match[2])
|
||||
|
||||
// Try to parse as typed value based on tool definition
|
||||
toolCall.Function.Arguments.Set(paramName, p.parseParamValue(paramName, paramValue))
|
||||
toolCall.Function.Arguments[paramName] = p.parseParamValue(paramName, paramValue)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -244,11 +244,9 @@ func (p *Nemotron3NanoParser) parseParamValue(paramName string, raw string) any
|
||||
// Find the matching tool to get parameter type
|
||||
var paramType api.PropertyType
|
||||
for _, tool := range p.tools {
|
||||
if tool.Function.Parameters.Properties != nil {
|
||||
if prop, ok := tool.Function.Parameters.Properties.Get(paramName); ok {
|
||||
paramType = prop.Type
|
||||
break
|
||||
}
|
||||
if prop, ok := tool.Function.Parameters.Properties[paramName]; ok {
|
||||
paramType = prop.Type
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
Arguments: map[string]any{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -65,7 +65,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "NYC"}),
|
||||
Arguments: map[string]any{"city": "NYC"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -78,10 +78,10 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "book_flight",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"from": "SFO",
|
||||
"to": "NYC",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -95,13 +95,13 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "San Francisco"}),
|
||||
Arguments: map[string]any{"city": "San Francisco"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "New York"}),
|
||||
Arguments: map[string]any{"city": "New York"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -115,7 +115,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
Arguments: map[string]any{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -130,7 +130,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: testArgs(map[string]any{"query": "test"}),
|
||||
Arguments: map[string]any{"query": "test"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -143,7 +143,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "create_note",
|
||||
Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}),
|
||||
Arguments: map[string]any{"content": "Line 1\nLine 2\nLine 3"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -165,7 +165,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
name: "tool call with no function name - returns empty tool call",
|
||||
input: "<tool_call>\n<function=>\n</function>\n</tool_call>",
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "", Arguments: api.NewToolCallFunctionArguments()}}},
|
||||
expectedCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "", Arguments: nil}}},
|
||||
},
|
||||
{
|
||||
name: "content with newlines preserved",
|
||||
@@ -194,7 +194,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_temp",
|
||||
Arguments: testArgs(map[string]any{"value": "42"}),
|
||||
Arguments: map[string]any{"value": "42"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -226,7 +226,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" {
|
||||
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(calls, tt.expectedCalls, argsComparer); diff != "" {
|
||||
if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
@@ -276,7 +276,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
Arguments: map[string]any{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -290,7 +290,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "NYC"}),
|
||||
Arguments: map[string]any{"city": "NYC"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -302,7 +302,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test",
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -329,10 +329,10 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "book_flight",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"from": "SFO",
|
||||
"to": "NYC",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -347,7 +347,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: testArgs(map[string]any{"query": "test query"}),
|
||||
Arguments: map[string]any{"query": "test query"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -367,13 +367,13 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "San Francisco"}),
|
||||
Arguments: map[string]any{"city": "San Francisco"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "New York"}),
|
||||
Arguments: map[string]any{"city": "New York"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -386,7 +386,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "create_note",
|
||||
Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}),
|
||||
Arguments: map[string]any{"content": "Line 1\nLine 2\nLine 3"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -413,7 +413,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test",
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -426,7 +426,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test",
|
||||
Arguments: testArgs(map[string]any{"name": ""}),
|
||||
Arguments: map[string]any{"name": ""},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -473,7 +473,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
if diff := cmp.Diff(allThinking, tt.expectedThinking); diff != "" {
|
||||
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(allCalls, tt.expectedCalls, argsComparer); diff != "" {
|
||||
if diff := cmp.Diff(allCalls, tt.expectedCalls); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
@@ -537,9 +537,9 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) {
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -548,7 +548,7 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) {
|
||||
p := &Nemotron3NanoParser{}
|
||||
returnedTools := p.Init(tools, nil, nil)
|
||||
|
||||
if diff := cmp.Diff(returnedTools, tools, toolsComparer); diff != "" {
|
||||
if diff := cmp.Diff(returnedTools, tools); diff != "" {
|
||||
t.Errorf("tools mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
@@ -563,12 +563,12 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
Arguments: map[string]any{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(calls, expectedCalls, argsComparer); diff != "" {
|
||||
if diff := cmp.Diff(calls, expectedCalls); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -242,8 +242,8 @@ func parseOlmo3SingleFunctionCall(s string) (api.ToolCall, error) {
|
||||
|
||||
// parseOlmo3Arguments parses comma-separated key=value pairs
|
||||
// Handles nested parentheses, brackets, braces, and quoted strings
|
||||
func parseOlmo3Arguments(s string) (api.ToolCallFunctionArguments, error) {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
func parseOlmo3Arguments(s string) (map[string]any, error) {
|
||||
args := make(map[string]any)
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return args, nil
|
||||
@@ -261,7 +261,7 @@ func parseOlmo3Arguments(s string) (api.ToolCallFunctionArguments, error) {
|
||||
// Find the first = sign
|
||||
eqIdx := strings.Index(part, "=")
|
||||
if eqIdx == -1 {
|
||||
return api.ToolCallFunctionArguments{}, fmt.Errorf("invalid argument format: %s", part)
|
||||
return nil, fmt.Errorf("invalid argument format: %s", part)
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(part[:eqIdx])
|
||||
@@ -269,10 +269,10 @@ func parseOlmo3Arguments(s string) (api.ToolCallFunctionArguments, error) {
|
||||
|
||||
value, err := parseOlmo3Value(valueStr)
|
||||
if err != nil {
|
||||
return api.ToolCallFunctionArguments{}, fmt.Errorf("failed to parse value for %s: %w", key, err)
|
||||
return nil, fmt.Errorf("failed to parse value for %s: %w", key, err)
|
||||
}
|
||||
|
||||
args.Set(key, value)
|
||||
args[key] = value
|
||||
}
|
||||
|
||||
return args, nil
|
||||
|
||||
@@ -28,7 +28,7 @@ func TestOlmo3Parser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "San Francisco"}),
|
||||
Arguments: map[string]any{"location": "San Francisco"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -41,7 +41,7 @@ func TestOlmo3Parser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "NYC"}),
|
||||
Arguments: map[string]any{"location": "NYC"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -53,11 +53,11 @@ func TestOlmo3Parser(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "book_flight",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"from": "SFO",
|
||||
"to": "NYC",
|
||||
"date": "2024-01-15",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -70,13 +70,13 @@ get_weather(location="New York")</function_calls>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "San Francisco"}),
|
||||
Arguments: map[string]any{"location": "San Francisco"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "New York"}),
|
||||
Arguments: map[string]any{"location": "New York"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -88,7 +88,7 @@ get_weather(location="New York")</function_calls>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_temperature",
|
||||
Arguments: testArgs(map[string]any{"value": int64(72)}),
|
||||
Arguments: map[string]any{"value": int64(72)},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -100,7 +100,7 @@ get_weather(location="New York")</function_calls>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_price",
|
||||
Arguments: testArgs(map[string]any{"amount": 19.99}),
|
||||
Arguments: map[string]any{"amount": 19.99},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -112,7 +112,7 @@ get_weather(location="New York")</function_calls>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "toggle_setting",
|
||||
Arguments: testArgs(map[string]any{"enabled": true}),
|
||||
Arguments: map[string]any{"enabled": true},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -124,7 +124,7 @@ get_weather(location="New York")</function_calls>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "clear_value",
|
||||
Arguments: testArgs(map[string]any{"field": nil}),
|
||||
Arguments: map[string]any{"field": nil},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -136,7 +136,7 @@ get_weather(location="New York")</function_calls>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process_items",
|
||||
Arguments: testArgs(map[string]any{"items": []any{"apple", "banana", "cherry"}}),
|
||||
Arguments: map[string]any{"items": []any{"apple", "banana", "cherry"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -148,12 +148,12 @@ get_weather(location="New York")</function_calls>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "update_config",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"settings": map[string]any{
|
||||
"theme": "dark",
|
||||
"fontSize": int64(14),
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -165,7 +165,7 @@ get_weather(location="New York")</function_calls>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "create_request",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"data": map[string]any{
|
||||
"user": map[string]any{
|
||||
"name": "John",
|
||||
@@ -173,7 +173,7 @@ get_weather(location="New York")</function_calls>`,
|
||||
},
|
||||
"active": true,
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -185,7 +185,7 @@ get_weather(location="New York")</function_calls>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_time",
|
||||
Arguments: testArgs(map[string]any{}),
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -197,7 +197,7 @@ get_weather(location="New York")</function_calls>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: testArgs(map[string]any{"query": "hello world"}),
|
||||
Arguments: map[string]any{"query": "hello world"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -209,7 +209,7 @@ get_weather(location="New York")</function_calls>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: testArgs(map[string]any{"query": `say "hello"`}),
|
||||
Arguments: map[string]any{"query": `say "hello"`},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -221,11 +221,11 @@ get_weather(location="New York")</function_calls>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "create_user",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"name": "John",
|
||||
"age": int64(30),
|
||||
"active": true,
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -257,7 +257,7 @@ get_weather(location="New York")</function_calls>`,
|
||||
if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" {
|
||||
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(calls, tt.expectedCalls, argsComparer); diff != "" {
|
||||
if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
@@ -283,7 +283,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "SF"}),
|
||||
Arguments: map[string]any{"location": "SF"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -296,7 +296,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "NYC"}),
|
||||
Arguments: map[string]any{"location": "NYC"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -308,7 +308,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test",
|
||||
Arguments: testArgs(map[string]any{}),
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -343,7 +343,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) {
|
||||
if diff := cmp.Diff(allContent, tt.expectedContent); diff != "" {
|
||||
t.Errorf("content mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(allCalls, tt.expectedCalls, argsComparer); diff != "" {
|
||||
if diff := cmp.Diff(allCalls, tt.expectedCalls); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
@@ -378,7 +378,7 @@ func TestParseOlmo3FunctionCalls(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "SF"}),
|
||||
Arguments: map[string]any{"location": "SF"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -390,11 +390,11 @@ func TestParseOlmo3FunctionCalls(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "send_email",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"to": "user@example.com",
|
||||
"subject": "Hello",
|
||||
"body": "Test message",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -407,13 +407,13 @@ get_time(timezone="PST")`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "SF"}),
|
||||
Arguments: map[string]any{"location": "SF"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_time",
|
||||
Arguments: testArgs(map[string]any{"timezone": "PST"}),
|
||||
Arguments: map[string]any{"timezone": "PST"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -437,7 +437,7 @@ get_time(timezone="PST")`,
|
||||
t.Errorf("parseOlmo3FunctionCalls() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if diff := cmp.Diff(calls, tt.expected, argsComparer); diff != "" {
|
||||
if diff := cmp.Diff(calls, tt.expected); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -270,12 +270,12 @@ func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, er
|
||||
}
|
||||
}
|
||||
|
||||
toolCall.Function.Arguments = api.NewToolCallFunctionArguments()
|
||||
toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
|
||||
for _, parameter := range functionCall.Parameters {
|
||||
// Look up the parameter type if we found the tool
|
||||
var paramType api.PropertyType
|
||||
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
|
||||
if prop, ok := matchedTool.Function.Parameters.Properties.Get(parameter.Name); ok {
|
||||
if prop, ok := matchedTool.Function.Parameters.Properties[parameter.Name]; ok {
|
||||
// Handle anyOf by collecting all types from the union
|
||||
if len(prop.AnyOf) > 0 {
|
||||
for _, anyOfProp := range prop.AnyOf {
|
||||
@@ -287,7 +287,7 @@ func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, er
|
||||
}
|
||||
}
|
||||
|
||||
toolCall.Function.Arguments.Set(parameter.Name, parseValue(parameter.Value, paramType))
|
||||
toolCall.Function.Arguments[parameter.Name] = parseValue(parameter.Value, paramType)
|
||||
}
|
||||
|
||||
return toolCall, nil
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
func tool(name string, props map[string]api.ToolProperty) api.Tool {
|
||||
t := api.Tool{Type: "function", Function: api.ToolFunction{Name: name}}
|
||||
t.Function.Parameters.Type = "object"
|
||||
t.Function.Parameters.Properties = testPropsMap(props)
|
||||
t.Function.Parameters.Properties = props
|
||||
return t
|
||||
}
|
||||
|
||||
@@ -369,10 +369,10 @@ celsius
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_temperature",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"location": "San Francisco",
|
||||
"unit": "celsius",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -390,10 +390,10 @@ celsius
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get current temperature",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"location with spaces": "San Francisco",
|
||||
"unit with spaces": "celsius",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -415,10 +415,10 @@ San Francisco
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "\"get current temperature\"",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"\"location with spaces\"": "San Francisco",
|
||||
"\"unit with spaces\"": "\"celsius\"",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -449,12 +449,12 @@ true
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "calculate",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"x": 3.14,
|
||||
"y": 42,
|
||||
"enabled": true,
|
||||
"items": []any{"a", "b", "c"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -470,9 +470,9 @@ ls && echo "done"
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "exec",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"command": "ls && echo \"done\"",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -487,9 +487,9 @@ ls && echo "a > b and a < b"
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "exec",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"command": "ls && echo \"a > b and a < b\"",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -507,10 +507,10 @@ Hello! 你好! 🌟 مرحبا
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "获取天气",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"城市": "北京",
|
||||
"message": "Hello! 你好! 🌟 مرحبا",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -521,7 +521,7 @@ Hello! 你好! 🌟 مرحبا
|
||||
if err != nil {
|
||||
t.Errorf("step %d (%s): %v", i, step.name, err)
|
||||
}
|
||||
if !toolCallEqual(gotToolCall, step.wantToolCall) {
|
||||
if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
|
||||
t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -550,10 +550,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get-current-weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"location": "San Francisco, CA",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -564,10 +564,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get current temperature",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"location with spaces": "San Francisco",
|
||||
"unit with spaces": "celsius",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -578,10 +578,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "\"get current temperature\"",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"\"location with spaces\"": "San Francisco",
|
||||
"\"unit with spaces\"": "\"celsius\"",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -592,12 +592,12 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "calculate",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"x": 3.14,
|
||||
"y": float64(42),
|
||||
"enabled": true,
|
||||
"items": []any{"a", "b", "c"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -608,9 +608,9 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "exec",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"command": "ls && echo \"done\"",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -621,9 +621,9 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "exec",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"command": "ls && echo \"a > b and a < b\"",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -634,10 +634,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "获取天气",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"城市": "北京",
|
||||
"message": "Hello! 你好! 🌟 مرحبا",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -648,7 +648,7 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Errorf("step %d (%s): %v", i, step.name, err)
|
||||
}
|
||||
if !toolCallEqual(gotToolCall, step.wantToolCall) {
|
||||
if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
|
||||
t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -241,10 +241,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get-current-weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"location": "San Francisco, CA",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -255,10 +255,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get current temperature",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"location with spaces": "San Francisco",
|
||||
"unit with spaces": "celsius",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -269,10 +269,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "\"get current temperature\"",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"\"location with spaces\"": "San Francisco",
|
||||
"\"unit with spaces\"": "\"celsius\"",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -283,12 +283,12 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "calculate",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"x": 3.14,
|
||||
"y": float64(42),
|
||||
"enabled": true,
|
||||
"items": []any{"a", "b", "c"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -299,9 +299,9 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "exec",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"command": "ls && echo \"done\"",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -312,9 +312,9 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "exec",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"command": "ls && echo \"a > b and a < b\"",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -325,10 +325,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "获取天气",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"城市": "北京",
|
||||
"message": "Hello! 你好! 🌟 مرحبا",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -339,7 +339,7 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Errorf("step %d (%s): %v", i, step.name, err)
|
||||
}
|
||||
if !toolCallEqual(gotToolCall, step.wantToolCall) {
|
||||
if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
|
||||
t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// argsComparer provides cmp options for comparing ToolCallFunctionArguments
|
||||
// It compares by logical equality (same keys with same values) not by order
|
||||
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
|
||||
// Convert both to maps and compare
|
||||
aMap := a.ToMap()
|
||||
bMap := b.ToMap()
|
||||
if len(aMap) != len(bMap) {
|
||||
return false
|
||||
}
|
||||
for k, av := range aMap {
|
||||
bv, ok := bMap[k]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
// Use JSON encoding for deep comparison of values
|
||||
aJSON, _ := json.Marshal(av)
|
||||
bJSON, _ := json.Marshal(bv)
|
||||
if string(aJSON) != string(bJSON) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// propsComparer provides cmp options for comparing ToolPropertiesMap
|
||||
var propsComparer = cmp.Comparer(func(a, b *api.ToolPropertiesMap) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
aJSON, _ := json.Marshal(a)
|
||||
bJSON, _ := json.Marshal(b)
|
||||
return string(aJSON) == string(bJSON)
|
||||
})
|
||||
|
||||
// toolsComparer combines argsComparer and propsComparer for comparing tools
|
||||
var toolsComparer = cmp.Options{argsComparer, propsComparer}
|
||||
|
||||
// toolCallEqual compares two tool calls by comparing their components
|
||||
// It compares arguments by logical equality (same keys with same values) not by order
|
||||
func toolCallEqual(a, b api.ToolCall) bool {
|
||||
if a.ID != b.ID {
|
||||
return false
|
||||
}
|
||||
if a.Function.Index != b.Function.Index {
|
||||
return false
|
||||
}
|
||||
if a.Function.Name != b.Function.Name {
|
||||
return false
|
||||
}
|
||||
// Compare arguments by logical equality using argsComparer logic
|
||||
aMap := a.Function.Arguments.ToMap()
|
||||
bMap := b.Function.Arguments.ToMap()
|
||||
if len(aMap) != len(bMap) {
|
||||
return false
|
||||
}
|
||||
for k, av := range aMap {
|
||||
bv, ok := bMap[k]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
aJSON, _ := json.Marshal(av)
|
||||
bJSON, _ := json.Marshal(bv)
|
||||
if string(aJSON) != string(bJSON) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
|
||||
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||
props := api.NewToolPropertiesMap()
|
||||
for k, v := range m {
|
||||
props.Set(k, v)
|
||||
}
|
||||
return props
|
||||
}
|
||||
|
||||
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
|
||||
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
@@ -94,12 +94,12 @@ You are a helpful assistant.
|
||||
Description: "Get current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
}),
|
||||
},
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
@@ -139,9 +139,9 @@ You have the following functions available:
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -162,9 +162,9 @@ You have the following functions available:
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -186,17 +186,17 @@ You have the following functions available:
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "London",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -226,12 +226,12 @@ You have the following functions available:
|
||||
Description: "Get current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
}),
|
||||
},
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
@@ -378,9 +378,9 @@ You are a pirate chatbot who always responds in pirate speak!
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -401,14 +401,14 @@ You are a pirate chatbot who always responds in pirate speak!
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process_data",
|
||||
Arguments: testArgsOrdered([]orderedArg{
|
||||
{"config", map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"items": []any{"item1", "item2", "item3"},
|
||||
"config": map[string]any{
|
||||
"enabled": true,
|
||||
"threshold": 0.95,
|
||||
"tags": []string{"important", "urgent"},
|
||||
}},
|
||||
{"items", []any{"item1", "item2", "item3"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -82,9 +82,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -104,9 +104,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -125,9 +125,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -147,17 +147,17 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "London",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -214,9 +214,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -235,9 +235,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"data": "test",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -281,9 +281,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -305,9 +305,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -355,9 +355,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -379,9 +379,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -436,17 +436,17 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Tokyo",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "New York",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -489,12 +489,12 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||
Description: "Get current weather information",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
}),
|
||||
},
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
@@ -535,12 +535,12 @@ Where:
|
||||
Description: "Get current weather information",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
}),
|
||||
},
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
@@ -578,9 +578,9 @@ Where:
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -594,12 +594,12 @@ Where:
|
||||
Description: "Get current weather information",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
}),
|
||||
},
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
@@ -638,9 +638,9 @@ Where:
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -656,12 +656,12 @@ Where:
|
||||
Description: "Get current weather information",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
}),
|
||||
},
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
@@ -701,9 +701,9 @@ Where:
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Tokyo",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -724,12 +724,12 @@ Where:
|
||||
Description: "Get current weather information",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
}),
|
||||
},
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
@@ -770,12 +770,12 @@ Where:
|
||||
Description: "Get current weather information",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
}),
|
||||
},
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
@@ -787,12 +787,12 @@ Where:
|
||||
Description: "Perform mathematical calculations",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"expression": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "Mathematical expression to evaluate",
|
||||
},
|
||||
}),
|
||||
},
|
||||
Required: []string{"expression"},
|
||||
},
|
||||
},
|
||||
@@ -834,17 +834,17 @@ Where:
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "calculate",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"expression": "25 * 4",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -860,12 +860,12 @@ Where:
|
||||
Description: "Get current weather information",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
}),
|
||||
},
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
@@ -877,12 +877,12 @@ Where:
|
||||
Description: "Perform mathematical calculations",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"expression": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "Mathematical expression to evaluate",
|
||||
},
|
||||
}),
|
||||
},
|
||||
Required: []string{"expression"},
|
||||
},
|
||||
},
|
||||
@@ -927,12 +927,12 @@ Where:
|
||||
Description: "Get current weather information",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
}),
|
||||
},
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -136,7 +136,7 @@ func (r *FunctionGemmaRenderer) renderToolDeclaration(tool api.Tool) string {
|
||||
needsComma := false
|
||||
|
||||
// Only include properties:{} if there are actual properties
|
||||
if fn.Parameters.Properties != nil && fn.Parameters.Properties.Len() > 0 {
|
||||
if len(fn.Parameters.Properties) > 0 {
|
||||
sb.WriteString("properties:{")
|
||||
r.writeProperties(&sb, fn.Parameters.Properties)
|
||||
sb.WriteString("}")
|
||||
@@ -172,16 +172,16 @@ func (r *FunctionGemmaRenderer) renderToolDeclaration(tool api.Tool) string {
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (r *FunctionGemmaRenderer) writeProperties(sb *strings.Builder, props *api.ToolPropertiesMap) {
|
||||
keys := make([]string, 0, props.Len())
|
||||
for k := range props.All() {
|
||||
func (r *FunctionGemmaRenderer) writeProperties(sb *strings.Builder, props map[string]api.ToolProperty) {
|
||||
keys := make([]string, 0, len(props))
|
||||
for k := range props {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, name := range keys {
|
||||
prop, _ := props.Get(name)
|
||||
prop := props[name]
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
@@ -203,15 +203,15 @@ func (r *FunctionGemmaRenderer) formatToolCall(tc api.ToolCall) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("<start_function_call>call:" + tc.Function.Name + "{")
|
||||
|
||||
keys := make([]string, 0, tc.Function.Arguments.Len())
|
||||
for k := range tc.Function.Arguments.All() {
|
||||
keys := make([]string, 0, len(tc.Function.Arguments))
|
||||
for k := range tc.Function.Arguments {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, key := range keys {
|
||||
value, _ := tc.Function.Arguments.Get(key)
|
||||
value := tc.Function.Arguments[key]
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
|
||||
@@ -51,9 +51,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -75,9 +75,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -107,9 +107,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -126,7 +126,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -141,9 +141,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -161,7 +161,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -176,9 +176,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -195,7 +195,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "add",
|
||||
Arguments: testArgs(map[string]any{"a": float64(1), "b": float64(2)}),
|
||||
Arguments: api.ToolCallFunctionArguments{"a": float64(1), "b": float64(2)},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -210,10 +210,10 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
Description: "Add numbers",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"a": {Type: api.PropertyType{"number"}},
|
||||
"b": {Type: api.PropertyType{"number"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -239,10 +239,10 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"city"},
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City Name"},
|
||||
"country": {Type: api.PropertyType{"string"}, Description: "Country Name"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -263,9 +263,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -276,9 +276,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
Description: "Get current time",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"timezone": {Type: api.PropertyType{"string"}, Description: "Timezone"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -296,13 +296,13 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_time",
|
||||
Arguments: testArgs(map[string]any{"timezone": "UTC"}),
|
||||
Arguments: api.ToolCallFunctionArguments{"timezone": "UTC"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -318,9 +318,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -331,9 +331,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
Description: "Get current time",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"timezone": {Type: api.PropertyType{"string"}, Description: "Timezone"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -351,7 +351,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -367,9 +367,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -391,7 +391,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
Description: "",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{}),
|
||||
Properties: map[string]api.ToolProperty{},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -430,7 +430,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_flag",
|
||||
Arguments: testArgs(map[string]any{"enabled": true}),
|
||||
Arguments: api.ToolCallFunctionArguments{"enabled": true},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -445,9 +445,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
Description: "Set a flag",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"enabled": {Type: api.PropertyType{"boolean"}, Description: "Flag value"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -468,11 +468,11 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"a", "b", "c"},
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"a": {Type: api.PropertyType{"string"}, Description: "A"},
|
||||
"b": {Type: api.PropertyType{"string"}, Description: "B"},
|
||||
"c": {Type: api.PropertyType{"string"}, Description: "C"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -492,9 +492,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
Description: "Test",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"items": {Type: api.PropertyType{"array"}, Description: "List of items"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -114,7 +114,7 @@ func (r *Nemotron3NanoRenderer) renderTools(tools []api.Tool) string {
|
||||
|
||||
sb.WriteString("\n<parameters>")
|
||||
if fn.Parameters.Properties != nil {
|
||||
for paramName, paramFields := range fn.Parameters.Properties.All() {
|
||||
for paramName, paramFields := range fn.Parameters.Properties {
|
||||
sb.WriteString("\n<parameter>")
|
||||
sb.WriteString("\n<name>" + paramName + "</name>")
|
||||
|
||||
@@ -202,7 +202,7 @@ func (r *Nemotron3NanoRenderer) formatContent(content string, truncate bool, add
|
||||
func (r *Nemotron3NanoRenderer) writeToolCalls(sb *strings.Builder, toolCalls []api.ToolCall) {
|
||||
for _, tc := range toolCalls {
|
||||
sb.WriteString("<tool_call>\n<function=" + tc.Function.Name + ">\n")
|
||||
for name, value := range tc.Function.Arguments.All() {
|
||||
for name, value := range tc.Function.Arguments {
|
||||
sb.WriteString("<parameter=" + name + ">\n" + r.formatArgValue(value) + "\n</parameter>\n")
|
||||
}
|
||||
sb.WriteString("</function>\n</tool_call>\n")
|
||||
|
||||
@@ -75,9 +75,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"city"},
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "The city name"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -113,7 +113,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
Arguments: map[string]any{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -129,9 +129,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"city"},
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "The city name"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -171,7 +171,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
Arguments: map[string]any{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -185,9 +185,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -238,13 +238,13 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
Arguments: map[string]any{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "London"}),
|
||||
Arguments: map[string]any{"city": "London"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -259,9 +259,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -304,13 +304,13 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "What's the weather in Paris and London? Also, what's 2+2?"},
|
||||
{Role: "assistant", Content: "", Thinking: "I need to check the weather for both cities and calculate 2+2. Let me start with the weather calls.", ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: testArgs(map[string]any{"city": "Paris"})}},
|
||||
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: testArgs(map[string]any{"city": "London"})}},
|
||||
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "Paris"}}},
|
||||
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "London"}}},
|
||||
}},
|
||||
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call1"},
|
||||
{Role: "tool", Content: "Rainy, 15°C", ToolCallID: "call2"},
|
||||
{Role: "assistant", Content: "", Thinking: "Now I have the weather data. Let me calculate 2+2.", ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "calculate", Arguments: testArgs(map[string]any{"expression": "2+2"})}},
|
||||
{Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}},
|
||||
}},
|
||||
{Role: "tool", Content: "4", ToolCallID: "call3"},
|
||||
{Role: "assistant", Content: "Based on the weather data, Paris is sunny at 22°C and London is rainy at 15°C. Also, 2+2 equals 4.", Thinking: "Perfect! I have all the information needed to provide a complete answer."},
|
||||
@@ -322,9 +322,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -334,9 +334,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||
Name: "calculate",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"expression": {Type: api.PropertyType{"string"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -389,7 +389,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "get_user", Arguments: testArgs(map[string]any{"id": "123"})}},
|
||||
{Function: api.ToolCallFunction{Name: "get_user", Arguments: map[string]any{"id": "123"}}},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"name": "John", "age": 30, "active": true}`},
|
||||
@@ -401,7 +401,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||
Name: "get_user",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{"id": {Type: api.PropertyType{"string"}}}),
|
||||
Properties: map[string]api.ToolProperty{"id": {Type: api.PropertyType{"string"}}},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -450,9 +450,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||
ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{
|
||||
Name: "create",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"data": map[string]any{"nested": "value", "count": 42},
|
||||
}),
|
||||
},
|
||||
}},
|
||||
},
|
||||
},
|
||||
@@ -465,7 +465,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||
Name: "create",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{"data": {Type: api.PropertyType{"object"}}}),
|
||||
Properties: map[string]api.ToolProperty{"data": {Type: api.PropertyType{"object"}}},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -512,7 +512,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "translate", Arguments: testArgs(map[string]any{"text": "你好"})}},
|
||||
{Function: api.ToolCallFunction{Name: "translate", Arguments: map[string]any{"text": "你好"}}},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Hello"},
|
||||
@@ -524,9 +524,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||
Name: "translate",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"text": {Type: api.PropertyType{"string"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -100,8 +100,8 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.
|
||||
sb.WriteString("(")
|
||||
|
||||
// Get sorted keys for deterministic output
|
||||
keys := make([]string, 0, tc.Function.Arguments.Len())
|
||||
for k := range tc.Function.Arguments.All() {
|
||||
keys := make([]string, 0, len(tc.Function.Arguments))
|
||||
for k := range tc.Function.Arguments {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
@@ -110,8 +110,7 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.
|
||||
if k > 0 {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
val, _ := tc.Function.Arguments.Get(key)
|
||||
value, err := json.Marshal(val)
|
||||
value, err := json.Marshal(tc.Function.Arguments[key])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -53,9 +53,9 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -80,9 +80,9 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -108,9 +108,9 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"location": "San Francisco",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -126,9 +126,9 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -172,14 +172,14 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "San Francisco"}),
|
||||
Arguments: map[string]any{"location": "San Francisco"},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "call_2",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "New York"}),
|
||||
Arguments: map[string]any{"location": "New York"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -194,9 +194,9 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -227,10 +227,10 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "book_flight",
|
||||
Arguments: testArgsOrdered([]orderedArg{
|
||||
{"from", "SFO"},
|
||||
{"to", "NYC"},
|
||||
}),
|
||||
Arguments: map[string]any{
|
||||
"from": "SFO",
|
||||
"to": "NYC",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -243,10 +243,10 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||
Name: "book_flight",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsOrdered([]orderedProp{
|
||||
{"from", api.ToolProperty{Type: api.PropertyType{"string"}}},
|
||||
{"to", api.ToolProperty{Type: api.PropertyType{"string"}}},
|
||||
}),
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"from": {Type: api.PropertyType{"string"}},
|
||||
"to": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -78,7 +78,7 @@ func TestOlmo3ThinkRenderer(t *testing.T) {
|
||||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "San Francisco"}),
|
||||
Arguments: map[string]any{"location": "San Francisco"},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -96,7 +96,7 @@ func (r *Qwen3CoderRenderer) Render(messages []api.Message, tools []api.Tool, _
|
||||
}
|
||||
sb.WriteString("\n<parameters>")
|
||||
|
||||
for name, prop := range tool.Function.Parameters.Properties.All() {
|
||||
for name, prop := range tool.Function.Parameters.Properties {
|
||||
sb.WriteString("\n<parameter>")
|
||||
sb.WriteString("\n<name>" + name + "</name>")
|
||||
|
||||
@@ -147,7 +147,7 @@ func (r *Qwen3CoderRenderer) Render(messages []api.Message, tools []api.Tool, _
|
||||
}
|
||||
for _, toolCall := range message.ToolCalls {
|
||||
sb.WriteString("\n<tool_call>\n<function=" + toolCall.Function.Name + ">")
|
||||
for name, value := range toolCall.Function.Arguments.All() {
|
||||
for name, value := range toolCall.Function.Arguments {
|
||||
valueStr := formatToolCallArgument(value)
|
||||
sb.WriteString("\n<parameter=" + name + ">\n" + valueStr + "\n</parameter>")
|
||||
}
|
||||
|
||||
@@ -39,9 +39,9 @@ Hello, how are you?<|im_end|>
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -55,7 +55,7 @@ Hello, how are you?<|im_end|>
|
||||
Description: "Get the current weather in a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Required: []string{"unit"},
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"unit": {Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, Description: "The unit of temperature"},
|
||||
// TODO(drifkin): add multiple params back once we have predictable
|
||||
// order via some sort of ordered map type (see
|
||||
@@ -63,7 +63,7 @@ Hello, how are you?<|im_end|>
|
||||
/*
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA"},
|
||||
*/
|
||||
}),
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
@@ -140,19 +140,19 @@ That sounds nice! What about New York?<|im_end|>
|
||||
{Role: "system", Content: "You are a helpful assistant with access to tools."},
|
||||
{Role: "user", Content: "call double(1) and triple(2)"},
|
||||
{Role: "assistant", Content: "I'll call double(1) and triple(2) for you.", ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "double", Arguments: testArgs(map[string]any{"number": "1"})}},
|
||||
{Function: api.ToolCallFunction{Name: "triple", Arguments: testArgs(map[string]any{"number": "2"})}},
|
||||
{Function: api.ToolCallFunction{Name: "double", Arguments: map[string]any{"number": "1"}}},
|
||||
{Function: api.ToolCallFunction{Name: "triple", Arguments: map[string]any{"number": "2"}}},
|
||||
}},
|
||||
{Role: "tool", Content: "{\"number\": 2}", ToolName: "double"},
|
||||
{Role: "tool", Content: "{\"number\": 6}", ToolName: "triple"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
{Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
|
||||
"number": {Type: api.PropertyType{"string"}, Description: "The number to double"},
|
||||
})}}},
|
||||
{Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
}}}},
|
||||
{Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
|
||||
"number": {Type: api.PropertyType{"string"}, Description: "The number to triple"},
|
||||
})}}},
|
||||
}}}},
|
||||
},
|
||||
expected: `<|im_start|>system
|
||||
You are a helpful assistant with access to tools.
|
||||
@@ -259,9 +259,9 @@ I'll tell you something interesting about cats`,
|
||||
{Role: "assistant", ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{
|
||||
Name: "echo",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: map[string]any{
|
||||
"payload": map[string]any{"foo": "bar"},
|
||||
}),
|
||||
},
|
||||
}},
|
||||
}},
|
||||
{Role: "tool", Content: "{\"payload\": {\"foo\": \"bar\"}}", ToolName: "echo"},
|
||||
|
||||
@@ -337,7 +337,7 @@ Let me analyze this image.`,
|
||||
Role: "assistant",
|
||||
Content: "I'll check.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: testArgsOrdered([]orderedArg{{"location", "Paris"}, {"unit", "celsius"}})}},
|
||||
{Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: map[string]any{"location": "Paris", "unit": "celsius"}}},
|
||||
},
|
||||
},
|
||||
{Role: "user", Content: "<tool_response>\n18\n</tool_response>"},
|
||||
@@ -367,8 +367,8 @@ Thanks!<|im_end|>
|
||||
Role: "assistant",
|
||||
Content: "before",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "add", Arguments: testArgsOrdered([]orderedArg{{"a", 2}, {"b", 3}})}},
|
||||
{Function: api.ToolCallFunction{Name: "mul", Arguments: testArgsOrdered([]orderedArg{{"x", 4}, {"y", 5}})}},
|
||||
{Function: api.ToolCallFunction{Name: "add", Arguments: map[string]any{"a": 2, "b": 3}}},
|
||||
{Function: api.ToolCallFunction{Name: "mul", Arguments: map[string]any{"x": 4, "y": 5}}},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -387,7 +387,7 @@ before
|
||||
name: "consecutive tool responses grouped",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Compute results"},
|
||||
{Role: "assistant", Content: "ok", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "job", Arguments: testArgs(map[string]any{"n": 1})}}}},
|
||||
{Role: "assistant", Content: "ok", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "job", Arguments: map[string]any{"n": 1}}}}},
|
||||
{Role: "tool", Content: "5", ToolName: "job"},
|
||||
{Role: "tool", Content: "6", ToolName: "job"},
|
||||
},
|
||||
@@ -412,7 +412,7 @@ ok
|
||||
name: "last message is tool then prefill",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "run"},
|
||||
{Role: "assistant", Content: "ok", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "exec", Arguments: testArgs(map[string]any{"cmd": "ls"})}}}},
|
||||
{Role: "assistant", Content: "ok", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "exec", Arguments: map[string]any{"cmd": "ls"}}}}},
|
||||
{Role: "tool", Content: "done", ToolName: "exec"},
|
||||
},
|
||||
expected: `<|im_start|>user
|
||||
@@ -447,7 +447,7 @@ done
|
||||
Role: "assistant",
|
||||
Content: "I'll check.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: testArgsOrdered([]orderedArg{{"location", "Paris"}, {"unit", "celsius"}})}},
|
||||
{Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: map[string]any{"location": "Paris", "unit": "celsius"}}},
|
||||
},
|
||||
},
|
||||
{Role: "user", Content: "<tool_response>\n18\n</tool_response>"},
|
||||
@@ -477,7 +477,7 @@ Thanks!<|im_end|>
|
||||
Role: "assistant",
|
||||
Content: "I'll check.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: testArgsOrdered([]orderedArg{{"location", "Paris"}, {"unit", "celsius"}})}},
|
||||
{Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: map[string]any{"location": "Paris", "unit": "celsius"}}},
|
||||
},
|
||||
},
|
||||
{Role: "user", Content: "\n\n\n\n<tool_response>\n18\n</tool_response> extra\n\n\n\n\n\n"},
|
||||
|
||||
@@ -128,10 +128,10 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
|
||||
// {
|
||||
// Function: api.ToolCallFunction{
|
||||
// Name: "get-current-weather",
|
||||
// Arguments: testArgs(map[string]any{
|
||||
// Arguments: map[string]any{
|
||||
// "location": "New York",
|
||||
// "unit": "fahrenheit",
|
||||
// }),
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
@@ -148,7 +148,7 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
|
||||
// Parameters: api.ToolFunctionParameters{
|
||||
// Type: "object",
|
||||
// Required: []string{"location"},
|
||||
// Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
// Properties: map[string]api.ToolProperty{
|
||||
// "location": {
|
||||
// Type: api.PropertyType{"string"},
|
||||
// Description: "The city and state, e.g. San Francisco, CA",
|
||||
@@ -158,7 +158,7 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
|
||||
// Enum: []any{"celsius", "fahrenheit"},
|
||||
// Description: "The temperature unit",
|
||||
// },
|
||||
// }),
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
@@ -216,19 +216,19 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
|
||||
// {
|
||||
// Function: api.ToolCallFunction{
|
||||
// Name: "add",
|
||||
// Arguments: testArgs(map[string]any{
|
||||
// Arguments: map[string]any{
|
||||
// "a": 2,
|
||||
// "b": 3,
|
||||
// }),
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// Function: api.ToolCallFunction{
|
||||
// Name: "multiply",
|
||||
// Arguments: testArgs(map[string]any{
|
||||
// Arguments: map[string]any{
|
||||
// "x": 4,
|
||||
// "y": 5,
|
||||
// }),
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
@@ -257,10 +257,10 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
|
||||
// Parameters: api.ToolFunctionParameters{
|
||||
// Type: "object",
|
||||
// Required: []string{"a", "b"},
|
||||
// Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
// Properties: map[string]api.ToolProperty{
|
||||
// "a": {Type: api.PropertyType{"integer"}, Description: "First number"},
|
||||
// "b": {Type: api.PropertyType{"integer"}, Description: "Second number"},
|
||||
// }),
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
@@ -272,10 +272,10 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
|
||||
// Parameters: api.ToolFunctionParameters{
|
||||
// Type: "object",
|
||||
// Required: []string{"x", "y"},
|
||||
// Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
// Properties: map[string]api.ToolProperty{
|
||||
// "x": {Type: api.PropertyType{"integer"}, Description: "First factor"},
|
||||
// "y": {Type: api.PropertyType{"integer"}, Description: "Second factor"},
|
||||
// }),
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
package renderers
|
||||
|
||||
import "github.com/ollama/ollama/api"
|
||||
|
||||
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
|
||||
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||
props := api.NewToolPropertiesMap()
|
||||
for k, v := range m {
|
||||
props.Set(k, v)
|
||||
}
|
||||
return props
|
||||
}
|
||||
|
||||
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
|
||||
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// orderedArg represents a key-value pair for ordered argument creation
|
||||
type orderedArg struct {
|
||||
Key string
|
||||
Value any
|
||||
}
|
||||
|
||||
// testArgsOrdered creates ToolCallFunctionArguments with a specific key order
|
||||
func testArgsOrdered(pairs []orderedArg) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for _, p := range pairs {
|
||||
args.Set(p.Key, p.Value)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// orderedProp represents a key-value pair for ordered property creation
|
||||
type orderedProp struct {
|
||||
Key string
|
||||
Value api.ToolProperty
|
||||
}
|
||||
|
||||
// testPropsOrdered creates a ToolPropertiesMap with a specific key order
|
||||
func testPropsOrdered(pairs []orderedProp) *api.ToolPropertiesMap {
|
||||
props := api.NewToolPropertiesMap()
|
||||
for _, p := range pairs {
|
||||
props.Set(p.Key, p.Value)
|
||||
}
|
||||
return props
|
||||
}
|
||||
@@ -10,20 +10,6 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
|
||||
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// argsComparer provides cmp options for comparing ToolCallFunctionArguments by value
|
||||
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
|
||||
return cmp.Equal(a.ToMap(), b.ToMap())
|
||||
})
|
||||
|
||||
const (
|
||||
prefix = `data:image/jpeg;base64,`
|
||||
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||
@@ -173,9 +159,9 @@ func TestToToolCallsPreservesIDs(t *testing.T) {
|
||||
Function: api.ToolCallFunction{
|
||||
Index: 2,
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Seattle",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -183,9 +169,9 @@ func TestToToolCallsPreservesIDs(t *testing.T) {
|
||||
Function: api.ToolCallFunction{
|
||||
Index: 7,
|
||||
Name: "get_time",
|
||||
Arguments: testArgs(map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"timezone": "UTC",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -229,7 +215,7 @@ func TestToToolCallsPreservesIDs(t *testing.T) {
|
||||
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(original, toolCalls, argsComparer); diff != "" {
|
||||
if diff := cmp.Diff(original, toolCalls); diff != "" {
|
||||
t.Errorf("input tool calls mutated (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -925,7 +925,7 @@ func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
|
||||
ID: "call_abc",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1800,7 +1800,7 @@ func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) {
|
||||
ID: "call_abc",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"golang.org/x/text/encoding/unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -802,7 +801,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var base convert.KV = map[string]any{"general.architecture": "test"}
|
||||
base := map[string]any{"general.architecture": "test"}
|
||||
maps.Copy(base, kv)
|
||||
|
||||
if err := ggml.WriteGGUF(f, base, ti); err != nil {
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
package progress
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// StepBar displays step-based progress (e.g., for image generation steps).
|
||||
type StepBar struct {
|
||||
message string
|
||||
current int
|
||||
total int
|
||||
}
|
||||
|
||||
func NewStepBar(message string, total int) *StepBar {
|
||||
return &StepBar{message: message, total: total}
|
||||
}
|
||||
|
||||
func (s *StepBar) Set(current int) {
|
||||
s.current = current
|
||||
}
|
||||
|
||||
func (s *StepBar) String() string {
|
||||
percent := float64(s.current) / float64(s.total) * 100
|
||||
barWidth := s.total
|
||||
empty := barWidth - s.current
|
||||
|
||||
// "Generating 0% ▕ ▏ 0/9"
|
||||
return fmt.Sprintf("%s %3.0f%% ▕%s%s▏ %d/%d",
|
||||
s.message, percent,
|
||||
strings.Repeat("█", s.current), strings.Repeat(" ", empty),
|
||||
s.current, s.total)
|
||||
}
|
||||
@@ -30,7 +30,7 @@ func (p *Prompt) placeholder() string {
|
||||
}
|
||||
|
||||
type Terminal struct {
|
||||
reader *bufio.Reader
|
||||
outchan chan rune
|
||||
rawmode bool
|
||||
termios any
|
||||
}
|
||||
@@ -264,21 +264,36 @@ func NewTerminal() (*Terminal, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := UnsetRawMode(fd, termios); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := &Terminal{
|
||||
reader: bufio.NewReader(os.Stdin),
|
||||
outchan: make(chan rune),
|
||||
rawmode: true,
|
||||
termios: termios,
|
||||
}
|
||||
|
||||
go t.ioloop()
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *Terminal) Read() (rune, error) {
|
||||
r, _, err := t.reader.ReadRune()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
func (t *Terminal) ioloop() {
|
||||
buf := bufio.NewReader(os.Stdin)
|
||||
|
||||
for {
|
||||
r, _, err := buf.ReadRune()
|
||||
if err != nil {
|
||||
close(t.outchan)
|
||||
break
|
||||
}
|
||||
t.outchan <- r
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Terminal) Read() (rune, error) {
|
||||
r, ok := <-t.outchan
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
@@ -18,7 +18,6 @@ const (
|
||||
CharCtrlL = 12
|
||||
CharEnter = 13
|
||||
CharNext = 14
|
||||
CharCtrlO = 15 // Ctrl+O - used for expanding tool output
|
||||
CharPrev = 16
|
||||
CharBckSearch = 18
|
||||
CharFwdSearch = 19
|
||||
|
||||
@@ -3,7 +3,6 @@ package runner
|
||||
import (
|
||||
"github.com/ollama/ollama/runner/llamarunner"
|
||||
"github.com/ollama/ollama/runner/ollamarunner"
|
||||
imagerunner "github.com/ollama/ollama/x/imagegen/runner"
|
||||
)
|
||||
|
||||
func Execute(args []string) error {
|
||||
@@ -12,19 +11,12 @@ func Execute(args []string) error {
|
||||
}
|
||||
|
||||
var newRunner bool
|
||||
var imageRunner bool
|
||||
if len(args) > 0 && args[0] == "--ollama-engine" {
|
||||
if args[0] == "--ollama-engine" {
|
||||
args = args[1:]
|
||||
newRunner = true
|
||||
}
|
||||
if len(args) > 0 && args[0] == "--image-engine" {
|
||||
args = args[1:]
|
||||
imageRunner = true
|
||||
}
|
||||
|
||||
if imageRunner {
|
||||
return imagerunner.Execute(args)
|
||||
} else if newRunner {
|
||||
if newRunner {
|
||||
return ollamarunner.Execute(args)
|
||||
} else {
|
||||
return llamarunner.Execute(args)
|
||||
|
||||
@@ -42,39 +42,18 @@ shift $(( $OPTIND - 1 ))
|
||||
_build_darwin() {
|
||||
for ARCH in $ARCHS; do
|
||||
status "Building darwin $ARCH"
|
||||
INSTALL_PREFIX=dist/darwin-$ARCH/
|
||||
INSTALL_PREFIX=dist/darwin-$ARCH/
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
|
||||
|
||||
if [ "$ARCH" = "amd64" ]; then
|
||||
status "Building darwin $ARCH dynamic backends"
|
||||
BUILD_DIR=build/darwin-$ARCH
|
||||
cmake -B $BUILD_DIR \
|
||||
cmake -B build/darwin-$ARCH \
|
||||
-DCMAKE_OSX_ARCHITECTURES=x86_64 \
|
||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
|
||||
-DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \
|
||||
-DMLX_ENGINE=ON \
|
||||
-DMLX_ENABLE_X64_MAC=ON \
|
||||
-DOLLAMA_RUNNER_DIR=./
|
||||
cmake --build $BUILD_DIR --target ggml-cpu -j
|
||||
cmake --build $BUILD_DIR --target mlx mlxc -j
|
||||
cmake --install $BUILD_DIR --component CPU
|
||||
cmake --install $BUILD_DIR --component MLX
|
||||
# Override CGO flags to point to the amd64 build directory
|
||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Accelerate -mmacosx-version-min=14.0"
|
||||
else
|
||||
BUILD_DIR=build
|
||||
cmake --preset MLX \
|
||||
-DOLLAMA_RUNNER_DIR=./ \
|
||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
|
||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 \
|
||||
-DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX
|
||||
cmake --build --preset MLX --parallel
|
||||
cmake --install $BUILD_DIR --component MLX
|
||||
# Use default CGO flags from mlx.go for arm64
|
||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
|
||||
cmake --build build/darwin-$ARCH --target ggml-cpu -j
|
||||
cmake --install build/darwin-$ARCH --component CPU
|
||||
fi
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/ollama-mlx .
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
|
||||
done
|
||||
}
|
||||
|
||||
@@ -82,19 +61,17 @@ _sign_darwin() {
|
||||
status "Creating universal binary..."
|
||||
mkdir -p dist/darwin
|
||||
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
|
||||
lipo -create -output dist/darwin/ollama-mlx dist/darwin-*/ollama-mlx
|
||||
chmod +x dist/darwin/ollama
|
||||
chmod +x dist/darwin/ollama-mlx
|
||||
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/ollama-mlx; do
|
||||
for F in dist/darwin/ollama dist/darwin-amd64/lib/ollama/*; do
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
|
||||
done
|
||||
|
||||
# create a temporary zip for notarization
|
||||
TEMP=$(mktemp -u).zip
|
||||
ditto -c -k --keepParent dist/darwin/ollama "$TEMP"
|
||||
xcrun notarytool submit "$TEMP" --wait --timeout 20m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
|
||||
xcrun notarytool submit "$TEMP" --wait --timeout 10m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
|
||||
rm -f "$TEMP"
|
||||
fi
|
||||
|
||||
@@ -154,25 +131,17 @@ _build_macapp() {
|
||||
mkdir -p dist/Ollama.app/Contents/Resources
|
||||
if [ -d dist/darwin-amd64 ]; then
|
||||
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
|
||||
lipo -create -output dist/Ollama.app/Contents/Resources/ollama-mlx dist/darwin-amd64/ollama-mlx dist/darwin-arm64/ollama-mlx
|
||||
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
|
||||
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
|
||||
done
|
||||
cp dist/darwin-*/lib/ollama/*.so dist/darwin-*/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
cp dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
# Copy MLX metallib (architecture-independent, just use arm64 version)
|
||||
cp dist/darwin-arm64/lib/ollama/*.metallib dist/Ollama.app/Contents/Resources/ 2>/dev/null || true
|
||||
cp dist/darwin-amd64/lib/ollama/*.so dist/darwin-amd64/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
else
|
||||
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
|
||||
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
fi
|
||||
cp -a dist/darwin/ollama-mlx dist/Ollama.app/Contents/Resources/ollama-mlx
|
||||
chmod a+x dist/Ollama.app/Contents/Resources/ollama
|
||||
|
||||
# Sign
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
|
||||
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib dist/Ollama.app/Contents/Resources/ollama-mlx ; do
|
||||
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib ; do
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
|
||||
done
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
|
||||
@@ -180,11 +149,11 @@ _build_macapp() {
|
||||
|
||||
rm -f dist/Ollama-darwin.zip
|
||||
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
|
||||
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
|
||||
|
||||
# Notarize and Staple
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
|
||||
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 10m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
|
||||
rm -f dist/Ollama-darwin.zip
|
||||
$(xcrun -f stapler) staple dist/Ollama.app
|
||||
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
@@ -208,7 +177,7 @@ _build_macapp() {
|
||||
rm -f dist/rw*.dmg
|
||||
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.dmg
|
||||
$(xcrun -f notarytool) submit dist/Ollama.dmg --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
|
||||
$(xcrun -f notarytool) submit dist/Ollama.dmg --wait --timeout 10m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
|
||||
$(xcrun -f stapler) staple dist/Ollama.dmg
|
||||
else
|
||||
echo "WARNING: Code signing disabled, this bundle will not work for upgrade testing"
|
||||
|
||||
@@ -12,17 +12,6 @@ set -eu
|
||||
|
||||
. $(dirname $0)/env.sh
|
||||
|
||||
# Check for required tools
|
||||
if ! command -v zstd >/dev/null 2>&1; then
|
||||
echo "ERROR: zstd is required but not installed." >&2
|
||||
echo "Please install zstd:" >&2
|
||||
echo " - macOS: brew install zstd" >&2
|
||||
echo " - Debian/Ubuntu: sudo apt-get install zstd" >&2
|
||||
echo " - RHEL/CentOS/Fedora: sudo dnf install zstd" >&2
|
||||
echo " - Arch: sudo pacman -S zstd" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
mkdir -p dist
|
||||
|
||||
docker buildx build \
|
||||
@@ -48,27 +37,19 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then
|
||||
.
|
||||
fi
|
||||
|
||||
# Run deduplication for each platform output directory
|
||||
if echo $PLATFORM | grep "," > /dev/null ; then
|
||||
$(dirname $0)/deduplicate_cuda_libs.sh "./dist/linux_amd64"
|
||||
$(dirname $0)/deduplicate_cuda_libs.sh "./dist/linux_arm64"
|
||||
elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then
|
||||
$(dirname $0)/deduplicate_cuda_libs.sh "./dist"
|
||||
fi
|
||||
|
||||
# buildx behavior changes for single vs. multiplatform
|
||||
echo "Compressing linux tar bundles..."
|
||||
if echo $PLATFORM | grep "," > /dev/null ; then
|
||||
tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst
|
||||
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst
|
||||
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst
|
||||
tar c -C ./dist/linux_amd64 --exclude rocm . | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
|
||||
tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst
|
||||
tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | pigz -9vc >./dist/ollama-linux-arm64.tgz
|
||||
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | pigz -9vc >./dist/ollama-linux-arm64-jetpack5.tgz
|
||||
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | pigz -9vc >./dist/ollama-linux-arm64-jetpack6.tgz
|
||||
tar c -C ./dist/linux_amd64 --exclude rocm . | pigz -9vc >./dist/ollama-linux-amd64.tgz
|
||||
tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | pigz -9vc >./dist/ollama-linux-amd64-rocm.tgz
|
||||
elif echo $PLATFORM | grep "arm64" > /dev/null ; then
|
||||
tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst
|
||||
tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst
|
||||
tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst
|
||||
tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | pigz -9vc >./dist/ollama-linux-arm64.tgz
|
||||
tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | pigz -9vc >./dist/ollama-linux-arm64-jetpack5.tgz
|
||||
tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | pigz -9vc >./dist/ollama-linux-arm64-jetpack6.tgz
|
||||
elif echo $PLATFORM | grep "amd64" > /dev/null ; then
|
||||
tar c -C ./dist/ --exclude rocm bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
|
||||
tar c -C ./dist/ ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst
|
||||
tar c -C ./dist/ --exclude rocm bin lib | pigz -9vc >./dist/ollama-linux-amd64.tgz
|
||||
tar c -C ./dist/ ./lib/ollama/rocm | pigz -9vc >./dist/ollama-linux-amd64-rocm.tgz
|
||||
fi
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user