Compare commits

..

58 Commits

Author SHA1 Message Date
jmorganca
f7c5f40daa model: add lfm2 2026-01-20 01:56:45 -08:00
Daniel Hiltgen
c42e9d244f test: add image gen test case (#13698)
* test: fix type regression in tools test.

* test: add image gen integration test
2026-01-19 16:01:31 -08:00
Devon Rifkin
e98b5e8b4e /api/show: default to empty model_info (#13785)
For `/api/show`, a fully missing `model_info` field trips up various
integrators (including a recent Android Studio integration).

The primary source of missing info tends to come from models with a
remote that are also missing other data. It seems better to me to return
an empty `model_info` than making up some other fields within
`model_info` (like saying the architecture is `remote` or something like
that). So this does slightly change `/api/show`'s behavior that possibly
someone is relying on, but it seems more important to ensure the field
is always there (from a quick sampling integrations seem to be robust to
missing fields _within_ it).

Fixes: https://github.com/ollama/ollama/issues/13783
2026-01-19 15:26:17 -08:00
Jeffrey Morgan
68e00c7c36 fix: prevent image generation models from loading during deletion (#13781)
Move the unload check (empty prompt + KeepAlive=0) before the image
generation model dispatch in GenerateHandler. This prevents models like
flux from being loaded into memory just to be immediately unloaded when
running `ollama rm`.

Also fix a bug in DeleteHandler where `args[0]` was used instead of
`arg` in the delete loop, causing only the first model to be unloaded
when deleting multiple models.
2026-01-19 12:48:34 -08:00
Jeffrey Morgan
4f138a1749 model: add Glm4MoeLiteForCausalLM architecture to support GLM-4.7-Flash (#13779) 2026-01-19 12:47:17 -08:00
Jeffrey Morgan
03bf241c33 x/imagegen: add FP4 quantization support for image generation models (#13773)
Add --quantize fp4 support to ollama create for image generation models
(flux2, z-image-turbo), using MLX's affine 4-bit quantization.

Changes:
- Add fp4 to validation in CreateImageGenModel
- Add FP4 case to quantizeTensor (group_size=32, bits=4, affine mode)
- Add GetQuantization() to WeightSource interface for dynamic params
- Update LoadLinearLayer to use quantization params from model metadata
2026-01-19 00:54:54 -08:00
Jeffrey Morgan
a887406c24 x/imagegen: add preliminary support for FLUX.2-klein model (#13772) 2026-01-18 22:30:49 -08:00
Jeffrey Morgan
d51e95ba7e server: prevent image generation models from reloading on every request (#13771)
The loadImageGen function was not setting Options on the runnerRef,
causing needsReload() to always return true (since it checks if
runner.Options == nil). This resulted in the image generation
subprocess being killed and restarted for every request.
2026-01-18 20:50:04 -08:00
Jeffrey Morgan
3d01f2aa34 parsers: refactor Nemotron parser to reuse Qwen3Coder for tool calls (#13764)
Simplify Nemotron3NanoParser by delegating tool call parsing to
Qwen3CoderParser instead of duplicating the parsing logic. The
Nemotron parser now only handles the thinking state machine and
transitions to Qwen3CoderParser for content and tool call parsing.

This also fixes an issue where tool calls without </think> would
cause the parser to get stuck in thinking mode.
2026-01-17 18:28:52 -08:00
Jeffrey Morgan
634c416645 Add experimental image generation fields to /api/generate (#13753)
Request fields (experimental):
- width: image width (max 4096)
- height: image height (max 4096)
- steps: denoising steps
- seed: random seed

Response fields (experimental):
- images: base64-encoded generated images
- completed: current step progress
- total: total steps

Other changes:
- Fix lifecycle bug where image models wouldn't unload (refCount issue)
- Fix "headers already written" error on Ctrl+C during streaming
- Add gin middleware for OpenAI /v1/images/generations compatibility
- Update CLI to use /api/generate with progress bar
- Add preload support in interactive mode
2026-01-17 18:27:41 -08:00
Michael
57de86cc61 docs: update claude code docs (#13757)
* docs: update claude code docs
2026-01-16 22:41:34 -08:00
Daniel Hiltgen
12719b6e87 MLX - dynamic loading of mlx-c (#13735)
* MLX - dynamic loading of mlx-c

Create a wrapper layer to indirect the dependency on mlx-c so
the main ollama binary does not have a load-time dependency on mlx-c, mlx, and on linux, cuda.  Lazy load the library via dlopen
so we can adjust the path to ensure the dependencies are found
and fail gracefully if not present.

* review comments

* fix broken tests
2026-01-16 16:34:22 -08:00
Patrick Devine
a077d996e3 Fix create and show commands for experimental models (#13741)
* x: make `ollama create --experimental` import from safetensors

This change allows pulling in safetensors models into the new experimental model format, and also
fixes the `ollama show` command to be able to correctly display the model information.

* gofumpt the linter

* gofumpt the linter again

* validate the model name
2026-01-16 14:31:55 -08:00
Jeffrey Morgan
c23d5095de x/imagegen: clean up image generation code (#13725) 2026-01-16 12:19:25 -08:00
Bruce MacDonald
7601f0e93e server: reject unexpected auth hosts (#13738)
Added validation to ensure auth redirects stay on the same host as the original request. The fix is a single check in getAuthorizationToken comparing the realm URL's host against the request host. Added tests for the auth flow.

Co-Authored-By: Gecko Security <188164982+geckosecurity@users.noreply.github.com>

* gofmt

---------

Co-authored-by: Gecko Security <188164982+geckosecurity@users.noreply.github.com>
2026-01-16 14:10:36 -05:00
Eva H
aad3f03890 app: allow macOS app to terminate during system shutdown (#13737) 2026-01-16 09:05:04 -05:00
Gyungrai Wang
55d0b6e8b9 integration: fix tools_test.go for ToolCallFunctionArguments API change (#13731) 2026-01-15 16:08:09 -08:00
Devon Rifkin
38eac40d56 openai: tweak v1/responses to conform better (#13736)
* openai: tweak v1/responses to conform better

* openai: provide better error for image URLs

* lint
2026-01-15 15:46:36 -08:00
Jeffrey Morgan
80f3f1bc25 readme: add instructions to build with MLX (#13733) 2026-01-15 11:03:52 -08:00
Parth Sareen
b1a0db547b docs: add env var needed for claude code in docs (#13721) 2026-01-15 10:11:00 -08:00
Parth Sareen
75d7b5f926 cmd: enable multi-line input and shift enter (#13694) 2026-01-14 17:52:46 -08:00
vincent d warmerdam
349d814814 docs: add marimo integration (#13326)
* docs added

* fix title

* add marimo to docs.json

---------

Co-authored-by: Devon Rifkin <drifkin@drifkin.net>
2026-01-14 17:37:38 -08:00
Yuhong Sun
c8743031e0 docs: add onyx integration (#13135)
* Ready for team review

* Update docs/integrations/onyx.mdx

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>

* update docs.json

---------

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
Co-authored-by: Devon Rifkin <drifkin@drifkin.net>
2026-01-14 17:32:05 -08:00
Jeffrey Morgan
4adb9cf4bb scripts: fix macOS auto-update signature verification failure (#13713)
Add --norsrc flag to ditto commands when creating Ollama-darwin.zip
to exclude AppleDouble resource fork files (._* files) from the archive.

The mlx.metallib file has extended attributes, which causes ditto to
include a ._mlx.metallib AppleDouble file in the zip. Since this file
is not part of the code signature seal, macOS rejects the bundle during
auto-update verification with:

  "a sealed resource is missing or invalid"
  "file added: .../._mlx.metallib"

The --norsrc flag prevents ditto from preserving resource forks and
extended attributes, ensuring only signed files are included in the
release archive.
2026-01-14 07:48:10 -08:00
Daniel Hiltgen
74f475e735 Revert "Documentation edits made through Mintlify web editor" (#13688)
This reverts commit c6d4c0c7f2.

Merge after 0.14.0 ships for the updated Linux documentation.
2026-01-14 07:42:34 -08:00
Maternion
875cecba74 docs: update default context window size to 4096 tokens (#13709) 2026-01-14 01:01:28 -08:00
Josh Daniel Bañares
7d411a4686 docs: update web search param in examples (#13711) 2026-01-14 00:38:39 -08:00
Daniel Hiltgen
02a2401596 mlx: bundle openblas dependency (#13706) 2026-01-13 15:29:47 -08:00
Daniel Hiltgen
e4b488a7b5 CI: dedup cuda libraries to reduce payload size (#13704) 2026-01-13 11:25:31 -08:00
Daniel Hiltgen
98079ddd79 ci: add missing mlx components to release build (#13702) 2026-01-13 09:13:09 -08:00
Jeffrey Morgan
d70942f47b x/imagegen/cli: skip local model check (#13699) 2026-01-12 22:38:10 -08:00
Jeffrey Morgan
58e4701557 scripts: increase notarization timeout to 20m (#13697)
The 100MB mlx.metallib file significantly increased the app bundle size,
causing Apple's notarization service to timeout with the previous 10m limit.
2026-01-12 20:38:38 -08:00
Jeffrey Morgan
dbf47ee55a cmake: use CMAKE_SYSTEM_PROCESSOR instead of CMAKE_OSX_ARCHITECTURES for mlx.metallib install (#13696)
The CMake condition for installing mlx.metallib checks
CMAKE_OSX_ARCHITECTURES, but this variable is only set when explicitly
passed - not auto-detected. The arm64 build was missing this flag,
causing the metallib to not be installed, which then caused codesign
to fail on the unexpanded glob pattern.
2026-01-12 20:05:11 -08:00
Jeffrey Morgan
af7ea6e96e x/imagegen: install mlx.metallib and fix macOS rpath handling, add mlx library directories to LD_LIBRARY_PATH (#13695)
- Install mlx.metallib for arm64 builds (required for Metal GPU acceleration)
- Apply rpath settings to all macOS builds, not just x86_64
- Add CMAKE_BUILD_WITH_INSTALL_RPATH to avoid install_name_tool errors
- Update build_darwin.sh to copy, sign, and package the metallib
2026-01-12 19:03:11 -08:00
Jeffrey Morgan
8f1e0140e7 x/imagegen: fix mlx build in Dockerfile and macOS build script (#13693) 2026-01-12 15:52:43 -08:00
Parth Sareen
35c3c9e3c2 anthropic: allow non-thinking models when using Anthropic API (#13692) 2026-01-12 15:13:26 -08:00
Parth Sareen
d06acbcb19 x/cmd: enable web search and web fetch with flag (#13690) 2026-01-12 13:59:40 -08:00
Jeffrey Morgan
9667c2282f x/imagegen: add naive TeaCache and FP8 quantization support (#13683)
TeaCache:
- Timestep embedding similarity caching for diffusion models
- Polynomial rescaling with configurable thresholds
- Reduces transformer forward passes by ~30-50%

FP8 quantization:
- Support for FP8 quantized models (8-bit weights with scales)
- QuantizedMatmul on Metal, Dequantize on CUDA
- Client-side quantization via ollama create --quantize fp8

Other bug fixes:
- Fix `/api/show` API for image generation models
- Server properly returns model info (architecture, parameters, quantization)
- Memory allocation optimizations
- CLI improvements for image generation
2026-01-12 13:45:22 -08:00
Jeffrey Morgan
a937a68317 server: fix slow 'ollama rm' of models with many layers (#13680)
RemoveLayers was calling Manifests() for each layer to check if it was
shared with other models. For models with many blobs (e.g., tensor
models), this caused O(N*M) manifest reads.

Now loads manifests once and builds a set of in-use digests.
2026-01-12 13:17:48 -08:00
Parth Sareen
2185112d84 x/cmd: connect /set flags to behavior in experimental mode (#13684) 2026-01-12 00:40:44 -08:00
Parth Sareen
91926601dc x: add missing /set, /show, /load, /save commands to experimental mode (#13682) 2026-01-11 23:12:31 -08:00
Jeffrey Morgan
361d6c16c2 x/imagegen/transfer: fix timeout and progress reporting (#13679)
Removes 5-minute HTTP client timeout that caused "context deadline
exceeded" errors on large file downloads. Stall detection (10s)
already handles unresponsive connections.

Fixes progress bar total going down on resume by calculating total
from all blobs upfront and reporting already-downloaded bytes
as completed immediately.
2026-01-11 15:33:53 -08:00
Patrick Devine
7e2496e88e Fix cmake install command in README (#13678)
Update installation command for MLX component in README.
2026-01-11 13:16:42 -08:00
WhatToPutHere
5b84e29882 docs: fix troubleshooting page (#13674)
Updated the link in the log output description to point to the correct troubleshooting guide format.
2026-01-11 00:58:07 -08:00
Jeffrey Morgan
7cc2a653f2 dockerfile: remove unused COPY command (#13664) 2026-01-09 23:07:15 -08:00
Jeffrey Morgan
2584940016 Add z-image image generation prototype (#13659) 2026-01-09 21:09:46 -08:00
Michael
c6d4c0c7f2 Documentation edits made through Mintlify web editor 2026-01-09 21:29:03 -05:00
Parth Sareen
1ef4241727 x: request access for all commands, add welcome message (#13662) 2026-01-09 18:20:39 -08:00
Parth Sareen
68fafd3002 x: improve approval selector with clearer labels (#13663) 2026-01-09 17:08:12 -08:00
Parth Sareen
2b2cda7a2b api: implement anthropic api (#13600)
* api: add Anthropic Messages API compatibility layer

Add middleware to support the Anthropic Messages API format at /v1/messages.
This enables tools like Claude Code to work with Ollama local and cloud models through the
Anthropic API interface.
2026-01-09 11:53:36 -08:00
Daniel Hiltgen
3cfe9fe146 docker: add missing deps (#13654)
The new MLX library has extra dependencies.
2026-01-09 07:34:40 -08:00
Parth Sareen
a23b559b4c x: disable web search tool registration (#13656) 2026-01-09 01:42:20 -08:00
Daniel Hiltgen
33ee7168ba Add experimental MLX backend and engine with imagegen support (#13648)
* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

To build the new MLX backend code:
  cmake --preset MLX
  cmake --build --preset MLX --parallel
  cmake --install build --component MLX
  go build -tags mlx .

Note: the main.go entrypoint for the MLX engine will change in a follow up commit.

* add experimental image generation runtime

* add experimental image generation runtime

* MLX: wire up cuda build for linux

* MLX: get dependencies correct and dedup

This is still too large for a unified github artifact, but is now "correct" for the mlx_cuda_v13
directory.

* fix relative link bug in dedup

* Add darwin build and readme

* add go build tag for mlx dependent code and wire up build_darwin.sh

* lint cleanup

* macos: build mlx for x86

This will be CPU only.

* cuda build instructions and fix drift from mlx bump

* stale comment

* Delete agent helper doc

* Clean up readme.md

* Revise README for tokenizer clarity and details

Updated README to clarify tokenizer functionality and removed correctness section.

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2026-01-08 16:18:59 -08:00
Daniel Hiltgen
34d0c55ea5 Linux: switch to zstd compression (#13651)
With the upcoming addition of MLX, the linux bundle will exceed the
maximum github artifact size of 2G.  This change will bring the size
back down.

The install.sh changes support backwards compatibility for prior versions
thus should be safe to merge concurrently with this change.
2026-01-08 15:47:32 -08:00
Parth Sareen
53a5a9e9ae x: redesign agent UI with minimal styling (#13650) 2026-01-08 15:40:07 -08:00
Parth Sareen
e30e08a7d6 x: remove Ctrl+O tool output expansion feature (#13640) 2026-01-07 15:34:08 -08:00
Parth Sareen
12e2b3514a x: agent loop ux improvements (#13635) 2026-01-07 01:27:15 -08:00
Devon Rifkin
626af2d809 template: fix args-as-json rendering (#13636)
In #13525, I accidentally broke templates' ability to automatically
render tool call function arguments as JSON.

We do need these to be proper maps because we need templates to be able
to call range, which can't be done on custom types.
2026-01-06 18:33:57 -08:00
277 changed files with 60121 additions and 5028 deletions

View File

@@ -13,7 +13,7 @@ body:
id: logs
attributes:
label: Relevant log output
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.mdx#how-to-troubleshoot-issues) for details.
render: shell
validations:
required: false

View File

@@ -68,6 +68,7 @@ jobs:
name: bundles-darwin
path: |
dist/*.tgz
dist/*.tar.zst
dist/*.zip
dist/*.dmg
@@ -371,13 +372,17 @@ jobs:
outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
cache-to: type=inline
- name: Deduplicate CUDA libraries
run: |
./scripts/deduplicate_cuda_libs.sh dist/${{ matrix.os }}-${{ matrix.arch }}
- run: |
for COMPONENT in bin/* lib/ollama/*; do
case "$COMPONENT" in
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
bin/ollama*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
@@ -392,13 +397,13 @@ jobs:
done
- run: |
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | zstd --ultra -22 -T0 >$(basename ${ARCHIVE//.*/}.tar.zst);
done
- uses: actions/upload-artifact@v4
with:
name: bundles-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.target }}
path: |
*.tgz
*.tar.zst
# Build each Docker variant (OS, arch, and flavor) separately. Using QEMU is unreliable and slower.
docker-build-push:
@@ -531,7 +536,7 @@ jobs:
- name: Upload release artifacts
run: |
pids=()
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.exe dist/*.dmg ; do
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg ; do
echo "Uploading $payload"
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
pids[$!]=$!

View File

@@ -2,6 +2,22 @@ cmake_minimum_required(VERSION 3.21)
project(Ollama C CXX)
# Handle cross-compilation on macOS: when CMAKE_OSX_ARCHITECTURES is set to a
# single architecture different from the host, override CMAKE_SYSTEM_PROCESSOR
# to match. This is necessary because CMAKE_SYSTEM_PROCESSOR defaults to the
# host architecture, but downstream projects (like MLX) use it to detect the
# target architecture.
if(CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES ";")
# Single architecture specified
if(CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
message(STATUS "Cross-compiling for x86_64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to x86_64")
set(CMAKE_SYSTEM_PROCESSOR "x86_64")
elseif(CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
message(STATUS "Cross-compiling for arm64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to arm64")
set(CMAKE_SYSTEM_PROCESSOR "arm64")
endif()
endif()
include(CheckLanguage)
include(GNUInstallDirs)
@@ -12,7 +28,7 @@ set(BUILD_SHARED_LIBS ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_CXX_EXTENSIONS ON) # Recent versions of MLX Requires gnu++17 extensions to compile properly
set(GGML_BUILD ON)
set(GGML_SHARED ON)
@@ -32,9 +48,10 @@ if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
set(GGML_CPU_ALL_VARIANTS ON)
endif()
if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
if(APPLE)
set(CMAKE_BUILD_RPATH "@loader_path")
set(CMAKE_INSTALL_RPATH "@loader_path")
set(CMAKE_BUILD_WITH_INSTALL_RPATH ON)
endif()
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
@@ -147,14 +164,56 @@ if(CMAKE_HIP_COMPILER)
endif()
endif()
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
if(NOT APPLE)
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
endif()
endif()
option(MLX_ENGINE "Enable MLX backend" OFF)
if(MLX_ENGINE)
message(STATUS "Setting up MLX (this takes a while...)")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
# Find CUDA toolkit if MLX is built with CUDA support
find_package(CUDAToolkit)
install(TARGETS mlx mlxc
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
)
# Install the Metal library for macOS arm64 (must be colocated with the binary)
# Metal backend is only built for arm64, not x86_64
if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
install(FILES ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/backend/metal/kernels/mlx.metallib
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
# Manually install cudart and cublas since they might not be picked up as direct dependencies
if(CUDAToolkit_FOUND)
file(GLOB CUDART_LIBS
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
if(CUDART_LIBS)
install(FILES ${CUDART_LIBS}
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
endif()
endif()

View File

@@ -41,7 +41,7 @@
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
"CMAKE_CUDA_FLAGS": "-t 2",
"CMAKE_CUDA_FLAGS": "-t 4",
"OLLAMA_RUNNER_DIR": "cuda_v13"
}
},
@@ -83,6 +83,28 @@
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "vulkan"
}
},
{
"name": "MLX",
"inherits": [ "Default" ],
"cacheVariables": {
"MLX_ENGINE": "ON",
"OLLAMA_RUNNER_DIR": "mlx"
}
},
{
"name": "MLX CUDA 12",
"inherits": [ "MLX", "CUDA 12" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v12"
}
},
{
"name": "MLX CUDA 13",
"inherits": [ "MLX", "CUDA 13" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
}
}
],
"buildPresets": [
@@ -140,6 +162,21 @@
"name": "Vulkan",
"targets": [ "ggml-vulkan" ],
"configurePreset": "Vulkan"
},
{
"name": "MLX",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX"
},
{
"name": "MLX CUDA 12",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 12"
},
{
"name": "MLX CUDA 13",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 13"
}
]
}

View File

@@ -32,7 +32,7 @@ ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
# install epel-release for ccache
RUN yum install -y yum-utils epel-release \
&& dnf install -y clang ccache \
&& dnf install -y clang ccache git \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
ENV CC=clang CXX=clang++
@@ -131,8 +131,32 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'Vulkan' \
&& cmake --build --parallel --preset 'Vulkan' \
&& cmake --install build --component Vulkan --strip --parallel 8
&& cmake --install build --component Vulkan --strip --parallel 8
FROM base AS mlx
ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} \
&& dnf install -y openblas-devel lapack-devel \
&& dnf install -y libcudnn9-cuda-13 libcudnn9-devel-cuda-13 \
&& dnf install -y libnccl libnccl-devel
ENV PATH=/usr/local/cuda-13/bin:$PATH
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
ARG PARALLEL
WORKDIR /go/src/github.com/ollama/ollama
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/ml/backend/mlx x/ml/backend/mlx
COPY go.mod go.sum .
COPY MLX_VERSION .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama
@@ -141,18 +165,21 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
COPY . .
# Clone mlx-c headers for CGO (version from MLX_VERSION file)
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ENV CGO_CFLAGS="-I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
ARG CGO_CXXFLAGS
RUN --mount=type=cache,target=/root/.cache/go-build \
go build -trimpath -buildmode=pie -o /bin/ollama .
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
FROM --platform=linux/amd64 scratch AS amd64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
COPY --from=vulkan dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
FROM --platform=linux/arm64 scratch AS arm64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
@@ -171,7 +198,7 @@ COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04
RUN apt-get update \
&& apt-get install -y ca-certificates libvulkan1 \
&& apt-get install -y ca-certificates libvulkan1 libopenblas0 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
COPY --from=archive /bin /usr/bin

1
MLX_VERSION Normal file
View File

@@ -0,0 +1 @@
v0.4.1

View File

@@ -48,7 +48,7 @@ ollama run gemma3
## Model library
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library 'ollama model library')
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library "ollama model library")
Here are some example models that can be downloaded:
@@ -79,7 +79,7 @@ Here are some example models that can be downloaded:
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| LLaVA | 7B | 4.5GB | `ollama run llava` |
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
> [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -260,6 +260,38 @@ Finally, in a separate shell, run a model:
./ollama run llama3.2
```
## Building with MLX (experimental)
First build the MLX libraries:
```shell
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install build --component MLX
```
When building with the `-tags mlx` flag, the main `ollama` binary includes MLX support for experimental features like image generation:
```shell
go build -tags mlx .
```
Finally, start the server:
```
./ollama serve
```
### Building MLX with CUDA
When building with CUDA, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with default architectures:
```shell
cmake --preset 'MLX CUDA 13'
cmake --build --preset 'MLX CUDA 13' --parallel
cmake --install build --component MLX
```
## REST API
Ollama has a REST API for running and managing models.
@@ -290,6 +322,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Web & Desktop
- [Onyx](https://github.com/onyx-dot-app/onyx)
- [Open WebUI](https://github.com/open-webui/open-webui)
- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
@@ -421,7 +454,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
@@ -493,7 +526,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Database
- [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector)
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
- [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases)
@@ -636,6 +669,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
### Observability
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
@@ -644,4 +678,5 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
### Security
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)

778
anthropic/anthropic.go Normal file
View File

@@ -0,0 +1,778 @@
package anthropic
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"time"
"github.com/ollama/ollama/api"
)
// Error types matching Anthropic API
type Error struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ErrorResponse struct {
Type string `json:"type"` // always "error"
Error Error `json:"error"`
RequestID string `json:"request_id,omitempty"`
}
// NewError creates a new ErrorResponse with the appropriate error type based on HTTP status code
func NewError(code int, message string) ErrorResponse {
var etype string
switch code {
case http.StatusBadRequest:
etype = "invalid_request_error"
case http.StatusUnauthorized:
etype = "authentication_error"
case http.StatusForbidden:
etype = "permission_error"
case http.StatusNotFound:
etype = "not_found_error"
case http.StatusTooManyRequests:
etype = "rate_limit_error"
case http.StatusServiceUnavailable, 529:
etype = "overloaded_error"
default:
etype = "api_error"
}
return ErrorResponse{
Type: "error",
Error: Error{Type: etype, Message: message},
RequestID: generateID("req"),
}
}
// Request types
// MessagesRequest represents an Anthropic Messages API request
type MessagesRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
Messages []MessageParam `json:"messages"`
System any `json:"system,omitempty"` // string or []ContentBlock
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
Metadata *Metadata `json:"metadata,omitempty"`
}
// MessageParam represents a message in the request
type MessageParam struct {
Role string `json:"role"` // "user" or "assistant"
Content any `json:"content"` // string or []ContentBlock
}
// ContentBlock represents a content block in a message.
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
// only when set, which is required for SDK streaming accumulation.
type ContentBlock struct {
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
Text *string `json:"text,omitempty"`
// For image blocks
Source *ImageSource `json:"source,omitempty"`
// For tool_use blocks
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
// For tool_result blocks
ToolUseID string `json:"tool_use_id,omitempty"`
Content any `json:"content,omitempty"` // string or []ContentBlock
IsError bool `json:"is_error,omitempty"`
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
Thinking *string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ImageSource represents the source of an image
type ImageSource struct {
Type string `json:"type"` // "base64" or "url"
MediaType string `json:"media_type,omitempty"`
Data string `json:"data,omitempty"`
URL string `json:"url,omitempty"`
}
// Tool represents a tool definition
type Tool struct {
Type string `json:"type,omitempty"` // "custom" for user-defined tools
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema json.RawMessage `json:"input_schema,omitempty"`
}
// ToolChoice controls how the model uses tools
type ToolChoice struct {
Type string `json:"type"` // "auto", "any", "tool", "none"
Name string `json:"name,omitempty"`
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
}
// ThinkingConfig controls extended thinking
type ThinkingConfig struct {
Type string `json:"type"` // "enabled" or "disabled"
BudgetTokens int `json:"budget_tokens,omitempty"`
}
// Metadata for the request
type Metadata struct {
UserID string `json:"user_id,omitempty"`
}
// Response types
// MessagesResponse represents an Anthropic Messages API response
type MessagesResponse struct {
ID string `json:"id"`
Type string `json:"type"` // "message"
Role string `json:"role"` // "assistant"
Model string `json:"model"`
Content []ContentBlock `json:"content"`
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
Usage Usage `json:"usage"`
}
// Usage contains token usage information
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
// Streaming event types
// MessageStartEvent is sent at the start of streaming
type MessageStartEvent struct {
Type string `json:"type"` // "message_start"
Message MessagesResponse `json:"message"`
}
// ContentBlockStartEvent signals the start of a content block
type ContentBlockStartEvent struct {
Type string `json:"type"` // "content_block_start"
Index int `json:"index"`
ContentBlock ContentBlock `json:"content_block"`
}
// ContentBlockDeltaEvent contains incremental content updates
type ContentBlockDeltaEvent struct {
Type string `json:"type"` // "content_block_delta"
Index int `json:"index"`
Delta Delta `json:"delta"`
}
// Delta represents an incremental update
type Delta struct {
Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta"
Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ContentBlockStopEvent signals the end of a content block
type ContentBlockStopEvent struct {
Type string `json:"type"` // "content_block_stop"
Index int `json:"index"`
}
// MessageDeltaEvent contains updates to the message
type MessageDeltaEvent struct {
Type string `json:"type"` // "message_delta"
Delta MessageDelta `json:"delta"`
Usage DeltaUsage `json:"usage"`
}
// MessageDelta contains stop information
type MessageDelta struct {
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
}
// DeltaUsage contains cumulative token usage
type DeltaUsage struct {
OutputTokens int `json:"output_tokens"`
}
// MessageStopEvent signals the end of the message
type MessageStopEvent struct {
Type string `json:"type"` // "message_stop"
}
// PingEvent is a keepalive event
type PingEvent struct {
Type string `json:"type"` // "ping"
}
// StreamErrorEvent is an error during streaming
type StreamErrorEvent struct {
Type string `json:"type"` // "error"
Error Error `json:"error"`
}
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
var messages []api.Message
if r.System != nil {
switch sys := r.System.(type) {
case string:
if sys != "" {
messages = append(messages, api.Message{Role: "system", Content: sys})
}
case []any:
// System can be an array of content blocks
var content strings.Builder
for _, block := range sys {
if blockMap, ok := block.(map[string]any); ok {
if blockMap["type"] == "text" {
if text, ok := blockMap["text"].(string); ok {
content.WriteString(text)
}
}
}
}
if content.Len() > 0 {
messages = append(messages, api.Message{Role: "system", Content: content.String()})
}
}
}
for _, msg := range r.Messages {
converted, err := convertMessage(msg)
if err != nil {
return nil, err
}
messages = append(messages, converted...)
}
options := make(map[string]any)
options["num_predict"] = r.MaxTokens
if r.Temperature != nil {
options["temperature"] = *r.Temperature
}
if r.TopP != nil {
options["top_p"] = *r.TopP
}
if r.TopK != nil {
options["top_k"] = *r.TopK
}
if len(r.StopSequences) > 0 {
options["stop"] = r.StopSequences
}
var tools api.Tools
for _, t := range r.Tools {
tool, err := convertTool(t)
if err != nil {
return nil, err
}
tools = append(tools, tool)
}
var think *api.ThinkValue
if r.Thinking != nil && r.Thinking.Type == "enabled" {
think = &api.ThinkValue{Value: true}
}
stream := r.Stream
return &api.ChatRequest{
Model: r.Model,
Messages: messages,
Options: options,
Stream: &stream,
Tools: tools,
Think: think,
}, nil
}
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
func convertMessage(msg MessageParam) ([]api.Message, error) {
var messages []api.Message
role := strings.ToLower(msg.Role)
switch content := msg.Content.(type) {
case string:
messages = append(messages, api.Message{Role: role, Content: content})
case []any:
var textContent strings.Builder
var images []api.ImageData
var toolCalls []api.ToolCall
var thinking string
var toolResults []api.Message
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
return nil, errors.New("invalid content block format")
}
blockType, _ := blockMap["type"].(string)
switch blockType {
case "text":
if text, ok := blockMap["text"].(string); ok {
textContent.WriteString(text)
}
case "image":
source, ok := blockMap["source"].(map[string]any)
if !ok {
return nil, errors.New("invalid image source")
}
sourceType, _ := source["type"].(string)
if sourceType == "base64" {
data, _ := source["data"].(string)
decoded, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return nil, fmt.Errorf("invalid base64 image data: %w", err)
}
images = append(images, decoded)
} else {
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
}
// URL images would need to be fetched - skip for now
case "tool_use":
id, ok := blockMap["id"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'id' field")
}
name, ok := blockMap["name"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'name' field")
}
tc := api.ToolCall{
ID: id,
Function: api.ToolCallFunction{
Name: name,
},
}
if input, ok := blockMap["input"].(map[string]any); ok {
tc.Function.Arguments = mapToArgs(input)
}
toolCalls = append(toolCalls, tc)
case "tool_result":
toolUseID, _ := blockMap["tool_use_id"].(string)
var resultContent string
switch c := blockMap["content"].(type) {
case string:
resultContent = c
case []any:
for _, cb := range c {
if cbMap, ok := cb.(map[string]any); ok {
if cbMap["type"] == "text" {
if text, ok := cbMap["text"].(string); ok {
resultContent += text
}
}
}
}
}
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: resultContent,
ToolCallID: toolUseID,
})
case "thinking":
if t, ok := blockMap["thinking"].(string); ok {
thinking = t
}
}
}
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
m := api.Message{
Role: role,
Content: textContent.String(),
Images: images,
ToolCalls: toolCalls,
Thinking: thinking,
}
messages = append(messages, m)
}
// Add tool results as separate messages
messages = append(messages, toolResults...)
default:
return nil, fmt.Errorf("invalid message content type: %T", content)
}
return messages, nil
}
// convertTool converts an Anthropic Tool to an Ollama api.Tool
func convertTool(t Tool) (api.Tool, error) {
var params api.ToolFunctionParameters
if len(t.InputSchema) > 0 {
if err := json.Unmarshal(t.InputSchema, &params); err != nil {
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
}
}
return api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: t.Name,
Description: t.Description,
Parameters: params,
},
}, nil
}
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
func ToMessagesResponse(id string, r api.ChatResponse) MessagesResponse {
var content []ContentBlock
if r.Message.Thinking != "" {
content = append(content, ContentBlock{
Type: "thinking",
Thinking: ptr(r.Message.Thinking),
})
}
if r.Message.Content != "" {
content = append(content, ContentBlock{
Type: "text",
Text: ptr(r.Message.Content),
})
}
for _, tc := range r.Message.ToolCalls {
content = append(content, ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: tc.Function.Arguments,
})
}
stopReason := mapStopReason(r.DoneReason, len(r.Message.ToolCalls) > 0)
return MessagesResponse{
ID: id,
Type: "message",
Role: "assistant",
Model: r.Model,
Content: content,
StopReason: stopReason,
Usage: Usage{
InputTokens: r.Metrics.PromptEvalCount,
OutputTokens: r.Metrics.EvalCount,
},
}
}
// mapStopReason converts Ollama done_reason to Anthropic stop_reason
func mapStopReason(reason string, hasToolCalls bool) string {
if hasToolCalls {
return "tool_use"
}
switch reason {
case "stop":
return "end_turn"
case "length":
return "max_tokens"
default:
if reason != "" {
return "stop_sequence"
}
return ""
}
}
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
type StreamConverter struct {
ID string
Model string
firstWrite bool
contentIndex int
inputTokens int
outputTokens int
thinkingStarted bool
thinkingDone bool
textStarted bool
toolCallsSent map[string]bool
}
func NewStreamConverter(id, model string) *StreamConverter {
return &StreamConverter{
ID: id,
Model: model,
firstWrite: true,
toolCallsSent: make(map[string]bool),
}
}
// StreamEvent represents a streaming event to be sent to the client
type StreamEvent struct {
Event string
Data any
}
// Process converts an Ollama ChatResponse to Anthropic streaming events
func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
var events []StreamEvent
if c.firstWrite {
c.firstWrite = false
c.inputTokens = r.Metrics.PromptEvalCount
events = append(events, StreamEvent{
Event: "message_start",
Data: MessageStartEvent{
Type: "message_start",
Message: MessagesResponse{
ID: c.ID,
Type: "message",
Role: "assistant",
Model: c.Model,
Content: []ContentBlock{},
Usage: Usage{
InputTokens: c.inputTokens,
OutputTokens: 0,
},
},
},
})
}
if r.Message.Thinking != "" && !c.thinkingDone {
if !c.thinkingStarted {
c.thinkingStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "thinking_delta",
Thinking: r.Message.Thinking,
},
},
})
}
if r.Message.Content != "" {
if c.thinkingStarted && !c.thinkingDone {
c.thinkingDone = true
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
}
if !c.textStarted {
c.textStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "text",
Text: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "text_delta",
Text: r.Message.Content,
},
},
})
}
for _, tc := range r.Message.ToolCalls {
if c.toolCallsSent[tc.ID] {
continue
}
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
c.textStarted = false
}
argsJSON, err := json.Marshal(tc.Function.Arguments)
if err != nil {
slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
continue
}
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: map[string]any{},
},
},
})
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "input_json_delta",
PartialJSON: string(argsJSON),
},
},
})
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.toolCallsSent[tc.ID] = true
c.contentIndex++
}
if r.Done {
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
} else if c.thinkingStarted && !c.thinkingDone {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
}
c.outputTokens = r.Metrics.EvalCount
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
events = append(events, StreamEvent{
Event: "message_delta",
Data: MessageDeltaEvent{
Type: "message_delta",
Delta: MessageDelta{
StopReason: stopReason,
},
Usage: DeltaUsage{
OutputTokens: c.outputTokens,
},
},
})
events = append(events, StreamEvent{
Event: "message_stop",
Data: MessageStopEvent{
Type: "message_stop",
},
})
}
return events
}
// generateID generates a unique ID with the given prefix using crypto/rand
func generateID(prefix string) string {
b := make([]byte, 12)
if _, err := rand.Read(b); err != nil {
// Fallback to time-based ID if crypto/rand fails
return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
}
return fmt.Sprintf("%s_%x", prefix, b)
}
// GenerateMessageID generates a unique message ID
func GenerateMessageID() string {
return generateID("msg")
}
// ptr returns a pointer to the given string value
func ptr(s string) *string {
return &s
}
// mapToArgs converts a map to ToolCallFunctionArguments
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}

953
anthropic/anthropic_test.go Normal file
View File

@@ -0,0 +1,953 @@
package anthropic
import (
"encoding/base64"
"encoding/json"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
const (
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
func TestFromMessagesRequest_Basic(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", result.Model)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
t.Errorf("unexpected message: %+v", result.Messages[0])
}
if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 1024 {
t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"])
}
}
func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: "You are a helpful assistant.",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." {
t.Errorf("unexpected system message: %+v", result.Messages[0])
}
}
func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: []any{
map[string]any{"type": "text", "text": "You are helpful."},
map[string]any{"type": "text", "text": " Be concise."},
},
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Content != "You are helpful. Be concise." {
t.Errorf("unexpected system message content: %q", result.Messages[0].Content)
}
}
func TestFromMessagesRequest_WithOptions(t *testing.T) {
temp := 0.7
topP := 0.9
topK := 40
req := MessagesRequest{
Model: "test-model",
MaxTokens: 2048,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Temperature: &temp,
TopP: &topP,
TopK: &topK,
StopSequences: []string{"\n", "END"},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Options["temperature"] != 0.7 {
t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"])
}
if result.Options["top_p"] != 0.9 {
t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"])
}
if result.Options["top_k"] != 40 {
t.Errorf("expected top_k 40, got %v", result.Options["top_k"])
}
if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" {
t.Errorf("stop sequences mismatch: %s", diff)
}
}
func TestFromMessagesRequest_WithImage(t *testing.T) {
imgData, _ := base64.StdEncoding.DecodeString(testImage)
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{"type": "text", "text": "What's in this image?"},
map[string]any{
"type": "image",
"source": map[string]any{
"type": "base64",
"media_type": "image/png",
"data": testImage,
},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Content != "What's in this image?" {
t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content)
}
if len(result.Messages[0].Images) != 1 {
t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images))
}
if string(result.Messages[0].Images[0]) != string(imgData) {
t.Error("image data mismatch")
}
}
func TestFromMessagesRequest_WithToolUse(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
"name": "get_weather",
"input": map[string]any{"location": "Paris"},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if len(result.Messages[1].ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls))
}
tc := result.Messages[1].ToolCalls[0]
if tc.ID != "call_123" {
t.Errorf("expected tool call ID 'call_123', got %q", tc.ID)
}
if tc.Function.Name != "get_weather" {
t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name)
}
}
func TestFromMessagesRequest_WithToolResult(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{
"type": "tool_result",
"tool_use_id": "call_123",
"content": "The weather in Paris is sunny, 22°C",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
msg := result.Messages[0]
if msg.Role != "tool" {
t.Errorf("expected role 'tool', got %q", msg.Role)
}
if msg.ToolCallID != "call_123" {
t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID)
}
if msg.Content != "The weather in Paris is sunny, 22°C" {
t.Errorf("unexpected content: %q", msg.Content)
}
}
func TestFromMessagesRequest_WithTools(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "get_weather",
Description: "Get current weather",
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(result.Tools))
}
tool := result.Tools[0]
if tool.Type != "function" {
t.Errorf("expected type 'function', got %q", tool.Type)
}
if tool.Function.Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", tool.Function.Name)
}
if tool.Function.Description != "Get current weather" {
t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description)
}
}
func TestFromMessagesRequest_WithThinking(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Think == nil {
t.Fatal("expected Think to be set")
}
if v, ok := result.Think.Value.(bool); !ok || !v {
t.Errorf("expected Think.Value to be true, got %v", result.Think.Value)
}
}
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "thinking",
"thinking": "Let me think about this...",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
assistantMsg := result.Messages[1]
if assistantMsg.Thinking != "Let me think about this..." {
t.Errorf("expected thinking content, got %q", assistantMsg.Thinking)
}
}
func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"name": "get_weather",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use id")
}
if err.Error() != "tool_use block missing required 'id' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use name")
}
if err.Error() != "tool_use block missing required 'name' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "bad_tool",
InputSchema: json.RawMessage(`{invalid json`),
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for invalid tool schema")
}
}
func TestToMessagesResponse_Basic(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello there!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 5,
},
}
result := ToMessagesResponse("msg_123", resp)
if result.ID != "msg_123" {
t.Errorf("expected ID 'msg_123', got %q", result.ID)
}
if result.Type != "message" {
t.Errorf("expected type 'message', got %q", result.Type)
}
if result.Role != "assistant" {
t.Errorf("expected role 'assistant', got %q", result.Role)
}
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "text" || result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
t.Errorf("unexpected content: %+v", result.Content[0])
}
if result.StopReason != "end_turn" {
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
}
if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 {
t.Errorf("unexpected usage: %+v", result.Usage)
}
}
func TestToMessagesResponse_WithToolCalls(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "tool_use" {
t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type)
}
if result.Content[0].ID != "call_123" {
t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID)
}
if result.Content[0].Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name)
}
if result.StopReason != "tool_use" {
t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason)
}
}
func TestToMessagesResponse_WithThinking(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "The answer is 42.",
Thinking: "Let me think about this...",
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 2 {
t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
}
if result.Content[0].Type != "thinking" {
t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type)
}
if result.Content[0].Thinking == nil || *result.Content[0].Thinking != "Let me think about this..." {
t.Errorf("unexpected thinking content: %v", result.Content[0].Thinking)
}
if result.Content[1].Type != "text" {
t.Errorf("expected second block type 'text', got %q", result.Content[1].Type)
}
}
func TestMapStopReason(t *testing.T) {
tests := []struct {
reason string
hasToolCalls bool
want string
}{
{"stop", false, "end_turn"},
{"length", false, "max_tokens"},
{"stop", true, "tool_use"},
{"other", false, "stop_sequence"},
{"", false, ""},
}
for _, tt := range tests {
got := mapStopReason(tt.reason, tt.hasToolCalls)
if got != tt.want {
t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want)
}
}
}
func TestNewError(t *testing.T) {
tests := []struct {
code int
want string
}{
{400, "invalid_request_error"},
{401, "authentication_error"},
{403, "permission_error"},
{404, "not_found_error"},
{429, "rate_limit_error"},
{500, "api_error"},
{503, "overloaded_error"},
{529, "overloaded_error"},
}
for _, tt := range tests {
result := NewError(tt.code, "test message")
if result.Type != "error" {
t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type)
}
if result.Error.Type != tt.want {
t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want)
}
if result.Error.Message != "test message" {
t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message)
}
if result.RequestID == "" {
t.Errorf("NewError(%d) request_id should not be empty", tt.code)
}
}
}
func TestGenerateMessageID(t *testing.T) {
id1 := GenerateMessageID()
id2 := GenerateMessageID()
if id1 == "" {
t.Error("GenerateMessageID returned empty string")
}
if id1 == id2 {
t.Error("GenerateMessageID returned duplicate IDs")
}
if len(id1) < 10 {
t.Errorf("GenerateMessageID returned short ID: %q", id1)
}
if id1[:4] != "msg_" {
t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4])
}
}
func TestStreamConverter_Basic(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
// First chunk
resp1 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello",
},
Metrics: api.Metrics{PromptEvalCount: 10},
}
events1 := conv.Process(resp1)
if len(events1) < 3 {
t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1))
}
// Should have message_start, content_block_start, content_block_delta
if events1[0].Event != "message_start" {
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
}
if events1[1].Event != "content_block_start" {
t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event)
}
if events1[2].Event != "content_block_delta" {
t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event)
}
// Final chunk
resp2 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: " world!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{EvalCount: 5},
}
events2 := conv.Process(resp2)
// Should have content_block_delta, content_block_stop, message_delta, message_stop
hasStop := false
for _, e := range events2 {
if e.Event == "message_stop" {
hasStop = true
}
}
if !hasStop {
t.Error("expected message_stop event in final chunk")
}
}
func TestStreamConverter_WithToolCalls(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
}
events := conv.Process(resp)
hasToolStart := false
hasToolDelta := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
hasToolDelta = true
}
}
}
}
if !hasToolStart {
t.Error("expected tool_use content_block_start event")
}
if !hasToolDelta {
t.Error("expected input_json_delta event")
}
}
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
// Test that unmarshalable arguments (like channels) are handled gracefully
// and don't cause a panic or corrupt stream
conv := NewStreamConverter("msg_123", "test-model")
// Create a channel which cannot be JSON marshaled
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
// Should not panic and should skip the unmarshalable tool call
events := conv.Process(resp)
// Verify no tool_use block was started (since marshal failed before block start)
hasToolStart := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
}
if hasToolStart {
t.Error("expected no tool_use block when arguments cannot be marshaled")
}
}
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
// Test that valid tool calls still work when mixed with invalid ones
conv := NewStreamConverter("msg_123", "test-model")
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_good",
Function: api.ToolCallFunction{
Name: "good_function",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
events := conv.Process(resp)
// Count tool_use blocks - should only have 1 (the valid one)
toolStartCount := 0
toolDeltaCount := 0
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
toolStartCount++
if start.ContentBlock.Name != "good_function" {
t.Errorf("expected tool name 'good_function', got %q", start.ContentBlock.Name)
}
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
toolDeltaCount++
}
}
}
}
if toolStartCount != 1 {
t.Errorf("expected 1 tool_use block, got %d", toolStartCount)
}
if toolDeltaCount != 1 {
t.Errorf("expected 1 input_json_delta, got %d", toolDeltaCount)
}
}
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
tests := []struct {
name string
block ContentBlock
wantKeys []string
}{
{
name: "text block includes empty text field",
block: ContentBlock{
Type: "text",
Text: ptr(""),
},
wantKeys: []string{"type", "text"},
},
{
name: "thinking block includes empty thinking field",
block: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
wantKeys: []string{"type", "thinking"},
},
{
name: "text block with content",
block: ContentBlock{
Type: "text",
Text: ptr("hello"),
},
wantKeys: []string{"type", "text"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.block)
if err != nil {
t.Fatalf("failed to marshal: %v", err)
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
for _, key := range tt.wantKeys {
if _, ok := result[key]; !ok {
t.Errorf("expected key %q to be present in JSON output, got: %s", key, string(data))
}
}
})
}
}
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
// events include the required empty fields for SDK compatibility.
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
t.Run("text block start includes empty text", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Content: "hello"},
}
events := conv.Process(resp)
var foundTextStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "text" {
foundTextStart = true
// Marshal and verify the text field is present
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["text"]; !ok {
t.Error("content_block_start for text should include 'text' field")
}
}
}
}
}
if !foundTextStart {
t.Error("expected text content_block_start event")
}
})
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Thinking: "let me think..."},
}
events := conv.Process(resp)
var foundThinkingStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "thinking" {
foundThinkingStart = true
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["thinking"]; !ok {
t.Error("content_block_start for thinking should include 'thinking' field")
}
}
}
}
}
if !foundThinkingStart {
t.Error("expected thinking content_block_start event")
}
})
}

View File

@@ -165,7 +165,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
return nil
}
const maxBufferSize = 512 * format.KiloByte
const maxBufferSize = 8 * format.MegaByte
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf io.Reader

View File

@@ -19,12 +19,6 @@ import (
"github.com/ollama/ollama/types/model"
)
// SkillRef is an alias for model.SkillRef representing a skill reference.
type SkillRef = model.SkillRef
// MCPRef is an alias for model.MCPRef representing an MCP server reference.
type MCPRef = model.MCPRef
// StatusError is an error with an HTTP status code and message.
type StatusError struct {
StatusCode int
@@ -133,6 +127,20 @@ type GenerateRequest struct {
// each with an associated log probability. Only applies when Logprobs is true.
// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
TopLogprobs int `json:"top_logprobs,omitempty"`
// Experimental: Image generation fields (may change or be removed)
// Width is the width of the generated image in pixels.
// Only used for image generation models.
Width int32 `json:"width,omitempty"`
// Height is the height of the generated image in pixels.
// Only used for image generation models.
Height int32 `json:"height,omitempty"`
// Steps is the number of diffusion steps for image generation.
// Only used for image generation models.
Steps int32 `json:"steps,omitempty"`
}
// ChatRequest describes a request sent by [Client.Chat].
@@ -696,18 +704,6 @@ type CreateRequest struct {
// Requires is the minimum version of Ollama required by the model.
Requires string `json:"requires,omitempty"`
// Skills is a list of skill references for the agent (local paths or registry refs)
Skills []SkillRef `json:"skills,omitempty"`
// MCPs is a list of MCP server references for the agent
MCPs []MCPRef `json:"mcps,omitempty"`
// AgentType defines the type of agent (e.g., "conversational", "task-based")
AgentType string `json:"agent_type,omitempty"`
// Entrypoint specifies an external command to run instead of the built-in chat loop
Entrypoint string `json:"entrypoint,omitempty"`
// Info is a map of additional information for the model
Info map[string]any `json:"info,omitempty"`
@@ -753,16 +749,12 @@ type ShowResponse struct {
Messages []Message `json:"messages,omitempty"`
RemoteModel string `json:"remote_model,omitempty"`
RemoteHost string `json:"remote_host,omitempty"`
ModelInfo map[string]any `json:"model_info,omitempty"`
ModelInfo map[string]any `json:"model_info"`
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
Tensors []Tensor `json:"tensors,omitempty"`
Capabilities []model.Capability `json:"capabilities,omitempty"`
ModifiedAt time.Time `json:"modified_at,omitempty"`
Requires string `json:"requires,omitempty"`
Skills []SkillRef `json:"skills,omitempty"`
MCPs []MCPRef `json:"mcps,omitempty"`
AgentType string `json:"agent_type,omitempty"`
Entrypoint string `json:"entrypoint,omitempty"`
}
// CopyRequest is the request passed to [Client.Copy].
@@ -882,6 +874,20 @@ type GenerateResponse struct {
// Logprobs contains log probability information for the generated tokens,
// if requested via the Logprobs parameter.
Logprobs []Logprob `json:"logprobs,omitempty"`
// Experimental: Image generation fields (may change or be removed)
// Image contains a base64-encoded generated image.
// Only present for image generation models.
Image string `json:"image,omitempty"`
// Completed is the number of completed steps in image generation.
// Only present for image generation models during streaming.
Completed int64 `json:"completed,omitempty"`
// Total is the total number of steps for image generation.
// Only present for image generation models during streaming.
Total int64 `json:"total,omitempty"`
}
// ModelDetails provides details about a model.

View File

@@ -14,6 +14,7 @@ extern NSString *SystemWidePath;
@interface AppDelegate () <NSWindowDelegate, WKNavigationDelegate, WKUIDelegate>
@property(strong, nonatomic) NSStatusItem *statusItem;
@property(assign, nonatomic) BOOL updateAvailable;
@property(assign, nonatomic) BOOL systemShutdownInProgress;
@end
@implementation AppDelegate
@@ -40,6 +41,13 @@ bool firstTimeRun,startHidden; // Set in run before initialization
}
- (void)applicationDidFinishLaunching:(NSNotification *)aNotification {
// Register for system shutdown/restart notification so we can allow termination
[[[NSWorkspace sharedWorkspace] notificationCenter]
addObserver:self
selector:@selector(systemWillPowerOff:)
name:NSWorkspaceWillPowerOffNotification
object:nil];
// if we're in development mode, set the app icon
NSString *bundlePath = [[NSBundle mainBundle] bundlePath];
if (![bundlePath hasSuffix:@".app"]) {
@@ -278,7 +286,18 @@ bool firstTimeRun,startHidden; // Set in run before initialization
[NSApp activateIgnoringOtherApps:YES];
}
- (void)systemWillPowerOff:(NSNotification *)notification {
// Set flag so applicationShouldTerminate: knows to allow termination.
// The system will call applicationShouldTerminate: after posting this notification.
self.systemShutdownInProgress = YES;
}
- (NSApplicationTerminateReply)applicationShouldTerminate:(NSApplication *)sender {
// Allow termination if the system is shutting down or restarting
if (self.systemShutdownInProgress) {
return NSTerminateNow;
}
// Otherwise just hide the app (for Cmd+Q, close button, etc.)
[NSApp hide:nil];
[NSApp setActivationPolicy:NSApplicationActivationPolicyAccessory];
return NSTerminateCancel;

View File

@@ -1,402 +0,0 @@
package cmd
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
// TestToolMessage verifies that tool messages are constructed correctly
// with ToolName and ToolCallID preserved from the tool call.
func TestToolMessage(t *testing.T) {
tests := []struct {
name string
call api.ToolCall
content string
expected api.Message
}{
{
name: "basic tool message with ID",
call: api.ToolCall{
ID: "call_abc123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
content: "Sunny, 22°C",
expected: api.Message{
Role: "tool",
Content: "Sunny, 22°C",
ToolName: "get_weather",
ToolCallID: "call_abc123",
},
},
{
name: "tool message without ID",
call: api.ToolCall{
Function: api.ToolCallFunction{
Name: "calculate",
Arguments: api.ToolCallFunctionArguments{
"expression": "2+2",
},
},
},
content: "4",
expected: api.Message{
Role: "tool",
Content: "4",
ToolName: "calculate",
// ToolCallID should be empty when call.ID is empty
},
},
{
name: "MCP tool message",
call: api.ToolCall{
ID: "call_mcp123",
Function: api.ToolCallFunction{
Name: "mcp_websearch_search",
Arguments: api.ToolCallFunctionArguments{
"query": "ollama agents",
},
},
},
content: "Found 10 results",
expected: api.Message{
Role: "tool",
Content: "Found 10 results",
ToolName: "mcp_websearch_search",
ToolCallID: "call_mcp123",
},
},
{
name: "skill tool message",
call: api.ToolCall{
ID: "call_skill456",
Function: api.ToolCallFunction{
Name: "run_skill_script",
Arguments: api.ToolCallFunctionArguments{
"skill": "calculator",
"command": "python scripts/calc.py 2+2",
},
},
},
content: "Result: 4",
expected: api.Message{
Role: "tool",
Content: "Result: 4",
ToolName: "run_skill_script",
ToolCallID: "call_skill456",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := toolMessage(tt.call, tt.content)
if diff := cmp.Diff(tt.expected, result); diff != "" {
t.Errorf("toolMessage() mismatch (-want +got):\n%s", diff)
}
})
}
}
// TestAssistantMessageWithThinking verifies that assistant messages
// in the tool loop should include thinking content.
func TestAssistantMessageConstruction(t *testing.T) {
tests := []struct {
name string
content string
thinking string
toolCalls []api.ToolCall
expectedMsg api.Message
}{
{
name: "assistant with thinking and tool calls",
content: "",
thinking: "I need to check the weather for Paris.",
toolCalls: []api.ToolCall{
{
ID: "call_1",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
},
},
},
expectedMsg: api.Message{
Role: "assistant",
Content: "",
Thinking: "I need to check the weather for Paris.",
ToolCalls: []api.ToolCall{
{
ID: "call_1",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
},
},
},
},
},
{
name: "assistant with content, thinking, and tool calls",
content: "Let me check that for you.",
thinking: "User wants weather info.",
toolCalls: []api.ToolCall{
{
ID: "call_2",
Function: api.ToolCallFunction{
Name: "search",
Arguments: api.ToolCallFunctionArguments{"query": "weather"},
},
},
},
expectedMsg: api.Message{
Role: "assistant",
Content: "Let me check that for you.",
Thinking: "User wants weather info.",
ToolCalls: []api.ToolCall{
{
ID: "call_2",
Function: api.ToolCallFunction{
Name: "search",
Arguments: api.ToolCallFunctionArguments{"query": "weather"},
},
},
},
},
},
{
name: "assistant with multiple tool calls",
content: "",
thinking: "I'll check both cities.",
toolCalls: []api.ToolCall{
{
ID: "call_a",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
},
},
{
ID: "call_b",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{"city": "London"},
},
},
},
expectedMsg: api.Message{
Role: "assistant",
Content: "",
Thinking: "I'll check both cities.",
ToolCalls: []api.ToolCall{
{
ID: "call_a",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
},
},
{
ID: "call_b",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{"city": "London"},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Simulate the assistant message construction as done in chat()
assistantMsg := api.Message{
Role: "assistant",
Content: tt.content,
Thinking: tt.thinking,
ToolCalls: tt.toolCalls,
}
if diff := cmp.Diff(tt.expectedMsg, assistantMsg); diff != "" {
t.Errorf("assistant message mismatch (-want +got):\n%s", diff)
}
})
}
}
// TestMessageStitchingOrder verifies that messages in a tool loop
// are stitched in the correct order:
// 1. User message
// 2. Assistant message with tool calls (and thinking)
// 3. Tool result messages (one per tool call, in order)
// 4. Next assistant response
func TestMessageStitchingOrder(t *testing.T) {
// Simulate a complete tool loop conversation
messages := []api.Message{
// Initial user message
{Role: "user", Content: "What's the weather in Paris and London?"},
// Assistant's first response with tool calls
{
Role: "assistant",
Content: "",
Thinking: "I need to check the weather for both cities.",
ToolCalls: []api.ToolCall{
{ID: "call_1", Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "Paris"}}},
{ID: "call_2", Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "London"}}},
},
},
// Tool results (in order matching tool calls)
{Role: "tool", Content: "Sunny, 22°C", ToolName: "get_weather", ToolCallID: "call_1"},
{Role: "tool", Content: "Rainy, 15°C", ToolName: "get_weather", ToolCallID: "call_2"},
// Final assistant response
{Role: "assistant", Content: "Paris is sunny at 22°C, and London is rainy at 15°C.", Thinking: "Got the data, now summarizing."},
}
// Verify structure
expectedRoles := []string{"user", "assistant", "tool", "tool", "assistant"}
for i, msg := range messages {
if msg.Role != expectedRoles[i] {
t.Errorf("message %d: expected role %q, got %q", i, expectedRoles[i], msg.Role)
}
}
// Verify tool results match tool calls in order
assistantWithTools := messages[1]
toolResults := []api.Message{messages[2], messages[3]}
if len(toolResults) != len(assistantWithTools.ToolCalls) {
t.Errorf("expected %d tool results for %d tool calls", len(assistantWithTools.ToolCalls), len(toolResults))
}
for i, result := range toolResults {
expectedToolCallID := assistantWithTools.ToolCalls[i].ID
if result.ToolCallID != expectedToolCallID {
t.Errorf("tool result %d: expected ToolCallID %q, got %q", i, expectedToolCallID, result.ToolCallID)
}
expectedToolName := assistantWithTools.ToolCalls[i].Function.Name
if result.ToolName != expectedToolName {
t.Errorf("tool result %d: expected ToolName %q, got %q", i, expectedToolName, result.ToolName)
}
}
// Verify thinking is present in assistant messages
if messages[1].Thinking == "" {
t.Error("first assistant message should have thinking content")
}
if messages[4].Thinking == "" {
t.Error("final assistant message should have thinking content")
}
}
// TestMultiTurnToolLoop verifies message stitching across multiple
// tool call iterations.
func TestMultiTurnToolLoop(t *testing.T) {
messages := []api.Message{
{Role: "user", Content: "What's 2+2 and also what's the weather in Paris?"},
// First tool call: calculate
{
Role: "assistant",
Thinking: "I'll start with the calculation.",
ToolCalls: []api.ToolCall{
{ID: "calc_1", Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expr": "2+2"}}},
},
},
{Role: "tool", Content: "4", ToolName: "calculate", ToolCallID: "calc_1"},
// Second tool call: weather
{
Role: "assistant",
Thinking: "Got the calculation. Now checking weather.",
ToolCalls: []api.ToolCall{
{ID: "weather_1", Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "Paris"}}},
},
},
{Role: "tool", Content: "Sunny, 20°C", ToolName: "get_weather", ToolCallID: "weather_1"},
// Final response
{Role: "assistant", Content: "2+2 equals 4, and Paris is sunny at 20°C."},
}
// Count message types
roleCounts := map[string]int{}
for _, msg := range messages {
roleCounts[msg.Role]++
}
if roleCounts["user"] != 1 {
t.Errorf("expected 1 user message, got %d", roleCounts["user"])
}
if roleCounts["assistant"] != 3 {
t.Errorf("expected 3 assistant messages, got %d", roleCounts["assistant"])
}
if roleCounts["tool"] != 2 {
t.Errorf("expected 2 tool messages, got %d", roleCounts["tool"])
}
// Verify each tool message follows an assistant with matching tool call
for i, msg := range messages {
if msg.Role == "tool" {
// Find preceding assistant message with tool calls
var precedingAssistant *api.Message
for j := i - 1; j >= 0; j-- {
if messages[j].Role == "assistant" && len(messages[j].ToolCalls) > 0 {
precedingAssistant = &messages[j]
break
}
}
if precedingAssistant == nil {
t.Errorf("tool message at index %d has no preceding assistant with tool calls", i)
continue
}
// Verify tool result matches one of the tool calls
found := false
for _, tc := range precedingAssistant.ToolCalls {
if tc.ID == msg.ToolCallID {
found = true
break
}
}
if !found {
t.Errorf("tool message at index %d has ToolCallID %q not found in preceding tool calls", i, msg.ToolCallID)
}
}
}
}
// TestSkillCatalogRunToolCallPreservesFields tests that skill catalog
// returns tool messages with correct fields.
func TestSkillCatalogToolMessageFields(t *testing.T) {
// Create a minimal test for toolMessage function
call := api.ToolCall{
ID: "test_id_123",
Function: api.ToolCallFunction{
Name: "run_skill_script",
Arguments: api.ToolCallFunctionArguments{
"skill": "test-skill",
"command": "echo hello",
},
},
}
msg := toolMessage(call, "hello")
if msg.Role != "tool" {
t.Errorf("expected role 'tool', got %q", msg.Role)
}
if msg.Content != "hello" {
t.Errorf("expected content 'hello', got %q", msg.Content)
}
if msg.ToolName != "run_skill_script" {
t.Errorf("expected ToolName 'run_skill_script', got %q", msg.ToolName)
}
if msg.ToolCallID != "test_id_123" {
t.Errorf("expected ToolCallID 'test_id_123', got %q", msg.ToolCallID)
}
}

View File

@@ -15,7 +15,6 @@ import (
"net"
"net/http"
"os"
"os/exec"
"os/signal"
"path/filepath"
"runtime"
@@ -47,6 +46,9 @@ import (
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd"
"github.com/ollama/ollama/x/create"
xcreateclient "github.com/ollama/ollama/x/create/client"
"github.com/ollama/ollama/x/imagegen"
)
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
@@ -92,11 +94,88 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()
// Validate model name early to fail fast
modelName := args[0]
name := model.ParseName(modelName)
if !name.IsValid() {
return fmt.Errorf("invalid model name: %s", modelName)
}
// Check for --experimental flag for safetensors model creation
experimental, _ := cmd.Flags().GetBool("experimental")
if experimental {
// Get Modelfile content - either from -f flag or default to "FROM ."
var reader io.Reader
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) || filename == "" {
// No Modelfile specified or found - use default
reader = strings.NewReader("FROM .\n")
} else if err != nil {
return err
} else {
f, err := os.Open(filename)
if err != nil {
return err
}
defer f.Close()
reader = f
}
// Parse the Modelfile
modelfile, err := parser.ParseFile(reader)
if err != nil {
return fmt.Errorf("failed to parse Modelfile: %w", err)
}
// Extract FROM path and configuration
var modelDir string
mfConfig := &xcreateclient.ModelfileConfig{}
for _, cmd := range modelfile.Commands {
switch cmd.Name {
case "model":
modelDir = cmd.Args
case "template":
mfConfig.Template = cmd.Args
case "system":
mfConfig.System = cmd.Args
case "license":
mfConfig.License = cmd.Args
}
}
if modelDir == "" {
modelDir = "."
}
// Resolve relative paths based on Modelfile location
if !filepath.IsAbs(modelDir) && filename != "" {
modelDir = filepath.Join(filepath.Dir(filename), modelDir)
}
quantize, _ := cmd.Flags().GetString("quantize")
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
ModelName: modelName,
ModelDir: modelDir,
Quantize: quantize,
Modelfile: mfConfig,
}, p)
}
var reader io.Reader
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) {
if filename == "" {
// No Modelfile found - check if current directory is an image gen model
if create.IsTensorModelDir(".") {
quantize, _ := cmd.Flags().GetString("quantize")
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
ModelName: modelName,
ModelDir: ".",
Quantize: quantize,
}, p)
}
reader = strings.NewReader("FROM .\n")
} else {
return errModelfileNotFound
@@ -128,7 +207,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}
spinner.Stop()
req.Model = args[0]
req.Model = modelName
quantize, _ := cmd.Flags().GetString("quantize")
if quantize != "" {
req.Quantize = quantize
@@ -458,6 +537,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
name := args[0]
info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: name}
info, err := client.Show(cmd.Context(), showReq)
@@ -496,16 +576,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
opts.ParentModel = info.Details.ParentModel
// Check if this is an agent
isAgent := info.AgentType != "" || len(info.Skills) > 0 || len(info.MCPs) > 0 || info.Entrypoint != ""
if isAgent {
opts.IsAgent = true
opts.AgentType = info.AgentType
opts.Skills = info.Skills
opts.MCPs = info.MCPs
opts.Entrypoint = info.Entrypoint
}
// Check if this is an embedding model
isEmbeddingModel := slices.Contains(info.Capabilities, model.CapabilityEmbedding)
@@ -529,12 +599,18 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
}
// Check if this is an image generation model
if slices.Contains(info.Capabilities, model.CapabilityImage) {
if opts.Prompt == "" && !interactive {
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
}
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
}
// Check for experimental flag
isExperimental, _ := cmd.Flags().GetBool("experimental")
// If agent has entrypoint, run it instead of chat loop
if opts.Entrypoint != "" {
return runEntrypoint(cmd, opts)
}
yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
enableWebsearch, _ := cmd.Flags().GetBool("experimental-websearch")
if interactive {
if err := loadOrUnloadModel(cmd, &opts); err != nil {
@@ -562,69 +638,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
}
// Use experimental agent loop with
// Use experimental agent loop with tools
if isExperimental {
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive)
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode, enableWebsearch)
}
return generateInteractive(cmd, opts)
}
// For agents, use chat API even in non-interactive mode to support tools
if opts.IsAgent {
opts.Messages = append(opts.Messages, api.Message{Role: "user", Content: opts.Prompt})
_, err := chat(cmd, opts)
return err
}
return generate(cmd, opts)
}
// runEntrypoint executes the agent's entrypoint command instead of the built-in chat loop.
func runEntrypoint(cmd *cobra.Command, opts runOptions) error {
entrypoint := opts.Entrypoint
// Check if entrypoint contains $PROMPT placeholder
hasPlaceholder := strings.Contains(entrypoint, "$PROMPT")
if hasPlaceholder && opts.Prompt != "" {
// Replace $PROMPT with the actual prompt
entrypoint = strings.ReplaceAll(entrypoint, "$PROMPT", opts.Prompt)
} else if hasPlaceholder {
// No prompt provided but placeholder exists - remove placeholder
entrypoint = strings.ReplaceAll(entrypoint, "$PROMPT", "")
}
// Parse entrypoint into command and args
parts := strings.Fields(entrypoint)
if len(parts) == 0 {
return fmt.Errorf("empty entrypoint")
}
command := parts[0]
args := parts[1:]
// If user provided a prompt and no placeholder was used, append it as argument
if opts.Prompt != "" && !hasPlaceholder {
args = append(args, opts.Prompt)
}
// Look up command in PATH
execPath, err := exec.LookPath(command)
if err != nil {
return fmt.Errorf("entrypoint command not found: %s", command)
}
// Create subprocess
proc := exec.Command(execPath, args...)
proc.Stdin = os.Stdin
proc.Stdout = os.Stdout
proc.Stderr = os.Stderr
// Run and wait
return proc.Run()
}
func SigninHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
@@ -723,7 +746,11 @@ func PushHandler(cmd *cobra.Command, args []string) error {
bar, ok := bars[resp.Digest]
if !ok {
bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
msg := resp.Status
if msg == "" {
msg = fmt.Sprintf("pushing %s...", resp.Digest[7:19])
}
bar = progress.NewBar(msg, resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
@@ -872,11 +899,11 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
for _, arg := range args {
// Unload the model if it's running before deletion
if err := loadOrUnloadModel(cmd, &runOptions{
Model: args[0],
Model: arg,
KeepAlive: &api.Duration{Duration: 0},
}); err != nil {
if !strings.Contains(strings.ToLower(err.Error()), "not found") {
fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", args[0])
fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", arg)
}
}
@@ -984,96 +1011,47 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
fmt.Fprintln(w)
}
// Only show Model section if there's actual model info (not for entrypoint-only agents)
hasModelInfo := resp.RemoteHost != "" || resp.ModelInfo != nil || resp.Details.Family != "" || resp.Details.ParameterSize != "" || resp.Details.QuantizationLevel != ""
if hasModelInfo {
tableRender("Model", func() (rows [][]string) {
if resp.RemoteHost != "" {
rows = append(rows, []string{"", "Remote model", resp.RemoteModel})
rows = append(rows, []string{"", "Remote URL", resp.RemoteHost})
}
tableRender("Model", func() (rows [][]string) {
if resp.RemoteHost != "" {
rows = append(rows, []string{"", "Remote model", resp.RemoteModel})
rows = append(rows, []string{"", "Remote URL", resp.RemoteHost})
}
if resp.ModelInfo != nil {
arch := resp.ModelInfo["general.architecture"].(string)
rows = append(rows, []string{"", "architecture", arch})
if resp.ModelInfo != nil {
arch := resp.ModelInfo["general.architecture"].(string)
rows = append(rows, []string{"", "architecture", arch})
var paramStr string
if resp.Details.ParameterSize != "" {
paramStr = resp.Details.ParameterSize
} else if v, ok := resp.ModelInfo["general.parameter_count"]; ok {
if f, ok := v.(float64); ok {
paramStr = format.HumanNumber(uint64(f))
}
}
rows = append(rows, []string{"", "parameters", paramStr})
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
if f, ok := v.(float64); ok {
rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)})
}
}
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok {
if f, ok := v.(float64); ok {
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)})
}
}
} else {
rows = append(rows, []string{"", "architecture", resp.Details.Family})
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
}
rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
if resp.Requires != "" {
rows = append(rows, []string{"", "requires", resp.Requires})
}
return
})
}
// Display agent information if this is an agent
if resp.AgentType != "" || len(resp.Skills) > 0 || len(resp.MCPs) > 0 || resp.Entrypoint != "" {
tableRender("Agent", func() (rows [][]string) {
if resp.AgentType != "" {
rows = append(rows, []string{"", "type", resp.AgentType})
}
if resp.Entrypoint != "" {
rows = append(rows, []string{"", "entrypoint", resp.Entrypoint})
}
if len(resp.Skills) > 0 {
for i, skill := range resp.Skills {
label := "skill"
if i > 0 {
label = ""
}
// Show skill name or digest
skillDisplay := skill.Name
if skillDisplay == "" && skill.Digest != "" {
skillDisplay = skill.Digest[:12] + "..."
}
rows = append(rows, []string{"", label, skillDisplay})
var paramStr string
if resp.Details.ParameterSize != "" {
paramStr = resp.Details.ParameterSize
} else if v, ok := resp.ModelInfo["general.parameter_count"]; ok {
if f, ok := v.(float64); ok {
paramStr = format.HumanNumber(uint64(f))
}
}
if len(resp.MCPs) > 0 {
for i, mcp := range resp.MCPs {
label := "mcp"
if i > 0 {
label = ""
}
// Show MCP name and command
mcpDisplay := mcp.Name
if mcp.Command != "" {
cmdLine := mcp.Command
if len(mcp.Args) > 0 {
cmdLine += " " + strings.Join(mcp.Args, " ")
}
mcpDisplay += " (" + cmdLine + ")"
}
rows = append(rows, []string{"", label, mcpDisplay})
rows = append(rows, []string{"", "parameters", paramStr})
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
if f, ok := v.(float64); ok {
rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)})
}
}
return
})
}
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok {
if f, ok := v.(float64); ok {
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)})
}
}
} else {
rows = append(rows, []string{"", "architecture", resp.Details.Family})
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
}
rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
if resp.Requires != "" {
rows = append(rows, []string{"", "requires", resp.Requires})
}
return
})
if len(resp.Capabilities) > 0 {
tableRender("Capabilities", func() (rows [][]string) {
@@ -1315,11 +1293,6 @@ type runOptions struct {
Think *api.ThinkValue
HideThinking bool
ShowConnect bool
IsAgent bool
AgentType string
Skills []api.SkillRef
MCPs []api.MCPRef
Entrypoint string
}
func (r runOptions) Copy() runOptions {
@@ -1349,12 +1322,6 @@ func (r runOptions) Copy() runOptions {
think = &cThink
}
var skills []api.SkillRef
if r.Skills != nil {
skills = make([]api.SkillRef, len(r.Skills))
copy(skills, r.Skills)
}
return runOptions{
Model: r.Model,
ParentModel: r.ParentModel,
@@ -1370,9 +1337,6 @@ func (r runOptions) Copy() runOptions {
Think: think,
HideThinking: r.HideThinking,
ShowConnect: r.ShowConnect,
IsAgent: r.IsAgent,
AgentType: r.AgentType,
Skills: skills,
}
}
@@ -1456,65 +1420,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
return nil, err
}
// Load skills for agents
var skillsCatalog *skillCatalog
if opts.IsAgent && len(opts.Skills) > 0 {
skillsCatalog, err = loadSkillsFromRefs(opts.Skills)
if err != nil {
return nil, fmt.Errorf("failed to load skills: %w", err)
}
if skillsCatalog != nil && len(skillsCatalog.Skills) > 0 {
var skillNames []string
for _, s := range skillsCatalog.Skills {
skillNames = append(skillNames, s.Name)
}
fmt.Fprintf(os.Stderr, "Loaded skills: %s\n", strings.Join(skillNames, ", "))
}
}
// Load MCP servers for agents (from opts and global config)
var mcpMgr *mcpManager
allMCPs := opts.MCPs
// Load global MCPs from ~/.ollama/mcp.json
if globalConfig, err := loadMCPConfig(); err == nil && len(globalConfig.MCPServers) > 0 {
for name, srv := range globalConfig.MCPServers {
// Skip disabled MCPs
if srv.Disabled {
continue
}
// Check if already in opts.MCPs (model takes precedence)
found := false
for _, m := range opts.MCPs {
if m.Name == name {
found = true
break
}
}
if !found {
allMCPs = append(allMCPs, api.MCPRef{
Name: name,
Command: srv.Command,
Args: srv.Args,
Env: srv.Env,
Type: srv.Type,
})
}
}
}
if len(allMCPs) > 0 {
mcpMgr = newMCPManager()
if err := mcpMgr.loadMCPsFromRefs(allMCPs); err != nil {
return nil, fmt.Errorf("failed to load MCP servers: %w", err)
}
if mcpMgr.ToolCount() > 0 {
fmt.Fprintf(os.Stderr, "Loaded MCP servers: %s (%d tools)\n",
strings.Join(mcpMgr.ServerNames(), ", "), mcpMgr.ToolCount())
}
defer mcpMgr.Shutdown()
}
p := progress.NewProgress(os.Stderr)
defer p.StopAndClear()
@@ -1538,7 +1443,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
var fullResponse strings.Builder
var thinkTagOpened bool = false
var thinkTagClosed bool = false
var pendingToolCalls []api.ToolCall
role := "assistant"
@@ -1579,13 +1483,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
if response.Message.ToolCalls != nil {
toolCalls := response.Message.ToolCalls
if len(toolCalls) > 0 {
if skillsCatalog != nil || mcpMgr != nil {
// Store tool calls for execution after response is complete
pendingToolCalls = append(pendingToolCalls, toolCalls...)
} else {
// No skills catalog or MCP, just display tool calls
fmt.Print(renderToolCalls(toolCalls, false))
}
fmt.Print(renderToolCalls(toolCalls, false))
}
}
@@ -1598,161 +1496,31 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
opts.Format = `"` + opts.Format + `"`
}
// Prepare messages with agent-specific system prompt
messages := opts.Messages
if skillsCatalog != nil {
// Add skills system prompt as the first system message
skillsPrompt := skillsCatalog.SystemPrompt()
if skillsPrompt != "" {
// Insert skills prompt at the beginning, or append to existing system message
if len(messages) > 0 && messages[0].Role == "system" {
// Append to existing system message
messages[0].Content = messages[0].Content + "\n\n" + skillsPrompt
} else {
// Insert new system message at the beginning
systemMsg := api.Message{Role: "system", Content: skillsPrompt}
messages = append([]api.Message{systemMsg}, messages...)
}
}
req := &api.ChatRequest{
Model: opts.Model,
Messages: opts.Messages,
Format: json.RawMessage(opts.Format),
Options: opts.Options,
Think: opts.Think,
}
// Agentic loop: continue until no more tool calls
for {
req := &api.ChatRequest{
Model: opts.Model,
Messages: messages,
Format: json.RawMessage(opts.Format),
Options: opts.Options,
Think: opts.Think,
if opts.KeepAlive != nil {
req.KeepAlive = opts.KeepAlive
}
if err := client.Chat(cancelCtx, req, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil, nil
}
// Add tools for agents (combine skills and MCP tools)
var allTools api.Tools
if skillsCatalog != nil {
allTools = append(allTools, skillsCatalog.Tools()...)
// this error should ideally be wrapped properly by the client
if strings.Contains(err.Error(), "upstream error") {
p.StopAndClear()
fmt.Println("An error occurred while processing your message. Please try again.")
fmt.Println()
return nil, nil
}
if mcpMgr != nil {
allTools = append(allTools, mcpMgr.Tools()...)
}
if len(allTools) > 0 {
req.Tools = allTools
}
if opts.KeepAlive != nil {
req.KeepAlive = opts.KeepAlive
}
if err := client.Chat(cancelCtx, req, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil, nil
}
// this error should ideally be wrapped properly by the client
if strings.Contains(err.Error(), "upstream error") {
p.StopAndClear()
fmt.Println("An error occurred while processing your message. Please try again.")
fmt.Println()
return nil, nil
}
return nil, err
}
// If no tool calls, we're done
if len(pendingToolCalls) == 0 || (skillsCatalog == nil && mcpMgr == nil) {
break
}
// Execute tool calls and continue the conversation
fmt.Fprintf(os.Stderr, "\n")
// Add assistant's tool call message to history (include thinking for proper rendering)
assistantMsg := api.Message{
Role: "assistant",
Content: fullResponse.String(),
Thinking: thinkingContent.String(),
ToolCalls: pendingToolCalls,
}
messages = append(messages, assistantMsg)
// Execute each tool call and collect results
var toolResults []api.Message
for _, call := range pendingToolCalls {
// Show what's being executed
switch call.Function.Name {
case "run_skill_script":
skillVal, _ := call.Function.Arguments.Get("skill")
skill, _ := skillVal.(string)
commandVal, _ := call.Function.Arguments.Get("command")
command, _ := commandVal.(string)
fmt.Fprintf(os.Stderr, "Running script in %s: %s\n", skill, command)
case "read_skill_file":
skillVal, _ := call.Function.Arguments.Get("skill")
skill, _ := skillVal.(string)
pathVal, _ := call.Function.Arguments.Get("path")
path, _ := pathVal.(string)
fmt.Fprintf(os.Stderr, "Reading file from %s: %s\n", skill, path)
default:
fmt.Fprintf(os.Stderr, "Executing: %s\n", call.Function.Name)
}
var result api.Message
var handled bool
var err error
// Try skill catalog first
if skillsCatalog != nil {
result, handled, err = skillsCatalog.RunToolCall(call)
}
// If not handled by skills, try MCP
if !handled && mcpMgr != nil {
result, handled, err = mcpMgr.RunToolCall(call)
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
// Add error result
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: fmt.Sprintf("Error: %v", err),
})
continue
}
if !handled {
fmt.Fprintf(os.Stderr, "Warning: Unknown tool %s\n", call.Function.Name)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: fmt.Sprintf("Unknown tool: %s", call.Function.Name),
})
continue
}
// Display tool output
if result.Content != "" {
fmt.Fprintf(os.Stderr, "Output:\n%s\n", result.Content)
}
// Add tool result to messages (preserves ToolName, ToolCallID from result)
toolResults = append(toolResults, result)
}
// Add tool results to message history
messages = append(messages, toolResults...)
fmt.Fprintf(os.Stderr, "\n")
// Reset state for next iteration
fullResponse.Reset()
thinkingContent.Reset()
thinkTagOpened = false
thinkTagClosed = false
pendingToolCalls = nil
state = &displayResponseState{}
// Start new progress spinner for next API call
p = progress.NewProgress(os.Stderr)
spinner = progress.NewSpinner("")
p.Add("", spinner)
return nil, err
}
if len(opts.Messages) > 0 {
@@ -2047,15 +1815,22 @@ func NewCLI() *cobra.Command {
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
createCmd := &cobra.Command{
Use: "create MODEL",
Short: "Create a model",
Args: cobra.ExactArgs(1),
PreRunE: checkServerHeartbeat,
RunE: CreateHandler,
Use: "create MODEL",
Short: "Create a model",
Args: cobra.ExactArgs(1),
PreRunE: func(cmd *cobra.Command, args []string) error {
// Skip server check for experimental mode (writes directly to disk)
if experimental, _ := cmd.Flags().GetBool("experimental"); experimental {
return nil
}
return checkServerHeartbeat(cmd, args)
},
RunE: CreateHandler,
}
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")")
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
createCmd.Flags().Bool("experimental", false, "Enable experimental safetensors model creation")
showCmd := &cobra.Command{
Use: "show MODEL",
@@ -2091,6 +1866,11 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
runCmd.Flags().Bool("experimental-websearch", false, "Enable web search tool in experimental mode")
// Image generation flags (width, height, steps, seed, etc.)
imagegen.RegisterFlags(runCmd)
stopCmd := &cobra.Command{
Use: "stop MODEL",
@@ -2205,6 +1985,7 @@ func NewCLI() *cobra.Command {
} {
switch cmd {
case runCmd:
imagegen.AppendFlagsDocs(cmd)
appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]})
case serveCmd:
appendEnvDocs(cmd, []envconfig.EnvVar{
@@ -2245,8 +2026,6 @@ func NewCLI() *cobra.Command {
copyCmd,
deleteCmd,
runnerCmd,
NewSkillCommand(),
NewMCPCommand(),
)
return rootCmd

View File

@@ -1547,6 +1547,79 @@ func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) {
}
}
func TestShowInfoImageGen(t *testing.T) {
var b bytes.Buffer
err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "ZImagePipeline",
ParameterSize: "10.3B",
QuantizationLevel: "FP8",
},
Capabilities: []model.Capability{model.CapabilityImage},
Requires: "0.14.0",
}, false, &b)
if err != nil {
t.Fatal(err)
}
expect := " Model\n" +
" architecture ZImagePipeline \n" +
" parameters 10.3B \n" +
" quantization FP8 \n" +
" requires 0.14.0 \n" +
"\n" +
" Capabilities\n" +
" image \n" +
"\n"
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
}
func TestPushProgressMessage(t *testing.T) {
tests := []struct {
name string
status string
digest string
wantMsg string
}{
{
name: "uses status when provided",
status: "uploading model",
digest: "sha256:abc123456789def",
wantMsg: "uploading model",
},
{
name: "falls back to digest when status empty",
status: "",
digest: "sha256:abc123456789def",
wantMsg: "pushing abc123456789...",
},
{
name: "handles short digest gracefully",
status: "",
digest: "sha256:abc",
wantMsg: "pushing sha256:abc...",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msg := tt.status
if msg == "" {
if len(tt.digest) >= 19 {
msg = fmt.Sprintf("pushing %s...", tt.digest[7:19])
} else {
msg = fmt.Sprintf("pushing %s...", tt.digest)
}
}
if msg != tt.wantMsg {
t.Errorf("got %q, want %q", msg, tt.wantMsg)
}
})
}
}
func TestRunOptions_Copy_Independence(t *testing.T) {
// Test that modifications to original don't affect copy
originalThink := &api.ThinkValue{Value: "original"}

View File

@@ -34,9 +34,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set Set session variables")
fmt.Fprintln(os.Stderr, " /show Show model information")
fmt.Fprintln(os.Stderr, " /skills Show available skills")
fmt.Fprintln(os.Stderr, " /skill Add or remove skills dynamically")
fmt.Fprintln(os.Stderr, " /mcp Show/add/remove MCP servers")
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
fmt.Fprintln(os.Stderr, " /clear Clear session context")
@@ -119,7 +116,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
Prompt: ">>> ",
AltPrompt: "... ",
Placeholder: "Send a message (/? for help)",
AltPlaceholder: `Use """ to end multi-line input`,
AltPlaceholder: "Press Enter to send",
})
if err != nil {
return err
@@ -447,411 +444,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
} else {
usageShow()
}
case strings.HasPrefix(line, "/skill "):
args := strings.Fields(line)
if len(args) < 2 {
fmt.Fprintln(os.Stderr, "Usage:")
fmt.Fprintln(os.Stderr, " /skill add <path> Add a skill from local path")
fmt.Fprintln(os.Stderr, " /skill remove <name> Remove a skill by name")
fmt.Fprintln(os.Stderr, " /skill list List current skills")
continue
}
switch args[1] {
case "add":
if len(args) < 3 {
fmt.Println("Usage: /skill add <path>")
continue
}
skillPath := args[2]
// Expand ~ to home directory
if strings.HasPrefix(skillPath, "~") {
home, err := os.UserHomeDir()
if err != nil {
fmt.Printf("Error expanding path: %v\n", err)
continue
}
skillPath = filepath.Join(home, skillPath[1:])
}
// Make absolute
absPath, err := filepath.Abs(skillPath)
if err != nil {
fmt.Printf("Error resolving path: %v\n", err)
continue
}
// Verify SKILL.md exists
skillMdPath := filepath.Join(absPath, "SKILL.md")
if _, err := os.Stat(skillMdPath); err != nil {
fmt.Printf("Error: %s does not contain SKILL.md\n", skillPath)
continue
}
// Extract skill name from SKILL.md
content, err := os.ReadFile(skillMdPath)
if err != nil {
fmt.Printf("Error reading SKILL.md: %v\n", err)
continue
}
skillName, _ := extractSkillMetadata(string(content))
if skillName == "" {
skillName = filepath.Base(absPath)
}
// Check if already added
for _, s := range opts.Skills {
if s.Name == skillName {
fmt.Printf("Skill '%s' is already loaded\n", skillName)
continue
}
}
// Add to skills (using path as Name, no digest for local skills)
opts.Skills = append(opts.Skills, api.SkillRef{Name: absPath})
opts.IsAgent = true // Enable agent mode if not already
fmt.Printf("Added skill '%s' from %s\n", skillName, skillPath)
case "remove", "rm":
if len(args) < 3 {
fmt.Println("Usage: /skill remove <name>")
continue
}
skillName := args[2]
found := false
newSkills := make([]api.SkillRef, 0, len(opts.Skills))
for _, s := range opts.Skills {
// Match by name or by path basename
name := s.Name
if strings.Contains(name, string(os.PathSeparator)) {
name = filepath.Base(name)
}
if name == skillName || s.Name == skillName {
found = true
fmt.Printf("Removed skill '%s'\n", skillName)
} else {
newSkills = append(newSkills, s)
}
}
if !found {
fmt.Printf("Skill '%s' not found\n", skillName)
} else {
opts.Skills = newSkills
}
case "list", "ls":
if len(opts.Skills) == 0 {
fmt.Println("No skills loaded in this session.")
} else {
fmt.Println("Skills loaded in this session:")
for _, skill := range opts.Skills {
if skill.Digest != "" {
fmt.Printf(" %s (%s)\n", skill.Name, skill.Digest[:19])
} else {
// For local paths, show basename
name := skill.Name
if strings.Contains(name, string(os.PathSeparator)) {
name = filepath.Base(name) + " (local: " + skill.Name + ")"
}
fmt.Printf(" %s\n", name)
}
}
}
fmt.Println()
default:
fmt.Printf("Unknown skill command '%s'. Use /skill add, /skill remove, or /skill list\n", args[1])
}
continue
case strings.HasPrefix(line, "/skills"):
// Show skills from model (bundled) + session skills
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
return err
}
req := &api.ShowRequest{
Name: opts.Model,
}
resp, err := client.Show(cmd.Context(), req)
if err != nil {
fmt.Println("error: couldn't get model info")
return err
}
// Combine model skills with session skills
allSkills := make([]api.SkillRef, 0)
allSkills = append(allSkills, resp.Skills...)
// Add session skills that aren't already in model skills
for _, sessionSkill := range opts.Skills {
found := false
for _, modelSkill := range resp.Skills {
if modelSkill.Name == sessionSkill.Name || modelSkill.Digest == sessionSkill.Digest {
found = true
break
}
}
if !found {
allSkills = append(allSkills, sessionSkill)
}
}
if len(allSkills) == 0 {
fmt.Println("No skills available.")
} else {
fmt.Println("Available Skills:")
for _, skill := range allSkills {
if skill.Digest != "" {
fmt.Printf(" %s (%s)\n", skill.Name, skill.Digest[:19])
} else {
name := skill.Name
if strings.Contains(name, string(os.PathSeparator)) {
name = filepath.Base(name) + " (session)"
}
fmt.Printf(" %s\n", name)
}
}
}
fmt.Println()
continue
case strings.HasPrefix(line, "/mcp"):
args := strings.Fields(line)
// If just "/mcp" with no args, show all MCP servers
if len(args) == 1 {
// Show MCPs from model (bundled) + global config
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
return err
}
req := &api.ShowRequest{
Name: opts.Model,
}
resp, err := client.Show(cmd.Context(), req)
if err != nil {
fmt.Println("error: couldn't get model info")
return err
}
// Combine model MCPs with global config MCPs
allMCPs := make([]api.MCPRef, 0)
allMCPs = append(allMCPs, resp.MCPs...)
// Load global config
globalConfig, _ := loadMCPConfig()
globalMCPNames := make(map[string]bool)
if globalConfig != nil {
for name, srv := range globalConfig.MCPServers {
// Check if already in model MCPs
found := false
for _, modelMCP := range resp.MCPs {
if modelMCP.Name == name {
found = true
break
}
}
if !found {
allMCPs = append(allMCPs, api.MCPRef{
Name: name,
Command: srv.Command,
Args: srv.Args,
Env: srv.Env,
Type: srv.Type,
})
}
globalMCPNames[name] = true
}
}
if len(allMCPs) == 0 {
fmt.Println("No MCP servers available.")
fmt.Println("Use '/mcp add <name> <command> [args...]' to add one.")
} else {
fmt.Println("Available MCP Servers:")
for _, mcp := range allMCPs {
cmdLine := mcp.Command
if len(mcp.Args) > 0 {
cmdLine += " " + strings.Join(mcp.Args, " ")
}
source := ""
disabled := ""
// Check if it's from model or global config
isFromModel := false
for _, modelMCP := range resp.MCPs {
if modelMCP.Name == mcp.Name {
isFromModel = true
break
}
}
if isFromModel {
source = " (model)"
} else if globalMCPNames[mcp.Name] {
source = " (global)"
// Check if disabled
if srv, ok := globalConfig.MCPServers[mcp.Name]; ok && srv.Disabled {
disabled = " [disabled]"
}
}
fmt.Printf(" %s: %s%s%s\n", mcp.Name, cmdLine, source, disabled)
}
}
fmt.Println()
continue
}
switch args[1] {
case "add":
if len(args) < 4 {
fmt.Println("Usage: /mcp add <name> <command> [args...]")
continue
}
mcpName := args[2]
mcpCommand := args[3]
mcpArgs := args[4:]
// Load global config
config, err := loadMCPConfig()
if err != nil {
fmt.Printf("Error loading MCP config: %v\n", err)
continue
}
// Check if already exists
if _, exists := config.MCPServers[mcpName]; exists {
fmt.Printf("Warning: overwriting existing MCP server '%s'\n", mcpName)
}
// Add to global config
config.MCPServers[mcpName] = MCPServerConfig{
Type: "stdio",
Command: mcpCommand,
Args: mcpArgs,
}
// Save config
if err := saveMCPConfig(config); err != nil {
fmt.Printf("Error saving MCP config: %v\n", err)
continue
}
cmdLine := mcpCommand
if len(mcpArgs) > 0 {
cmdLine += " " + strings.Join(mcpArgs, " ")
}
fmt.Printf("Added MCP server '%s' (%s) to %s\n", mcpName, cmdLine, getMCPConfigPath())
fmt.Println("Note: MCP server will be started on next message.")
case "remove", "rm":
if len(args) < 3 {
fmt.Println("Usage: /mcp remove <name>")
continue
}
mcpName := args[2]
// Load global config
config, err := loadMCPConfig()
if err != nil {
fmt.Printf("Error loading MCP config: %v\n", err)
continue
}
if _, exists := config.MCPServers[mcpName]; !exists {
fmt.Printf("MCP server '%s' not found in global config\n", mcpName)
continue
}
delete(config.MCPServers, mcpName)
if err := saveMCPConfig(config); err != nil {
fmt.Printf("Error saving MCP config: %v\n", err)
continue
}
fmt.Printf("Removed MCP server '%s' from %s\n", mcpName, getMCPConfigPath())
fmt.Println("Note: Changes will take effect on next message.")
case "disable":
if len(args) < 3 {
fmt.Println("Usage: /mcp disable <name>")
continue
}
mcpName := args[2]
config, err := loadMCPConfig()
if err != nil {
fmt.Printf("Error loading MCP config: %v\n", err)
continue
}
srv, exists := config.MCPServers[mcpName]
if !exists {
fmt.Printf("MCP server '%s' not found in global config\n", mcpName)
continue
}
if srv.Disabled {
fmt.Printf("MCP server '%s' is already disabled\n", mcpName)
continue
}
srv.Disabled = true
config.MCPServers[mcpName] = srv
if err := saveMCPConfig(config); err != nil {
fmt.Printf("Error saving MCP config: %v\n", err)
continue
}
fmt.Printf("Disabled MCP server '%s'\n", mcpName)
fmt.Println("Note: Changes will take effect on next message.")
case "enable":
if len(args) < 3 {
fmt.Println("Usage: /mcp enable <name>")
continue
}
mcpName := args[2]
config, err := loadMCPConfig()
if err != nil {
fmt.Printf("Error loading MCP config: %v\n", err)
continue
}
srv, exists := config.MCPServers[mcpName]
if !exists {
fmt.Printf("MCP server '%s' not found in global config\n", mcpName)
continue
}
if !srv.Disabled {
fmt.Printf("MCP server '%s' is already enabled\n", mcpName)
continue
}
srv.Disabled = false
config.MCPServers[mcpName] = srv
if err := saveMCPConfig(config); err != nil {
fmt.Printf("Error saving MCP config: %v\n", err)
continue
}
fmt.Printf("Enabled MCP server '%s'\n", mcpName)
fmt.Println("Note: Changes will take effect on next message.")
default:
fmt.Printf("Unknown mcp command '%s'. Use /mcp, /mcp add, /mcp remove, /mcp disable, or /mcp enable\n", args[1])
}
continue
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
args := strings.Fields(line)
if len(args) > 1 {
@@ -860,20 +452,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
usageSet()
case "show", "/show":
usageShow()
case "skill", "/skill":
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /skill add <path> Add a skill from local path")
fmt.Fprintln(os.Stderr, " /skill remove <name> Remove a skill by name")
fmt.Fprintln(os.Stderr, " /skill list List current session skills")
fmt.Fprintln(os.Stderr, "")
case "mcp", "/mcp":
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /mcp Show all MCP servers")
fmt.Fprintln(os.Stderr, " /mcp add <name> <command> [args...] Add an MCP server to global config")
fmt.Fprintln(os.Stderr, " /mcp remove <name> Remove an MCP server from global config")
fmt.Fprintln(os.Stderr, " /mcp disable <name> Disable an MCP server (keep in config)")
fmt.Fprintln(os.Stderr, " /mcp enable <name> Re-enable a disabled MCP server")
fmt.Fprintln(os.Stderr, "")
case "shortcut", "shortcuts":
usageShortcuts()
}

View File

@@ -1,570 +0,0 @@
package cmd
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"text/tabwriter"
"time"
"github.com/spf13/cobra"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
)
// SkillPushHandler handles the skill push command.
func SkillPushHandler(cmd *cobra.Command, args []string) error {
if len(args) != 2 {
return fmt.Errorf("usage: ollama skill push NAME[:TAG] PATH")
}
name := args[0]
path := args[1]
// Expand path
if strings.HasPrefix(path, "~") {
home, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("expanding home directory: %w", err)
}
path = filepath.Join(home, path[1:])
}
absPath, err := filepath.Abs(path)
if err != nil {
return fmt.Errorf("resolving path: %w", err)
}
// Validate skill directory
skillMdPath := filepath.Join(absPath, "SKILL.md")
if _, err := os.Stat(skillMdPath); err != nil {
return fmt.Errorf("skill directory must contain SKILL.md: %w", err)
}
// Parse skill name (will set Kind="skill")
n := server.ParseSkillName(name)
if n.Model == "" {
return fmt.Errorf("invalid skill name: %s", name)
}
p := progress.NewProgress(os.Stderr)
defer p.Stop()
// Create skill layer
displayName := n.DisplayShortest()
status := fmt.Sprintf("Creating skill layer for %s", displayName)
spinner := progress.NewSpinner(status)
p.Add(status, spinner)
layer, err := server.CreateSkillLayer(absPath)
if err != nil {
return fmt.Errorf("creating skill layer: %w", err)
}
spinner.Stop()
// Create skill manifest
manifest, configLayer, err := createSkillManifest(absPath, layer)
if err != nil {
return fmt.Errorf("creating skill manifest: %w", err)
}
// Write manifest locally
manifestPath, err := server.GetSkillManifestPath(n)
if err != nil {
return fmt.Errorf("getting manifest path: %w", err)
}
if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
return fmt.Errorf("creating manifest directory: %w", err)
}
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return fmt.Errorf("marshaling manifest: %w", err)
}
if err := os.WriteFile(manifestPath, manifestJSON, 0o644); err != nil {
return fmt.Errorf("writing manifest: %w", err)
}
fmt.Fprintf(os.Stderr, "Skill %s created locally\n", displayName)
fmt.Fprintf(os.Stderr, " Config: %s (%s)\n", configLayer.Digest, format.HumanBytes(configLayer.Size))
fmt.Fprintf(os.Stderr, " Layer: %s (%s)\n", layer.Digest, format.HumanBytes(layer.Size))
// Push to registry
client, err := api.ClientFromEnvironment()
if err != nil {
return fmt.Errorf("creating client: %w", err)
}
insecure, _ := cmd.Flags().GetBool("insecure")
// For now, we'll use the existing push mechanism
fmt.Fprintf(os.Stderr, "\nPushing to registry...\n")
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
bar := progress.NewBar(resp.Status, resp.Total, resp.Completed)
p.Add(resp.Digest, bar)
} else if resp.Status != "" {
spinner := progress.NewSpinner(resp.Status)
p.Add(resp.Status, spinner)
}
return nil
}
req := &api.PushRequest{
Model: displayName,
Insecure: insecure,
}
if err := client.Push(context.Background(), req, fn); err != nil {
// If push fails, still show success for local creation
fmt.Fprintf(os.Stderr, "\nNote: Local skill created but push failed: %v\n", err)
fmt.Fprintf(os.Stderr, "You can try pushing later with: ollama skill push %s\n", name)
return nil
}
fmt.Fprintf(os.Stderr, "Successfully pushed %s\n", displayName)
return nil
}
// SkillPullHandler handles the skill pull command.
func SkillPullHandler(cmd *cobra.Command, args []string) error {
if len(args) != 1 {
return fmt.Errorf("usage: ollama skill pull NAME[:TAG]")
}
name := args[0]
n := server.ParseSkillName(name)
if n.Model == "" {
return fmt.Errorf("invalid skill name: %s", name)
}
client, err := api.ClientFromEnvironment()
if err != nil {
return fmt.Errorf("creating client: %w", err)
}
insecure, _ := cmd.Flags().GetBool("insecure")
p := progress.NewProgress(os.Stderr)
defer p.Stop()
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
bar := progress.NewBar(resp.Status, resp.Total, resp.Completed)
p.Add(resp.Digest, bar)
} else if resp.Status != "" {
spinner := progress.NewSpinner(resp.Status)
p.Add(resp.Status, spinner)
}
return nil
}
displayName := n.DisplayShortest()
req := &api.PullRequest{
Model: displayName,
Insecure: insecure,
}
if err := client.Pull(context.Background(), req, fn); err != nil {
return fmt.Errorf("pulling skill: %w", err)
}
fmt.Fprintf(os.Stderr, "Successfully pulled %s\n", displayName)
return nil
}
// SkillListHandler handles the skill list command.
func SkillListHandler(cmd *cobra.Command, args []string) error {
skills, err := listLocalSkills()
if err != nil {
return fmt.Errorf("listing skills: %w", err)
}
if len(skills) == 0 {
fmt.Println("No skills installed")
return nil
}
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
fmt.Fprintln(w, "NAME\tTAG\tSIZE\tMODIFIED")
for _, skill := range skills {
fmt.Fprintf(w, "%s/%s\t%s\t%s\t%s\n",
skill.Namespace,
skill.Name,
skill.Tag,
format.HumanBytes(skill.Size),
format.HumanTime(skill.ModifiedAt, "Never"),
)
}
return w.Flush()
}
// SkillRemoveHandler handles the skill rm command.
func SkillRemoveHandler(cmd *cobra.Command, args []string) error {
if len(args) == 0 {
return fmt.Errorf("usage: ollama skill rm NAME[:TAG] [NAME[:TAG]...]")
}
for _, name := range args {
n := server.ParseSkillName(name)
if n.Model == "" {
fmt.Fprintf(os.Stderr, "Invalid skill name: %s\n", name)
continue
}
displayName := n.DisplayShortest()
manifestPath, err := server.GetSkillManifestPath(n)
if err != nil {
fmt.Fprintf(os.Stderr, "Error getting manifest path for %s: %v\n", name, err)
continue
}
if _, err := os.Stat(manifestPath); os.IsNotExist(err) {
fmt.Fprintf(os.Stderr, "Skill not found: %s\n", displayName)
continue
}
if err := os.Remove(manifestPath); err != nil {
fmt.Fprintf(os.Stderr, "Error removing %s: %v\n", displayName, err)
continue
}
// Clean up empty parent directories
dir := filepath.Dir(manifestPath)
for dir != filepath.Join(os.Getenv("HOME"), ".ollama", "models", "manifests") {
entries, _ := os.ReadDir(dir)
if len(entries) == 0 {
os.Remove(dir)
dir = filepath.Dir(dir)
} else {
break
}
}
fmt.Fprintf(os.Stderr, "Deleted '%s'\n", displayName)
}
return nil
}
// SkillShowHandler handles the skill show command.
func SkillShowHandler(cmd *cobra.Command, args []string) error {
if len(args) != 1 {
return fmt.Errorf("usage: ollama skill show NAME[:TAG]")
}
name := args[0]
n := server.ParseSkillName(name)
if n.Model == "" {
return fmt.Errorf("invalid skill name: %s", name)
}
displayName := n.DisplayShortest()
manifestPath, err := server.GetSkillManifestPath(n)
if err != nil {
return fmt.Errorf("getting manifest path: %w", err)
}
data, err := os.ReadFile(manifestPath)
if err != nil {
if os.IsNotExist(err) {
return fmt.Errorf("skill not found: %s", displayName)
}
return fmt.Errorf("reading manifest: %w", err)
}
var manifest server.Manifest
if err := json.Unmarshal(data, &manifest); err != nil {
return fmt.Errorf("parsing manifest: %w", err)
}
fmt.Printf("Skill: %s\n\n", displayName)
fmt.Println("Layers:")
for _, layer := range manifest.Layers {
fmt.Printf(" %s %s %s\n", layer.MediaType, layer.Digest[:19], format.HumanBytes(layer.Size))
}
// Try to read and display SKILL.md content
if len(manifest.Layers) > 0 {
for _, layer := range manifest.Layers {
if layer.MediaType == server.MediaTypeSkill {
skillPath, err := server.GetSkillsPath(layer.Digest)
if err == nil {
skillMdPath := filepath.Join(skillPath, "SKILL.md")
if content, err := os.ReadFile(skillMdPath); err == nil {
fmt.Println("\nContent:")
fmt.Println(string(content))
}
}
}
}
}
return nil
}
// SkillInfo represents information about an installed skill.
type SkillInfo struct {
Namespace string
Name string
Tag string
Size int64
ModifiedAt time.Time
}
// listLocalSkills returns a list of locally installed skills.
// Skills are stored with 5-part paths: host/namespace/kind/model/tag
// where kind is "skill".
func listLocalSkills() ([]SkillInfo, error) {
manifestsPath := filepath.Join(os.Getenv("HOME"), ".ollama", "models", "manifests")
var skills []SkillInfo
// Walk through all registries
registries, err := os.ReadDir(manifestsPath)
if err != nil {
if os.IsNotExist(err) {
return skills, nil
}
return nil, err
}
for _, registry := range registries {
if !registry.IsDir() {
continue
}
// Walk namespaces
namespaces, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name()))
if err != nil {
continue
}
for _, namespace := range namespaces {
if !namespace.IsDir() {
continue
}
// Walk kinds looking for "skill"
kinds, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name()))
if err != nil {
continue
}
for _, kind := range kinds {
if !kind.IsDir() {
continue
}
// Only process skill kind
if kind.Name() != server.SkillNamespace {
continue
}
// Walk skill names (model names)
skillNames, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name()))
if err != nil {
continue
}
for _, skillName := range skillNames {
if !skillName.IsDir() {
continue
}
// Walk tags
tags, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name(), skillName.Name()))
if err != nil {
continue
}
for _, tag := range tags {
manifestPath := filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name(), skillName.Name(), tag.Name())
fi, err := os.Stat(manifestPath)
if err != nil || fi.IsDir() {
continue
}
// Read manifest to get size
data, err := os.ReadFile(manifestPath)
if err != nil {
continue
}
var manifest server.Manifest
if err := json.Unmarshal(data, &manifest); err != nil {
continue
}
var totalSize int64
for _, layer := range manifest.Layers {
totalSize += layer.Size
}
// Build display name using model.Name
n := model.Name{
Host: registry.Name(),
Namespace: namespace.Name(),
Kind: kind.Name(),
Model: skillName.Name(),
Tag: tag.Name(),
}
skills = append(skills, SkillInfo{
Namespace: n.Namespace + "/" + n.Kind,
Name: n.Model,
Tag: n.Tag,
Size: totalSize,
ModifiedAt: fi.ModTime(),
})
}
}
}
}
}
return skills, nil
}
// createSkillManifest creates a manifest for a standalone skill.
func createSkillManifest(skillDir string, layer server.Layer) (*server.Manifest, *server.Layer, error) {
// Read SKILL.md to extract metadata
skillMdPath := filepath.Join(skillDir, "SKILL.md")
content, err := os.ReadFile(skillMdPath)
if err != nil {
return nil, nil, fmt.Errorf("reading SKILL.md: %w", err)
}
// Extract name and description from frontmatter
name, description := extractSkillMetadata(string(content))
if name == "" {
return nil, nil, errors.New("skill name not found in SKILL.md frontmatter")
}
// Create config
config := map[string]any{
"name": name,
"description": description,
"architecture": "amd64",
"os": "linux",
}
configJSON, err := json.Marshal(config)
if err != nil {
return nil, nil, fmt.Errorf("marshaling config: %w", err)
}
// Create config layer
configLayer, err := server.NewLayer(strings.NewReader(string(configJSON)), "application/vnd.docker.container.image.v1+json")
if err != nil {
return nil, nil, fmt.Errorf("creating config layer: %w", err)
}
manifest := &server.Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: configLayer,
Layers: []server.Layer{layer},
}
return manifest, &configLayer, nil
}
// extractSkillMetadata extracts name and description from SKILL.md frontmatter.
func extractSkillMetadata(content string) (name, description string) {
lines := strings.Split(content, "\n")
inFrontmatter := false
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if trimmed == "---" {
if !inFrontmatter {
inFrontmatter = true
continue
} else {
break // End of frontmatter
}
}
if inFrontmatter {
if strings.HasPrefix(trimmed, "name:") {
name = strings.TrimSpace(strings.TrimPrefix(trimmed, "name:"))
} else if strings.HasPrefix(trimmed, "description:") {
description = strings.TrimSpace(strings.TrimPrefix(trimmed, "description:"))
}
}
}
return name, description
}
// NewSkillCommand creates the skill parent command with subcommands.
func NewSkillCommand() *cobra.Command {
skillCmd := &cobra.Command{
Use: "skill",
Short: "Manage skills",
Long: "Commands for managing agent skills (push, pull, list, rm, show)",
}
pushCmd := &cobra.Command{
Use: "push NAME[:TAG] PATH",
Short: "Push a skill to a registry",
Long: "Package a local skill directory and push it to a registry",
Args: cobra.ExactArgs(2),
PreRunE: checkServerHeartbeat,
RunE: SkillPushHandler,
}
pushCmd.Flags().Bool("insecure", false, "Use an insecure registry")
pullCmd := &cobra.Command{
Use: "pull NAME[:TAG]",
Short: "Pull a skill from a registry",
Args: cobra.ExactArgs(1),
PreRunE: checkServerHeartbeat,
RunE: SkillPullHandler,
}
pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
listCmd := &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List installed skills",
Args: cobra.NoArgs,
RunE: SkillListHandler,
}
rmCmd := &cobra.Command{
Use: "rm NAME[:TAG] [NAME[:TAG]...]",
Aliases: []string{"remove", "delete"},
Short: "Remove a skill",
Args: cobra.MinimumNArgs(1),
RunE: SkillRemoveHandler,
}
showCmd := &cobra.Command{
Use: "show NAME[:TAG]",
Short: "Show skill details",
Args: cobra.ExactArgs(1),
RunE: SkillShowHandler,
}
skillCmd.AddCommand(pushCmd, pullCmd, listCmd, rmCmd, showCmd)
return skillCmd
}

View File

@@ -1,591 +0,0 @@
package cmd
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io/fs"
"os"
"os/exec"
"path/filepath"
"regexp"
"sort"
"strings"
"time"
"gopkg.in/yaml.v3"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/server"
)
const (
skillFileName = "SKILL.md"
maxSkillDescription = 1024
maxSkillNameLength = 64
)
var skillNamePattern = regexp.MustCompile(`^[a-z0-9]+(?:-[a-z0-9]+)*$`)
type skillMetadata struct {
Name string `yaml:"name"`
Description string `yaml:"description"`
}
type skillDefinition struct {
Name string
Description string
Content string // Full SKILL.md content (without frontmatter)
Dir string
SkillPath string
}
type skillCatalog struct {
Skills []skillDefinition
byName map[string]skillDefinition
}
func loadSkills(paths []string) (*skillCatalog, error) {
if len(paths) == 0 {
return nil, nil
}
var skills []skillDefinition
byName := make(map[string]skillDefinition)
for _, root := range paths {
info, err := os.Stat(root)
if err != nil {
return nil, fmt.Errorf("skills directory %q: %w", root, err)
}
if !info.IsDir() {
return nil, fmt.Errorf("skills path %q is not a directory", root)
}
err = filepath.WalkDir(root, func(path string, entry fs.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
}
if entry.IsDir() {
return nil
}
if entry.Name() != skillFileName {
return nil
}
skillDir := filepath.Dir(path)
skill, err := parseSkillFile(path, skillDir)
if err != nil {
fmt.Fprintf(os.Stderr, "Warning: skipping skill at %s: %v\n", path, err)
return nil
}
if _, exists := byName[skill.Name]; exists {
fmt.Fprintf(os.Stderr, "Warning: duplicate skill name %q at %s\n", skill.Name, path)
return nil
}
byName[skill.Name] = skill
skills = append(skills, skill)
return nil
})
if err != nil {
return nil, err
}
}
if len(skills) == 0 {
return nil, nil
}
sort.Slice(skills, func(i, j int) bool {
return skills[i].Name < skills[j].Name
})
return &skillCatalog{Skills: skills, byName: byName}, nil
}
// loadSkillsFromRefs loads skills from a list of SkillRef objects.
// Skills can be referenced by:
// - Digest: loaded from the extracted skill cache (for bundled/pulled skills)
// - Name (local path): loaded from the filesystem (for development)
func loadSkillsFromRefs(refs []api.SkillRef) (*skillCatalog, error) {
if len(refs) == 0 {
return nil, nil
}
var skills []skillDefinition
byName := make(map[string]skillDefinition)
for _, ref := range refs {
var skillDir string
if ref.Digest != "" {
// Load from extracted skill cache
path, err := server.GetSkillsPath(ref.Digest)
if err != nil {
return nil, fmt.Errorf("getting skill path for %s: %w", ref.Digest, err)
}
// Check if skill is already extracted
skillMdPath := filepath.Join(path, skillFileName)
if _, err := os.Stat(skillMdPath); os.IsNotExist(err) {
// Try to extract the skill blob
path, err = server.ExtractSkillBlob(ref.Digest)
if err != nil {
return nil, fmt.Errorf("extracting skill %s: %w", ref.Digest, err)
}
}
skillDir = path
} else if ref.Name != "" {
// Check if this is a local path or a registry reference
if !server.IsLocalSkillPath(ref.Name) {
// Registry reference without a digest - skill needs to be pulled first
// This happens when an agent references a skill that hasn't been bundled
return nil, fmt.Errorf("skill %q is a registry reference but has no digest - the agent may need to be recreated or the skill pulled separately", ref.Name)
}
// Local path - resolve it
skillPath := ref.Name
if strings.HasPrefix(skillPath, "~") {
home, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("expanding home directory: %w", err)
}
skillPath = filepath.Join(home, skillPath[1:])
}
absPath, err := filepath.Abs(skillPath)
if err != nil {
return nil, fmt.Errorf("resolving skill path %q: %w", ref.Name, err)
}
// Check if this is a directory containing skills or a single skill
info, err := os.Stat(absPath)
if err != nil {
return nil, fmt.Errorf("skill path %q: %w", ref.Name, err)
}
if info.IsDir() {
// Check if it's a skill directory (has SKILL.md) or a parent of skill directories
skillMdPath := filepath.Join(absPath, skillFileName)
if _, err := os.Stat(skillMdPath); err == nil {
// Direct skill directory
skillDir = absPath
} else {
// Parent directory - walk to find skill subdirectories
err := filepath.WalkDir(absPath, func(path string, entry fs.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
}
if entry.IsDir() {
return nil
}
if entry.Name() != skillFileName {
return nil
}
skillSubDir := filepath.Dir(path)
skill, err := parseSkillFile(path, skillSubDir)
if err != nil {
fmt.Fprintf(os.Stderr, "Warning: skipping skill at %s: %v\n", path, err)
return nil
}
if _, exists := byName[skill.Name]; exists {
fmt.Fprintf(os.Stderr, "Warning: duplicate skill name %q at %s\n", skill.Name, path)
return nil
}
byName[skill.Name] = skill
skills = append(skills, skill)
return nil
})
if err != nil {
return nil, err
}
continue
}
} else {
return nil, fmt.Errorf("skill path %q is not a directory", ref.Name)
}
} else {
// Both empty - skip
continue
}
// Parse the skill from skillDir if set
if skillDir != "" {
skillMdPath := filepath.Join(skillDir, skillFileName)
skill, err := parseSkillFile(skillMdPath, skillDir)
if err != nil {
return nil, fmt.Errorf("parsing skill at %s: %w", skillDir, err)
}
if _, exists := byName[skill.Name]; exists {
fmt.Fprintf(os.Stderr, "Warning: duplicate skill name %q\n", skill.Name)
continue
}
byName[skill.Name] = skill
skills = append(skills, skill)
}
}
if len(skills) == 0 {
return nil, nil
}
sort.Slice(skills, func(i, j int) bool {
return skills[i].Name < skills[j].Name
})
return &skillCatalog{Skills: skills, byName: byName}, nil
}
func parseSkillFile(path, skillDir string) (skillDefinition, error) {
rawContent, err := os.ReadFile(path)
if err != nil {
return skillDefinition{}, err
}
frontmatter, bodyContent, err := extractFrontmatterAndContent(string(rawContent))
if err != nil {
return skillDefinition{}, err
}
var meta skillMetadata
if err := yaml.Unmarshal([]byte(frontmatter), &meta); err != nil {
return skillDefinition{}, fmt.Errorf("invalid frontmatter: %w", err)
}
if err := validateSkillMetadata(meta, skillDir); err != nil {
return skillDefinition{}, err
}
absPath, err := filepath.Abs(path)
if err != nil {
return skillDefinition{}, err
}
absDir, err := filepath.Abs(skillDir)
if err != nil {
return skillDefinition{}, err
}
return skillDefinition{
Name: meta.Name,
Description: meta.Description,
Content: bodyContent,
Dir: absDir,
SkillPath: absPath,
}, nil
}
func extractFrontmatterAndContent(content string) (frontmatter string, body string, err error) {
scanner := bufio.NewScanner(strings.NewReader(content))
if !scanner.Scan() {
return "", "", errors.New("empty SKILL.md")
}
if strings.TrimSpace(scanner.Text()) != "---" {
return "", "", errors.New("missing YAML frontmatter")
}
var fmLines []string
foundEnd := false
for scanner.Scan() {
line := scanner.Text()
if strings.TrimSpace(line) == "---" {
foundEnd = true
break
}
fmLines = append(fmLines, line)
}
if !foundEnd {
return "", "", errors.New("frontmatter not terminated")
}
// Collect remaining content as body
var bodyLines []string
for scanner.Scan() {
bodyLines = append(bodyLines, scanner.Text())
}
return strings.Join(fmLines, "\n"), strings.TrimSpace(strings.Join(bodyLines, "\n")), nil
}
func validateSkillMetadata(meta skillMetadata, skillDir string) error {
name := strings.TrimSpace(meta.Name)
description := strings.TrimSpace(meta.Description)
switch {
case name == "":
return errors.New("missing skill name")
case len(name) > maxSkillNameLength:
return fmt.Errorf("skill name exceeds %d characters", maxSkillNameLength)
case !skillNamePattern.MatchString(name):
return fmt.Errorf("invalid skill name %q", name)
}
if description == "" {
return errors.New("missing skill description")
}
if len(description) > maxSkillDescription {
return fmt.Errorf("skill description exceeds %d characters", maxSkillDescription)
}
// Skip directory name check for digest-based paths (extracted from blobs)
dirName := filepath.Base(skillDir)
if !strings.HasPrefix(dirName, "sha256-") && dirName != name {
return fmt.Errorf("skill directory %q does not match name %q", dirName, name)
}
return nil
}
func (c *skillCatalog) SystemPrompt() string {
if c == nil || len(c.Skills) == 0 {
return ""
}
var b strings.Builder
b.WriteString("# Skills\n\n")
b.WriteString("You have the following skills loaded. Each skill provides instructions and may include executable scripts.\n\n")
b.WriteString("## Available Tools\n\n")
b.WriteString("- `run_skill_script`: Execute a script bundled with a skill. Use this when the skill instructions tell you to run a script.\n")
b.WriteString("- `read_skill_file`: Read additional files from a skill directory.\n\n")
for _, skill := range c.Skills {
fmt.Fprintf(&b, "## Skill: %s\n\n", skill.Name)
fmt.Fprintf(&b, "%s\n\n", skill.Content)
b.WriteString("---\n\n")
}
return b.String()
}
func (c *skillCatalog) Tools() api.Tools {
if c == nil || len(c.Skills) == 0 {
return nil
}
runScriptProps := api.NewToolPropertiesMap()
runScriptProps.Set("skill", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The name of the skill containing the script",
})
runScriptProps.Set("command", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The command to execute (e.g., 'python scripts/calculate.py 25 4' or './scripts/run.sh')",
})
readFileProps := api.NewToolPropertiesMap()
readFileProps.Set("skill", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The name of the skill containing the file",
})
readFileProps.Set("path", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The relative path to the file within the skill directory",
})
return api.Tools{
{
Type: "function",
Function: api.ToolFunction{
Name: "run_skill_script",
Description: "Execute a script or command within a skill's directory. Use this to run Python scripts, shell scripts, or other executables bundled with a skill.",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"skill", "command"},
Properties: runScriptProps,
},
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "read_skill_file",
Description: "Read a file from a skill's directory. Use this to read additional documentation, reference files, or data files bundled with a skill.",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"skill", "path"},
Properties: readFileProps,
},
},
},
}
}
func (c *skillCatalog) RunToolCall(call api.ToolCall) (api.Message, bool, error) {
switch call.Function.Name {
case "read_skill_file":
skillName, err := requireStringArg(call.Function.Arguments, "skill")
if err != nil {
return toolMessage(call, err.Error()), true, nil
}
relPath, err := requireStringArg(call.Function.Arguments, "path")
if err != nil {
return toolMessage(call, err.Error()), true, nil
}
skill, ok := c.byName[skillName]
if !ok {
return toolMessage(call, fmt.Sprintf("unknown skill %q", skillName)), true, nil
}
content, err := readSkillFile(skill.Dir, relPath)
if err != nil {
return toolMessage(call, err.Error()), true, nil
}
return toolMessage(call, content), true, nil
case "run_skill_script":
skillName, err := requireStringArg(call.Function.Arguments, "skill")
if err != nil {
return toolMessage(call, err.Error()), true, nil
}
command, err := requireStringArg(call.Function.Arguments, "command")
if err != nil {
return toolMessage(call, err.Error()), true, nil
}
skill, ok := c.byName[skillName]
if !ok {
return toolMessage(call, fmt.Sprintf("unknown skill %q", skillName)), true, nil
}
output, err := runSkillScript(skill.Dir, command)
if err != nil {
return toolMessage(call, fmt.Sprintf("error: %v\noutput: %s", err, output)), true, nil
}
return toolMessage(call, output), true, nil
default:
return api.Message{}, false, nil
}
}
// runSkillScript executes a shell command within a skill's directory.
//
// SECURITY LIMITATIONS (TODO):
// - No sandboxing: commands run with full user permissions
// - No path validation: model can run any command, not just scripts in skill dir
// - Shell injection risk: sh -c is used, malicious input could be crafted
// - No executable allowlist: any program can be called (curl, rm, etc.)
// - No environment isolation: scripts inherit full environment variables
//
// POTENTIAL IMPROVEMENTS:
// - Restrict commands to only reference files within skill directory
// - Allowlist specific executables (python3, node, bash)
// - Use sandboxing (Docker, nsjail, seccomp)
// - Require explicit script registration in SKILL.md frontmatter
// - Add per-skill configurable timeouts
func runSkillScript(skillDir, command string) (string, error) {
// Validate the skill directory exists
absSkillDir, err := filepath.Abs(skillDir)
if err != nil {
return "", err
}
if _, err := os.Stat(absSkillDir); err != nil {
return "", fmt.Errorf("skill directory not found: %w", err)
}
// Create command with timeout
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "sh", "-c", command)
cmd.Dir = absSkillDir
// Inject the current working directory (where ollama run was called from)
// as an environment variable so scripts can reference files in that directory
workingDir, err := os.Getwd()
if err != nil {
return "", fmt.Errorf("failed to get working directory: %w", err)
}
cmd.Env = append(os.Environ(), "OLLAMA_WORKING_DIR="+workingDir)
// Capture both stdout and stderr
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err = cmd.Run()
// Combine output
output := stdout.String()
if stderr.Len() > 0 {
if output != "" {
output += "\n"
}
output += stderr.String()
}
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
return output, fmt.Errorf("command timed out after 30 seconds")
}
return output, err
}
return output, nil
}
func readSkillFile(skillDir, relPath string) (string, error) {
relPath = filepath.Clean(strings.TrimSpace(relPath))
if relPath == "" {
return "", errors.New("path is required")
}
if filepath.IsAbs(relPath) {
return "", errors.New("path must be relative to the skill directory")
}
target := filepath.Join(skillDir, relPath)
absTarget, err := filepath.Abs(target)
if err != nil {
return "", err
}
absSkillDir, err := filepath.Abs(skillDir)
if err != nil {
return "", err
}
rel, err := filepath.Rel(absSkillDir, absTarget)
if err != nil {
return "", err
}
if strings.HasPrefix(rel, "..") {
return "", errors.New("path escapes the skill directory")
}
content, err := os.ReadFile(absTarget)
if err != nil {
return "", fmt.Errorf("failed to read %q: %w", relPath, err)
}
return string(content), nil
}
func requireStringArg(args api.ToolCallFunctionArguments, name string) (string, error) {
value, ok := args.Get(name)
if !ok {
return "", fmt.Errorf("missing required argument %q", name)
}
str, ok := value.(string)
if !ok {
return "", fmt.Errorf("argument %q must be a string", name)
}
if strings.TrimSpace(str) == "" {
return "", fmt.Errorf("argument %q cannot be empty", name)
}
return str, nil
}
func toolMessage(call api.ToolCall, content string) api.Message {
msg := api.Message{
Role: "tool",
Content: content,
ToolName: call.Function.Name,
}
if call.ID != "" {
msg.ToolCallID = call.ID
}
return msg
}

View File

@@ -6,11 +6,14 @@ import (
"errors"
"fmt"
"io/fs"
"iter"
"log/slog"
"maps"
"os"
"slices"
"strings"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -18,8 +21,13 @@ type ModelParameters struct {
Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"`
// TODO is this needed?
ModelType string `json:"model_type"`
TextModel struct {
VocabSize uint32 `json:"vocab_size"`
VocabSize uint32 `json:"vocab_size"`
HiddenSize uint32 `json:"hidden_size"`
ModelType string `json:"model_type"`
} `json:"text_config"`
}
@@ -33,8 +41,94 @@ type AdapterParameters struct {
} `json:"lora_parameters"`
}
func (ModelParameters) KV(t *Tokenizer) ggml.KV {
kv := ggml.KV{
type KV map[string]any
func (kv KV) Architecture() string {
return kv.String("general.architecture", "unknown")
}
type valueTypes interface {
uint8 | int8 | uint16 | int16 |
uint32 | int32 | uint64 | int64 |
string | float32 | float64 | bool
}
type arrayValueTypes interface {
[]uint8 | []int8 | []uint16 | []int16 |
[]uint32 | []int32 | []uint64 | []int64 |
[]string | []float32 | []float64 | []bool
}
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key
}
if val, ok := kv[key].(T); ok {
return val, true
}
return defaultValue[0], false
}
func (kv KV) String(key string, defaultValue ...string) string {
val, _ := keyValue(kv, key, append(defaultValue, "")...)
return val
}
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
}
func (kv KV) Float(key string, defaultValue ...float32) float32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
}
func (kv KV) Bool(key string, defaultValue ...bool) bool {
val, _ := keyValue(kv, key, append(defaultValue, false)...)
return val
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
val, _ := keyValue(kv, key, append(defaultValue, []string{""})...)
return val
}
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
val, _ := keyValue(kv, key, append(defaultValue, []int32{0})...)
return val
}
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
val, _ := keyValue(kv, key, append(defaultValue, []uint32{0})...)
return val
}
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
val, _ := keyValue(kv, key, append(defaultValue, []float32{0})...)
return val
}
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
val, _ := keyValue(kv, key, append(defaultValue, []bool{false})...)
return val
}
func (kv KV) Len() int {
return len(kv)
}
func (kv KV) Keys() iter.Seq[string] {
return maps.Keys(kv)
}
func (kv KV) Value(key string) any {
return kv[key]
}
func (ModelParameters) KV(t *Tokenizer) KV {
kv := KV{
"general.file_type": uint32(1),
"general.quantization_version": uint32(2),
"tokenizer.ggml.pre": t.Pre,
@@ -63,7 +157,7 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV {
return kv
}
func (p AdapterParameters) KV() ggml.KV {
func (p AdapterParameters) KV() KV {
var alpha float32
if p.LoraParameters.Alpha == 0 {
alpha = float32(p.Alpha)
@@ -71,7 +165,7 @@ func (p AdapterParameters) KV() ggml.KV {
alpha = p.LoraParameters.Alpha
}
kv := ggml.KV{
kv := KV{
"adapter.lora.alpha": alpha,
"adapter.type": "lora",
"general.file_type": uint32(1),
@@ -88,9 +182,14 @@ func (ModelParameters) specialTokenTypes() []string {
}
}
type ModelConverter interface {
type ModelKV interface {
// KV maps parameters to LLM key-values
KV(*Tokenizer) ggml.KV
KV(*Tokenizer) KV
}
type ModelConverter interface {
ModelKV
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
@@ -107,7 +206,7 @@ type moreParser interface {
type AdapterConverter interface {
// KV maps parameters to LLM key-values
KV(ggml.KV) ggml.KV
KV(ofs.Config) KV
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
@@ -115,7 +214,7 @@ type AdapterConverter interface {
Replacements() []string
}
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
bts, err := fs.ReadFile(fsys, "adapter_config.json")
if err != nil {
return err
@@ -126,8 +225,8 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
return err
}
arch, ok := baseKV["general.architecture"]
if !ok {
arch := baseKV.Architecture()
if arch == "" {
return errors.New("architecture not set for the base model")
}
@@ -153,23 +252,19 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
}
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
bts, err := fs.ReadFile(fsys, "config.json")
if err != nil {
return err
return nil, nil, err
}
var p ModelParameters
if err := json.Unmarshal(bts, &p); err != nil {
return err
return nil, nil, err
}
if len(p.Architectures) < 1 {
return errors.New("unknown architecture")
return nil, nil, errors.New("unknown architecture")
}
var conv ModelConverter
@@ -216,23 +311,27 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &deepseekocr{}
case "DeepseekV3ForCausalLM":
conv = &deepseek2Model{}
case "Glm4MoeLiteForCausalLM":
conv = &glm4MoeLiteModel{}
case "Lfm2ForCausalLM":
conv = &lfm2Model{}
default:
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}
if err := json.Unmarshal(bts, conv); err != nil {
return err
return nil, nil, err
}
if t, ok := conv.(moreParser); ok {
if err := t.parseMore(fsys); err != nil {
return err
return nil, nil, err
}
}
t, err := parseTokenizer(fsys, conv.specialTokenTypes())
if err != nil {
return err
return nil, nil, err
}
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
@@ -254,6 +353,19 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
default:
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
}
return conv, t, nil
}
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
kv, t, err := LoadModelMetadata(fsys)
if err != nil {
return err
}
conv := kv.(ModelConverter)
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
if err != nil {
@@ -263,7 +375,7 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
return writeFile(f, conv.KV(t), conv.Tensors(ts))
}
func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
func writeFile(f *os.File, kv KV, ts []*ggml.Tensor) error {
for i := range ts {
ts[i].Shape = slices.Clone(ts[i].Shape)
slices.Reverse(ts[i].Shape)

View File

@@ -88,7 +88,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
return nil
}
func (p *bertModel) KV(t *Tokenizer) ggml.KV {
func (p *bertModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "bert"
kv["bert.attention.causal"] = false

View File

@@ -24,7 +24,7 @@ type commandrModel struct {
var _ ModelConverter = (*commandrModel)(nil)
func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
func (p *commandrModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "command-r"
kv["general.name"] = "command-r"

View File

@@ -47,7 +47,7 @@ type deepseek2Model struct {
Architecture string
}
func (p *deepseek2Model) KV(t *Tokenizer) ggml.KV {
func (p *deepseek2Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "deepseek2"
kv["general.type"] = "model"

View File

@@ -41,7 +41,7 @@ type deepseekocr struct {
} `json:"vision_config"`
}
func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
func (m *deepseekocr) KV(t *Tokenizer) KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "deepseekocr"
kv["block_count"] = m.LanguageConfig.HiddenLayers

View File

@@ -23,7 +23,7 @@ type gemmaModel struct {
var _ ModelConverter = (*gemmaModel)(nil)
func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
func (p *gemmaModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma"
kv["gemma.context_length"] = p.MaxPositionEmbeddings

View File

@@ -1,7 +1,5 @@
package convert
import "github.com/ollama/ollama/fs/ggml"
type gemma2Model struct {
gemmaModel
SlidingWindow uint32 `json:"sliding_window"`
@@ -9,7 +7,7 @@ type gemma2Model struct {
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
}
func (p *gemma2Model) KV(t *Tokenizer) ggml.KV {
func (p *gemma2Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma2"
kv["gemma2.context_length"] = p.MaxPositionEmbeddings

View File

@@ -6,6 +6,7 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -15,7 +16,7 @@ type gemma2Adapter struct {
var _ AdapterConverter = (*gemma2Adapter)(nil)
func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
func (p *gemma2Adapter) KV(baseKV fs.Config) KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "gemma2"
return kv

View File

@@ -3,8 +3,6 @@ package convert
import (
"cmp"
"slices"
"github.com/ollama/ollama/fs/ggml"
)
type gemma3Model struct {
@@ -55,7 +53,7 @@ const (
gemma27BLayerCount = 62
)
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
func (p *gemma3Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3"

View File

@@ -38,7 +38,7 @@ type gemma3nModel struct {
VisionModel struct{} `json:"vision_config"`
}
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
func (m *gemma3nModel) KV(t *Tokenizer) KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3n"
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {

View File

@@ -0,0 +1,150 @@
package convert
import (
"cmp"
"fmt"
"log/slog"
"regexp"
"strconv"
"github.com/ollama/ollama/fs/ggml"
)
type glm4MoeLiteModel struct {
ModelParameters
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
QKNopeHeadDim uint32 `json:"qk_nope_head_dim"`
QKRopeHeadDim uint32 `json:"qk_rope_head_dim"`
KVLoraRank uint32 `json:"kv_lora_rank"`
QLoraRank uint32 `json:"q_lora_rank"`
VHeadDim uint32 `json:"v_head_dim"`
ExpertCount uint32 `json:"n_routed_experts"`
ExpertSharedCount uint32 `json:"n_shared_experts"`
ExpertIntermediateSize uint32 `json:"moe_intermediate_size"`
ExpertUsedCount uint32 `json:"num_experts_per_tok"`
ExpertWeightsNorm bool `json:"norm_topk_prob"`
ExpertWeightsScale float32 `json:"routed_scaling_factor"`
LeadingDenseBlockCount uint32 `json:"first_k_dense_replace"`
}
func (p *glm4MoeLiteModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "glm4moelite"
kv["general.type"] = "model"
kv["glm4moelite.block_count"] = p.HiddenLayers
numHeads := p.NumAttentionHeads
numKVHeads := p.NumKeyValueHeads
kv["glm4moelite.attention.head_count"] = numHeads
kv["glm4moelite.attention.head_count_kv"] = numKVHeads
kv["glm4moelite.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim
kv["glm4moelite.attention.kv_lora_rank"] = p.KVLoraRank
kv["glm4moelite.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["glm4moelite.attention.q_lora_rank"] = p.QLoraRank
kv["glm4moelite.attention.value_length"] = p.VHeadDim
kv["glm4moelite.context_length"] = p.MaxPositionEmbeddings
kv["glm4moelite.embedding_length"] = p.HiddenSize
kv["glm4moelite.expert_count"] = p.ExpertCount
kv["glm4moelite.expert_feed_forward_length"] = p.ExpertIntermediateSize
kv["glm4moelite.expert_shared_count"] = p.ExpertSharedCount
kv["glm4moelite.expert_gating_func"] = uint32(2)
kv["glm4moelite.expert_used_count"] = p.ExpertUsedCount
kv["glm4moelite.expert_weights_norm"] = p.ExpertWeightsNorm
kv["glm4moelite.expert_weights_scale"] = p.ExpertWeightsScale
kv["glm4moelite.feed_forward_length"] = p.IntermediateSize
kv["glm4moelite.leading_dense_block_count"] = p.LeadingDenseBlockCount
kv["glm4moelite.rope.dimension_count"] = p.QKRopeHeadDim
kv["glm4moelite.rope.freq_base"] = cmp.Or(p.RopeTheta, float32(1000000.0))
kv["tokenizer.ggml.pre"] = "glm4"
return kv
}
func (p *glm4MoeLiteModel) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.kv_a_proj_with_mqa", "attn_kv_a_mqa",
"self_attn.kv_a_layernorm", "attn_kv_a_norm",
"self_attn.kv_b_proj", "attn_kv_b",
"self_attn.q_a_proj", "attn_q_a",
"self_attn.q_a_layernorm", "attn_q_a_norm",
"self_attn.q_b_proj", "attn_q_b",
"self_attn.o_proj", "attn_output",
"post_attention_layernorm", "ffn_norm",
"mlp.shared_experts.down_proj", "ffn_down_shexp",
"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
"mlp.shared_experts.up_proj", "ffn_up_shexp",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"mlp.gate.e_score_correction_bias", "exp_probs_b.bias",
"mlp.gate", "ffn_gate_inp",
}
}
func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
merges := make([]merge, p.HiddenLayers*3)
for i := range p.HiddenLayers {
merges[i*3+0] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}
merges[i*3+1] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}
merges[i*3+2] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}
}
skipLayer := func(n string, minValue uint32) bool {
re := regexp.MustCompile(`^blk\.(\d+)`)
matches := re.FindStringSubmatch(n)
if matches == nil {
return false
}
blkNum, err := strconv.Atoi(matches[1])
if err != nil {
return false
}
return uint32(blkNum) >= minValue
}
out, s = mergeTensors(s, merges...)
for _, t := range s {
// skip any additional layers (such as the Multi-Token Prediction layer)
if skipLayer(t.Name(), p.HiddenLayers) {
slog.Debug("skipping layer", "name", t.Name())
continue
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}

View File

@@ -37,7 +37,7 @@ type gptossModel struct {
var _ ModelConverter = (*gptossModel)(nil)
func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
func (m *gptossModel) KV(t *Tokenizer) KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gptoss"
kv["general.file_type"] = uint32(4)

115
convert/convert_lfm2.go Normal file
View File

@@ -0,0 +1,115 @@
package convert
import (
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type lfm2Model struct {
ModelParameters
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
NormEps float32 `json:"norm_eps"`
ConvLCache uint32 `json:"conv_L_cache"`
LayerTypes []string `json:"layer_types"`
TieEmbedding bool `json:"tie_embedding"`
}
var _ ModelConverter = (*lfm2Model)(nil)
func (p *lfm2Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "lfm2"
kv["lfm2.vocab_size"] = p.VocabSize
kv["lfm2.block_count"] = p.NumHiddenLayers
kv["lfm2.embedding_length"] = p.HiddenSize
kv["lfm2.feed_forward_length"] = p.IntermediateSize
kv["lfm2.context_length"] = p.MaxPositionEmbeddings
// Build per-layer head count arrays based on layer_types
headCounts := make([]uint32, p.NumHiddenLayers)
kvHeadCounts := make([]uint32, p.NumHiddenLayers)
for i := uint32(0); i < p.NumHiddenLayers; i++ {
if i < uint32(len(p.LayerTypes)) && p.LayerTypes[i] == "full_attention" {
headCounts[i] = p.NumAttentionHeads
kvHeadCounts[i] = p.NumKeyValueHeads
} else {
// Conv layers have 0 head counts
headCounts[i] = 0
kvHeadCounts[i] = 0
}
}
kv["lfm2.attention.head_count"] = headCounts
kv["lfm2.attention.head_count_kv"] = kvHeadCounts
kv["lfm2.attention.layer_norm_rms_epsilon"] = p.NormEps
kv["lfm2.rope.freq_base"] = p.RopeTheta
kv["lfm2.shortconv.l_cache"] = p.ConvLCache
// Renderer and parser config for thinking model
kv["tokenizer.chat_template.renderer"] = "lfm2-thinking"
kv["tokenizer.chat_template.parser"] = "lfm2-thinking"
return kv
}
func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
shape := t.Shape()
// Squeeze conv weights from [D, 1, K] to [D, K]
// The shortconv.conv.weight tensors have shape [D, 1, K] in HuggingFace
// We remove the middle dimension (which is always 1) to get [D, K]
// Note: ollama's GGUF writer does NOT reverse shapes like llama.cpp's Python writer does,
// so we keep the shape in the same order as the original safetensors file
if strings.HasSuffix(t.Name(), "shortconv.conv.weight") && len(shape) == 3 && shape[1] == 1 {
// Squeeze: [D, 1, K] -> [D, K]
shape = []uint64{shape[0], shape[2]}
// No repacker needed - data layout is already correct since the middle dim is 1
t.SetRepacker(func(_ string, data []float32, _ []uint64) ([]float32, error) {
return data, nil
})
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
}
return out
}
func (p *lfm2Model) Replacements() []string {
return []string{
"model.embed_tokens", "token_embd",
"model.embedding_norm", "output_norm",
"model.layers", "blk",
"operator_norm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.out_proj", "attn_output",
"self_attn.q_layernorm", "attn_q_norm",
"self_attn.k_layernorm", "attn_k_norm",
"conv.conv", "shortconv.conv",
"conv.in_proj", "shortconv.in_proj",
"conv.out_proj", "shortconv.out_proj",
"feed_forward.w1", "ffn_gate",
"feed_forward.w2", "ffn_down",
"feed_forward.w3", "ffn_up",
"ffn_norm", "ffn_norm",
}
}

View File

@@ -48,7 +48,7 @@ type llamaModel struct {
var _ ModelConverter = (*llamaModel)(nil)
func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
func (p *llamaModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama"
kv["llama.vocab_size"] = p.VocabSize

View File

@@ -35,7 +35,7 @@ type llama4Model struct {
}
// KV implements ModelConverter.
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
func (p *llama4Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama4"

View File

@@ -7,6 +7,7 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -18,13 +19,13 @@ type llamaAdapter struct {
var _ AdapterConverter = (*llamaAdapter)(nil)
func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
func (p *llamaAdapter) KV(baseKV fs.Config) KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "llama"
kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
kv["llama.attention.head_count_kv"] = baseKV["llama.attention.head_count_kv"]
kv["llama.attention.head_count"] = baseKV.Value("llama.attention.head_count")
kv["llama.attention.head_count_kv"] = baseKV.Value("llama.attention.head_count_kv")
p.NumAttentionHeads = baseKV["llama.attention.head_count"].(uint32)
p.NumAttentionHeads = baseKV.Value("llama.attention.head_count").(uint32)
return kv
}

View File

@@ -60,7 +60,7 @@ type mistral3Model struct {
ProjectorHiddenAct string `json:"projector_hidden_act"`
}
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
func (p *mistral3Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.TextModel.VocabSize

View File

@@ -39,7 +39,7 @@ type mistral3CausalModel struct {
} `json:"rope_parameters"`
}
func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV {
func (p *mistral3CausalModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.VocabSize

View File

@@ -12,7 +12,7 @@ type mixtralModel struct {
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
}
func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
func (p *mixtralModel) KV(t *Tokenizer) KV {
kv := p.llamaModel.KV(t)
if p.NumLocalExperts > 0 {

View File

@@ -34,7 +34,7 @@ type mllamaModel struct {
} `json:"vision_config"`
}
func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
func (m *mllamaModel) KV(t *Tokenizer) KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "mllama"

View File

@@ -87,7 +87,7 @@ func (p *nomicbertModel) parseMore(fsys fs.FS) error {
return nil
}
func (p *nomicbertModel) KV(t *Tokenizer) ggml.KV {
func (p *nomicbertModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
// Determine architecture based on MoE parameters (following qwen3 pattern)

View File

@@ -34,7 +34,7 @@ type olmoModel struct {
var _ ModelConverter = (*olmoModel)(nil)
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
func (p *olmoModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "olmo3"
kv["olmo3.block_count"] = p.NumHiddenLayers

View File

@@ -37,7 +37,7 @@ type phi3Model struct {
var _ ModelConverter = (*phi3Model)(nil)
func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
func (p *phi3Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "phi3"
kv["phi3.context_length"] = p.MaxPositionEmbeddings

View File

@@ -22,7 +22,7 @@ type qwen2Model struct {
var _ ModelConverter = (*qwen2Model)(nil)
func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
func (q *qwen2Model) KV(t *Tokenizer) KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen2"
kv["qwen2.block_count"] = q.HiddenLayers

View File

@@ -29,7 +29,7 @@ type qwen25VLModel struct {
var _ ModelConverter = (*qwen25VLModel)(nil)
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
func (q *qwen25VLModel) KV(t *Tokenizer) KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen25vl"

View File

@@ -32,7 +32,7 @@ type qwen3Model struct {
}
// KV implements ModelConverter.
func (q *qwen3Model) KV(t *Tokenizer) ggml.KV {
func (q *qwen3Model) KV(t *Tokenizer) KV {
arch := "qwen3"
if q.NumExperts > 0 {
arch += "moe"

View File

@@ -45,7 +45,7 @@ func (m *qwen3VLModel) parseMore(fsys fs.FS) error {
return json.Unmarshal(bts, &m.VisionModel)
}
func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV {
func (m *qwen3VLModel) KV(t *Tokenizer) KV {
kv := m.qwen3Model.KV(t)
arch := "qwen3vl"

View File

@@ -19,6 +19,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
fsc "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -28,7 +29,7 @@ type tensorData struct {
Shape []int `json:"shape"`
}
func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors) {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "f16")
@@ -59,9 +60,10 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
return r, m.KV(), m.Tensors()
}
func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string {
func generateResultsJSON(t *testing.T, f *os.File, kv fsc.Config, tensors ggml.Tensors) map[string]string {
actual := make(map[string]string)
for k, v := range kv {
for k := range kv.Keys() {
v := kv.Value(k)
if s, ok := v.(json.Marshaler); !ok {
actual[k] = fmt.Sprintf("%v", v)
} else {
@@ -277,7 +279,7 @@ func generateSafetensorTestData(t *testing.T, tempDir string, tensorData map[str
func TestConvertAdapter(t *testing.T) {
type AdapterCase struct {
Name string
BaseKV map[string]any
BaseKV KV
Expected map[string]string
}

View File

@@ -14,6 +14,7 @@
* [API Reference](https://docs.ollama.com/api)
* [Modelfile Reference](https://docs.ollama.com/modelfile)
* [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility)
* [Anthropic Compatibility](./api/anthropic-compatibility.mdx)
### Resources

View File

@@ -16,6 +16,7 @@
- [Generate Embeddings](#generate-embeddings)
- [List Running Models](#list-running-models)
- [Version](#version)
- [Experimental: Image Generation](#image-generation-experimental)
## Conventions
@@ -58,6 +59,15 @@ Advanced parameters (optional):
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
- `context` (deprecated): the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
Experimental image generation parameters (for image generation models only):
> [!WARNING]
> These parameters are experimental and may change in future versions.
- `width`: width of the generated image in pixels
- `height`: height of the generated image in pixels
- `steps`: number of diffusion steps
#### Structured outputs
Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [structured outputs](#request-structured-outputs) example below.
@@ -1867,3 +1877,55 @@ curl http://localhost:11434/api/version
"version": "0.5.1"
}
```
## Experimental Features
### Image Generation (Experimental)
> [!WARNING]
> Image generation is experimental and may change in future versions.
Image generation is now supported through the standard `/api/generate` endpoint when using image generation models. The API automatically detects when an image generation model is being used.
See the [Generate a completion](#generate-a-completion) section for the full API documentation. The experimental image generation parameters (`width`, `height`, `steps`) are documented there.
#### Example
##### Request
```shell
curl http://localhost:11434/api/generate -d '{
"model": "x/z-image-turbo",
"prompt": "a sunset over mountains",
"width": 1024,
"height": 768
}'
```
##### Response (streaming)
Progress updates during generation:
```json
{
"model": "x/z-image-turbo",
"created_at": "2024-01-15T10:30:00.000000Z",
"completed": 5,
"total": 20,
"done": false
}
```
##### Final Response
```json
{
"model": "x/z-image-turbo",
"created_at": "2024-01-15T10:30:15.000000Z",
"image": "iVBORw0KGgoAAAANSUhEUg...",
"done": true,
"done_reason": "stop",
"total_duration": 15000000000,
"load_duration": 2000000000
}
```

View File

@@ -0,0 +1,408 @@
---
title: Anthropic compatibility
---
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
## Recommended models
For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.
Pull a model before use:
```shell
ollama pull qwen3-coder
ollama pull glm-4.7:cloud
```
## Usage
### Environment variables
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama # required but ignored
```
### Simple `/v1/messages` example
<CodeGroup dropdown>
```python basic.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama', # required but ignored
)
message = client.messages.create(
model='qwen3-coder',
max_tokens=1024,
messages=[
{'role': 'user', 'content': 'Hello, how are you?'}
]
)
print(message.content[0].text)
```
```javascript basic.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama", // required but ignored
});
const message = await anthropic.messages.create({
model: "qwen3-coder",
max_tokens: 1024,
messages: [{ role: "user", content: "Hello, how are you?" }],
});
console.log(message.content[0].text);
```
```shell basic.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-H "x-api-key: ollama" \
-H "anthropic-version: 2023-06-01" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"messages": [{ "role": "user", "content": "Hello, how are you?" }]
}'
```
</CodeGroup>
### Streaming example
<CodeGroup dropdown>
```python streaming.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
with client.messages.stream(
model='qwen3-coder',
max_tokens=1024,
messages=[{'role': 'user', 'content': 'Count from 1 to 10'}]
) as stream:
for text in stream.text_stream:
print(text, end='', flush=True)
```
```javascript streaming.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const stream = await anthropic.messages.stream({
model: "qwen3-coder",
max_tokens: 1024,
messages: [{ role: "user", content: "Count from 1 to 10" }],
});
for await (const event of stream) {
if (
event.type === "content_block_delta" &&
event.delta.type === "text_delta"
) {
process.stdout.write(event.delta.text);
}
}
```
```shell streaming.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"stream": true,
"messages": [{ "role": "user", "content": "Count from 1 to 10" }]
}'
```
</CodeGroup>
### Tool calling example
<CodeGroup dropdown>
```python tools.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
message = client.messages.create(
model='qwen3-coder',
max_tokens=1024,
tools=[
{
'name': 'get_weather',
'description': 'Get the current weather in a location',
'input_schema': {
'type': 'object',
'properties': {
'location': {
'type': 'string',
'description': 'The city and state, e.g. San Francisco, CA'
}
},
'required': ['location']
}
}
],
messages=[{'role': 'user', 'content': "What's the weather in San Francisco?"}]
)
for block in message.content:
if block.type == 'tool_use':
print(f'Tool: {block.name}')
print(f'Input: {block.input}')
```
```javascript tools.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const message = await anthropic.messages.create({
model: "qwen3-coder",
max_tokens: 1024,
tools: [
{
name: "get_weather",
description: "Get the current weather in a location",
input_schema: {
type: "object",
properties: {
location: {
type: "string",
description: "The city and state, e.g. San Francisco, CA",
},
},
required: ["location"],
},
},
],
messages: [{ role: "user", content: "What's the weather in San Francisco?" }],
});
for (const block of message.content) {
if (block.type === "tool_use") {
console.log("Tool:", block.name);
console.log("Input:", block.input);
}
}
```
```shell tools.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"tools": [
{
"name": "get_weather",
"description": "Get the current weather in a location",
"input_schema": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state"
}
},
"required": ["location"]
}
}
],
"messages": [{ "role": "user", "content": "What is the weather in San Francisco?" }]
}'
```
</CodeGroup>
## Using with Claude Code
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
```shell
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
Or set the environment variables in your shell profile:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
Then run Claude Code with any Ollama model:
```shell
# Local models
claude --model qwen3-coder
claude --model gpt-oss:20b
# Cloud models
claude --model glm-4.7:cloud
claude --model minimax-m2.1:cloud
```
## Endpoints
### `/v1/messages`
#### Supported features
- [x] Messages
- [x] Streaming
- [x] System prompts
- [x] Multi-turn conversations
- [x] Vision (images)
- [x] Tools (function calling)
- [x] Tool results
- [x] Thinking/extended thinking
#### Supported request fields
- [x] `model`
- [x] `max_tokens`
- [x] `messages`
- [x] Text `content`
- [x] Image `content` (base64)
- [x] Array of content blocks
- [x] `tool_use` blocks
- [x] `tool_result` blocks
- [x] `thinking` blocks
- [x] `system` (string or array)
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `top_k`
- [x] `stop_sequences`
- [x] `tools`
- [x] `thinking`
- [ ] `tool_choice`
- [ ] `metadata`
#### Supported response fields
- [x] `id`
- [x] `type`
- [x] `role`
- [x] `model`
- [x] `content` (text, tool_use, thinking blocks)
- [x] `stop_reason` (end_turn, max_tokens, tool_use)
- [x] `usage` (input_tokens, output_tokens)
#### Streaming events
- [x] `message_start`
- [x] `content_block_start`
- [x] `content_block_delta` (text_delta, input_json_delta, thinking_delta)
- [x] `content_block_stop`
- [x] `message_delta`
- [x] `message_stop`
- [x] `ping`
- [x] `error`
## Models
Ollama supports both local and cloud models.
### Local models
Pull a local model before use:
```shell
ollama pull qwen3-coder
```
Recommended local models:
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model
### Cloud models
Cloud models are available immediately without pulling:
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
### Default model names
For tooling that relies on default Anthropic model names such as `claude-3-5-sonnet`, use `ollama cp` to copy an existing model name:
```shell
ollama cp qwen3-coder claude-3-5-sonnet
```
Afterwards, this new model name can be specified in the `model` field:
```shell
curl http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "claude-3-5-sonnet",
"max_tokens": 1024,
"messages": [
{
"role": "user",
"content": "Hello!"
}
]
}'
```
## Differences from the Anthropic API
### Behavior differences
- API key is accepted but not validated
- `anthropic-version` header is accepted but not used
- Token counts are approximations based on the underlying model's tokenizer
### Not supported
The following Anthropic API features are not currently supported:
| Feature | Description |
|---------|-------------|
| `/v1/messages/count_tokens` | Token counting endpoint |
| `tool_choice` | Forcing specific tool use or disabling tools |
| `metadata` | Request metadata (user_id) |
| Prompt caching | `cache_control` blocks for caching prefixes |
| Batches API | `/v1/messages/batches` for async batch processing |
| Citations | `citations` content blocks |
| PDF support | `document` content blocks with PDF files |
| Server-sent errors | `error` events during streaming (errors return HTTP status) |
### Partial support
| Feature | Status |
|---------|--------|
| Image content | Base64 images supported; URL images not supported |
| Extended thinking | Basic support; `budget_tokens` accepted but not enforced |

View File

@@ -275,6 +275,73 @@ curl -X POST http://localhost:11434/v1/chat/completions \
- [x] `dimensions`
- [ ] `user`
### `/v1/images/generations` (experimental)
> Note: This endpoint is experimental and may change or be removed in future versions.
Generate images using image generation models.
<CodeGroup dropdown>
```python images.py
from openai import OpenAI
client = OpenAI(
base_url='http://localhost:11434/v1/',
api_key='ollama', # required but ignored
)
response = client.images.generate(
model='x/z-image-turbo',
prompt='A cute robot learning to paint',
size='1024x1024',
response_format='b64_json',
)
print(response.data[0].b64_json[:50] + '...')
```
```javascript images.js
import OpenAI from "openai";
const openai = new OpenAI({
baseURL: "http://localhost:11434/v1/",
apiKey: "ollama", // required but ignored
});
const response = await openai.images.generate({
model: "x/z-image-turbo",
prompt: "A cute robot learning to paint",
size: "1024x1024",
response_format: "b64_json",
});
console.log(response.data[0].b64_json.slice(0, 50) + "...");
```
```shell images.sh
curl -X POST http://localhost:11434/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"model": "x/z-image-turbo",
"prompt": "A cute robot learning to paint",
"size": "1024x1024",
"response_format": "b64_json"
}'
```
</CodeGroup>
#### Supported request fields
- [x] `model`
- [x] `prompt`
- [x] `size` (e.g. "1024x1024")
- [x] `response_format` (only `b64_json` supported)
- [ ] `n`
- [ ] `quality`
- [ ] `style`
- [ ] `user`
### `/v1/responses`
> Note: Added in Ollama v0.13.3

View File

@@ -110,7 +110,7 @@ More Ollama [Python example](https://github.com/ollama/ollama-python/blob/main/e
import { Ollama } from "ollama";
const client = new Ollama();
const results = await client.webSearch({ query: "what is ollama?" });
const results = await client.webSearch("what is ollama?");
console.log(JSON.stringify(results, null, 2));
```
@@ -213,7 +213,7 @@ models](https://ollama.com/models)\n\nAvailable for macOS, Windows, and Linux',
import { Ollama } from "ollama";
const client = new Ollama();
const fetchResult = await client.webFetch({ url: "https://ollama.com" });
const fetchResult = await client.webFetch("https://ollama.com");
console.log(JSON.stringify(fetchResult, null, 2));
```

View File

@@ -32,7 +32,9 @@
"codeblocks": "system"
},
"contextual": {
"options": ["copy"]
"options": [
"copy"
]
},
"navbar": {
"links": [
@@ -52,7 +54,9 @@
"display": "simple"
},
"examples": {
"languages": ["curl"]
"languages": [
"curl"
]
}
},
"redirects": [
@@ -97,6 +101,7 @@
{
"group": "Integrations",
"pages": [
"/integrations/claude-code",
"/integrations/vscode",
"/integrations/jetbrains",
"/integrations/codex",
@@ -106,7 +111,9 @@
"/integrations/zed",
"/integrations/roo-code",
"/integrations/n8n",
"/integrations/xcode"
"/integrations/xcode",
"/integrations/onyx",
"/integrations/marimo"
]
},
{
@@ -139,7 +146,8 @@
"/api/streaming",
"/api/usage",
"/api/errors",
"/api/openai-compatibility"
"/api/openai-compatibility",
"/api/anthropic-compatibility"
]
},
{

View File

@@ -22,7 +22,7 @@ Please refer to the [GPU docs](./gpu).
## How can I specify the context window size?
By default, Ollama uses a context window size of 2048 tokens.
By default, Ollama uses a context window size of 4096 tokens.
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 174 KiB

BIN
docs/images/marimo-chat.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 80 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 230 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 178 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 186 KiB

BIN
docs/images/onyx-login.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 100 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 306 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 300 KiB

BIN
docs/images/onyx-query.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 211 KiB

View File

@@ -0,0 +1,78 @@
---
title: Claude Code
---
Claude Code is Anthropic's agentic coding tool that can read, modify, and execute code in your working directory.
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `qwen3-coder`, `gpt-oss:20b`, or other models.
![Claude Code with Ollama](https://files.ollama.com/claude-code.png)
## Install
Install [Claude Code](https://code.claude.com/docs/en/overview):
<CodeGroup>
```shell macOS / Linux
curl -fsSL https://claude.ai/install.sh | bash
```
```powershell Windows
irm https://claude.ai/install.ps1 | iex
```
</CodeGroup>
## Usage with Ollama
Claude Code connects to Ollama using the Anthropic-compatible API.
1. Set the environment variables:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434
```
2. Run Claude Code with an Ollama model:
```shell
claude --model gpt-oss:20b
```
Or run with environment variables inline:
```shell
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model gpt-oss:20b
```
**Note:** Claude Code requires a large context window. We recommend at least 32K tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
## Connecting to ollama.com
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
2. Set the environment variables:
```shell
export ANTHROPIC_BASE_URL=https://ollama.com
export ANTHROPIC_API_KEY=<your-api-key>
```
3. Run Claude Code with a cloud model:
```shell
claude --model glm-4.7:cloud
```
## Recommended Models
### Cloud models
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
- `qwen3-coder:480b` - Large coding model
### Local models
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model
- `gpt-oss:120b` - Larger general-purpose model for more complex tasks

View File

@@ -0,0 +1,73 @@
---
title: marimo
---
## Install
Install [marimo](https://marimo.io). You can use `pip` or `uv` for this. You
can also use `uv` to create a sandboxed environment for marimo by running:
```
uvx marimo edit --sandbox notebook.py
```
## Usage with Ollama
1. In marimo, go to the user settings and go to the AI tab. From here
you can find and configure Ollama as an AI provider. For local use you
would typically point the base url to `http://localhost:11434/v1`.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-settings.png"
alt="Ollama settings in marimo"
width="50%"
/>
</div>
2. Once the AI provider is set up, you can turn on/off specific AI models you'd like to access.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-models.png"
alt="Selecting an Ollama model"
width="50%"
/>
</div>
3. You can also add a model to the list of available models by scrolling to the bottom and using the UI there.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-add-model.png"
alt="Adding a new Ollama model"
width="50%"
/>
</div>
4. Once configured, you can now use Ollama for AI chats in marimo.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-chat.png"
alt="Configure code completion"
width="50%"
/>
</div>
4. Alternatively, you can now use Ollama for **inline code completion** in marimo. This can be configured in the "AI Features" tab.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-code-completion.png"
alt="Configure code completion"
width="50%"
/>
</div>
## Connecting to ollama.com
1. Sign in to ollama cloud via `ollama signin`
2. In the ollama model settings add a model that ollama hosts, like `gpt-oss:120b`.
3. You can now refer to this model in marimo!

View File

@@ -0,0 +1,63 @@
---
title: Onyx
---
## Overview
[Onyx](http://onyx.app/) is a self-hostable Chat UI that integrates with all Ollama models. Features include:
- Creating custom Agents
- Web search
- Deep Research
- RAG over uploaded documents and connected apps
- Connectors to applications like Google Drive, Email, Slack, etc.
- MCP and OpenAPI Actions support
- Image generation
- User/Groups management, RBAC, SSO, etc.
Onyx can be deployed for single users or large organizations.
## Install Onyx
Deploy Onyx with the [quickstart guide](https://docs.onyx.app/deployment/getting_started/quickstart).
<Info>
Resourcing/scaling docs [here](https://docs.onyx.app/deployment/getting_started/resourcing).
</Info>
## Usage with Ollama
1. Login to your Onyx deployment (create an account first).
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-login.png"
alt="Onyx Login Page"
width="75%"
/>
</div>
2. In the set-up process select `Ollama` as the LLM provider.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-ollama-llm.png"
alt="Onyx Set Up Form"
width="75%"
/>
</div>
3. Provide your **Ollama API URL** and select your models.
<Note>If you're running Onyx in Docker, to access your computer's local network use `http://host.docker.internal` instead of `http://127.0.0.1`.</Note>
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-ollama-form.png"
alt="Selecting Ollama Models"
width="75%"
/>
</div>
You can also easily connect up Onyx Cloud with the `Ollama Cloud` tab of the setup.
## Send your first query
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-query.png"
alt="Onyx Query Example"
width="75%"
/>
</div>

View File

@@ -20,8 +20,8 @@ curl -fsSL https://ollama.com/install.sh | sh
Download and extract the package:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
| sudo tar zx -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
| sudo tar x -C /usr
```
Start Ollama:
@@ -41,8 +41,8 @@ ollama -v
If you have an AMD GPU, also download and extract the additional ROCm package:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
| sudo tar zx -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
| sudo tar x -C /usr
```
### ARM64 install
@@ -50,8 +50,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
Download and extract the ARM64-specific package:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \
| sudo tar zx -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
| sudo tar x -C /usr
```
### Adding Ollama as a startup service (recommended)
@@ -146,8 +146,8 @@ curl -fsSL https://ollama.com/install.sh | sh
Or by re-downloading Ollama:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
| sudo tar zx -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
| sudo tar x -C /usr
```
## Installing specific versions

View File

@@ -1,548 +0,0 @@
# Ollama Skills
Skills are reusable capability packages that extend what agents can do. They bundle instructions, scripts, and data that teach an agent how to perform specific tasks.
## Quick Start
### Creating a Skill
Create a directory with a `SKILL.md` file:
```
my-skill/
├── SKILL.md # Required: Instructions for the agent
└── scripts/ # Optional: Executable scripts
└── run.py
```
The `SKILL.md` file must have YAML frontmatter:
```markdown
---
name: my-skill
description: A brief description of what this skill does
---
# My Skill
## Purpose
Explain what this skill does and when to use it.
## Instructions
Step-by-step instructions for the agent on how to use this skill.
## Examples
Show example inputs and expected outputs.
```
### Using Skills in an Agent
Reference skills in your Agentfile:
```dockerfile
FROM llama3.2:3b
AGENT_TYPE conversational
# Local skill (bundled with agent)
SKILL ./path/to/my-skill
# Registry skill (pulled from ollama.com)
SKILL library/skill/calculator:1.0.0
# User skill from registry
SKILL myname/skill/calculator:1.0.0
SYSTEM You are a helpful assistant.
```
### Managing Skills
```bash
# Push a skill to the registry (uses your namespace)
ollama skill push myname/skill/calculator:1.0.0 ./my-skill
# Pull a skill from the official library
ollama skill pull skill/calculator:1.0.0
# Pull a skill from a user's namespace
ollama skill pull myname/skill/calculator:1.0.0
# List installed skills
ollama skill list
# Show skill details
ollama skill show skill/calculator:1.0.0
# Remove a skill
ollama skill rm skill/calculator:1.0.0
```
### Dynamic Skills in Chat
You can add and remove skills dynamically during an interactive chat session:
```
>>> /skills
Available Skills:
calculator (sha256:abc123def456...)
>>> /skill add ./my-local-skill
Added skill 'my-skill' from ./my-local-skill
>>> /skill list
Skills loaded in this session:
my-skill (local: /path/to/my-local-skill)
>>> /skill remove my-skill
Removed skill 'my-skill'
```
| Command | Description |
|---------|-------------|
| `/skills` | Show all available skills (model + session) |
| `/skill add <path>` | Add a skill from a local path |
| `/skill remove <name>` | Remove a skill by name |
| `/skill list` | List skills loaded in this session |
Dynamic skills take effect on the next message. This is useful for:
- Testing skills during development
- Temporarily adding capabilities to a model
- Experimenting with skill combinations
## Skill Reference Formats
Skills use a 5-part name structure: `host/namespace/kind/model:tag`
| Format | Example | Description |
|--------|---------|-------------|
| Local path | `./skills/calc` | Bundled with agent at create time |
| Library skill | `skill/calculator:1.0.0` | From the official skill library (library/skill/calculator) |
| User skill | `alice/skill/calc:1.0.0` | From a user's namespace |
| Full path | `registry.ollama.ai/alice/skill/calc:1.0.0` | Fully qualified with host |
The `kind` field distinguishes skills from models:
- `skill` - Skill packages
- `agent` - Agent packages (future)
- (empty) - Regular models
## SKILL.md Structure
### Required Frontmatter
```yaml
---
name: skill-name # Must match directory name
description: Brief description of the skill
---
```
### Recommended Sections
1. **Purpose**: What the skill does and when to use it
2. **When to use**: Trigger conditions for the agent
3. **Instructions**: Step-by-step usage guide
4. **Examples**: Input/output examples
5. **Scripts**: Documentation for any bundled scripts
### Example: Calculator Skill
```markdown
---
name: calculator
description: Performs mathematical calculations using Python
---
# Calculator Skill
## Purpose
This skill performs mathematical calculations using a bundled Python script.
## When to use
- User asks to calculate something
- User wants to do math operations
- Any arithmetic is needed
## Instructions
1. When calculation is needed, use the `run_skill_script` tool
2. Call: `python3 scripts/calculate.py "<expression>"`
3. Return the result to the user
## Examples
**Input**: "What is 25 * 4?"
**Action**: `run_skill_script` with command `python3 scripts/calculate.py '25 * 4'`
**Output**: "25 * 4 = 100"
```
## Storage Layout
```
~/.ollama/models/
├── blobs/
│ └── sha256-<digest> # Skill tar.gz blob
├── manifests/
│ └── registry.ollama.ai/
│ └── skill/ # Library skills
│ └── calculator/
│ └── 1.0.0
│ └── skill-username/ # User skills
│ └── my-skill/
│ └── latest
└── skills/
└── sha256-<digest>/ # Extracted skill cache
├── SKILL.md
└── scripts/
```
---
# Security Considerations
## Current State (Development)
The current implementation has several security considerations that need to be addressed before production use.
### 1. Script Execution
**Risk**: Skills can bundle arbitrary scripts that execute on the host system.
**Current behavior**:
- Scripts run with the same permissions as the Ollama process
- No sandboxing or isolation
- Full filesystem access
**Mitigations needed**:
- [ ] Sandbox script execution (containers, seccomp, etc.)
- [ ] Resource limits (CPU, memory, time)
- [ ] Filesystem isolation (read-only mounts, restricted paths)
- [ ] Network policy controls
- [ ] Capability dropping
### 2. Skill Provenance
**Risk**: Malicious skills could be pushed to the registry.
**Current behavior**:
- No code signing or verification
- No malware scanning
- Trust based on namespace ownership
**Mitigations needed**:
- [ ] Skill signing with author keys
- [ ] Registry-side malware scanning
- [ ] Content policy enforcement
- [ ] Reputation system for skill authors
### 3. Namespace Squatting
**Risk**: Malicious actors could register skill names that impersonate official tools.
**Current behavior**:
- First-come-first-served namespace registration
- No verification of skill names
**Mitigations needed**:
- [ ] Reserved namespace list (official tools, common names)
- [ ] Trademark/name verification for popular skills
- [ ] Clear namespacing conventions
### 4. Supply Chain Attacks
**Risk**: Compromised skills could inject malicious code into agents.
**Current behavior**:
- Skills pulled without integrity verification beyond digest
- No dependency tracking
**Mitigations needed**:
- [ ] SBOM (Software Bill of Materials) for skills
- [ ] Dependency vulnerability scanning
- [ ] Pinned versions in Agentfiles
- [ ] Audit logging of skill usage
### 5. Data Exfiltration
**Risk**: Skills could exfiltrate sensitive data from conversations or the host.
**Current behavior**:
- Skills have access to conversation context
- Scripts can make network requests
**Mitigations needed**:
- [ ] Network egress controls
- [ ] Sensitive data detection/masking
- [ ] Audit logging of script network activity
- [ ] User consent for data access
### 6. Privilege Escalation
**Risk**: Skills could escalate privileges through script execution.
**Current behavior**:
- Scripts inherit Ollama process privileges
- No capability restrictions
**Mitigations needed**:
- [ ] Run scripts as unprivileged user
- [ ] Drop all capabilities
- [ ] Mandatory access controls (SELinux/AppArmor)
## Recommended Security Model
### Skill Trust Levels
```
┌─────────────────────────────────────────────────────────────┐
│ Level 0: Untrusted (default) │
│ - No script execution │
│ - Instructions only │
│ - Safe for any skill │
├─────────────────────────────────────────────────────────────┤
│ Level 1: Sandboxed │
│ - Scripts run in isolated container │
│ - No network access │
│ - Read-only filesystem │
│ - Resource limits enforced │
├─────────────────────────────────────────────────────────────┤
│ Level 2: Trusted │
│ - Scripts run with network access │
│ - Can write to designated directories │
│ - Requires explicit user approval │
├─────────────────────────────────────────────────────────────┤
│ Level 3: Privileged (admin only) │
│ - Full host access │
│ - System administration skills │
│ - Requires admin approval │
└─────────────────────────────────────────────────────────────┘
```
### Skill Manifest Security Fields (Future)
```yaml
---
name: my-skill
description: A skill description
security:
trust_level: sandboxed
permissions:
- network:read # Can make HTTP GET requests
- filesystem:read:/data # Can read from /data
resource_limits:
max_memory: 256MB
max_cpu_time: 30s
max_disk: 100MB
signature: sha256:abc... # Author signature
---
```
---
# Future Considerations
## Feature Roadmap
### Phase 1: Foundation (Current)
- [x] Skill bundling with agents
- [x] Local skill development
- [x] Basic CLI commands (push, pull, list, rm, show)
- [x] Registry blob storage
- [ ] Registry namespace configuration
### Phase 2: Security
- [ ] Script sandboxing
- [ ] Permission model
- [ ] Skill signing
- [ ] Audit logging
### Phase 3: Discovery
- [ ] Skill search on ollama.com
- [ ] Skill ratings and reviews
- [ ] Usage analytics
- [ ] Featured/trending skills
### Phase 4: Advanced Features
- [ ] Skill dependencies
- [ ] Skill versioning constraints
- [ ] Skill composition (skills using skills)
- [ ] Skill testing framework
## Open Questions
### 1. Skill Execution Model
**Question**: How should skills execute scripts?
Options:
- **A) In-process**: Fast but unsafe
- **B) Subprocess**: Current approach, moderate isolation
- **C) Container**: Good isolation, requires container runtime
- **D) WASM**: Portable and safe, limited capabilities
- **E) Remote execution**: Offload to secure service
### 2. Skill Versioning
**Question**: How strict should version pinning be?
Options:
- **A) Always latest**: Simple but risky
- **B) Semantic versioning**: `^1.0.0` allows minor updates
- **C) Exact pinning**: `=1.0.0` requires explicit updates
- **D) Digest pinning**: `@sha256:abc` immutable reference
### 3. Skill Permissions
**Question**: How should users grant permissions to skills?
Options:
- **A) All or nothing**: Accept all permissions or don't use
- **B) Granular consent**: Approve each permission individually
- **C) Trust levels**: Pre-defined permission bundles
- **D) Runtime prompts**: Ask when permission is first used
### 4. Skill Discovery
**Question**: How should users find skills?
Options:
- **A) Central registry only**: ollama.com/skills
- **B) Federated registries**: Multiple skill sources
- **C) Git repositories**: Pull from GitHub, etc.
- **D) All of the above**: Multiple discovery mechanisms
### 5. Skill Monetization
**Question**: Should skill authors be able to monetize?
Options:
- **A) Free only**: All skills are free and open
- **B) Paid skills**: Authors can charge for skills
- **C) Freemium**: Free tier with paid features
- **D) Donations**: Voluntary support for authors
### 6. Skill Updates
**Question**: How should skill updates be handled?
Options:
- **A) Manual**: User explicitly updates
- **B) Auto-update**: Always use latest
- **C) Notify**: Alert user to available updates
- **D) Policy-based**: Organization controls update policy
## API Considerations
### Skill Metadata API
```
GET /api/skills
GET /api/skills/:namespace/:name
GET /api/skills/:namespace/:name/versions
GET /api/skills/:namespace/:name/readme
```
### Skill Execution API
```
POST /api/skills/:namespace/:name/execute
{
"command": "python3 scripts/run.py",
"args": ["--input", "data"],
"timeout": 30
}
```
### Skill Permissions API
```
GET /api/skills/:namespace/:name/permissions
POST /api/skills/:namespace/:name/permissions/grant
DELETE /api/skills/:namespace/:name/permissions/revoke
```
## Testing Considerations
### Skill Testing Framework
```bash
# Run skill tests
ollama skill test ./my-skill
# Test with specific model
ollama skill test ./my-skill --model llama3.2:3b
# Generate test report
ollama skill test ./my-skill --report
```
### Test File Format
```yaml
# my-skill/tests/test.yaml
tests:
- name: "basic calculation"
input: "What is 2 + 2?"
expect:
contains: "4"
tool_called: "run_skill_script"
- name: "complex expression"
input: "Calculate 15% of 200"
expect:
contains: "30"
```
## Compatibility Considerations
### Minimum Ollama Version
Skills should declare minimum Ollama version:
```yaml
---
name: my-skill
requires:
ollama: ">=0.4.0"
---
```
### Model Compatibility
Skills may require specific model capabilities:
```yaml
---
name: vision-skill
requires:
capabilities:
- vision
- tools
---
```
## Migration Path
### From Local to Registry
```bash
# Develop locally
SKILL ./my-skill
# Push when ready
ollama skill push myname/my-skill:1.0.0 ./my-skill
# Update Agentfile
SKILL skill/myname/my-skill:1.0.0
```
### Version Upgrades
```bash
# Check for updates
ollama skill outdated
# Update specific skill
ollama skill update calculator:1.0.0
# Update all skills
ollama skill update --all
```

View File

@@ -1,3 +0,0 @@
# Troubleshooting
For troubleshooting, see [https://docs.ollama.com/troubleshooting](https://docs.ollama.com/troubleshooting)

View File

@@ -148,16 +148,6 @@ func Remotes() []string {
return r
}
// Skills returns the list of skill directories. Skills directories can be configured via the OLLAMA_SKILLS environment variable.
// Returns empty slice if not configured.
func Skills() []string {
raw := strings.TrimSpace(Var("OLLAMA_SKILLS"))
if raw == "" {
return []string{}
}
return strings.Split(raw, ",")
}
func BoolWithDefault(k string) func(defaultValue bool) bool {
return func(defaultValue bool) bool {
if s := Var(k); s != "" {
@@ -327,9 +317,6 @@ func AsMap() map[string]EnvVar {
ret["OLLAMA_VULKAN"] = EnvVar{"OLLAMA_VULKAN", EnableVulkan(), "Enable experimental Vulkan support"}
}
// Skills configuration would go here when added
ret["OLLAMA_SKILLS"] = EnvVar{"OLLAMA_SKILLS", Skills(), "Comma-separated list of skill directories"}
return ret
}

View File

@@ -1,5 +1,7 @@
package fs
import "iter"
type Config interface {
Architecture() string
String(string, ...string) string
@@ -11,4 +13,8 @@ type Config interface {
Ints(string, ...[]int32) []int32
Floats(string, ...[]float32) []float32
Bools(string, ...[]bool) []bool
Len() int
Keys() iter.Seq[string]
Value(key string) any
}

View File

@@ -6,7 +6,9 @@ import (
"errors"
"fmt"
"io"
"iter"
"log/slog"
"maps"
"math"
"slices"
"strings"
@@ -239,6 +241,18 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
return val.values
}
func (kv KV) Len() int {
return len(kv)
}
func (kv KV) Keys() iter.Seq[string] {
return maps.Keys(kv)
}
func (kv KV) Value(key string) any {
return kv[key]
}
func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{
"bert",
@@ -255,6 +269,8 @@ func (kv KV) OllamaEngineRequired() bool {
"qwen25vl",
"qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe",
"glm4moelite",
"lfm2",
}, kv.Architecture())
}
@@ -842,6 +858,7 @@ func (f GGML) FlashAttention() bool {
return slices.Contains([]string{
"bert",
"gemma3",
"glm4moelite",
"gptoss", "gpt-oss",
"mistral3",
"olmo3",

View File

@@ -8,12 +8,12 @@ import (
"fmt"
"io"
"log/slog"
"maps"
"os"
"runtime"
"slices"
"strings"
"github.com/ollama/ollama/fs"
"golang.org/x/sync/errgroup"
)
@@ -508,7 +508,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
return binary.Write(w, binary.LittleEndian, s)
}
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
arch := kv.String("general.architecture")
if arch == "" {
return fmt.Errorf("architecture not set")
@@ -526,12 +526,12 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
return err
}
if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
if err := binary.Write(f, binary.LittleEndian, uint64(kv.Len())); err != nil {
return err
}
for _, key := range slices.Sorted(maps.Keys(kv)) {
if err := ggufWriteKV(f, arch, key, kv[key]); err != nil {
for _, key := range slices.Sorted(kv.Keys()) {
if err := ggufWriteKV(f, arch, key, kv.Value(key)); err != nil {
return err
}
}

2
go.mod
View File

@@ -87,5 +87,5 @@ require (
golang.org/x/term v0.36.0
golang.org/x/text v0.30.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -0,0 +1,174 @@
//go:build integration
package integration
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
)
func TestImageGeneration(t *testing.T) {
skipUnderMinVRAM(t, 8)
type testCase struct {
imageGenModel string
visionModel string
prompt string
expectedWords []string
}
testCases := []testCase{
{
imageGenModel: "jmorgan/z-image-turbo",
visionModel: "llama3.2-vision",
prompt: "A cartoon style llama flying like a superhero through the air with clouds in the background",
expectedWords: []string{"llama", "flying", "cartoon", "cloud", "sky", "superhero", "air", "animal", "camelid"},
},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("%s->%s", tc.imageGenModel, tc.visionModel), func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
client, testEndpoint, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Pull both models
if err := PullIfMissing(ctx, client, tc.imageGenModel); err != nil {
t.Fatalf("failed to pull image gen model: %v", err)
}
if err := PullIfMissing(ctx, client, tc.visionModel); err != nil {
t.Fatalf("failed to pull vision model: %v", err)
}
// Generate the image
t.Logf("Generating image with prompt: %s", tc.prompt)
imageBase64, err := generateImage(ctx, testEndpoint, tc.imageGenModel, tc.prompt)
if err != nil {
if strings.Contains(err.Error(), "image generation not available") {
t.Skip("Target system does not support image generation")
} else if strings.Contains(err.Error(), "executable file not found in") { // Windows pattern, not yet supported
t.Skip("Windows does not support image generation yet")
} else if strings.Contains(err.Error(), "CUDA driver version is insufficient") {
t.Skip("Driver is too old")
} else if strings.Contains(err.Error(), "insufficient memory for image generation") {
t.Skip("insufficient memory for image generation")
} else if strings.Contains(err.Error(), "error while loading shared libraries: libcuda.so.1") { // AMD GPU or CPU
t.Skip("CUDA GPU is not available")
} else if strings.Contains(err.Error(), "ollama-mlx: no such file or directory") {
// most likely linux arm - not supported yet
t.Skip("unsupported architecture")
}
t.Fatalf("failed to generate image: %v", err)
}
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
if err != nil {
t.Fatalf("failed to decode image: %v", err)
}
t.Logf("Generated image: %d bytes", len(imageData))
// Preload vision model and check GPU loading
err = client.Generate(ctx, &api.GenerateRequest{Model: tc.visionModel}, func(response api.GenerateResponse) error { return nil })
if err != nil {
t.Fatalf("failed to load vision model: %v", err)
}
// Use vision model to describe the image
chatReq := api.ChatRequest{
Model: tc.visionModel,
Messages: []api.Message{
{
Role: "user",
Content: "Describe this image in detail. What is shown? What style is it? What is the main subject doing?",
Images: []api.ImageData{imageData},
},
},
Stream: &stream,
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
}
// Verify the vision model's response contains expected keywords
response := DoChat(ctx, t, client, chatReq, tc.expectedWords, 240*time.Second, 30*time.Second)
if response != nil {
t.Logf("Vision model response: %s", response.Content)
// Additional detailed check for keywords
content := strings.ToLower(response.Content)
foundWords := []string{}
missingWords := []string{}
for _, word := range tc.expectedWords {
if strings.Contains(content, word) {
foundWords = append(foundWords, word)
} else {
missingWords = append(missingWords, word)
}
}
t.Logf("Found keywords: %v", foundWords)
if len(missingWords) > 0 {
t.Logf("Missing keywords (at least one was found so test passed): %v", missingWords)
}
}
})
}
}
// generateImage calls the OpenAI-compatible image generation API and returns the base64 image data
func generateImage(ctx context.Context, endpoint, model, prompt string) (string, error) {
reqBody := imagegenapi.ImageGenerationRequest{
Model: model,
Prompt: prompt,
N: 1,
Size: "512x512",
ResponseFormat: "b64_json",
}
jsonBody, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("failed to marshal request: %w", err)
}
url := fmt.Sprintf("http://%s/v1/images/generations", endpoint)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody))
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
var buf bytes.Buffer
buf.ReadFrom(resp.Body)
return "", fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, buf.String())
}
var genResp imagegenapi.ImageGenerationResponse
if err := json.NewDecoder(resp.Body).Decode(&genResp); err != nil {
return "", fmt.Errorf("failed to decode response: %w", err)
}
if len(genResp.Data) == 0 {
return "", fmt.Errorf("no image data in response")
}
return genResp.Data[0].B64JSON, nil
}

View File

@@ -131,7 +131,7 @@ func TestAPIToolCalling(t *testing.T) {
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
}
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
if _, ok := lastToolCall.Function.Arguments.Get("location"); !ok {
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
}
case <-ctx.Done():

View File

@@ -1464,6 +1464,12 @@ type CompletionRequest struct {
// TopLogprobs specifies the number of most likely alternative tokens to return (0-20)
TopLogprobs int
// Image generation fields
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int32 `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}
// DoneReason represents the reason why a completion response is done
@@ -1512,6 +1518,15 @@ type CompletionResponse struct {
// Logprobs contains log probability information if requested
Logprobs []Logprob `json:"logprobs,omitempty"`
// Image contains base64-encoded image data for image generation
Image string `json:"image,omitempty"`
// Step is the current step in image generation
Step int `json:"step,omitempty"`
// TotalSteps is the total number of steps for image generation
TotalSteps int `json:"total_steps,omitempty"`
}
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {

152
middleware/anthropic.go Normal file
View File

@@ -0,0 +1,152 @@
package middleware
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/anthropic"
"github.com/ollama/ollama/api"
)
// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format
type AnthropicWriter struct {
BaseWriter
stream bool
id string
model string
converter *anthropic.StreamConverter
}
func (w *AnthropicWriter) writeError(data []byte) (int, error) {
var errData struct {
Error string `json:"error"`
}
if err := json.Unmarshal(data, &errData); err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.ResponseWriter.Status(), errData.Error))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *AnthropicWriter) writeEvent(eventType string, data any) error {
d, err := json.Marshal(data)
if err != nil {
return err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
if err != nil {
return err
}
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
return nil
}
func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
var chatResponse api.ChatResponse
err := json.Unmarshal(data, &chatResponse)
if err != nil {
return 0, err
}
if w.stream {
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
events := w.converter.Process(chatResponse)
for _, event := range events {
if err := w.writeEvent(event.Event, event.Data); err != nil {
return 0, err
}
}
return len(data), nil
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
response := anthropic.ToMessagesResponse(w.id, chatResponse)
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
}
func (w *AnthropicWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
// AnthropicMessagesMiddleware handles Anthropic Messages API requests
func AnthropicMessagesMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req anthropic.MessagesRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "model is required"))
return
}
if req.MaxTokens <= 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "max_tokens is required and must be positive"))
return
}
if len(req.Messages) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "messages is required"))
return
}
chatReq, err := anthropic.FromMessagesRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
return
}
// Set think to nil when being used with Anthropic API to connect to tools like claude code
c.Set("relax_thinking", true)
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
messageID := anthropic.GenerateMessageID()
w := &AnthropicWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: messageID,
model: req.Model,
converter: anthropic.NewStreamConverter(messageID, req.Model),
}
if req.Stream {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
}
c.Writer = w
c.Next()
}
}

View File

@@ -0,0 +1,607 @@
package middleware
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/anthropic"
"github.com/ollama/ollama/api"
)
func captureAnthropicRequest(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
_ = json.Unmarshal(bodyBytes, capturedRequest)
c.Next()
}
}
// testProps creates ToolPropertiesMap from a map (convenience function for tests)
func testProps(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
func TestAnthropicMessagesMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.ChatRequest
err anthropic.ErrorResponse
}
var capturedRequest *api.ChatRequest
stream := true
testCases := []testCase{
{
name: "basic message",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with system prompt",
body: `{
"model": "test-model",
"max_tokens": 1024,
"system": "You are helpful.",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with options",
body: `{
"model": "test-model",
"max_tokens": 2048,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"stop_sequences": ["\n", "END"],
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{
"num_predict": 2048,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"stop": []string{"\n", "END"},
},
Stream: &False,
},
},
{
name: "streaming",
body: `{
"model": "test-model",
"max_tokens": 1024,
"stream": true,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &stream,
},
},
{
name: "with tools",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "What's the weather?"}
],
"tools": [{
"name": "get_weather",
"description": "Get current weather",
"input_schema": {
"type": "object",
"properties": {
"location": {"type": "string"}
},
"required": ["location"]
}
}]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
},
Tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: testProps(map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
}),
},
},
},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with tool result",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "What's the weather?"},
{"role": "assistant", "content": [
{"type": "tool_use", "id": "call_123", "name": "get_weather", "input": {"location": "Paris"}}
]},
{"role": "user", "content": [
{"type": "tool_result", "tool_use_id": "call_123", "content": "Sunny, 22°C"}
]}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call_123"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with thinking enabled",
body: `{
"model": "test-model",
"max_tokens": 1024,
"thinking": {"type": "enabled", "budget_tokens": 1000},
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
Think: &api.ThinkValue{Value: true},
},
},
{
name: "missing model error",
body: `{
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "model is required",
},
},
},
{
name: "missing max_tokens error",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "max_tokens is required and must be positive",
},
},
},
{
name: "missing messages error",
body: `{
"model": "test-model",
"max_tokens": 1024
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "messages is required",
},
},
},
{
name: "tool_use missing id error",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "assistant", "content": [
{"type": "tool_use", "name": "test"}
]}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "tool_use block missing required 'id' field",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware(), captureAnthropicRequest(&capturedRequest))
router.Handle(http.MethodPost, "/v1/messages", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
defer func() { capturedRequest = nil }()
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if tc.err.Type != "" {
// Expect error
if resp.Code == http.StatusOK {
t.Fatalf("expected error response, got 200 OK")
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error: %v", err)
}
if errResp.Type != tc.err.Type {
t.Errorf("expected error type %q, got %q", tc.err.Type, errResp.Type)
}
if errResp.Error.Type != tc.err.Error.Type {
t.Errorf("expected error.type %q, got %q", tc.err.Error.Type, errResp.Error.Type)
}
if errResp.Error.Message != tc.err.Error.Message {
t.Errorf("expected error.message %q, got %q", tc.err.Error.Message, errResp.Error.Message)
}
return
}
if resp.Code != http.StatusOK {
t.Fatalf("unexpected status code: %d, body: %s", resp.Code, resp.Body.String())
}
if capturedRequest == nil {
t.Fatal("request was not captured")
}
// Compare relevant fields
if capturedRequest.Model != tc.req.Model {
t.Errorf("model mismatch: got %q, want %q", capturedRequest.Model, tc.req.Model)
}
if diff := cmp.Diff(tc.req.Messages, capturedRequest.Messages,
cmpopts.IgnoreUnexported(api.ToolCallFunctionArguments{}, api.ToolPropertiesMap{})); diff != "" {
t.Errorf("messages mismatch (-want +got):\n%s", diff)
}
if tc.req.Stream != nil && capturedRequest.Stream != nil {
if *tc.req.Stream != *capturedRequest.Stream {
t.Errorf("stream mismatch: got %v, want %v", *capturedRequest.Stream, *tc.req.Stream)
}
}
if tc.req.Think != nil {
if capturedRequest.Think == nil {
t.Error("expected Think to be set")
} else if capturedRequest.Think.Value != tc.req.Think.Value {
t.Errorf("Think mismatch: got %v, want %v", capturedRequest.Think.Value, tc.req.Think.Value)
}
}
})
}
}
func TestAnthropicMessagesMiddleware_Headers(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Run("streaming sets correct headers", func(t *testing.T) {
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Check headers were set
if c.Writer.Header().Get("Content-Type") != "text/event-stream" {
t.Errorf("expected Content-Type text/event-stream, got %q", c.Writer.Header().Get("Content-Type"))
}
if c.Writer.Header().Get("Cache-Control") != "no-cache" {
t.Errorf("expected Cache-Control no-cache, got %q", c.Writer.Header().Get("Cache-Control"))
}
c.Status(http.StatusOK)
})
body := `{"model": "test", "max_tokens": 100, "stream": true, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
})
}
func TestAnthropicMessagesMiddleware_InvalidJSON(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{invalid json`))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", resp.Code)
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error: %v", err)
}
if errResp.Type != "error" {
t.Errorf("expected type 'error', got %q", errResp.Type)
}
if errResp.Error.Type != "invalid_request_error" {
t.Errorf("expected error type 'invalid_request_error', got %q", errResp.Error.Type)
}
}
func TestAnthropicWriter_NonStreaming(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Simulate Ollama response
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello there!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 5,
},
}
data, _ := json.Marshal(resp)
c.Writer.WriteHeader(http.StatusOK)
_, _ = c.Writer.Write(data)
})
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", resp.Code)
}
var result anthropic.MessagesResponse
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if result.Type != "message" {
t.Errorf("expected type 'message', got %q", result.Type)
}
if result.Role != "assistant" {
t.Errorf("expected role 'assistant', got %q", result.Role)
}
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
t.Errorf("expected text 'Hello there!', got %v", result.Content[0].Text)
}
if result.StopReason != "end_turn" {
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
}
if result.Usage.InputTokens != 10 {
t.Errorf("expected input_tokens 10, got %d", result.Usage.InputTokens)
}
if result.Usage.OutputTokens != 5 {
t.Errorf("expected output_tokens 5, got %d", result.Usage.OutputTokens)
}
}
// TestAnthropicWriter_ErrorFromRoutes tests error handling when routes.go sends
// gin.H{"error": "message"} without a StatusCode field (which is the common case)
func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
statusCode int
errorPayload any
wantErrorType string
wantMessage string
}{
// routes.go sends errors without StatusCode in JSON, so we must use HTTP status
{
name: "404 with gin.H error (model not found)",
statusCode: http.StatusNotFound,
errorPayload: gin.H{"error": "model 'nonexistent' not found"},
wantErrorType: "not_found_error",
wantMessage: "model 'nonexistent' not found",
},
{
name: "400 with gin.H error (bad request)",
statusCode: http.StatusBadRequest,
errorPayload: gin.H{"error": "model is required"},
wantErrorType: "invalid_request_error",
wantMessage: "model is required",
},
{
name: "500 with gin.H error (internal error)",
statusCode: http.StatusInternalServerError,
errorPayload: gin.H{"error": "something went wrong"},
wantErrorType: "api_error",
wantMessage: "something went wrong",
},
{
name: "404 with api.StatusError",
statusCode: http.StatusNotFound,
errorPayload: api.StatusError{
StatusCode: http.StatusNotFound,
ErrorMessage: "model not found via StatusError",
},
wantErrorType: "not_found_error",
wantMessage: "model not found via StatusError",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Simulate what routes.go does - set status and write error JSON
data, _ := json.Marshal(tt.errorPayload)
c.Writer.WriteHeader(tt.statusCode)
_, _ = c.Writer.Write(data)
})
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != tt.statusCode {
t.Errorf("expected status %d, got %d", tt.statusCode, resp.Code)
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error response: %v\nbody: %s", err, resp.Body.String())
}
if errResp.Type != "error" {
t.Errorf("expected type 'error', got %q", errResp.Type)
}
if errResp.Error.Type != tt.wantErrorType {
t.Errorf("expected error type %q, got %q", tt.wantErrorType, errResp.Error.Type)
}
if errResp.Error.Message != tt.wantMessage {
t.Errorf("expected message %q, got %q", tt.wantMessage, errResp.Error.Message)
}
})
}
}
func TestAnthropicMessagesMiddleware_SetsRelaxThinkingFlag(t *testing.T) {
gin.SetMode(gin.TestMode)
var flagSet bool
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
_, flagSet = c.Get("relax_thinking")
c.Status(http.StatusOK)
})
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if !flagSet {
t.Error("expected relax_thinking flag to be set in context")
}
}

View File

@@ -8,6 +8,7 @@ import (
"math/rand"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
@@ -441,6 +442,7 @@ type ResponsesWriter struct {
stream bool
responseID string
itemID string
request openai.ResponsesRequest
}
func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
@@ -478,7 +480,9 @@ func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {
// Non-streaming response
w.ResponseWriter.Header().Set("Content-Type", "application/json")
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse)
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse, w.request)
completedAt := time.Now().Unix()
response.CompletedAt = &completedAt
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
}
@@ -523,11 +527,12 @@ func ResponsesMiddleware() gin.HandlerFunc {
w := &ResponsesWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model),
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model, req),
model: req.Model,
stream: streamRequested,
responseID: responseID,
itemID: itemID,
request: req,
}
// Set headers based on streaming mode
@@ -541,3 +546,66 @@ func ResponsesMiddleware() gin.HandlerFunc {
c.Next()
}
}
type ImageWriter struct {
BaseWriter
}
func (w *ImageWriter) writeResponse(data []byte) (int, error) {
var generateResponse api.GenerateResponse
if err := json.Unmarshal(data, &generateResponse); err != nil {
return 0, err
}
// Only write response when done with image
if generateResponse.Done && generateResponse.Image != "" {
w.ResponseWriter.Header().Set("Content-Type", "application/json")
return len(data), json.NewEncoder(w.ResponseWriter).Encode(openai.ToImageGenerationResponse(generateResponse))
}
return len(data), nil
}
func (w *ImageWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func ImageGenerationsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req openai.ImageGenerationRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Prompt == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(openai.FromImageGenerationRequest(req)); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &ImageWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
}
c.Writer = w
c.Next()
}
}

View File

@@ -961,3 +961,154 @@ func TestRetrieveMiddleware(t *testing.T) {
}
}
}
func TestImageGenerationsMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.GenerateRequest
err openai.ErrorResponse
}
var capturedRequest *api.GenerateRequest
testCases := []testCase{
{
name: "image generation basic",
body: `{
"model": "test-model",
"prompt": "a beautiful sunset"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "a beautiful sunset",
},
},
{
name: "image generation with size",
body: `{
"model": "test-model",
"prompt": "a beautiful sunset",
"size": "512x768"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "a beautiful sunset",
Width: 512,
Height: 768,
},
},
{
name: "image generation missing prompt",
body: `{
"model": "test-model"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "prompt is required",
Type: "invalid_request_error",
},
},
},
{
name: "image generation missing model",
body: `{
"prompt": "a beautiful sunset"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "model is required",
Type: "invalid_request_error",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ImageGenerationsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/generate", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
defer func() { capturedRequest = nil }()
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if tc.err.Error.Message != "" {
var errResp openai.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(tc.err, errResp); diff != "" {
t.Fatalf("errors did not match:\n%s", diff)
}
return
}
if resp.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
}
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
t.Fatalf("requests did not match:\n%s", diff)
}
})
}
}
func TestImageWriterResponse(t *testing.T) {
gin.SetMode(gin.TestMode)
// Test that ImageWriter transforms GenerateResponse to OpenAI format
endpoint := func(c *gin.Context) {
resp := api.GenerateResponse{
Model: "test-model",
CreatedAt: time.Unix(1234567890, 0).UTC(),
Done: true,
Image: "dGVzdC1pbWFnZS1kYXRh", // base64 of "test-image-data"
}
data, _ := json.Marshal(resp)
c.Writer.Write(append(data, '\n'))
}
router := gin.New()
router.Use(ImageGenerationsMiddleware())
router.Handle(http.MethodPost, "/api/generate", endpoint)
body := `{"model": "test-model", "prompt": "test"}`
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
}
var imageResp openai.ImageGenerationResponse
if err := json.Unmarshal(resp.Body.Bytes(), &imageResp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if imageResp.Created != 1234567890 {
t.Errorf("expected created 1234567890, got %d", imageResp.Created)
}
if len(imageResp.Data) != 1 {
t.Fatalf("expected 1 image, got %d", len(imageResp.Data))
}
if imageResp.Data[0].B64JSON != "dGVzdC1pbWFnZS1kYXRh" {
t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON)
}
}

View File

@@ -162,6 +162,7 @@ type Tensor interface {
AvgPool2D(ctx Context, k, s int, p float32) Tensor
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor
SSMConv(ctx Context, kernel Tensor) Tensor
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

View File

@@ -1641,6 +1641,13 @@ func (t *Tensor) Conv3D(ctx ml.Context, t2 ml.Tensor, c, s0, s1, s2, p0, p1, p2,
return tt
}
func (t *Tensor) SSMConv(ctx ml.Context, kernel ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_ssm_conv(ctx.(*Context).ctx, t.t, kernel.(*Tensor).t),
}
}
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
return &Tensor{
b: t.b,

View File

@@ -0,0 +1,304 @@
package glm4moelite
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
numExpertsUsed int
numExperts int
normTopKProb bool
routedScalingFactor float32
kvLoraRank,
qkNopeHeadDim,
qkRopeHeadDim,
kqNopeHeadDim,
qkHeadDim int
qLoraRank int
vHeadDim int
hiddenSize,
numHeads,
numKVHeads int
eps,
ropeBase float32
kqScale float64
}
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
return nn.RoPE(ctx, t, p, o.qkRopeHeadDim, o.ropeBase, 1.0)
}
type Attention struct {
Q *nn.Linear `gguf:"attn_q"`
QA *nn.Linear `gguf:"attn_q_a"`
QANorm *nn.RMSNorm `gguf:"attn_q_a_norm"`
QB *nn.Linear `gguf:"attn_q_b"`
KVA *nn.Linear `gguf:"attn_kv_a_mqa"`
KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
KVB *nn.Linear `gguf:"attn_kv_b"`
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
}
func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
seqLength := hiddenStates.Dim(1)
var query ml.Tensor
if opts.qLoraRank == 0 {
query = attn.Q.Forward(ctx, hiddenStates)
} else {
query = attn.QA.Forward(ctx, hiddenStates)
query = attn.QANorm.Forward(ctx, query, opts.eps)
query = attn.QB.Forward(ctx, query)
}
query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength)
queryChunks := query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim)
compressedKV := attn.KVA.Forward(ctx, hiddenStates)
kPass := compressedKV.Slice(ctx, 0, 0, opts.kvLoraRank, 1)
kRot := compressedKV.View(ctx,
opts.kvLoraRank*compressedKV.Stride(0), opts.qkRopeHeadDim,
compressedKV.Stride(1), 1,
compressedKV.Stride(1), compressedKV.Dim(1),
)
qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
kPass = attn.KVB.Forward(ctx, kPass)
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
query = qRot.Concat(ctx, queryChunks[0], 0)
key := kRot.Concat(ctx, kvChunks[0], 0)
attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
return attn.Output.Forward(ctx, attention)
}
type MLP interface {
Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
}
type sparse struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Gate *nn.Linear `gguf:"ffn_gate_exps"`
Up *nn.Linear `gguf:"ffn_up_exps"`
Down *nn.Linear `gguf:"ffn_down_exps"`
SharedExpert *dense `gguf:",suf:_shexp"`
ExpProbsBias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"`
}
func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml.Tensor, opts *Options) ml.Tensor {
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
upStates := moe.Up.Weight.MulmatID(ctx, hiddenStates, topKIndices)
hiddenStates = moe.Gate.Weight.MulmatID(ctx, hiddenStates, topKIndices)
hiddenStates = hiddenStates.SILU(ctx, upStates)
experts := moe.Down.Weight.MulmatID(ctx, hiddenStates, topKIndices)
experts = experts.Mul(ctx, topKWeights)
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
}
return nextStates
}
func (moe *sparse) topKIndices(ctx ml.Context, scores ml.Tensor, opts *Options) ml.Tensor {
if moe.ExpProbsBias != nil {
scores = scores.Add(ctx, moe.ExpProbsBias)
}
topKIndices := scores.TopK(ctx, opts.numExpertsUsed)
return topKIndices
}
func (moe *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
residuals := hiddenStates
routerLogits := moe.Router.Forward(ctx, hiddenStates)
scores := routerLogits.Sigmoid(ctx)
topKIndices := moe.topKIndices(ctx, scores, opts)
topKWeights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, topKIndices)
if opts.normTopKProb {
topKWeights = topKWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1))
topKWeights = topKWeights.Div(ctx, topKWeights.SumRows(ctx))
topKWeights = topKWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1))
}
topKWeights = topKWeights.Scale(ctx, float64(opts.routedScalingFactor))
hiddenStates = moe.Moe(ctx, hiddenStates, topKIndices, topKWeights, opts)
sharedExpertResult := moe.SharedExpert.Forward(ctx, residuals, opts)
hiddenStates = hiddenStates.Add(ctx, sharedExpertResult)
return hiddenStates
}
type dense struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Attention *Attention
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP MLP
}
func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenStates
hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = t.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = t.MLP.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
return hiddenStates
}
type Model struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*Options
}
func New(c fs.Config) (model.Model, error) {
layers := make([]Layer, c.Uint("block_count"))
firstDenseLayerIndex := int(c.Uint("leading_dense_block_count"))
for i := range layers {
if i < firstDenseLayerIndex {
layers[i].MLP = &dense{}
} else {
layers[i].MLP = &sparse{}
}
}
keyLength := int(c.Uint("attention.key_length"))
valueLength := int(c.Uint("attention.value_length"))
kqScale := 1.0 / math.Sqrt(float64(keyLength))
var pre []string
switch c.String("tokenizer.ggml.pre") {
case "glm4":
pre = []string{
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
}
default:
return nil, model.ErrUnsupportedTokenizer
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
pre...,
),
Layers: layers,
Options: &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
normTopKProb: c.Bool("expert_weights_norm", true),
qLoraRank: int(c.Uint("attention.q_lora_rank")),
kvLoraRank: int(c.Uint("attention.kv_lora_rank")),
qkHeadDim: keyLength,
vHeadDim: valueLength,
qkRopeHeadDim: int(c.Uint("rope.dimension_count")),
qkNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
kqNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
routedScalingFactor: c.Float("expert_weights_scale"),
kqScale: kqScale,
},
}
m.Cache = kvcache.NewCausalCache(m.Shift)
return &m, nil
}
func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
func init() {
model.Register("glm4moelite", New)
}

358
model/models/lfm2/cache.go Normal file
View File

@@ -0,0 +1,358 @@
package lfm2
import (
"slices"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
// HybridCache stores:
// - a standard causal KV cache for attention layers
// - a per-sequence recurrent conv state for shortconv layers
//
// Conv state shape (per layer, per sequence): [dConv, hiddenSize] where dConv = L_cache - 1.
// Stored internally as a tensor of shape [dConv * hiddenSize, maxSlots].
type HybridCache struct {
kv *kvcache.Causal
backend ml.Backend
dtype ml.DType
maxSequences int
hiddenSize int
dConv int
// slot mapping for recurrent state
slotForSeq map[int]int
refCount []int
freeSlots []int
// per-layer conv state buffers (allocated lazily)
convCtxs map[int]ml.Context
convStates map[int]ml.Tensor // [dConv*hiddenSize, maxSlots]
// current forward batch (derived in StartForward)
curSeqs []int
curSlots []int
curSlotsInput ml.Tensor
curSeqTokens int
// track if EnsureWritable has been called for this forward pass
writableEnsured bool
// track any error from EnsureWritable to propagate later
writableError error
}
func NewHybridCache(shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error), hiddenSize, dConv int) *HybridCache {
return &HybridCache{
kv: kvcache.NewCausalCache(shift),
hiddenSize: hiddenSize,
dConv: dConv,
slotForSeq: make(map[int]int),
convCtxs: make(map[int]ml.Context),
convStates: make(map[int]ml.Tensor),
}
}
func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.backend = backend
c.dtype = dtype
c.maxSequences = maxSequences
// initialize slot allocator
c.refCount = make([]int, maxSequences)
c.freeSlots = c.freeSlots[:0]
for i := maxSequences - 1; i >= 0; i-- {
c.freeSlots = append(c.freeSlots, i)
}
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}
func (c *HybridCache) Close() {
for _, ctx := range c.convCtxs {
ctx.Close()
}
c.kv.Close()
}
func (c *HybridCache) SetConfig(config ml.CacheConfig) {
c.kv.SetConfig(config)
}
func (c *HybridCache) SetLayer(layer int) {
c.kv.SetLayer(layer)
}
func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
return c.kv.Get(ctx)
}
func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.kv.Put(ctx, key, value)
}
func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
return err
}
// Derive equal-length sequence layout for shortconv.
// LFM2 shortconv assumes tokens form a [seq_tokens, seqs] grid.
seqCounts := make(map[int]int)
c.curSeqs = c.curSeqs[:0]
for _, s := range batch.Sequences {
if _, ok := seqCounts[s]; !ok {
c.curSeqs = append(c.curSeqs, s)
}
seqCounts[s]++
}
if len(c.curSeqs) == 0 {
return nil
}
nTokens := len(batch.Sequences)
nSeqs := len(c.curSeqs)
want := nTokens / nSeqs
for _, s := range c.curSeqs {
if seqCounts[s] != want {
return kvcache.ErrNotSupported
}
}
c.curSeqTokens = want
// Ensure slots exist for sequences in this batch
c.curSlots = c.curSlots[:0]
for _, s := range c.curSeqs {
slot, ok := c.slotForSeq[s]
if !ok {
var err error
slot, err = c.allocSlot()
if err != nil {
return err
}
c.slotForSeq[s] = slot
c.refCount[slot] = 1
}
c.curSlots = append(c.curSlots, slot)
}
// Create a tensor for the current slots
slots := make([]int32, len(c.curSlots))
for i, v := range c.curSlots {
slots[i] = int32(v)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
// Reset writable state for new forward pass
c.writableEnsured = false
c.writableError = nil
return nil
}
func (c *HybridCache) allocSlot() (int, error) {
if len(c.freeSlots) == 0 {
return 0, kvcache.ErrKvCacheFull
}
slot := c.freeSlots[len(c.freeSlots)-1]
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
return slot, nil
}
func (c *HybridCache) freeSlot(slot int) {
// Bounds check before freeing
if slot >= 0 && slot < c.maxSequences {
c.freeSlots = append(c.freeSlots, slot)
}
}
// EnsureWritable ensures that sequences in the current batch have private (non-shared) conv slots.
// Returns an error if slot allocation fails.
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
for i, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
// Bounds check
if slot < 0 || slot >= len(c.refCount) {
continue
}
if c.refCount[slot] <= 1 {
continue
}
newSlot, err := c.allocSlot()
if err != nil {
return err
}
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
c.curSlots[i] = newSlot
// Copy existing conv state for all initialized layers
for _, buf := range c.convStates {
// buf: [dConv*hiddenSize, maxSlots]
src := buf.Rows(ctx, ctx.Input().FromInts([]int32{int32(slot)}, 1))
ctx.Forward(buf.SetRows(ctx, src, ctx.Input().FromInts([]int32{int32(newSlot)}, 1)))
}
}
// Rebuild current slots tensor
slots := make([]int32, len(c.curSlots))
for i, v := range c.curSlots {
slots[i] = int32(v)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
return nil
}
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
// KV cache shares prefix metadata (no copy) which is correct for prefix reuse.
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
// For shortconv state we implement copy-on-write: dst shares the same slot as src.
// On the first write to dst, EnsureWritable will create a private slot.
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
// Bounds check before decrementing
if dstSlot >= 0 && dstSlot < len(c.refCount) {
c.refCount[dstSlot]--
if c.refCount[dstSlot] <= 0 {
c.refCount[dstSlot] = 0
c.freeSlot(dstSlot)
}
}
delete(c.slotForSeq, dstSeq)
}
srcSlot, ok := c.slotForSeq[srcSeq]
if !ok {
// src may not have a slot yet; dst will allocate on demand
return
}
// Bounds check before incrementing
if srcSlot >= 0 && srcSlot < len(c.refCount) {
c.slotForSeq[dstSeq] = srcSlot
c.refCount[srcSlot]++
}
}
func (c *HybridCache) CanResume(seq int, pos int32) bool {
return c.kv.CanResume(seq, pos)
}
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
return err
}
// For recurrent state, any removal invalidates the state because
// the state at position N depends on all previous positions.
// Drop the slot mapping so it resets on next use.
slot, ok := c.slotForSeq[seq]
if !ok {
return nil
}
// Bounds check
if slot < 0 || slot >= len(c.refCount) {
delete(c.slotForSeq, seq)
return nil
}
c.refCount[slot]--
if c.refCount[slot] <= 0 {
c.refCount[slot] = 0
c.freeSlot(slot)
}
delete(c.slotForSeq, seq)
return nil
}
func (c *HybridCache) slotsTensor() ml.Tensor {
return c.curSlotsInput
}
func (c *HybridCache) seqTokens() int {
return c.curSeqTokens
}
func (c *HybridCache) numSeqs() int {
return len(c.curSeqs)
}
func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
if buf, ok := c.convStates[layer]; ok {
return buf
}
if _, ok := c.convCtxs[layer]; !ok {
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
}
buf := c.convCtxs[layer].Zeros(c.dtype, c.dConv*c.hiddenSize, c.maxSequences)
c.convStates[layer] = buf
return buf
}
// ConvState returns the conv state for current batch sequences as shape [dConv, hiddenSize, nSeqs].
// Returns an error if copy-on-write allocation fails.
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
if !c.writableEnsured {
needsWritable := false
for _, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
needsWritable = true
break
}
}
if needsWritable {
if err := c.EnsureWritable(ctx); err != nil {
c.writableError = err
}
}
c.writableEnsured = true
}
if c.writableError != nil {
return nil, c.writableError
}
buf := c.convBuffer(ctx, layer)
cur := buf.Rows(ctx, c.slotsTensor())
return cur.Reshape(ctx, c.dConv, c.hiddenSize, c.numSeqs()), nil
}
// UpdateConvState writes a new conv state for current batch sequences.
// newState must have shape [dConv, hiddenSize, nSeqs].
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
buf := c.convBuffer(ctx, layer)
src := newState.Reshape(ctx, c.dConv*c.hiddenSize, c.numSeqs())
ctx.Forward(buf.SetRows(ctx, src, c.slotsTensor()))
}
// IsSupportedForBatch returns true if the current batch layout supports shortconv.
func (c *HybridCache) IsSupportedForBatch() bool {
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}
// Seqs returns the ordered unique sequences for the current forward pass.
func (c *HybridCache) Seqs() []int {
return slices.Clone(c.curSeqs)
}

284
model/models/lfm2/model.go Normal file
View File

@@ -0,0 +1,284 @@
package lfm2
import (
"cmp"
"fmt"
"math"
"strings"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
hiddenSize int
headDim, ropeDim int
eps, ropeBase, ropeScale float32
ropeType string
originalContextLength int
// per-layer head counts (LFM2 alternates attention and recurrent layers)
numHeadsByLayer []int
numKVHeadsByLayer []int
}
func (o Options) headDimValue() int {
// Head dim is shared across layers; fall back to first attention layer head count.
for _, h := range o.numHeadsByLayer {
if h > 0 {
return cmp.Or(o.headDim, o.hiddenSize/h)
}
}
return cmp.Or(o.headDim, o.hiddenSize)
}
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
opts := []func(*rope.Options){rope.WithTypeNeoX()}
if o.ropeType == "yarn" {
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
opts = append(opts,
rope.WithOriginalContextLength(o.originalContextLength),
rope.WithExtrapolationFactor(1.),
rope.WithAttentionFactor(attnFactor),
)
}
headCount := 1
for _, h := range o.numHeadsByLayer {
if h > 0 {
headCount = h
break
}
}
return nn.RoPE(ctx, states, positions, cmp.Or(o.ropeDim, o.headDim, o.hiddenSize/headCount), o.ropeBase, 1./o.ropeScale, opts...)
}
type Model struct {
model.Base
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm,alt:token_embd_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Options
}
func New(c fs.Config) (model.Model, error) {
if c.Uint("expert_count") > 0 {
// TODO: support mixtures of experts
return nil, model.ErrUnsupportedModel
}
// Tokenizer
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
}
var processor model.TextProcessor
switch c.String("tokenizer.ggml.model") {
case "gpt2":
// LFM2 uses a llama3-style BPE pretokenizer.
var pretokenizers []string
switch c.String("tokenizer.ggml.pre") {
case "lfm2", "llama3", "llama-v3", "llama-bpe":
pretokenizers = []string{
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
case "qwen2":
pretokenizers = []string{
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
case "refact":
pretokenizers = []string{
`\p{N}`,
`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`,
}
case "tekken":
pretokenizers = []string{
"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
case "default":
// no-op use the default bpe pretokenizer
default:
// use a llama-style pretokenizer
pretokenizers = []string{
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
}
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
case "llama":
return nil, fmt.Errorf("unsupported tokenizer: llama")
default:
return nil, model.ErrUnsupportedTokenizer
}
if strings.HasPrefix(c.String("general.name"), "Qwen2-beta") {
return nil, fmt.Errorf("unsupported model: %s", c.String("general.name"))
}
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeType: c.String("rope.scaling.type"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
},
}
type headCounts interface {
HeadCount() []uint64
HeadCountKV() []uint64
}
hc, ok := c.(headCounts)
if !ok {
return nil, model.ErrUnsupportedModel
}
headCount := hc.HeadCount()
headCountKV := hc.HeadCountKV()
m.numHeadsByLayer = make([]int, len(m.Layers))
m.numKVHeadsByLayer = make([]int, len(m.Layers))
for i := range m.Layers {
m.numHeadsByLayer[i] = int(headCount[i])
m.numKVHeadsByLayer[i] = int(headCountKV[i])
if m.numKVHeadsByLayer[i] == 0 {
m.Layers[i].Operator = &ShortConv{}
} else {
m.Layers[i].Operator = &Attention{}
}
}
lCache := int(c.Uint("shortconv.l_cache"))
dConv := max(0, lCache-1)
m.Cache = NewHybridCache(m.Shift, m.hiddenSize, dConv)
return &m, nil
}
type Operator interface {
Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor
}
type Attention struct {
Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output,alt:attn_out"`
}
func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor {
batchSize := hiddenStates.Dim(1)
headDim := opts.headDimValue()
numHeads := opts.numHeadsByLayer[layer]
numKVHeads := opts.numKVHeadsByLayer[layer]
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, headDim, numHeads, batchSize)
key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
value = value.Reshape(ctx, headDim, numKVHeads, batchSize)
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return sa.Output.Forward(ctx, attention)
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Operator Operator
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenState, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.Operator.Forward(ctx, hiddenState, positions, cache, layer, opts)
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, i, hiddenState, positions, outputs, m.Cache.(*HybridCache), &m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState), nil
}
func init() {
model.Register("lfm2", New)
}

View File

@@ -0,0 +1,52 @@
package lfm2
import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
type shortConvKernel struct {
Weight ml.Tensor `gguf:"weight"`
}
// ShortConv implements the LFM2 short-convolution block (GGML_OP_SSM_CONV) with a recurrent
// state stored in the HybridCache.
type ShortConv struct {
Conv *shortConvKernel `gguf:"shortconv.conv"`
InProj *nn.Linear `gguf:"shortconv.in_proj"`
OutProj *nn.Linear `gguf:"shortconv.out_proj"`
}
func (sc *ShortConv) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor {
nSeqs := cache.numSeqs()
seqTokens := cache.seqTokens()
hiddenSize := hiddenStates.Dim(0)
if nSeqs <= 0 || seqTokens <= 0 || hiddenStates.Dim(1) != nSeqs*seqTokens {
panic("lfm2: unsupported batch layout for shortconv")
}
bcx := sc.InProj.Forward(ctx, hiddenStates).Reshape(ctx, 3*hiddenSize, seqTokens, nSeqs)
elementSize := bcx.Stride(0)
b := bcx.View(ctx, 0*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
c := bcx.View(ctx, 1*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
x := bcx.View(ctx, 2*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
bx := b.Mul(ctx, x).Permute(ctx, 1, 0, 2, 3)
state, err := cache.ConvState(ctx, layer)
if err != nil {
panic("lfm2: failed to get conv state: " + err.Error())
}
sx := state.Concat(ctx, bx, 0)
// Cast weight to F32 for SSMConv (Metal requires F32)
weightF32 := sc.Conv.Weight.Cast(ctx, ml.DTypeF32)
convOut := sx.SSMConv(ctx, weightF32)
y := c.Mul(ctx, convOut)
dConv := sx.Dim(0) - seqTokens
cache.UpdateConvState(ctx, layer, sx.Slice(ctx, 0, sx.Dim(0)-dConv, sx.Dim(0), 1))
return sc.OutProj.Forward(ctx, y.Reshape(ctx, hiddenSize, seqTokens*nSeqs))
}

View File

@@ -7,7 +7,9 @@ import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n"
_ "github.com/ollama/ollama/model/models/glm4moelite"
_ "github.com/ollama/ollama/model/models/gptoss"
_ "github.com/ollama/ollama/model/models/lfm2"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3"

410
model/parsers/glm46.go Normal file
View File

@@ -0,0 +1,410 @@
package parsers
import (
"context"
"encoding/xml"
"fmt"
"log/slog"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
type glm46ParserState int
const (
glm46ParserState_LookingForThinkingOpen glm46ParserState = iota
glm46ParserState_ThinkingStartedEatingWhitespace
glm46ParserState_CollectingThinking
glm46ParserState_ThinkingDoneEatingWhitespace
glm46ParserState_CollectingContent
glm46ParserState_ToolStartedEatingWhitespace
glm46ParserState_CollectingToolContent
)
const (
glm46ThinkingOpenTag = "<think>"
glm46ThinkingCloseTag = "</think>"
glm46ToolOpenTag = "<tool_call>"
glm46ToolCloseTag = "</tool_call>"
)
type GLM46Parser struct {
state glm46ParserState
buffer strings.Builder
tools []api.Tool
}
func (p *GLM46Parser) HasToolSupport() bool {
return true
}
func (p *GLM46Parser) HasThinkingSupport() bool {
return true
}
// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
return tools
}
type glm46Event interface {
isGLM46Event()
}
type glm46EventContent struct {
content string
}
func (glm46EventContent) isGLM46Event() {}
type glm46EventRawToolCall struct {
raw string
}
func (glm46EventRawToolCall) isGLM46Event() {}
type glm46EventThinkingContent struct {
content string
}
func (glm46EventThinkingContent) isGLM46Event() {}
func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case glm46EventRawToolCall:
toolCall, err := parseGLM46ToolCall(event, p.tools)
if err != nil {
slog.Warn("glm-4.6 tool call parsing failed", "error", err)
return "", "", nil, err
}
toolCalls = append(toolCalls, toolCall)
case glm46EventThinkingContent:
thinkingSb.WriteString(event.content)
case glm46EventContent:
// TODO(drifkin): if the same turn contains multiple interleaved content
// events, we naively append them together here.
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
func (p *GLM46Parser) parseEvents() []glm46Event {
var all []glm46Event
keepLooping := true
for keepLooping {
var events []glm46Event
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "glm-4.6 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
}
return all
}
// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer
// and transitions to the next state. Returns (nil, false) if only whitespace remains
// in the buffer (needs more input), or (nil, true) if we successfully transitioned.
func (p *GLM46Parser) eatLeadingWhitespaceAndTransitionTo(nextState glm46ParserState) ([]glm46Event, bool) {
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
p.buffer.Reset()
if trimmed == "" {
return nil, false // Still only whitespace, keep waiting for more input
}
p.state = nextState
p.buffer.WriteString(trimmed)
return nil, true // Successfully transitioned
}
// glm46SplitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace),
// the content after (optionally trimmed of leading whitespace), and updates the buffer
func glm46SplitAtTag(p *GLM46Parser, tag string, trimAfter bool) (string, string) {
split := strings.SplitN(p.buffer.String(), tag, 2)
before := split[0]
before = strings.TrimRightFunc(before, unicode.IsSpace)
after := split[1]
if trimAfter {
after = strings.TrimLeftFunc(after, unicode.IsSpace)
}
p.buffer.Reset()
p.buffer.WriteString(after)
return before, after
}
func (p *GLM46Parser) eat() ([]glm46Event, bool) {
var events []glm46Event
switch p.state {
case glm46ParserState_LookingForThinkingOpen:
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
if strings.HasPrefix(trimmed, glm46ThinkingOpenTag) {
// Found <think> opening tag
after := strings.TrimPrefix(trimmed, glm46ThinkingOpenTag)
after = strings.TrimLeftFunc(after, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(after)
if after == "" {
p.state = glm46ParserState_ThinkingStartedEatingWhitespace
} else {
p.state = glm46ParserState_CollectingThinking
}
return events, true
} else if strings.HasPrefix(glm46ThinkingOpenTag, trimmed) {
// Partial opening tag seen, keep accumulating
return events, false
} else if trimmed == "" {
// Only whitespace, keep accumulating
return events, false
} else {
// No thinking tag found, skip to content collection
p.state = glm46ParserState_CollectingContent
// Don't trim - we want to keep the original content
return events, true
}
case glm46ParserState_ThinkingStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingThinking)
case glm46ParserState_CollectingThinking:
acc := p.buffer.String()
if strings.Contains(acc, glm46ThinkingCloseTag) {
thinking, remaining := glm46SplitAtTag(p, glm46ThinkingCloseTag, true)
if len(thinking) > 0 {
events = append(events, glm46EventThinkingContent{content: thinking})
}
if remaining == "" {
p.state = glm46ParserState_ThinkingDoneEatingWhitespace
} else {
p.state = glm46ParserState_CollectingContent
}
return events, true
} else if overlapLen := overlap(acc, glm46ThinkingCloseTag); overlapLen > 0 {
// Partial closing tag - withhold it along with any trailing whitespace before it
beforePartialTag := acc[:len(acc)-overlapLen]
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, glm46EventThinkingContent{content: unambiguous})
}
return events, false
} else {
// Pure thinking content - withhold trailing whitespace (might precede closing tag)
whitespaceLen := trailingWhitespaceLen(acc)
ambiguousStart := len(acc) - whitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, glm46EventThinkingContent{content: unambiguous})
}
return events, false
}
case glm46ParserState_ThinkingDoneEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingContent)
case glm46ParserState_CollectingContent:
if strings.Contains(p.buffer.String(), glm46ToolOpenTag) {
before, after := glm46SplitAtTag(p, glm46ToolOpenTag, true)
if len(before) > 0 {
events = append(events, glm46EventContent{content: before})
}
if after == "" {
p.state = glm46ParserState_ToolStartedEatingWhitespace
} else {
p.state = glm46ParserState_CollectingToolContent
}
return events, true
} else if overlapLen := overlap(p.buffer.String(), glm46ToolOpenTag); overlapLen > 0 {
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, glm46EventContent{content: unambiguous})
}
return events, false
} else {
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
ambiguousStart := len(p.buffer.String()) - whitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, glm46EventContent{content: unambiguous})
}
return events, false
}
case glm46ParserState_ToolStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingToolContent)
case glm46ParserState_CollectingToolContent:
acc := p.buffer.String()
if strings.Contains(acc, glm46ToolCloseTag) {
toolContent, _ := glm46SplitAtTag(p, glm46ToolCloseTag, true)
if len(toolContent) == 0 {
slog.Warn("glm46 tool call closing tag found but no content before it")
}
events = append(events, glm46EventRawToolCall{raw: toolContent})
p.state = glm46ParserState_CollectingContent
return events, true
} else {
// Keep accumulating - tool calls are not streamed
// We just wait for the closing tag
return events, false
}
default:
panic("unreachable")
}
}
// GLMToolCallXML represents the structure of a GLM-4.6 tool call for XML parsing
type GLMToolCallXML struct {
XMLName xml.Name `xml:"tool_call"`
Content string `xml:",chardata"` // Function name (text nodes between tags)
Keys []string `xml:"arg_key"` // All arg_key elements in document order
Values []string `xml:"arg_value"` // All arg_value elements in document order
}
// escapeGLM46Content escapes XML entities in text content while preserving arg_key/arg_value tags
func escapeGLM46Content(s string) string {
var result strings.Builder
inTag := false
for i := range len(s) {
ch := s[i]
if ch == '<' {
// Check if this is a known tag
if strings.HasPrefix(s[i:], "<arg_key>") ||
strings.HasPrefix(s[i:], "</arg_key>") ||
strings.HasPrefix(s[i:], "<arg_value>") ||
strings.HasPrefix(s[i:], "</arg_value>") {
inTag = true
}
}
if inTag {
result.WriteByte(ch)
if ch == '>' {
inTag = false
}
} else {
// Escape special characters in text content
switch ch {
case '&':
result.WriteString("&amp;")
case '<':
result.WriteString("&lt;")
case '>':
result.WriteString("&gt;")
default:
result.WriteByte(ch)
}
}
}
return result.String()
}
func parseGLM46ToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
// Escape any unescaped entities in text content
// We need to escape text between tags, but not the tags themselves
escaped := escapeGLM46Content(raw.raw)
// Wrap the content in a root element to make it valid XML
xmlString := "<tool_call>" + escaped + "</tool_call>"
// Parse XML into struct
var parsed GLMToolCallXML
if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
}
// Extract and trim function name
functionName := strings.TrimSpace(parsed.Content)
if functionName == "" {
return api.ToolCall{}, fmt.Errorf("empty function name")
}
// Verify keys and values are paired correctly
if len(parsed.Keys) != len(parsed.Values) {
return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
}
// Find the matching tool to get parameter types
var matchedTool *api.Tool
for i := range tools {
if tools[i].Function.Name == functionName {
matchedTool = &tools[i]
break
}
}
// Build arguments map by pairing keys and values
toolCall := api.ToolCall{
Function: api.ToolCallFunction{
Name: functionName,
Arguments: api.NewToolCallFunctionArguments(),
},
}
for i := range parsed.Keys {
key := strings.TrimSpace(parsed.Keys[i])
value := parsed.Values[i] // Don't trim here - parseValue handles it
// Look up parameter type
var paramType api.PropertyType
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
// Handle anyOf by collecting all types from the union
if len(prop.AnyOf) > 0 {
for _, anyOfProp := range prop.AnyOf {
paramType = append(paramType, anyOfProp.Type...)
}
} else {
paramType = prop.Type
}
}
}
// Parse value with type coercion
toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
}
return toolCall, nil
}

862
model/parsers/glm46_test.go Normal file
View File

@@ -0,0 +1,862 @@
package parsers
import (
"encoding/xml"
"reflect"
"testing"
"github.com/ollama/ollama/api"
)
func TestGLM46ParserStreaming(t *testing.T) {
type step struct {
input string
wantEvents []glm46Event
}
cases := []struct {
desc string
steps []step
only bool
}{
{
desc: "leading whitespace before think tag",
steps: []step{
{
input: " \n\t ",
wantEvents: []glm46Event{},
},
{
input: "<think>thinking</think>",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "thinking"}},
},
},
},
{
desc: "think tag with whitespace inside",
steps: []step{
{
input: "<think> \n thinking content \n </think>regular content",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking content"},
glm46EventContent{content: "regular content"},
},
},
},
},
{
desc: "tool call with leading whitespace after opening tag",
steps: []step{
{
input: "<think></think><tool_call> \n test \n </tool_call>",
wantEvents: []glm46Event{
glm46EventRawToolCall{raw: "test"},
},
},
},
},
{
desc: "simple thinking then content",
steps: []step{
{
input: "<think>I am thinking</think>Now I respond",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "I am thinking"},
glm46EventContent{content: "Now I respond"},
},
},
},
},
{
desc: "streamed thinking content",
steps: []step{
{
input: "<think>hello",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "hello"}},
},
{
input: " world",
wantEvents: []glm46Event{glm46EventThinkingContent{content: " world"}},
},
{
input: "</think>content",
wantEvents: []glm46Event{
glm46EventContent{content: "content"},
},
},
},
},
{
desc: "content before tool call",
steps: []step{
{
input: "<think>Let me call a tool</think>here is text<tool_call>",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "Let me call a tool"},
glm46EventContent{content: "here is text"},
},
},
{
input: "function_name\n<arg_key>param</arg_key>\n<arg_value>value</arg_value>\n</tool_call>",
wantEvents: []glm46Event{
glm46EventRawToolCall{raw: "function_name\n<arg_key>param</arg_key>\n<arg_value>value</arg_value>"},
},
},
},
},
{
desc: "tool call with content after",
steps: []step{
{
input: "<think>thinking</think><tool_call>test</tool_call>after tool",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
glm46EventRawToolCall{raw: "test"},
glm46EventContent{content: "after tool"},
},
},
},
},
{
desc: "trailing whitespace between content and tool call is trimmed",
steps: []step{
{
input: "<think>thinking</think>content\n \t <tool_call>test</tool_call>",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
glm46EventContent{content: "content"},
glm46EventRawToolCall{raw: "test"},
},
},
},
},
{
desc: "trailing whitespace between tool call and content is trimmed",
steps: []step{
{
input: "<think>think</think><tool_call>test</tool_call>\n\t after",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "think"},
glm46EventRawToolCall{raw: "test"},
glm46EventContent{content: "after"},
},
},
},
},
{
desc: "split thinking close tag",
steps: []step{
{
input: "<think>thinking content</th",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "thinking content"}},
},
{
input: "ink>after",
wantEvents: []glm46Event{
glm46EventContent{content: "after"},
},
},
},
},
{
desc: "split thinking open tag",
steps: []step{
{
input: " <thi",
wantEvents: []glm46Event{},
},
{
input: "nk>content</think>",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "content"}},
},
},
},
{
desc: "split tool open tag",
steps: []step{
{
input: "<think>think</think>content<tool",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "think"}, glm46EventContent{content: "content"}},
},
{
input: "_call>inside",
wantEvents: []glm46Event{},
},
{
input: "</tool_call>",
wantEvents: []glm46Event{
glm46EventRawToolCall{raw: "inside"},
},
},
},
},
{
desc: "partial thinking close tag fakeout",
steps: []step{
{
input: "<think>content</th",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "content"}},
},
{
input: "ought more",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "</thought more"}},
},
},
},
{
desc: "partial thinking open tag fakeout",
steps: []step{
{
input: " <thi",
wantEvents: []glm46Event{},
},
{
input: "nking is fun",
wantEvents: []glm46Event{
glm46EventContent{content: " <thinking is fun"},
},
},
},
},
{
desc: "partial tool open tag fakeout",
steps: []step{
{
input: "<think></think>content\n<tool",
wantEvents: []glm46Event{
glm46EventContent{content: "content"},
},
},
{
input: " fakeout",
wantEvents: []glm46Event{
glm46EventContent{content: "\n<tool fakeout"},
},
},
},
},
{
desc: "partial tool close tag fakeout",
steps: []step{
{
input: "<think></think><tool_call>content</tool",
wantEvents: []glm46Event{},
},
{
input: " fakeout",
wantEvents: []glm46Event{},
},
{
input: "</tool_call>",
wantEvents: []glm46Event{
glm46EventRawToolCall{raw: "content</tool fakeout"},
},
},
},
},
{
desc: "empty thinking tag",
steps: []step{
{
input: "<think></think>content here",
wantEvents: []glm46Event{
glm46EventContent{content: "content here"},
},
},
},
},
{
desc: "multiple tool calls in sequence",
steps: []step{
{
input: "<think>think</think><tool_call>first</tool_call>between<tool_call>second</tool_call>end",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "think"},
glm46EventRawToolCall{raw: "first"},
glm46EventContent{content: "between"},
glm46EventRawToolCall{raw: "second"},
glm46EventContent{content: "end"},
},
},
},
},
{
desc: "no thinking tag - direct to content",
steps: []step{
{
input: "just content here",
wantEvents: []glm46Event{
glm46EventContent{content: "just content here"},
},
},
},
},
{
desc: "no thinking tag - skip to content then tool call",
steps: []step{
{
input: "Here's the answer:<tool_call>test</tool_call>done",
wantEvents: []glm46Event{
glm46EventContent{content: "Here's the answer:"},
glm46EventRawToolCall{raw: "test"},
glm46EventContent{content: "done"},
},
},
},
},
{
desc: "no thinking tag - whitespace preserved when no tags",
steps: []step{
{
input: " \n content with leading whitespace",
wantEvents: []glm46Event{
glm46EventContent{content: " \n content with leading whitespace"},
},
},
},
},
{
desc: "whitespace after think close tag gets eaten",
steps: []step{
{
input: "<think>thinking</think> \n\t content",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
glm46EventContent{content: "content"},
},
},
},
},
{
desc: "whitespace after tool_call close tag gets eaten",
steps: []step{
{
input: "<think></think><tool_call>test</tool_call> \n\t content",
wantEvents: []glm46Event{
glm46EventRawToolCall{raw: "test"},
glm46EventContent{content: "content"},
},
},
},
},
{
desc: "thinking content withholds trailing whitespace (single chunk)",
steps: []step{
{
input: "<think>thinking content ",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking content"},
},
},
{
input: "</think>after",
wantEvents: []glm46Event{
glm46EventContent{content: "after"},
},
},
},
},
{
desc: "thinking content withholds trailing whitespace with newlines",
steps: []step{
{
input: "<think>thinking\n\n ",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
},
},
{
input: "</think>content",
wantEvents: []glm46Event{
glm46EventContent{content: "content"},
},
},
},
},
{
desc: "thinking content trailing whitespace emitted when more content arrives",
steps: []step{
{
input: "<think>thinking ",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
},
},
{
input: "more thinking",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: " more thinking"},
},
},
{
input: "</think>",
wantEvents: []glm46Event{},
},
},
},
{
desc: "thinking content withholds trailing whitespace before partial close tag",
steps: []step{
{
input: "<think>thinking </th",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
},
},
{
input: "ink>content",
wantEvents: []glm46Event{
glm46EventContent{content: "content"},
},
},
},
},
}
anyOnlies := false
for _, tc := range cases {
if tc.only {
anyOnlies = true
}
}
for _, tc := range cases {
if anyOnlies && !tc.only {
continue
}
t.Run(tc.desc, func(t *testing.T) {
parser := GLM46Parser{}
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
gotEvents := parser.parseEvents()
if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
// avoid deep equal on empty vs. nil slices
continue
}
if !reflect.DeepEqual(gotEvents, step.wantEvents) {
t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
}
}
})
}
}
// TestGLMToolCallXMLOrderPreservation verifies that xml.Unmarshal preserves
// document order when collecting multiple elements with the same tag name into slices.
// This is a critical assumption for the GLM-4.6 parser's struct-based approach.
func TestGLMToolCallXMLOrderPreservation(t *testing.T) {
testCases := []struct {
name string
xml string
wantKeys []string
wantValues []string
}{
{
name: "alternating keys and values",
xml: `<tool_call>
function_name
<arg_key>first</arg_key>
<arg_value>A</arg_value>
<arg_key>second</arg_key>
<arg_value>B</arg_value>
<arg_key>third</arg_key>
<arg_value>C</arg_value>
</tool_call>`,
wantKeys: []string{"first", "second", "third"},
wantValues: []string{"A", "B", "C"},
},
{
name: "all keys then all values",
xml: `<tool_call>
function_name
<arg_key>key1</arg_key>
<arg_key>key2</arg_key>
<arg_key>key3</arg_key>
<arg_value>val1</arg_value>
<arg_value>val2</arg_value>
<arg_value>val3</arg_value>
</tool_call>`,
wantKeys: []string{"key1", "key2", "key3"},
wantValues: []string{"val1", "val2", "val3"},
},
{
name: "mixed grouping",
xml: `<tool_call>
function_name
<arg_key>a</arg_key>
<arg_value>1</arg_value>
<arg_key>b</arg_key>
<arg_key>c</arg_key>
<arg_value>2</arg_value>
<arg_value>3</arg_value>
</tool_call>`,
wantKeys: []string{"a", "b", "c"},
wantValues: []string{"1", "2", "3"},
},
{
name: "reverse order - all values then all keys",
xml: `<tool_call>
function_name
<arg_value>X</arg_value>
<arg_value>Y</arg_value>
<arg_value>Z</arg_value>
<arg_key>x</arg_key>
<arg_key>y</arg_key>
<arg_key>z</arg_key>
</tool_call>`,
wantKeys: []string{"x", "y", "z"},
wantValues: []string{"X", "Y", "Z"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var parsed GLMToolCallXML
err := xml.Unmarshal([]byte(tc.xml), &parsed)
if err != nil {
t.Fatalf("failed to unmarshal XML: %v", err)
}
if !reflect.DeepEqual(parsed.Keys, tc.wantKeys) {
t.Errorf("Keys order mismatch:\ngot: %v\nwant: %v", parsed.Keys, tc.wantKeys)
}
if !reflect.DeepEqual(parsed.Values, tc.wantValues) {
t.Errorf("Values order mismatch:\ngot: %v\nwant: %v", parsed.Values, tc.wantValues)
}
})
}
}
func TestGLM46ToolCallParsing(t *testing.T) {
type testCase struct {
name string
rawToolCall string
tools []api.Tool
wantToolCall api.ToolCall
}
cases := []testCase{
{
name: "simple tool call",
tools: []api.Tool{},
rawToolCall: `get-current-weather
<arg_key>location</arg_key>
<arg_value>New York, NY</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get-current-weather",
Arguments: args(`{"location": "New York, NY", "unit": "celsius"}`),
},
},
},
{
name: "tool call with typed parameters",
tools: []api.Tool{
tool("calculate", map[string]api.ToolProperty{
"x": {Type: api.PropertyType{"number"}},
"y": {Type: api.PropertyType{"integer"}},
"enabled": {Type: api.PropertyType{"boolean"}},
"items": {Type: api.PropertyType{"array"}},
}),
},
rawToolCall: `calculate
<arg_key>x</arg_key>
<arg_value>3.14</arg_value>
<arg_key>y</arg_key>
<arg_value>42</arg_value>
<arg_key>enabled</arg_key>
<arg_value>true</arg_value>
<arg_key>items</arg_key>
<arg_value>["a", "b", "c"]</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "calculate",
Arguments: args(`{"enabled": true, "items": ["a", "b", "c"], "x": 3.14, "y": 42}`),
},
},
},
{
name: "function name with whitespace",
tools: []api.Tool{},
rawToolCall: ` get-weather
<arg_key>city</arg_key>
<arg_value>Paris</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get-weather",
Arguments: args(`{"city": "Paris"}`),
},
},
},
{
name: "values with special characters",
tools: []api.Tool{},
rawToolCall: `execute-command
<arg_key>command</arg_key>
<arg_value>ls && echo "done"</arg_value>
<arg_key>message</arg_key>
<arg_value>a < b and c > d</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "execute-command",
Arguments: args(`{"command": "ls && echo \"done\"", "message": "a < b and c > d"}`),
},
},
},
{
name: "unicode in function names and values",
tools: []api.Tool{},
rawToolCall: `获取天气
<arg_key>城市</arg_key>
<arg_value>北京</arg_value>
<arg_key>message</arg_key>
<arg_value>Hello! 你好! 🌟</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "获取天气",
Arguments: args(`{"message": "Hello! 你好! 🌟", "城市": "北京"}`),
},
},
},
{
name: "empty value",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>param1</arg_key>
<arg_value></arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"param1": ""}`),
},
},
},
{
name: "special chars in arg_key names",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>param<1></arg_key>
<arg_value>value1</arg_value>
<arg_key>a&b</arg_key>
<arg_value>value2</arg_value>
<arg_key>x>y</arg_key>
<arg_value>value3</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"a&b": "value2", "param<1>": "value1", "x>y": "value3"}`),
},
},
},
{
name: "multiple consecutive ampersands",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>param</arg_key>
<arg_value>test &&&& more</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"param": "test &&&& more"}`),
},
},
},
{
name: "mixed special chars together",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>param</arg_key>
<arg_value><>&<>&</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"param": "<>&<>&"}`),
},
},
},
{
name: "newlines and tabs in parameter values",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>multiline</arg_key>
<arg_value>line1
indented line2
line3</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"multiline": "line1\n\tindented line2\nline3"}`),
},
},
},
{
name: "single and double quotes in values",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>quotes</arg_key>
<arg_value>She said "Hello's there!"</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"quotes": "She said \"Hello's there!\""}`),
},
},
},
{
name: "CDATA-like content that should be treated as text",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>cdata</arg_key>
<arg_value><![CDATA[not actual cdata]]></arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"cdata": "<![CDATA[not actual cdata]]>"}`),
},
},
},
{
name: "all special XML entities",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>entities</arg_key>
<arg_value>&lt;&gt;&amp;&apos;&quot;</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"entities": "&lt;&gt;&amp;&apos;&quot;"}`),
},
},
},
{
name: "order preservation with multiple parameters",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>first</arg_key>
<arg_value>value1</arg_value>
<arg_key>second</arg_key>
<arg_value>value2</arg_value>
<arg_key>third</arg_key>
<arg_value>value3</arg_value>
<arg_key>fourth</arg_key>
<arg_value>value4</arg_value>
<arg_key>fifth</arg_key>
<arg_value>value5</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"fifth": "value5", "first": "value1", "fourth": "value4", "second": "value2", "third": "value3"}`),
},
},
},
{
name: "order preservation with identical key names but different positions",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>param</arg_key>
<arg_value>first occurrence</arg_value>
<arg_key>other</arg_key>
<arg_value>middle</arg_value>
<arg_key>param</arg_key>
<arg_value>second occurrence</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
// Later occurrence should overwrite earlier one
Arguments: args(`{"other": "middle", "param": "second occurrence"}`),
},
},
},
{
name: "array with mixed types",
tools: []api.Tool{
tool("process", map[string]api.ToolProperty{
"items": {Type: api.PropertyType{"array"}},
}),
},
rawToolCall: `process
<arg_key>items</arg_key>
<arg_value>[1, "hello", true, null]</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "process",
Arguments: args(`{"items": [1, "hello", true, null]}`),
},
},
},
{
name: "empty array",
tools: []api.Tool{
tool("test", map[string]api.ToolProperty{
"tags": {Type: api.PropertyType{"array"}},
}),
},
rawToolCall: `test
<arg_key>tags</arg_key>
<arg_value>[]</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test",
Arguments: args(`{"tags": []}`),
},
},
},
{
name: "anyOf array or string - with array of objects",
tools: []api.Tool{
tool("TodoWrite", map[string]api.ToolProperty{
"todos": {AnyOf: []api.ToolProperty{{Type: api.PropertyType{"array"}}, {Type: api.PropertyType{"string"}}}},
}),
},
// <tool_call>TodoWrite
// <arg_key>todos</arg_key>
// <arg_value>[{"content": "Set up HTML file and basic structure", "id": "1", "priority": "high", "status": "pending"}, {"content": "Create 3D scene with Three.js", "id": "2", "priority": "high", "status": "pending"}, {"content": "Implement terrain generation with blocks", "id": "3", "priority": "high", "status": "pending"}, {"content": "Add player controls (movement, camera)", "id": "4", "priority": "high", "status": "pending"}, {"content": "Implement block placement/destruction", "id": "5", "priority": "medium", "status": "pending"}, {"content": "Add lighting and textures", "id": "6", "priority": "medium", "status": "pending"}, {"content": "Test and optimize performance", "id": "7", "priority": "low", "status": "pending"}]</arg_value>
// </tool_call>
rawToolCall: `TodoWrite
<arg_key>todos</arg_key>
<arg_value>[{"content": "task 1", "status": "pending", "priority": "high", "id": "1"}, {"content": "task 2", "status": "completed", "priority": "low", "id": "2"}]</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "TodoWrite",
Arguments: args(`{"todos": [{"content": "task 1", "id": "1", "priority": "high", "status": "pending"}, {"content": "task 2", "id": "2", "priority": "low", "status": "completed"}]}`),
},
},
},
{
name: "anyOf array or string - with plain string",
tools: []api.Tool{
tool("TodoWrite", map[string]api.ToolProperty{
"todos": {Type: api.PropertyType{"array", "string"}},
}),
},
rawToolCall: `TodoWrite
<arg_key>todos</arg_key>
<arg_value>Error: could not load todos</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "TodoWrite",
Arguments: args(`{"todos": "Error: could not load todos"}`),
},
},
},
}
for i, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
gotToolCall, err := parseGLM46ToolCall(glm46EventRawToolCall{raw: tc.rawToolCall}, tc.tools)
if err != nil {
t.Errorf("case %d (%s): %v", i, tc.name, err)
}
if !toolCallEqual(gotToolCall, tc.wantToolCall) {
t.Errorf("case %d (%s): got tool call %#v, want %#v", i, tc.name, gotToolCall, tc.wantToolCall)
}
})
}
}

20
model/parsers/glm47.go Normal file
View File

@@ -0,0 +1,20 @@
package parsers
import "github.com/ollama/ollama/api"
// GLM47Parser extends GLM46Parser with thinking-aware initialization.
// GLM-4.7's prompt ends with <think> when thinking is enabled, so the parser
// must start in CollectingThinking state (the model outputs thinking content directly).
type GLM47Parser struct {
GLM46Parser
}
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
// When thinking is enabled (nil or true), the prompt ends with <think>,
// so model output starts directly with thinking content (no opening tag).
if thinkValue == nil || thinkValue.Bool() {
p.state = glm46ParserState_CollectingThinking
}
return tools
}

View File

@@ -0,0 +1,99 @@
package parsers
import (
"reflect"
"testing"
"github.com/ollama/ollama/api"
)
func TestGLM47ParserAdd(t *testing.T) {
parser := GLM47Parser{}
parser.Init([]api.Tool{
tool("calculate", map[string]api.ToolProperty{
"count": {Type: api.PropertyType{"integer"}},
"enabled": {Type: api.PropertyType{"boolean"}},
}),
}, nil, nil)
// When thinking is enabled (thinkValue nil), the prompt ends with <think>,
// so the model output does NOT include the opening <think> tag.
content, thinking, calls, err := parser.Add("plan</think>Answer<tool_call>calculate<arg_key>count</arg_key><arg_value>3</arg_value><arg_key>enabled</arg_key><arg_value>true</arg_value></tool_call>", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "plan" {
t.Fatalf("expected thinking 'plan', got %q", thinking)
}
if content != "Answer" {
t.Fatalf("expected content 'Answer', got %q", content)
}
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
expectedArgs := args(`{"count": 3, "enabled": true}`)
if !toolCallEqual(api.ToolCall{Function: api.ToolCallFunction{Arguments: calls[0].Function.Arguments}}, api.ToolCall{Function: api.ToolCallFunction{Arguments: expectedArgs}}) {
t.Fatalf("expected args %#v, got %#v", expectedArgs.ToMap(), calls[0].Function.Arguments.ToMap())
}
}
func TestGLM47ParserNoThinkingContent(t *testing.T) {
parser := GLM47Parser{}
parser.Init(nil, nil, nil)
// When thinking is enabled but model has no thinking to output,
// it should output </think> immediately followed by content.
content, thinking, calls, err := parser.Add("</think>Plain answer", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "" {
t.Fatalf("expected empty thinking, got %q", thinking)
}
if content != "Plain answer" {
t.Fatalf("expected content 'Plain answer', got %q", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestGLM47ParserThinkingDisabled(t *testing.T) {
parser := GLM47Parser{}
// When thinking is disabled, parser stays in LookingForThinkingOpen state
parser.Init(nil, nil, &api.ThinkValue{Value: false})
// Model outputs plain content (prompt ended with </think>)
content, thinking, calls, err := parser.Add("Plain answer", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "" {
t.Fatalf("expected empty thinking, got %q", thinking)
}
if content != "Plain answer" {
t.Fatalf("expected content 'Plain answer', got %q", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestGLM47ParserToolCallEscaping(t *testing.T) {
toolCall, err := parseGLM46ToolCall(glm46EventRawToolCall{raw: `exec
<arg_key>expr</arg_key>
<arg_value>a < b && c > d</arg_value>`}, nil)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
expected := api.ToolCall{
Function: api.ToolCallFunction{
Name: "exec",
Arguments: args(`{"expr": "a < b && c > d"}`),
},
}
if !reflect.DeepEqual(toolCall, expected) {
t.Fatalf("expected %#v, got %#v", expected, toolCall)
}
}

393
model/parsers/lfm2.go Normal file
View File

@@ -0,0 +1,393 @@
package parsers
import (
"encoding/json"
"errors"
"log/slog"
"strconv"
"strings"
"unicode"
"github.com/ollama/ollama/api"
)
type LFM2ParserState int
const (
LFM2CollectingThinking LFM2ParserState = iota
LFM2CollectingContent
LFM2CollectingToolCalls
)
const (
lfm2ThinkingOpenTag = "<think>"
lfm2ThinkingCloseTag = "</think>"
lfm2ToolCallStartTag = "<|tool_call_start|>"
lfm2ToolCallEndTag = "<|tool_call_end|>"
)
type LFM2Parser struct {
state LFM2ParserState
buffer strings.Builder
hasThinkingSupport bool
}
func (p *LFM2Parser) HasToolSupport() bool {
return true
}
func (p *LFM2Parser) HasThinkingSupport() bool {
return p.hasThinkingSupport
}
func (p *LFM2Parser) setInitialState(lastMessage *api.Message, thinkValue *api.ThinkValue) {
prefill := lastMessage != nil && lastMessage.Role == "assistant"
// Check both model capability AND request preference
thinkingEnabled := p.HasThinkingSupport() && (thinkValue != nil && thinkValue.Bool())
if !thinkingEnabled {
p.state = LFM2CollectingContent
return
}
if prefill && lastMessage.Content != "" {
p.state = LFM2CollectingContent
return
}
p.state = LFM2CollectingThinking
}
func (p *LFM2Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.setInitialState(lastMessage, thinkValue)
return tools
}
type lfm2Event interface {
isLFM2Event()
}
type lfm2EventThinkingContent struct {
content string
}
type lfm2EventContent struct {
content string
}
type lfm2EventToolCall struct {
toolCall api.ToolCall
}
func (lfm2EventThinkingContent) isLFM2Event() {}
func (lfm2EventContent) isLFM2Event() {}
func (lfm2EventToolCall) isLFM2Event() {}
func (p *LFM2Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case lfm2EventToolCall:
toolCalls = append(toolCalls, event.toolCall)
case lfm2EventThinkingContent:
thinkingSb.WriteString(event.content)
case lfm2EventContent:
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
func (p *LFM2Parser) parseEvents() []lfm2Event {
var all []lfm2Event
keepLooping := true
for keepLooping {
var events []lfm2Event
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
return all
}
func (p *LFM2Parser) eat() ([]lfm2Event, bool) {
var events []lfm2Event
bufStr := p.buffer.String()
if bufStr == "" {
return events, false
}
switch p.state {
case LFM2CollectingThinking:
// Strip opening <think> tag if present
if strings.HasPrefix(bufStr, lfm2ThinkingOpenTag) {
bufStr = bufStr[len(lfm2ThinkingOpenTag):]
bufStr = strings.TrimLeftFunc(bufStr, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(bufStr)
}
if strings.Contains(bufStr, lfm2ThinkingCloseTag) { // thinking[</think>] -> content
split := strings.SplitN(bufStr, lfm2ThinkingCloseTag, 2)
thinking := split[0]
thinking = strings.TrimRightFunc(thinking, unicode.IsSpace)
remaining := split[1]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = LFM2CollectingContent
if len(thinking) > 0 {
events = append(events, lfm2EventThinkingContent{content: thinking})
}
return events, true
} else if overlapLen := overlap(bufStr, lfm2ThinkingCloseTag); overlapLen > 0 { // partial </think>
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
trailingLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, lfm2EventThinkingContent{content: unambiguous})
}
return events, false
} else { // otherwise its thinking content
whitespaceLen := trailingWhitespaceLen(bufStr)
ambiguousStart := len(bufStr) - whitespaceLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, lfm2EventThinkingContent{content: unambiguous})
}
return events, false
}
case LFM2CollectingContent:
if strings.Contains(bufStr, lfm2ToolCallStartTag) { // content[<|tool_call_start|>] -> tool calls
split := strings.SplitN(bufStr, lfm2ToolCallStartTag, 2)
contentBefore := strings.TrimRightFunc(split[0], unicode.IsSpace)
remaining := split[1]
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = LFM2CollectingToolCalls
if len(contentBefore) > 0 {
events = append(events, lfm2EventContent{content: contentBefore})
}
return events, true
} else { // otherwise its content
p.buffer.Reset()
if len(bufStr) > 0 {
events = append(events, lfm2EventContent{content: bufStr})
}
return events, false
}
case LFM2CollectingToolCalls:
// Look for complete tool call JSON between tags
if idx := strings.Index(bufStr, lfm2ToolCallEndTag); idx != -1 {
toolCallContent := bufStr[:idx]
if toolCall, err := p.parseToolCallContent(toolCallContent); err == nil {
remaining := bufStr[idx+len(lfm2ToolCallEndTag):]
// Check if there's another tool call
if strings.HasPrefix(remaining, lfm2ToolCallStartTag) {
remaining = remaining[len(lfm2ToolCallStartTag):]
} else {
// No more tool calls, go back to content
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.state = LFM2CollectingContent
}
p.buffer.Reset()
p.buffer.WriteString(remaining)
events = append(events, lfm2EventToolCall{toolCall: toolCall})
return events, true
} else {
slog.Warn("lfm2 tool call parsing failed", "error", err, "content", toolCallContent)
}
}
return events, false
}
return events, false
}
func (p *LFM2Parser) parseToolCallContent(content string) (api.ToolCall, error) {
content = strings.TrimSpace(content)
// Try JSON format first: {"name": "func", "arguments": {...}}
var parsed struct {
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments"`
}
if err := json.Unmarshal([]byte(content), &parsed); err == nil && parsed.Name != "" {
var args api.ToolCallFunctionArguments
if len(parsed.Arguments) > 0 {
if err := json.Unmarshal(parsed.Arguments, &args); err != nil {
return api.ToolCall{}, err
}
} else {
args = api.NewToolCallFunctionArguments()
}
return api.ToolCall{
Function: api.ToolCallFunction{
Name: parsed.Name,
Arguments: args,
},
}, nil
}
// Try Python-style format: [func(arg1='val1', arg2='val2')] or func(arg1='val1')
return p.parsePythonStyleToolCall(content)
}
// parsePythonStyleToolCall parses tool calls in Python function call syntax
// Examples: [bash(command='ls')] or bash(command='ls', flag='-la')
func (p *LFM2Parser) parsePythonStyleToolCall(content string) (api.ToolCall, error) {
content = strings.TrimSpace(content)
// Strip outer brackets if present: [func(...)] -> func(...)
if strings.HasPrefix(content, "[") && strings.HasSuffix(content, "]") {
content = content[1 : len(content)-1]
}
// Find function name and arguments: func(args)
parenIdx := strings.Index(content, "(")
if parenIdx == -1 {
return api.ToolCall{}, errors.New("invalid tool call: no opening parenthesis")
}
funcName := strings.TrimSpace(content[:parenIdx])
if funcName == "" {
return api.ToolCall{}, errors.New("invalid tool call: empty function name")
}
// Extract arguments between parentheses
if !strings.HasSuffix(content, ")") {
return api.ToolCall{}, errors.New("invalid tool call: no closing parenthesis")
}
argsStr := content[parenIdx+1 : len(content)-1]
args := api.NewToolCallFunctionArguments()
if argsStr != "" {
// Parse key='value' or key="value" pairs
if err := parsePythonArgs(argsStr, &args); err != nil {
return api.ToolCall{}, err
}
}
return api.ToolCall{
Function: api.ToolCallFunction{
Name: funcName,
Arguments: args,
},
}, nil
}
// parsePythonArgs parses Python-style keyword arguments: key='value', key2="value2"
func parsePythonArgs(argsStr string, args *api.ToolCallFunctionArguments) error {
// Simple state machine to parse key='value' pairs
// Handles: command='ls', flag="-la", count=42, enabled=true
var key string
i := 0
for i < len(argsStr) {
// Skip whitespace
for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
i++
}
if i >= len(argsStr) {
break
}
// Parse key
keyStart := i
for i < len(argsStr) && argsStr[i] != '=' && argsStr[i] != ',' {
i++
}
if i >= len(argsStr) || argsStr[i] != '=' {
return errors.New("invalid argument: expected '='")
}
key = strings.TrimSpace(argsStr[keyStart:i])
i++ // skip '='
// Skip whitespace after =
for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t') {
i++
}
// Parse value
var value string
if i < len(argsStr) && (argsStr[i] == '\'' || argsStr[i] == '"') {
// Quoted string
quote := argsStr[i]
i++
valueStart := i
for i < len(argsStr) && argsStr[i] != quote {
if argsStr[i] == '\\' && i+1 < len(argsStr) {
i += 2 // skip escaped char
} else {
i++
}
}
value = argsStr[valueStart:i]
if i < len(argsStr) {
i++ // skip closing quote
}
args.Set(key, value)
} else {
// Unquoted value (number, bool, etc)
valueStart := i
for i < len(argsStr) && argsStr[i] != ',' {
i++
}
value = strings.TrimSpace(argsStr[valueStart:i])
// Try to parse as number or bool
if v, err := strconv.ParseInt(value, 10, 64); err == nil {
args.Set(key, v)
} else if v, err := strconv.ParseFloat(value, 64); err == nil {
args.Set(key, v)
} else if value == "true" {
args.Set(key, true)
} else if value == "false" {
args.Set(key, false)
} else {
args.Set(key, value)
}
}
// Skip comma and whitespace
for i < len(argsStr) && (argsStr[i] == ',' || argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
i++
}
}
return nil
}

549
model/parsers/lfm2_test.go Normal file
View File

@@ -0,0 +1,549 @@
package parsers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestLFM2Parser(t *testing.T) {
tests := []struct {
name string
input string
expectedContent string
expectedThinking string
expectedCalls []api.ToolCall
hasThinking bool
}{
{
name: "simple_content",
input: "Hello, how are you?",
expectedContent: "Hello, how are you?",
hasThinking: false,
},
{
name: "thinking_content",
input: "I need to think about this...</think>The answer is 42.",
expectedThinking: "I need to think about this...",
expectedContent: "The answer is 42.",
hasThinking: true,
},
{
name: "no_thinking_simple",
input: "Just a regular response.",
expectedContent: "Just a regular response.",
hasThinking: false,
},
{
name: "thinking_with_newlines",
input: "Let me think:\n- Point 1\n- Point 2</think>\n\nHere's my answer.",
expectedThinking: "Let me think:\n- Point 1\n- Point 2",
expectedContent: "Here's my answer.",
hasThinking: true,
},
{
name: "tool_call_simple",
input: "I'll check the weather.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|>",
expectedContent: "I'll check the weather.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
hasThinking: false,
},
{
name: "multiple_tool_calls",
input: "Getting weather for both cities.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|><|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"London\"}}<|tool_call_end|>",
expectedContent: "Getting weather for both cities.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "London",
}),
},
},
},
hasThinking: false,
},
{
name: "complex_tool_arguments",
input: "Processing data.<|tool_call_start|>{\"name\":\"process_data\",\"arguments\":{\"items\":[\"item1\",\"item2\"],\"config\":{\"enabled\":true,\"threshold\":0.95}}}<|tool_call_end|>",
expectedContent: "Processing data.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "process_data",
Arguments: testArgs(map[string]any{
"items": []interface{}{"item1", "item2"},
"config": map[string]interface{}{"enabled": true, "threshold": 0.95},
}),
},
},
},
hasThinking: false,
},
{
name: "thinking_with_tool_call",
input: "Let me check the weather...</think>I'll get that for you.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|>",
expectedThinking: "Let me check the weather...",
expectedContent: "I'll get that for you.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
hasThinking: true,
},
{
name: "empty_content",
input: "",
expectedContent: "",
hasThinking: false,
},
{
name: "only_thinking",
input: "Just thinking content</think>",
expectedThinking: "Just thinking content",
expectedContent: "",
hasThinking: true,
},
{
name: "unicode_content",
input: "مرحبا بالعالم! 你好世界! 🌍",
expectedContent: "مرحبا بالعالم! 你好世界! 🌍",
hasThinking: false,
},
{
name: "emoji_passthrough",
input: "Task completed ✅ 🎉",
expectedContent: "Task completed ✅ 🎉",
hasThinking: false,
},
{
name: "newlines_and_whitespace",
input: "Line 1\n\nLine 3\t\tTabbed content",
expectedContent: "Line 1\n\nLine 3\t\tTabbed content",
hasThinking: false,
},
{
name: "thinking_with_unicode",
input: "我在思考这个问题...</think>答案是42。",
expectedThinking: "我在思考这个问题...",
expectedContent: "答案是42。",
hasThinking: true,
},
{
name: "tool_call_with_unicode_args",
input: "Searching for information.<|tool_call_start|>{\"name\":\"search\",\"arguments\":{\"query\":\"北京天气\",\"language\":\"中文\"}}<|tool_call_end|>",
expectedContent: "Searching for information.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "search",
Arguments: testArgs(map[string]any{
"query": "北京天气",
"language": "中文",
}),
},
},
},
hasThinking: false,
},
{
name: "thinking_with_special_chars",
input: "Let me calculate: 2+2=4 & 3*3=9...</think>The results are correct!",
expectedThinking: "Let me calculate: 2+2=4 & 3*3=9...",
expectedContent: "The results are correct!",
hasThinking: true,
},
{
name: "empty_tool_call_args",
input: "Pinging server.<|tool_call_start|>{\"name\":\"ping\",\"arguments\":{}}<|tool_call_end|>",
expectedContent: "Pinging server.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "ping",
Arguments: api.NewToolCallFunctionArguments(),
},
},
},
hasThinking: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &LFM2Parser{hasThinkingSupport: tt.hasThinking}
parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking})
content, thinking, calls, err := parser.Add(tt.input, true)
if err != nil {
t.Fatalf("Add() error = %v", err)
}
if diff := cmp.Diff(tt.expectedContent, content); diff != "" {
t.Errorf("Content mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" {
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedCalls, calls, argsComparer); diff != "" {
t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestLFM2Parser_Streaming(t *testing.T) {
tests := []struct {
name string
chunks []string
expectedContent string
expectedThinking string
expectedCalls []api.ToolCall
hasThinking bool
}{
{
name: "streaming_simple_content",
chunks: []string{"Hello, ", "how are ", "you?"},
expectedContent: "Hello, how are you?",
hasThinking: false,
},
{
name: "streaming_thinking",
chunks: []string{"I need to ", "think about this", "...</think>", "The answer is 42."},
expectedThinking: "I need to think about this...",
expectedContent: "The answer is 42.",
hasThinking: true,
},
{
name: "streaming_tool_call",
chunks: []string{"I'll check weather.", "<|tool_call_start|>", "{\"name\":\"get_weather\",", "\"arguments\":{\"location\":\"Paris\"}}", "<|tool_call_end|>"},
expectedContent: "I'll check weather.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
hasThinking: false,
},
{
name: "streaming_thinking_with_partial_tag",
chunks: []string{"Thinking about this", "...</", "think>", "Done thinking."},
expectedThinking: "Thinking about this...",
expectedContent: "Done thinking.",
hasThinking: true,
},
{
name: "streaming_unicode_content",
chunks: []string{"مرحبا ", "بالعالم! ", "你好", "世界!"},
expectedContent: "مرحبا بالعالم! 你好世界!",
hasThinking: false,
},
{
name: "streaming_tool_call_with_split_json",
chunks: []string{"Processing.", "<|tool_call_start|>{\"name\":\"calc\",\"arguments\":{\"x\":", "42,\"y\":", "24}}<|tool_call_end|>"},
expectedContent: "Processing.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "calc",
Arguments: testArgs(map[string]any{
"x": float64(42),
"y": float64(24),
}),
},
},
},
hasThinking: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &LFM2Parser{hasThinkingSupport: tt.hasThinking}
parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking})
var allContent, allThinking string
var allCalls []api.ToolCall
for i, chunk := range tt.chunks {
done := i == len(tt.chunks)-1
content, thinking, calls, err := parser.Add(chunk, done)
if err != nil {
t.Fatalf("Add() error = %v", err)
}
allContent += content
allThinking += thinking
allCalls = append(allCalls, calls...)
}
if diff := cmp.Diff(tt.expectedContent, allContent); diff != "" {
t.Errorf("Content mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedThinking, allThinking); diff != "" {
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedCalls, allCalls, argsComparer); diff != "" {
t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestLFM2Parser_HasThinkingSupport(t *testing.T) {
tests := []struct {
name string
hasThinking bool
expectedSupport bool
}{
{
name: "thinking_enabled",
hasThinking: true,
expectedSupport: true,
},
{
name: "thinking_disabled",
hasThinking: false,
expectedSupport: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &LFM2Parser{hasThinkingSupport: tt.hasThinking}
if got := parser.HasThinkingSupport(); got != tt.expectedSupport {
t.Errorf("HasThinkingSupport() = %v, want %v", got, tt.expectedSupport)
}
})
}
}
func TestLFM2Parser_HasToolSupport(t *testing.T) {
parser := &LFM2Parser{}
if !parser.HasToolSupport() {
t.Error("HasToolSupport() should return true")
}
}
func TestLFM2Parser_Init(t *testing.T) {
parser := &LFM2Parser{hasThinkingSupport: true}
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "test_tool",
},
},
}
returnedTools := parser.Init(tools, nil, &api.ThinkValue{Value: true})
if diff := cmp.Diff(tools, returnedTools, toolsComparer); diff != "" {
t.Errorf("Init() returned tools mismatch (-want +got):\n%s", diff)
}
// Test initial state is set to thinking when enabled
if parser.state != LFM2CollectingThinking {
t.Errorf("Expected initial state to be LFM2CollectingThinking, got %v", parser.state)
}
}
func TestLFM2Parser_parseToolCallContent(t *testing.T) {
tests := []struct {
name string
content string
expected api.ToolCall
expectError bool
}{
{
name: "valid_tool_call",
content: `{"name":"get_weather","arguments":{"location":"Paris"}}`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
{
name: "complex_arguments",
content: `{"name":"process_data","arguments":{"items":["a","b"],"config":{"enabled":true}}}`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "process_data",
Arguments: testArgs(map[string]any{
"items": []interface{}{"a", "b"},
"config": map[string]interface{}{"enabled": true},
}),
},
},
},
{
name: "empty_arguments",
content: `{"name":"ping","arguments":{}}`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "ping",
Arguments: api.NewToolCallFunctionArguments(),
},
},
},
{
name: "unicode_in_tool_name",
content: `{"name":"获取天气","arguments":{"城市":"北京"}}`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "获取天气",
Arguments: testArgs(map[string]any{
"城市": "北京",
}),
},
},
},
{
name: "numeric_arguments",
content: `{"name":"calculate","arguments":{"x":3.14,"y":42,"enabled":true}}`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "calculate",
Arguments: testArgs(map[string]any{
"x": 3.14,
"y": float64(42),
"enabled": true,
}),
},
},
},
{
name: "invalid_json",
content: `{invalid json}`,
expectError: true,
},
{
name: "missing_name",
content: `{"arguments":{"arg":"value"}}`,
expectError: true,
},
{
name: "empty_name",
content: `{"name":"","arguments":{"arg":"value"}}`,
expectError: true,
},
}
parser := &LFM2Parser{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := parser.parseToolCallContent(tt.content)
if tt.expectError {
if err == nil {
t.Error("Expected error but got none")
}
return
}
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if diff := cmp.Diff(tt.expected, result, argsComparer); diff != "" {
t.Errorf("parseToolCallContent() mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestLFM2Parser_EdgeCases(t *testing.T) {
tests := []struct {
name string
input string
expectedContent string
expectedThinking string
hasThinking bool
}{
{
name: "multiple_think_close_tags",
input: "First thought</think>Second thought</think>Final content",
expectedThinking: "First thought",
expectedContent: "Second thought</think>Final content",
hasThinking: true,
},
{
name: "empty_thinking_content",
input: "</think>Just content",
expectedThinking: "",
expectedContent: "Just content",
hasThinking: true,
},
{
name: "thinking_disabled_with_think_tags",
input: "Some content</think>More content",
expectedContent: "Some content</think>More content",
hasThinking: false,
},
{
name: "whitespace_only_content",
input: " \n\t ",
expectedContent: " \n\t ",
hasThinking: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &LFM2Parser{hasThinkingSupport: tt.hasThinking}
parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking})
content, thinking, _, err := parser.Add(tt.input, true)
if err != nil {
t.Fatalf("Add() error = %v", err)
}
if diff := cmp.Diff(tt.expectedContent, content); diff != "" {
t.Errorf("Content mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" {
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
}
})
}
}

View File

@@ -1,7 +1,6 @@
package parsers
import (
"regexp"
"strings"
"unicode"
@@ -14,243 +13,114 @@ const (
Nemotron3NanoCollectingThinking Nemotron3NanoParserState = iota
Nemotron3NanoSkipWhitespaceAfterThinking
Nemotron3NanoCollectingContent
Nemotron3NanoCollectingToolCalls
)
const (
nemotronThinkClose = "</think>"
nemotronToolCallOpen = "<tool_call>"
nemotronToolCallClose = "</tool_call>"
nemotronThinkClose = "</think>"
nemotronToolCallOpen = "<tool_call>"
)
type Nemotron3NanoParser struct {
state Nemotron3NanoParserState
buffer strings.Builder
tools []api.Tool
state Nemotron3NanoParserState
buffer strings.Builder
toolParser *Qwen3CoderParser
}
func (p *Nemotron3NanoParser) HasToolSupport() bool { return true }
func (p *Nemotron3NanoParser) HasThinkingSupport() bool { return true }
func (p *Nemotron3NanoParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
p.toolParser = &Qwen3CoderParser{}
p.toolParser.Init(tools, nil, nil)
// thinking is enabled if user requests it
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
prefill := lastMessage != nil && lastMessage.Role == "assistant"
if !thinkingEnabled {
if !thinkingEnabled || (prefill && lastMessage.Content != "") {
p.state = Nemotron3NanoCollectingContent
return tools
} else {
p.state = Nemotron3NanoCollectingThinking
}
if prefill && lastMessage.Content != "" {
p.state = Nemotron3NanoCollectingContent
return tools
}
p.state = Nemotron3NanoCollectingThinking
return tools
}
type nemotronEvent interface {
isNemotronEvent()
}
type nemotronEventThinkingContent struct {
content string
}
type nemotronEventContent struct {
content string
}
type nemotronEventToolCall struct {
toolCall api.ToolCall
}
func (nemotronEventThinkingContent) isNemotronEvent() {}
func (nemotronEventContent) isNemotronEvent() {}
func (nemotronEventToolCall) isNemotronEvent() {}
func (p *Nemotron3NanoParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case nemotronEventToolCall:
toolCalls = append(toolCalls, event.toolCall)
case nemotronEventThinkingContent:
thinkingSb.WriteString(event.content)
case nemotronEventContent:
contentSb.WriteString(event.content)
}
if p.state == Nemotron3NanoCollectingContent {
return p.toolParser.Add(s, done)
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
func (p *Nemotron3NanoParser) parseEvents() []nemotronEvent {
var all []nemotronEvent
keepLooping := true
for keepLooping {
var events []nemotronEvent
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
return all
}
// emitWithPartialCheck extracts unambiguous content before a potential partial tag
func (p *Nemotron3NanoParser) emitWithPartialCheck(bufStr, tag string) (unambiguous, ambiguous string) {
if overlapLen := overlap(bufStr, tag); overlapLen > 0 {
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
trailingLen := trailingWhitespaceLen(beforePartialTag)
return bufStr[:len(beforePartialTag)-trailingLen], bufStr[len(beforePartialTag)-trailingLen:]
}
wsLen := trailingWhitespaceLen(bufStr)
return bufStr[:len(bufStr)-wsLen], bufStr[len(bufStr)-wsLen:]
}
func (p *Nemotron3NanoParser) eat() ([]nemotronEvent, bool) {
bufStr := p.buffer.String()
if bufStr == "" {
return nil, false
}
switch p.state {
case Nemotron3NanoCollectingThinking:
if strings.Contains(bufStr, nemotronThinkClose) {
split := strings.SplitN(bufStr, nemotronThinkClose, 2)
thinking := strings.TrimRightFunc(split[0], unicode.IsSpace)
p.buffer.Reset()
remainder := strings.TrimLeftFunc(split[1], unicode.IsSpace)
p.buffer.WriteString(remainder)
// Transition to whitespace-skipping state if buffer is empty,
// otherwise go directly to content collection
if remainder == "" {
p.state = Nemotron3NanoSkipWhitespaceAfterThinking
} else {
p.state = Nemotron3NanoCollectingContent
}
if thinking != "" {
return []nemotronEvent{nemotronEventThinkingContent{content: thinking}}, true
}
return nil, true
}
unambig, ambig := p.emitWithPartialCheck(bufStr, nemotronThinkClose)
p.buffer.Reset()
p.buffer.WriteString(ambig)
if unambig != "" {
return []nemotronEvent{nemotronEventThinkingContent{content: unambig}}, false
}
return nil, false
// We only want to skip whitespace between thinking and content
case Nemotron3NanoSkipWhitespaceAfterThinking:
bufStr = strings.TrimLeftFunc(bufStr, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(bufStr)
if bufStr == "" {
return nil, false
if p.state == Nemotron3NanoSkipWhitespaceAfterThinking {
s = strings.TrimLeftFunc(s, unicode.IsSpace)
if s == "" {
return "", "", nil, nil
}
p.state = Nemotron3NanoCollectingContent
return nil, true
return p.toolParser.Add(s, done)
}
case Nemotron3NanoCollectingContent:
if strings.Contains(bufStr, nemotronToolCallOpen) {
split := strings.SplitN(bufStr, nemotronToolCallOpen, 2)
content := strings.TrimRightFunc(split[0], unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(split[1])
p.state = Nemotron3NanoCollectingToolCalls
if content != "" {
return []nemotronEvent{nemotronEventContent{content: content}}, true
}
return nil, true
}
unambig, ambig := p.emitWithPartialCheck(bufStr, nemotronToolCallOpen)
// Nemotron3NanoCollectingThinking - buffer and look for end markers
p.buffer.WriteString(s)
bufStr := p.buffer.String()
// Look for end of thinking: </think> or <tool_call> (model may skip </think>)
thinkIdx := strings.Index(bufStr, nemotronThinkClose)
toolIdx := strings.Index(bufStr, nemotronToolCallOpen)
var endIdx int = -1
var remainder string
if thinkIdx != -1 && (toolIdx == -1 || thinkIdx < toolIdx) {
endIdx = thinkIdx
remainder = strings.TrimLeftFunc(bufStr[thinkIdx+len(nemotronThinkClose):], unicode.IsSpace)
} else if toolIdx != -1 {
endIdx = toolIdx
remainder = bufStr[toolIdx:] // Include <tool_call> tag
}
if endIdx != -1 {
thinking = strings.TrimRightFunc(bufStr[:endIdx], unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(ambig)
if unambig != "" {
return []nemotronEvent{nemotronEventContent{content: unambig}}, false
if remainder == "" {
p.state = Nemotron3NanoSkipWhitespaceAfterThinking
} else {
p.state = Nemotron3NanoCollectingContent
content, _, calls, err = p.toolParser.Add(remainder, done)
}
return nil, false
case Nemotron3NanoCollectingToolCalls:
if strings.Contains(bufStr, nemotronToolCallClose) {
split := strings.SplitN(bufStr, nemotronToolCallClose, 2)
remaining := strings.TrimLeftFunc(split[1], unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
var events []nemotronEvent
if tc, err := p.parseToolCall(split[0]); err == nil {
events = append(events, nemotronEventToolCall{toolCall: tc})
}
if !strings.Contains(remaining, nemotronToolCallOpen) {
p.state = Nemotron3NanoCollectingContent
}
return events, true
}
return nil, false
return content, thinking, calls, err
}
return nil, false
// No end marker - emit unambiguous thinking
thinking = p.emitThinking(bufStr)
return "", thinking, nil, nil
}
var (
nemotronFunctionRegex = regexp.MustCompile(`<function=([^>]+)>`)
nemotronParameterRegex = regexp.MustCompile(`<parameter=([^>]+)>\n?([\s\S]*?)\n?</parameter>`)
)
// emitThinking returns unambiguous thinking content, keeping potential partial tags in buffer
func (p *Nemotron3NanoParser) emitThinking(bufStr string) string {
// Check for partial </think> or <tool_call> at end
thinkOverlap := overlap(bufStr, nemotronThinkClose)
toolOverlap := overlap(bufStr, nemotronToolCallOpen)
maxOverlap := max(thinkOverlap, toolOverlap)
func (p *Nemotron3NanoParser) parseToolCall(content string) (api.ToolCall, error) {
toolCall := api.ToolCall{}
// Extract function name
fnMatch := nemotronFunctionRegex.FindStringSubmatch(content)
if len(fnMatch) < 2 {
return toolCall, nil
}
toolCall.Function.Name = fnMatch[1]
// Extract parameters
toolCall.Function.Arguments = api.NewToolCallFunctionArguments()
paramMatches := nemotronParameterRegex.FindAllStringSubmatch(content, -1)
for _, match := range paramMatches {
if len(match) >= 3 {
paramName := match[1]
paramValue := strings.TrimSpace(match[2])
// Try to parse as typed value based on tool definition
toolCall.Function.Arguments.Set(paramName, p.parseParamValue(paramName, paramValue))
}
if maxOverlap > 0 {
unambiguous := bufStr[:len(bufStr)-maxOverlap]
unambiguous = strings.TrimRightFunc(unambiguous, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(bufStr[len(bufStr)-maxOverlap:])
return unambiguous
}
return toolCall, nil
}
func (p *Nemotron3NanoParser) parseParamValue(paramName string, raw string) any {
// Find the matching tool to get parameter type
var paramType api.PropertyType
for _, tool := range p.tools {
if tool.Function.Parameters.Properties != nil {
if prop, ok := tool.Function.Parameters.Properties.Get(paramName); ok {
paramType = prop.Type
break
}
}
}
return parseValue(raw, paramType)
// No partial tags - emit all but trailing whitespace
wsLen := trailingWhitespaceLen(bufStr)
if wsLen > 0 {
unambiguous := bufStr[:len(bufStr)-wsLen]
p.buffer.Reset()
p.buffer.WriteString(bufStr[len(bufStr)-wsLen:])
return unambiguous
}
// Nothing to hold back
p.buffer.Reset()
return bufStr
}

View File

@@ -8,6 +8,8 @@ import (
"github.com/ollama/ollama/api"
)
// TestNemotron3NanoParser tests Nemotron-specific behavior (thinking support).
// Tool call parsing is tested in qwen3coder_test.go since Nemotron delegates to Qwen3CoderParser.
func TestNemotron3NanoParser(t *testing.T) {
tests := []struct {
name string
@@ -17,18 +19,6 @@ func TestNemotron3NanoParser(t *testing.T) {
expectedThinking string
expectedCalls []api.ToolCall
}{
{
name: "simple content - no thinking",
input: "Hello, how can I help you?",
thinkValue: nil,
expectedContent: "Hello, how can I help you?",
},
{
name: "simple content - thinking disabled",
input: "Hello, how can I help you?",
thinkValue: &api.ThinkValue{Value: false},
expectedContent: "Hello, how can I help you?",
},
{
name: "thinking then content",
input: "Let me think about this...</think>\nHere is my answer.",
@@ -43,69 +33,6 @@ func TestNemotron3NanoParser(t *testing.T) {
expectedThinking: "Step 1: Analyze\nStep 2: Process\nStep 3: Conclude",
expectedContent: "The answer is 42.",
},
{
name: "simple tool call",
input: "<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>",
thinkValue: nil,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
},
},
},
},
{
name: "content then tool call",
input: "Let me check the weather.\n<tool_call>\n<function=get_weather>\n<parameter=city>\nNYC\n</parameter>\n</function>\n</tool_call>",
thinkValue: nil,
expectedContent: "Let me check the weather.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "NYC"}),
},
},
},
},
{
name: "tool call with multiple parameters",
input: "<tool_call>\n<function=book_flight>\n<parameter=from>\nSFO\n</parameter>\n<parameter=to>\nNYC\n</parameter>\n</function>\n</tool_call>",
thinkValue: nil,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "book_flight",
Arguments: testArgs(map[string]any{
"from": "SFO",
"to": "NYC",
}),
},
},
},
},
{
name: "multiple tool calls",
input: "<tool_call>\n<function=get_weather>\n<parameter=city>\nSan Francisco\n</parameter>\n</function>\n</tool_call>\n" +
"<tool_call>\n<function=get_weather>\n<parameter=city>\nNew York\n</parameter>\n</function>\n</tool_call>",
thinkValue: nil,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "San Francisco"}),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "New York"}),
},
},
},
},
{
name: "thinking then tool call",
input: "I should check the weather...</think>\n<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>",
@@ -135,19 +62,6 @@ func TestNemotron3NanoParser(t *testing.T) {
},
},
},
{
name: "tool call with multiline parameter value",
input: "<tool_call>\n<function=create_note>\n<parameter=content>\nLine 1\nLine 2\nLine 3\n</parameter>\n</function>\n</tool_call>",
thinkValue: nil,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "create_note",
Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}),
},
},
},
},
{
name: "empty thinking block - immediate close",
input: "</think>\nHere is my answer.",
@@ -161,18 +75,6 @@ func TestNemotron3NanoParser(t *testing.T) {
thinkValue: &api.ThinkValue{Value: false},
expectedContent: "</think>\nSome content after spurious tag.",
},
{
name: "tool call with no function name - returns empty tool call",
input: "<tool_call>\n<function=>\n</function>\n</tool_call>",
thinkValue: nil,
expectedCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "", Arguments: api.NewToolCallFunctionArguments()}}},
},
{
name: "content with newlines preserved",
input: "Line 1\n\nLine 2\n\n\nLine 3",
thinkValue: nil,
expectedContent: "Line 1\n\nLine 2\n\n\nLine 3",
},
{
name: "thinking with only whitespace after close tag",
input: "My thoughts...</think> \n\t\n Content here.",
@@ -180,25 +82,6 @@ func TestNemotron3NanoParser(t *testing.T) {
expectedThinking: "My thoughts...",
expectedContent: "Content here.",
},
{
name: "unicode content",
input: "Hello 世界! 🌍 Ñoño",
thinkValue: nil,
expectedContent: "Hello 世界! 🌍 Ñoño",
},
{
name: "tool call with numeric parameter",
input: "<tool_call>\n<function=set_temp>\n<parameter=value>\n42\n</parameter>\n</function>\n</tool_call>",
thinkValue: nil,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "set_temp",
Arguments: testArgs(map[string]any{"value": "42"}),
},
},
},
},
}
for _, tt := range tests {
@@ -233,6 +116,8 @@ func TestNemotron3NanoParser(t *testing.T) {
}
}
// TestNemotron3NanoParser_Streaming tests streaming behavior for thinking support.
// Tool call streaming is tested in qwen3coder_test.go.
func TestNemotron3NanoParser_Streaming(t *testing.T) {
tests := []struct {
name string
@@ -242,18 +127,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
expectedThinking string
expectedCalls []api.ToolCall
}{
{
name: "streaming content character by character",
chunks: []string{"H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"},
thinkValue: nil,
expectedContent: "Hello, world!",
},
{
name: "streaming content small tokens",
chunks: []string{"Hel", "lo", ", ", "how ", "can", " I", " help", " you", " today", "?"},
thinkValue: nil,
expectedContent: "Hello, how can I help you today?",
},
{
name: "streaming thinking then content - granular",
chunks: []string{"Let", " me", " th", "ink", " about", " this", "...", "<", "/", "think", ">", "\n", "Here", " is", " my", " answer", "."},
@@ -268,45 +141,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
expectedThinking: "Step 1: Analyze\nStep 2: Process",
expectedContent: "The answer.",
},
{
name: "streaming tool call - highly granular",
chunks: []string{"<", "tool", "_", "call", ">", "\n", "<", "func", "tion", "=", "get", "_", "weather", ">", "\n", "<", "param", "eter", "=", "city", ">", "\n", "Par", "is", "\n", "</", "param", "eter", ">", "\n", "</", "func", "tion", ">", "\n", "</", "tool", "_", "call", ">"},
thinkValue: nil,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
},
},
},
},
{
name: "streaming content then tool call - granular",
chunks: []string{"Let", " me", " check", " the", " weather", ".", "\n<", "tool_call", ">", "\n", "<function=", "get_weather", ">", "\n", "<parameter=", "city", ">", "\n", "NYC", "\n", "</parameter>", "\n", "</function>", "\n", "</tool_call>"},
thinkValue: nil,
expectedContent: "Let me check the weather.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "NYC"}),
},
},
},
},
{
name: "tool call tag split character by character",
chunks: []string{"<", "t", "o", "o", "l", "_", "c", "a", "l", "l", ">", "\n", "<", "f", "u", "n", "c", "t", "i", "o", "n", "=", "t", "e", "s", "t", ">", "\n", "<", "/", "f", "u", "n", "c", "t", "i", "o", "n", ">", "\n", "<", "/", "t", "o", "o", "l", "_", "c", "a", "l", "l", ">"},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "test",
Arguments: api.NewToolCallFunctionArguments(),
},
},
},
},
{
name: "thinking close tag split character by character",
chunks: []string{"I", "'", "m", " ", "t", "h", "i", "n", "k", "i", "n", "g", ".", ".", ".", "<", "/", "t", "h", "i", "n", "k", ">", "\n", "D", "o", "n", "e", "!"},
@@ -321,22 +155,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
expectedThinking: "Thinking...",
expectedContent: "Content here.",
},
{
name: "tool call with multiple parameters - streaming",
chunks: []string{"<tool_", "call>\n", "<function", "=book_", "flight>", "\n<para", "meter=", "from>\n", "SFO\n", "</param", "eter>", "\n<param", "eter=to", ">\nNYC", "\n</para", "meter>", "\n</func", "tion>\n", "</tool_", "call>"},
thinkValue: nil,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "book_flight",
Arguments: testArgs(map[string]any{
"from": "SFO",
"to": "NYC",
}),
},
},
},
},
{
name: "thinking then content then tool call - streaming",
chunks: []string{"Ana", "lyzing", " your", " request", "...", "</", "think", ">\n", "I'll", " check", " that", " for", " you", ".", "\n", "<tool", "_call", ">\n", "<function", "=search", ">\n", "<parameter", "=query", ">\n", "test", " query", "\n</", "parameter", ">\n", "</function", ">\n", "</tool", "_call", ">"},
@@ -352,45 +170,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
},
},
},
{
name: "multiple tool calls - streaming",
chunks: []string{
"<tool_call>", "\n", "<function=", "get_weather>", "\n",
"<parameter=", "city>\n", "San Fran", "cisco\n", "</parameter>", "\n",
"</function>", "\n", "</tool_call>", "\n",
"<tool_", "call>\n", "<function", "=get_weather", ">\n",
"<param", "eter=city", ">\nNew", " York\n", "</parameter>\n",
"</function>\n", "</tool_call>",
},
thinkValue: nil,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "San Francisco"}),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "New York"}),
},
},
},
},
{
name: "tool call with multiline parameter - streaming",
chunks: []string{"<tool_call>\n", "<function=", "create_note>\n", "<parameter=", "content>\n", "Line 1", "\nLine", " 2\n", "Line 3", "\n</parameter>\n", "</function>\n", "</tool_call>"},
thinkValue: nil,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "create_note",
Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}),
},
},
},
},
{
name: "empty thinking block",
chunks: []string{"</think>", "\n", "Just content."},
@@ -398,12 +177,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
expectedThinking: "",
expectedContent: "Just content.",
},
{
name: "empty input chunks interspersed",
chunks: []string{"Hello", "", " ", "", "world", "", "!"},
thinkValue: nil,
expectedContent: "Hello world!",
},
{
name: "tool call immediately after think close - no content",
chunks: []string{"Analyzing...", "</think>", "\n", "<tool_call>", "\n<function=test>\n</function>\n", "</tool_call>"},
@@ -418,25 +191,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
},
},
},
{
name: "tool call with empty parameter value",
chunks: []string{"<tool_call>\n<function=test>\n<parameter=name>\n", "\n</parameter>\n</function>\n</tool_call>"},
thinkValue: nil,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "test",
Arguments: testArgs(map[string]any{"name": ""}),
},
},
},
},
{
name: "partial tool call tag at end - buffered",
chunks: []string{"Here's some content", "<tool"},
thinkValue: nil,
expectedContent: "Here's some content",
},
}
for _, tt := range tests {
@@ -572,3 +326,65 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) {
t.Errorf("calls mismatch (-got +want):\n%s", diff)
}
}
// TestNemotron3NanoParser_ToolCallWithoutThinkClose tests the case where thinking is enabled
// but the model outputs content + tool call WITHOUT the </think> tag.
// The parser should still parse the tool call (content before is treated as thinking).
func TestNemotron3NanoParser_ToolCallWithoutThinkClose(t *testing.T) {
chunks := []string{
"Let", " me", " analyze", " this", ".", "\n",
"<tool_call>", "\n",
"<function=get_weather>", "\n",
"<parameter=city>", "Paris", "</parameter>", "\n",
"</function>", "\n",
"</tool_call>",
}
p := &Nemotron3NanoParser{}
p.Init(nil, nil, &api.ThinkValue{Value: true}) // thinking ENABLED but model doesn't output </think>
var allContent string
var allThinking string
var allCalls []api.ToolCall
for _, chunk := range chunks {
content, thinking, calls, err := p.Add(chunk, false)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
allContent += content
allThinking += thinking
allCalls = append(allCalls, calls...)
}
// Drain
content, thinking, calls, err := p.Add("", true)
if err != nil {
t.Fatalf("unexpected error on done: %v", err)
}
allContent += content
allThinking += thinking
allCalls = append(allCalls, calls...)
// The parser was in thinking mode, so text before <tool_call> is emitted as thinking.
expectedThinking := "Let me analyze this."
expectedCalls := []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
},
},
}
if allContent != "" {
t.Errorf("expected no content (text was streamed as thinking), got: %q", allContent)
}
if diff := cmp.Diff(allThinking, expectedThinking); diff != "" {
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
}
if diff := cmp.Diff(allCalls, expectedCalls, argsComparer); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff)
}
}

View File

@@ -68,6 +68,12 @@ func ParserForName(name string) Parser {
return &Nemotron3NanoParser{}
case "functiongemma":
return &FunctionGemmaParser{}
case "glm-4.7":
return &GLM47Parser{}
case "lfm2":
return &LFM2Parser{hasThinkingSupport: false}
case "lfm2-thinking":
return &LFM2Parser{hasThinkingSupport: true}
default:
return nil
}

View File

@@ -91,6 +91,37 @@ func TestQwenParserStreaming(t *testing.T) {
},
},
},
{
desc: "tool call tags split character by character",
steps: []step{
{input: "<", wantEvents: []qwenEvent{}},
{input: "t", wantEvents: []qwenEvent{}},
{input: "o", wantEvents: []qwenEvent{}},
{input: "o", wantEvents: []qwenEvent{}},
{input: "l", wantEvents: []qwenEvent{}},
{input: "_", wantEvents: []qwenEvent{}},
{input: "c", wantEvents: []qwenEvent{}},
{input: "a", wantEvents: []qwenEvent{}},
{input: "l", wantEvents: []qwenEvent{}},
{input: "l", wantEvents: []qwenEvent{}},
{input: ">", wantEvents: []qwenEvent{}},
{input: "a", wantEvents: []qwenEvent{}},
{input: "b", wantEvents: []qwenEvent{}},
{input: "c", wantEvents: []qwenEvent{}},
{input: "<", wantEvents: []qwenEvent{}},
{input: "/", wantEvents: []qwenEvent{}},
{input: "t", wantEvents: []qwenEvent{}},
{input: "o", wantEvents: []qwenEvent{}},
{input: "o", wantEvents: []qwenEvent{}},
{input: "l", wantEvents: []qwenEvent{}},
{input: "_", wantEvents: []qwenEvent{}},
{input: "c", wantEvents: []qwenEvent{}},
{input: "a", wantEvents: []qwenEvent{}},
{input: "l", wantEvents: []qwenEvent{}},
{input: "l", wantEvents: []qwenEvent{}},
{input: ">", wantEvents: []qwenEvent{qwenEventRawToolCall{raw: "abc"}}},
},
},
{
desc: "trailing whitespace between content and tool call",
steps: []step{

View File

@@ -96,3 +96,11 @@ func testArgs(m map[string]any) api.ToolCallFunctionArguments {
}
return args
}
func args(s string) api.ToolCallFunctionArguments {
var result api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(s), &result); err != nil {
panic("invalid JSON in args(): " + err.Error())
}
return result
}

110
model/renderers/glm46.go Normal file
View File

@@ -0,0 +1,110 @@
package renderers
import (
"encoding/json"
"fmt"
"strings"
"github.com/ollama/ollama/api"
)
type GLM46Renderer struct{}
func (r *GLM46Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
sb.WriteString("[gMASK]<sop>")
var lastUserIndex int
for i, message := range messages {
if message.Role == "user" {
lastUserIndex = i
}
}
if len(tools) > 0 {
sb.WriteString("<|system|>\n")
sb.WriteString("# Tools\n\n")
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
sb.WriteString("<tools>\n")
for _, tool := range tools {
d, _ := json.Marshal(tool)
sb.WriteString(string(d) + "\n")
}
sb.WriteString("</tools>\n\n")
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
sb.WriteString("<tool_call>{function-name}\n")
sb.WriteString("<arg_key>{arg-key-1}</arg_key>\n")
sb.WriteString("<arg_value>{arg-value-1}</arg_value>\n")
sb.WriteString("<arg_key>{arg-key-2}</arg_key>\n")
sb.WriteString("<arg_value>{arg-value-2}</arg_value>\n")
sb.WriteString("...\n")
sb.WriteString("</tool_call>")
}
for i, message := range messages {
switch message.Role {
case "user":
sb.WriteString("<|user|>\n")
sb.WriteString(message.Content)
if thinkValue != nil && !thinkValue.Bool() && !strings.HasSuffix(message.Content, "/nothink") {
sb.WriteString("/nothink")
}
case "assistant":
sb.WriteString("<|assistant|>")
if i > lastUserIndex {
if message.Thinking != "" {
sb.WriteString("\n<think>" + message.Thinking + "</think>")
} else {
sb.WriteString("\n<think></think>")
}
}
if message.Content != "" {
sb.WriteString("\n" + message.Content)
}
if len(message.ToolCalls) > 0 {
for _, toolCall := range message.ToolCalls {
sb.WriteString("\n<tool_call>" + toolCall.Function.Name + "\n")
for key, value := range toolCall.Function.Arguments.All() {
sb.WriteString("<arg_key>" + key + "</arg_key>\n")
var valueStr string
if str, ok := value.(string); ok {
valueStr = str
} else {
jsonBytes, err := json.Marshal(value)
if err != nil {
valueStr = fmt.Sprintf("%v", value)
} else {
valueStr = string(jsonBytes)
}
}
sb.WriteString("<arg_value>" + valueStr + "</arg_value>\n")
}
sb.WriteString("</tool_call>")
}
}
case "tool":
if i == 0 || messages[i-1].Role != "tool" {
sb.WriteString("<|observation|>")
}
sb.WriteString("\n<tool_response>\n")
sb.WriteString(message.Content)
sb.WriteString("\n</tool_response>")
case "system":
sb.WriteString("<|system|>\n")
sb.WriteString(message.Content)
}
}
// Add generation prompt
sb.WriteString("<|assistant|>")
if thinkValue != nil && !thinkValue.Bool() {
sb.WriteString("\n<think></think>\n")
}
return sb.String(), nil
}

View File

@@ -0,0 +1,223 @@
package renderers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestGLM46Renderer(t *testing.T) {
tests := []struct {
name string
messages []api.Message
tools []api.Tool
thinkValue *api.ThinkValue
expected string
skip string
}{
{
name: "basic",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
expected: `[gMASK]<sop><|user|>
Hello, how are you?<|assistant|>`,
},
{
name: "basic with system message",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello, how are you?"},
},
expected: `[gMASK]<sop><|system|>
You are a helpful assistant.<|user|>
Hello, how are you?<|assistant|>`,
},
{
name: "basic with user assistant user",
messages: []api.Message{
{Role: "user", Content: "What is the capital of France?"},
{Role: "assistant", Thinking: "Let me analyze the request...", Content: "The capital of France is Paris."},
{Role: "user", Content: "Fantastic!"},
},
expected: `[gMASK]<sop><|user|>
What is the capital of France?<|assistant|>
The capital of France is Paris.<|user|>
Fantastic!<|assistant|>`,
},
{
skip: "tool call ordering not guaranteed yet",
name: "tools",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant with access to tools."},
{Role: "user", Content: "What is the weather like in Tokyo?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: propsMap(`{"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}`),
},
},
},
},
expected: `[gMASK]<sop><|system|>
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
</tools>
For each function call, output the function name and arguments within the following XML format:
<tool_call>{function-name}
<arg_key>{arg-key-1}</arg_key>
<arg_value>{arg-value-1}</arg_value>
<arg_key>{arg-key-2}</arg_key>
<arg_value>{arg-value-2}</arg_value>
...
</tool_call><|system|>
You are a helpful assistant with access to tools.<|user|>
What is the weather like in Tokyo?<|assistant|>`,
},
{
skip: "tool call ordering not guaranteed yet",
name: "tool calls",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant with access to tools."},
{Role: "user", Content: "What is the weather like in Tokyo?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args(`{"location": "Tokyo, Japan", "unit": "celsius"}`),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args(`{"location": "Japan", "unit": "fahrenheit"}`),
},
},
},
},
{
Role: "tool",
Content: "{\"temperature\": 22, \"weather\": \"partly cloudy\", \"humidity\": 65}",
ToolName: "get_weather",
},
{
Role: "tool",
Content: "{\"temperature\": 68, \"weather\": \"sunny\", \"humidity\": 75}",
ToolName: "get_weather",
},
{
Role: "assistant",
Content: "The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.",
},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: propsMap(`{"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}`),
},
},
},
},
expected: `[gMASK]<sop><|system|>
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
</tools>
For each function call, output the function name and arguments within the following XML format:
<tool_call>{function-name}
<arg_key>{arg-key-1}</arg_key>
<arg_value>{arg-value-1}</arg_value>
<arg_key>{arg-key-2}</arg_key>
<arg_value>{arg-value-2}</arg_value>
...
</tool_call><|system|>
You are a helpful assistant with access to tools.<|user|>
What is the weather like in Tokyo?<|assistant|>
<think></think>
<tool_call>get_weather
<arg_key>location</arg_key>
<arg_value>Tokyo, Japan</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>
</tool_call>
<tool_call>get_weather
<arg_key>location</arg_key>
<arg_value>Japan</arg_value>
<arg_key>unit</arg_key>
<arg_value>fahrenheit</arg_value>
</tool_call><|observation|>
<tool_response>
{"temperature": 22, "weather": "partly cloudy", "humidity": 65}
</tool_response>
<tool_response>
{"temperature": 68, "weather": "sunny", "humidity": 75}
</tool_response><|assistant|>
<think></think>
The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.<|assistant|>`,
},
{
name: "think true",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: `[gMASK]<sop><|user|>
Hello, how are you?<|assistant|>`,
},
{
name: "think false",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `[gMASK]<sop><|user|>
Hello, how are you?/nothink<|assistant|>
<think></think>
`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.skip != "" {
t.Skip(tt.skip)
}
renderer := &GLM46Renderer{}
rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
t.Logf("Got:\n%s", rendered)
t.Logf("Expected:\n%s", tt.expected)
}
})
}
}

Some files were not shown because too many files have changed in this diff Show More