Compare commits

..

10 Commits

Author SHA1 Message Date
ParthSareen
5c3bf414ef close to working 2025-12-08 18:17:56 -08:00
ParthSareen
0a9862a383 wip 2025-12-08 15:20:52 -08:00
ParthSareen
f475cc365a wip 2025-12-08 14:42:18 -08:00
ParthSareen
dd3306d3a0 renderers/parsers: olmo3 instruct 2025-12-08 13:23:11 -08:00
nicole pardal
57c1d7db9a fixed generation issue 2025-12-08 00:35:49 -08:00
nicole pardal
91d6370a62 removed original olmo support 2025-12-01 14:17:46 -08:00
nicole pardal
38a2a6468f removed olmo1 support 2025-12-01 14:14:31 -08:00
nicole pardal
064ec63ddf lint 2025-11-26 20:05:25 -08:00
nicole pardal
fd959fbf7a updated converter 2025-11-26 19:42:34 -08:00
nicole pardal
cfc9729edf olmo model initial 2025-11-25 15:49:09 -08:00
702 changed files with 33019 additions and 99709 deletions

2
.gitattributes vendored
View File

@@ -19,8 +19,6 @@ ml/backend/**/*.comp linguist-vendored
ml/backend/**/*.glsl linguist-vendored
ml/backend/**/CMakeLists.txt linguist-vendored
app/webview linguist-vendored
llama/build-info.cpp linguist-generated
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.s linguist-generated

View File

@@ -16,15 +16,13 @@ jobs:
outputs:
GOFLAGS: ${{ steps.goflags.outputs.GOFLAGS }}
VERSION: ${{ steps.goflags.outputs.VERSION }}
vendorsha: ${{ steps.changes.outputs.vendorsha }}
steps:
- uses: actions/checkout@v4
- name: Set environment
id: goflags
run: |
echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" | tee -a $GITHUB_OUTPUT
echo VERSION="${GITHUB_REF_NAME#v}" | tee -a $GITHUB_OUTPUT
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" >>$GITHUB_OUTPUT
echo VERSION="${GITHUB_REF_NAME#v}" >>$GITHUB_OUTPUT
darwin-build:
runs-on: macos-14-xlarge
@@ -55,9 +53,6 @@ jobs:
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
cache-dependency-path: |
go.sum
Makefile.sync
- run: |
./scripts/build_darwin.sh
- name: Log build results
@@ -68,7 +63,6 @@ jobs:
name: bundles-darwin
path: |
dist/*.tgz
dist/*.tar.zst
dist/*.zip
dist/*.dmg
@@ -191,7 +185,7 @@ jobs:
- uses: actions/cache@v4
with:
path: ${{ github.workspace }}\.ccache
key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}-${{ needs.setup-environment.outputs.vendorsha }}
key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}
- name: Build target "${{ matrix.preset }}"
run: |
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
@@ -255,9 +249,6 @@ jobs:
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
cache-dependency-path: |
go.sum
Makefile.sync
- name: Verify gcc is actually clang
run: |
$ErrorActionPreference='Continue'
@@ -311,9 +302,6 @@ jobs:
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
cache-dependency-path: |
go.sum
Makefile.sync
- uses: actions/download-artifact@v4
with:
pattern: depends-windows*
@@ -393,13 +381,13 @@ jobs:
done
- run: |
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | zstd --ultra -22 -T0 >$(basename ${ARCHIVE//.*/}.tar.zst);
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
done
- uses: actions/upload-artifact@v4
with:
name: bundles-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.target }}
path: |
*.tar.zst
*.tgz
# Build each Docker variant (OS, arch, and flavor) separately. Using QEMU is unreliable and slower.
docker-build-push:
@@ -532,7 +520,7 @@ jobs:
- name: Upload release artifacts
run: |
pids=()
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg ; do
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.exe dist/*.dmg ; do
echo "Uploading $payload"
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
pids[$!]=$!

View File

@@ -22,7 +22,6 @@ jobs:
runs-on: ubuntu-latest
outputs:
changed: ${{ steps.changes.outputs.changed }}
vendorsha: ${{ steps.changes.outputs.vendorsha }}
steps:
- uses: actions/checkout@v4
with:
@@ -38,7 +37,6 @@ jobs:
}
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
linux:
needs: [changes]
@@ -85,7 +83,7 @@ jobs:
- uses: actions/cache@v4
with:
path: /github/home/.cache/ccache
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
- run: |
cmake --preset ${{ matrix.preset }} ${{ matrix.flags }}
cmake --build --preset ${{ matrix.preset }} --parallel
@@ -180,7 +178,7 @@ jobs:
- uses: actions/cache@v4
with:
path: ${{ github.workspace }}\.ccache
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
- run: |
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
@@ -208,9 +206,6 @@ jobs:
- uses: actions/setup-go@v5
with:
go-version-file: 'go.mod'
cache-dependency-path: |
go.sum
Makefile.sync
- uses: actions/setup-node@v4
with:
node-version: '20'

View File

@@ -1,51 +1,77 @@
version: "2"
linters:
default: none
enable:
- asasalint
- bidichk
- bodyclose
- containedctx
- copyloopvar
- errcheck
- errorlint
- exptostd
- gocheckcompilerdirectives
- gocritic
- govet
- ineffassign
- intrange
- makezero
- misspell
- modernize
- nilerr
- nilnil
- nolintlint
- nosprintfhostport
- perfsprint
- prealloc
- sloglint
- staticcheck
- unconvert
- unused
- usestdlibvars
- usetesting
- wastedassign
- whitespace
disable:
- errcheck
- usestdlibvars
settings:
govet:
disable:
- unusedresult
errcheck:
exclude-functions:
- fmt.Fprintf
perfsprint:
strconcat: false
concat-loop: false
staticcheck:
checks:
- all
- -QF* # disable quick fix suggestions
# Using a deprecated function, variable, constant or field.
# https://staticcheck.dev/docs/checks/#SA1019
- -SA1019
- -ST1000 # package comment format
- -ST1003 # underscores in package names
- -ST1005 # error strings should not be capitalized
- -ST1012 # error var naming (ErrFoo)
- -ST1016 # receiver name consistency
- -ST1020 # comment on exported function format
- -ST1021 # comment on exported type format
- -ST1022 # comment on exported var format
- -ST1023 # omit type from declaration
severity:
default: error
rules:
- linters:
- gofmt
- goimports
- intrange
severity: info
# Incorrect or missing package comment.
# https://staticcheck.dev/docs/checks/#ST1000
- -ST1000
# Poorly chosen identifier.
# https://staticcheck.dev/docs/checks/#ST1003
- -ST1003
# The documentation of an exported function should start with the function's name.
# https://staticcheck.dev/docs/checks/#ST1020
- -ST1020
# The documentation of an exported type should start with type's name.
# https://staticcheck.dev/docs/checks/#ST1021
- -ST1021
# The documentation of an exported variable or constant should start with variable's name.
# https://staticcheck.dev/docs/checks/#ST1022
- -ST1022
usestdlibvars:
http-method: false
http-status-code: false
formatters:
enable:
- gci
- gofmt
- gofumpt
settings:
gci:
sections:
- standard
- default
- localmodule

View File

@@ -2,22 +2,6 @@ cmake_minimum_required(VERSION 3.21)
project(Ollama C CXX)
# Handle cross-compilation on macOS: when CMAKE_OSX_ARCHITECTURES is set to a
# single architecture different from the host, override CMAKE_SYSTEM_PROCESSOR
# to match. This is necessary because CMAKE_SYSTEM_PROCESSOR defaults to the
# host architecture, but downstream projects (like MLX) use it to detect the
# target architecture.
if(CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES ";")
# Single architecture specified
if(CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
message(STATUS "Cross-compiling for x86_64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to x86_64")
set(CMAKE_SYSTEM_PROCESSOR "x86_64")
elseif(CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
message(STATUS "Cross-compiling for arm64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to arm64")
set(CMAKE_SYSTEM_PROCESSOR "arm64")
endif()
endif()
include(CheckLanguage)
include(GNUInstallDirs)
@@ -28,7 +12,7 @@ set(BUILD_SHARED_LIBS ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS ON) # Recent versions of MLX Requires gnu++17 extensions to compile properly
set(CMAKE_CXX_EXTENSIONS OFF)
set(GGML_BUILD ON)
set(GGML_SHARED ON)
@@ -70,13 +54,6 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cp
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
# Define GGML version variables for shared library SOVERSION
# These are required by ggml/src/CMakeLists.txt for proper library versioning
set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 0)
set(GGML_VERSION_PATCH 0)
set(GGML_VERSION "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
set(GGML_CPU ON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE)
@@ -163,48 +140,14 @@ if(CMAKE_HIP_COMPILER)
endif()
endif()
if(NOT APPLE)
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
endif()
endif()
option(MLX_ENGINE "Enable MLX backend" OFF)
if(MLX_ENGINE)
message(STATUS "Setting up MLX (this takes a while...)")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
# Find CUDA toolkit if MLX is built with CUDA support
find_package(CUDAToolkit)
install(TARGETS mlx mlxc
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
# Manually install cudart and cublas since they might not be picked up as direct dependencies
if(CUDAToolkit_FOUND)
file(GLOB CUDART_LIBS
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
if(CUDART_LIBS)
install(FILES ${CUDART_LIBS}
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
endif()
endif()
endif()

View File

@@ -41,7 +41,7 @@
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
"CMAKE_CUDA_FLAGS": "-t 4",
"CMAKE_CUDA_FLAGS": "-t 2",
"OLLAMA_RUNNER_DIR": "cuda_v13"
}
},
@@ -83,28 +83,6 @@
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "vulkan"
}
},
{
"name": "MLX",
"inherits": [ "Default" ],
"cacheVariables": {
"MLX_ENGINE": "ON",
"OLLAMA_RUNNER_DIR": "mlx"
}
},
{
"name": "MLX CUDA 12",
"inherits": [ "MLX", "CUDA 12" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v12"
}
},
{
"name": "MLX CUDA 13",
"inherits": [ "MLX", "CUDA 13" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
}
}
],
"buildPresets": [
@@ -162,21 +140,6 @@
"name": "Vulkan",
"targets": [ "ggml-vulkan" ],
"configurePreset": "Vulkan"
},
{
"name": "MLX",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX"
},
{
"name": "MLX CUDA 12",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 12"
},
{
"name": "MLX CUDA 13",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 13"
}
]
}

View File

@@ -131,36 +131,8 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'Vulkan' \
&& cmake --build --parallel --preset 'Vulkan' \
&& cmake --install build --component Vulkan --strip --parallel 8
&& cmake --install build --component Vulkan --strip --parallel 8
FROM base AS mlx
ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} \
&& dnf install -y openblas-devel lapack-devel \
&& dnf install -y libcudnn9-cuda-13 libcudnn9-devel-cuda-13 \
&& dnf install -y libnccl libnccl-devel
ENV PATH=/usr/local/cuda-13/bin:$PATH
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
ARG PARALLEL
WORKDIR /go/src/github.com/ollama/ollama
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/ml/backend/mlx x/ml/backend/mlx
COPY go.mod go.sum .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
COPY . .
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama
@@ -181,7 +153,6 @@ FROM --platform=linux/amd64 scratch AS amd64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
COPY --from=vulkan dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
FROM --platform=linux/arm64 scratch AS arm64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
@@ -200,7 +171,7 @@ COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04
RUN apt-get update \
&& apt-get install -y ca-certificates libvulkan1 libopenblas0 \
&& apt-get install -y ca-certificates libvulkan1 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
COPY --from=archive /bin /usr/bin

View File

@@ -1,6 +1,6 @@
UPSTREAM=https://github.com/ggml-org/llama.cpp.git
WORKDIR=llama/vendor
FETCH_HEAD=ec98e2002
FETCH_HEAD=3cfa9c3f125763305b4226bc032f1954f08990dc
.PHONY: help
help:
@@ -57,7 +57,7 @@ checkout: $(WORKDIR)
$(WORKDIR):
git clone $(UPSTREAM) $(WORKDIR)
.PHONY: format-patches
.PHONE: format-patches
format-patches: llama/patches
git -C $(WORKDIR) format-patch \
--no-signature \
@@ -66,11 +66,7 @@ format-patches: llama/patches
-o $(realpath $<) \
$(FETCH_HEAD)
.PHONY: clean
.PHONE: clean
clean: checkout
@git -C $(WORKDIR) am --abort || true
$(RM) llama/patches/.*.patched
.PHONY: print-base
print-base:
@echo $(FETCH_HEAD)

View File

@@ -555,7 +555,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Parakeet](https://github.com/parakeet-nest/parakeet) is a GoLang library, made to simplify the development of small generative AI applications with Ollama.
- [Haverscript](https://github.com/andygill/haverscript) with [examples](https://github.com/andygill/haverscript/tree/main/examples)
- [Ollama for Swift](https://github.com/mattt/ollama-swift)
- [Swollama for Swift](https://github.com/guitaripod/Swollama) with [DocC](https://guitaripod.github.io/Swollama/documentation/swollama)
- [Swollama for Swift](https://github.com/marcusziade/Swollama) with [DocC](https://marcusziade.github.io/Swollama/documentation/swollama/)
- [GoLamify](https://github.com/prasad89/golamify)
- [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API)

View File

@@ -1,778 +0,0 @@
package anthropic
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"time"
"github.com/ollama/ollama/api"
)
// Error types matching Anthropic API
type Error struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ErrorResponse struct {
Type string `json:"type"` // always "error"
Error Error `json:"error"`
RequestID string `json:"request_id,omitempty"`
}
// NewError creates a new ErrorResponse with the appropriate error type based on HTTP status code
func NewError(code int, message string) ErrorResponse {
var etype string
switch code {
case http.StatusBadRequest:
etype = "invalid_request_error"
case http.StatusUnauthorized:
etype = "authentication_error"
case http.StatusForbidden:
etype = "permission_error"
case http.StatusNotFound:
etype = "not_found_error"
case http.StatusTooManyRequests:
etype = "rate_limit_error"
case http.StatusServiceUnavailable, 529:
etype = "overloaded_error"
default:
etype = "api_error"
}
return ErrorResponse{
Type: "error",
Error: Error{Type: etype, Message: message},
RequestID: generateID("req"),
}
}
// Request types
// MessagesRequest represents an Anthropic Messages API request
type MessagesRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
Messages []MessageParam `json:"messages"`
System any `json:"system,omitempty"` // string or []ContentBlock
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
Metadata *Metadata `json:"metadata,omitempty"`
}
// MessageParam represents a message in the request
type MessageParam struct {
Role string `json:"role"` // "user" or "assistant"
Content any `json:"content"` // string or []ContentBlock
}
// ContentBlock represents a content block in a message.
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
// only when set, which is required for SDK streaming accumulation.
type ContentBlock struct {
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
Text *string `json:"text,omitempty"`
// For image blocks
Source *ImageSource `json:"source,omitempty"`
// For tool_use blocks
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
// For tool_result blocks
ToolUseID string `json:"tool_use_id,omitempty"`
Content any `json:"content,omitempty"` // string or []ContentBlock
IsError bool `json:"is_error,omitempty"`
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
Thinking *string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ImageSource represents the source of an image
type ImageSource struct {
Type string `json:"type"` // "base64" or "url"
MediaType string `json:"media_type,omitempty"`
Data string `json:"data,omitempty"`
URL string `json:"url,omitempty"`
}
// Tool represents a tool definition
type Tool struct {
Type string `json:"type,omitempty"` // "custom" for user-defined tools
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema json.RawMessage `json:"input_schema,omitempty"`
}
// ToolChoice controls how the model uses tools
type ToolChoice struct {
Type string `json:"type"` // "auto", "any", "tool", "none"
Name string `json:"name,omitempty"`
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
}
// ThinkingConfig controls extended thinking
type ThinkingConfig struct {
Type string `json:"type"` // "enabled" or "disabled"
BudgetTokens int `json:"budget_tokens,omitempty"`
}
// Metadata for the request
type Metadata struct {
UserID string `json:"user_id,omitempty"`
}
// Response types
// MessagesResponse represents an Anthropic Messages API response
type MessagesResponse struct {
ID string `json:"id"`
Type string `json:"type"` // "message"
Role string `json:"role"` // "assistant"
Model string `json:"model"`
Content []ContentBlock `json:"content"`
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
Usage Usage `json:"usage"`
}
// Usage contains token usage information
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
// Streaming event types
// MessageStartEvent is sent at the start of streaming
type MessageStartEvent struct {
Type string `json:"type"` // "message_start"
Message MessagesResponse `json:"message"`
}
// ContentBlockStartEvent signals the start of a content block
type ContentBlockStartEvent struct {
Type string `json:"type"` // "content_block_start"
Index int `json:"index"`
ContentBlock ContentBlock `json:"content_block"`
}
// ContentBlockDeltaEvent contains incremental content updates
type ContentBlockDeltaEvent struct {
Type string `json:"type"` // "content_block_delta"
Index int `json:"index"`
Delta Delta `json:"delta"`
}
// Delta represents an incremental update
type Delta struct {
Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta"
Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ContentBlockStopEvent signals the end of a content block
type ContentBlockStopEvent struct {
Type string `json:"type"` // "content_block_stop"
Index int `json:"index"`
}
// MessageDeltaEvent contains updates to the message
type MessageDeltaEvent struct {
Type string `json:"type"` // "message_delta"
Delta MessageDelta `json:"delta"`
Usage DeltaUsage `json:"usage"`
}
// MessageDelta contains stop information
type MessageDelta struct {
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
}
// DeltaUsage contains cumulative token usage
type DeltaUsage struct {
OutputTokens int `json:"output_tokens"`
}
// MessageStopEvent signals the end of the message
type MessageStopEvent struct {
Type string `json:"type"` // "message_stop"
}
// PingEvent is a keepalive event
type PingEvent struct {
Type string `json:"type"` // "ping"
}
// StreamErrorEvent is an error during streaming
type StreamErrorEvent struct {
Type string `json:"type"` // "error"
Error Error `json:"error"`
}
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
var messages []api.Message
if r.System != nil {
switch sys := r.System.(type) {
case string:
if sys != "" {
messages = append(messages, api.Message{Role: "system", Content: sys})
}
case []any:
// System can be an array of content blocks
var content strings.Builder
for _, block := range sys {
if blockMap, ok := block.(map[string]any); ok {
if blockMap["type"] == "text" {
if text, ok := blockMap["text"].(string); ok {
content.WriteString(text)
}
}
}
}
if content.Len() > 0 {
messages = append(messages, api.Message{Role: "system", Content: content.String()})
}
}
}
for _, msg := range r.Messages {
converted, err := convertMessage(msg)
if err != nil {
return nil, err
}
messages = append(messages, converted...)
}
options := make(map[string]any)
options["num_predict"] = r.MaxTokens
if r.Temperature != nil {
options["temperature"] = *r.Temperature
}
if r.TopP != nil {
options["top_p"] = *r.TopP
}
if r.TopK != nil {
options["top_k"] = *r.TopK
}
if len(r.StopSequences) > 0 {
options["stop"] = r.StopSequences
}
var tools api.Tools
for _, t := range r.Tools {
tool, err := convertTool(t)
if err != nil {
return nil, err
}
tools = append(tools, tool)
}
var think *api.ThinkValue
if r.Thinking != nil && r.Thinking.Type == "enabled" {
think = &api.ThinkValue{Value: true}
}
stream := r.Stream
return &api.ChatRequest{
Model: r.Model,
Messages: messages,
Options: options,
Stream: &stream,
Tools: tools,
Think: think,
}, nil
}
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
func convertMessage(msg MessageParam) ([]api.Message, error) {
var messages []api.Message
role := strings.ToLower(msg.Role)
switch content := msg.Content.(type) {
case string:
messages = append(messages, api.Message{Role: role, Content: content})
case []any:
var textContent strings.Builder
var images []api.ImageData
var toolCalls []api.ToolCall
var thinking string
var toolResults []api.Message
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
return nil, errors.New("invalid content block format")
}
blockType, _ := blockMap["type"].(string)
switch blockType {
case "text":
if text, ok := blockMap["text"].(string); ok {
textContent.WriteString(text)
}
case "image":
source, ok := blockMap["source"].(map[string]any)
if !ok {
return nil, errors.New("invalid image source")
}
sourceType, _ := source["type"].(string)
if sourceType == "base64" {
data, _ := source["data"].(string)
decoded, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return nil, fmt.Errorf("invalid base64 image data: %w", err)
}
images = append(images, decoded)
} else {
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
}
// URL images would need to be fetched - skip for now
case "tool_use":
id, ok := blockMap["id"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'id' field")
}
name, ok := blockMap["name"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'name' field")
}
tc := api.ToolCall{
ID: id,
Function: api.ToolCallFunction{
Name: name,
},
}
if input, ok := blockMap["input"].(map[string]any); ok {
tc.Function.Arguments = mapToArgs(input)
}
toolCalls = append(toolCalls, tc)
case "tool_result":
toolUseID, _ := blockMap["tool_use_id"].(string)
var resultContent string
switch c := blockMap["content"].(type) {
case string:
resultContent = c
case []any:
for _, cb := range c {
if cbMap, ok := cb.(map[string]any); ok {
if cbMap["type"] == "text" {
if text, ok := cbMap["text"].(string); ok {
resultContent += text
}
}
}
}
}
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: resultContent,
ToolCallID: toolUseID,
})
case "thinking":
if t, ok := blockMap["thinking"].(string); ok {
thinking = t
}
}
}
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
m := api.Message{
Role: role,
Content: textContent.String(),
Images: images,
ToolCalls: toolCalls,
Thinking: thinking,
}
messages = append(messages, m)
}
// Add tool results as separate messages
messages = append(messages, toolResults...)
default:
return nil, fmt.Errorf("invalid message content type: %T", content)
}
return messages, nil
}
// convertTool converts an Anthropic Tool to an Ollama api.Tool
func convertTool(t Tool) (api.Tool, error) {
var params api.ToolFunctionParameters
if len(t.InputSchema) > 0 {
if err := json.Unmarshal(t.InputSchema, &params); err != nil {
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
}
}
return api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: t.Name,
Description: t.Description,
Parameters: params,
},
}, nil
}
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
func ToMessagesResponse(id string, r api.ChatResponse) MessagesResponse {
var content []ContentBlock
if r.Message.Thinking != "" {
content = append(content, ContentBlock{
Type: "thinking",
Thinking: ptr(r.Message.Thinking),
})
}
if r.Message.Content != "" {
content = append(content, ContentBlock{
Type: "text",
Text: ptr(r.Message.Content),
})
}
for _, tc := range r.Message.ToolCalls {
content = append(content, ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: tc.Function.Arguments,
})
}
stopReason := mapStopReason(r.DoneReason, len(r.Message.ToolCalls) > 0)
return MessagesResponse{
ID: id,
Type: "message",
Role: "assistant",
Model: r.Model,
Content: content,
StopReason: stopReason,
Usage: Usage{
InputTokens: r.Metrics.PromptEvalCount,
OutputTokens: r.Metrics.EvalCount,
},
}
}
// mapStopReason converts Ollama done_reason to Anthropic stop_reason
func mapStopReason(reason string, hasToolCalls bool) string {
if hasToolCalls {
return "tool_use"
}
switch reason {
case "stop":
return "end_turn"
case "length":
return "max_tokens"
default:
if reason != "" {
return "stop_sequence"
}
return ""
}
}
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
type StreamConverter struct {
ID string
Model string
firstWrite bool
contentIndex int
inputTokens int
outputTokens int
thinkingStarted bool
thinkingDone bool
textStarted bool
toolCallsSent map[string]bool
}
func NewStreamConverter(id, model string) *StreamConverter {
return &StreamConverter{
ID: id,
Model: model,
firstWrite: true,
toolCallsSent: make(map[string]bool),
}
}
// StreamEvent represents a streaming event to be sent to the client
type StreamEvent struct {
Event string
Data any
}
// Process converts an Ollama ChatResponse to Anthropic streaming events
func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
var events []StreamEvent
if c.firstWrite {
c.firstWrite = false
c.inputTokens = r.Metrics.PromptEvalCount
events = append(events, StreamEvent{
Event: "message_start",
Data: MessageStartEvent{
Type: "message_start",
Message: MessagesResponse{
ID: c.ID,
Type: "message",
Role: "assistant",
Model: c.Model,
Content: []ContentBlock{},
Usage: Usage{
InputTokens: c.inputTokens,
OutputTokens: 0,
},
},
},
})
}
if r.Message.Thinking != "" && !c.thinkingDone {
if !c.thinkingStarted {
c.thinkingStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "thinking_delta",
Thinking: r.Message.Thinking,
},
},
})
}
if r.Message.Content != "" {
if c.thinkingStarted && !c.thinkingDone {
c.thinkingDone = true
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
}
if !c.textStarted {
c.textStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "text",
Text: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "text_delta",
Text: r.Message.Content,
},
},
})
}
for _, tc := range r.Message.ToolCalls {
if c.toolCallsSent[tc.ID] {
continue
}
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
c.textStarted = false
}
argsJSON, err := json.Marshal(tc.Function.Arguments)
if err != nil {
slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
continue
}
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: map[string]any{},
},
},
})
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "input_json_delta",
PartialJSON: string(argsJSON),
},
},
})
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.toolCallsSent[tc.ID] = true
c.contentIndex++
}
if r.Done {
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
} else if c.thinkingStarted && !c.thinkingDone {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
}
c.outputTokens = r.Metrics.EvalCount
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
events = append(events, StreamEvent{
Event: "message_delta",
Data: MessageDeltaEvent{
Type: "message_delta",
Delta: MessageDelta{
StopReason: stopReason,
},
Usage: DeltaUsage{
OutputTokens: c.outputTokens,
},
},
})
events = append(events, StreamEvent{
Event: "message_stop",
Data: MessageStopEvent{
Type: "message_stop",
},
})
}
return events
}
// generateID generates a unique ID with the given prefix using crypto/rand
func generateID(prefix string) string {
b := make([]byte, 12)
if _, err := rand.Read(b); err != nil {
// Fallback to time-based ID if crypto/rand fails
return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
}
return fmt.Sprintf("%s_%x", prefix, b)
}
// GenerateMessageID generates a unique message ID
func GenerateMessageID() string {
return generateID("msg")
}
// ptr returns a pointer to the given string value
func ptr(s string) *string {
return &s
}
// mapToArgs converts a map to ToolCallFunctionArguments
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}

View File

@@ -1,953 +0,0 @@
package anthropic
import (
"encoding/base64"
"encoding/json"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
const (
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
func TestFromMessagesRequest_Basic(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", result.Model)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
t.Errorf("unexpected message: %+v", result.Messages[0])
}
if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 1024 {
t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"])
}
}
func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: "You are a helpful assistant.",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." {
t.Errorf("unexpected system message: %+v", result.Messages[0])
}
}
func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: []any{
map[string]any{"type": "text", "text": "You are helpful."},
map[string]any{"type": "text", "text": " Be concise."},
},
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Content != "You are helpful. Be concise." {
t.Errorf("unexpected system message content: %q", result.Messages[0].Content)
}
}
func TestFromMessagesRequest_WithOptions(t *testing.T) {
temp := 0.7
topP := 0.9
topK := 40
req := MessagesRequest{
Model: "test-model",
MaxTokens: 2048,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Temperature: &temp,
TopP: &topP,
TopK: &topK,
StopSequences: []string{"\n", "END"},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Options["temperature"] != 0.7 {
t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"])
}
if result.Options["top_p"] != 0.9 {
t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"])
}
if result.Options["top_k"] != 40 {
t.Errorf("expected top_k 40, got %v", result.Options["top_k"])
}
if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" {
t.Errorf("stop sequences mismatch: %s", diff)
}
}
func TestFromMessagesRequest_WithImage(t *testing.T) {
imgData, _ := base64.StdEncoding.DecodeString(testImage)
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{"type": "text", "text": "What's in this image?"},
map[string]any{
"type": "image",
"source": map[string]any{
"type": "base64",
"media_type": "image/png",
"data": testImage,
},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Content != "What's in this image?" {
t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content)
}
if len(result.Messages[0].Images) != 1 {
t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images))
}
if string(result.Messages[0].Images[0]) != string(imgData) {
t.Error("image data mismatch")
}
}
func TestFromMessagesRequest_WithToolUse(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
"name": "get_weather",
"input": map[string]any{"location": "Paris"},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if len(result.Messages[1].ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls))
}
tc := result.Messages[1].ToolCalls[0]
if tc.ID != "call_123" {
t.Errorf("expected tool call ID 'call_123', got %q", tc.ID)
}
if tc.Function.Name != "get_weather" {
t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name)
}
}
func TestFromMessagesRequest_WithToolResult(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{
"type": "tool_result",
"tool_use_id": "call_123",
"content": "The weather in Paris is sunny, 22°C",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
msg := result.Messages[0]
if msg.Role != "tool" {
t.Errorf("expected role 'tool', got %q", msg.Role)
}
if msg.ToolCallID != "call_123" {
t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID)
}
if msg.Content != "The weather in Paris is sunny, 22°C" {
t.Errorf("unexpected content: %q", msg.Content)
}
}
func TestFromMessagesRequest_WithTools(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "get_weather",
Description: "Get current weather",
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(result.Tools))
}
tool := result.Tools[0]
if tool.Type != "function" {
t.Errorf("expected type 'function', got %q", tool.Type)
}
if tool.Function.Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", tool.Function.Name)
}
if tool.Function.Description != "Get current weather" {
t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description)
}
}
func TestFromMessagesRequest_WithThinking(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Think == nil {
t.Fatal("expected Think to be set")
}
if v, ok := result.Think.Value.(bool); !ok || !v {
t.Errorf("expected Think.Value to be true, got %v", result.Think.Value)
}
}
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "thinking",
"thinking": "Let me think about this...",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
assistantMsg := result.Messages[1]
if assistantMsg.Thinking != "Let me think about this..." {
t.Errorf("expected thinking content, got %q", assistantMsg.Thinking)
}
}
func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"name": "get_weather",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use id")
}
if err.Error() != "tool_use block missing required 'id' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use name")
}
if err.Error() != "tool_use block missing required 'name' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "bad_tool",
InputSchema: json.RawMessage(`{invalid json`),
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for invalid tool schema")
}
}
func TestToMessagesResponse_Basic(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello there!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 5,
},
}
result := ToMessagesResponse("msg_123", resp)
if result.ID != "msg_123" {
t.Errorf("expected ID 'msg_123', got %q", result.ID)
}
if result.Type != "message" {
t.Errorf("expected type 'message', got %q", result.Type)
}
if result.Role != "assistant" {
t.Errorf("expected role 'assistant', got %q", result.Role)
}
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "text" || result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
t.Errorf("unexpected content: %+v", result.Content[0])
}
if result.StopReason != "end_turn" {
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
}
if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 {
t.Errorf("unexpected usage: %+v", result.Usage)
}
}
func TestToMessagesResponse_WithToolCalls(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "tool_use" {
t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type)
}
if result.Content[0].ID != "call_123" {
t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID)
}
if result.Content[0].Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name)
}
if result.StopReason != "tool_use" {
t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason)
}
}
func TestToMessagesResponse_WithThinking(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "The answer is 42.",
Thinking: "Let me think about this...",
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 2 {
t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
}
if result.Content[0].Type != "thinking" {
t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type)
}
if result.Content[0].Thinking == nil || *result.Content[0].Thinking != "Let me think about this..." {
t.Errorf("unexpected thinking content: %v", result.Content[0].Thinking)
}
if result.Content[1].Type != "text" {
t.Errorf("expected second block type 'text', got %q", result.Content[1].Type)
}
}
func TestMapStopReason(t *testing.T) {
tests := []struct {
reason string
hasToolCalls bool
want string
}{
{"stop", false, "end_turn"},
{"length", false, "max_tokens"},
{"stop", true, "tool_use"},
{"other", false, "stop_sequence"},
{"", false, ""},
}
for _, tt := range tests {
got := mapStopReason(tt.reason, tt.hasToolCalls)
if got != tt.want {
t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want)
}
}
}
func TestNewError(t *testing.T) {
tests := []struct {
code int
want string
}{
{400, "invalid_request_error"},
{401, "authentication_error"},
{403, "permission_error"},
{404, "not_found_error"},
{429, "rate_limit_error"},
{500, "api_error"},
{503, "overloaded_error"},
{529, "overloaded_error"},
}
for _, tt := range tests {
result := NewError(tt.code, "test message")
if result.Type != "error" {
t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type)
}
if result.Error.Type != tt.want {
t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want)
}
if result.Error.Message != "test message" {
t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message)
}
if result.RequestID == "" {
t.Errorf("NewError(%d) request_id should not be empty", tt.code)
}
}
}
func TestGenerateMessageID(t *testing.T) {
id1 := GenerateMessageID()
id2 := GenerateMessageID()
if id1 == "" {
t.Error("GenerateMessageID returned empty string")
}
if id1 == id2 {
t.Error("GenerateMessageID returned duplicate IDs")
}
if len(id1) < 10 {
t.Errorf("GenerateMessageID returned short ID: %q", id1)
}
if id1[:4] != "msg_" {
t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4])
}
}
func TestStreamConverter_Basic(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
// First chunk
resp1 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello",
},
Metrics: api.Metrics{PromptEvalCount: 10},
}
events1 := conv.Process(resp1)
if len(events1) < 3 {
t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1))
}
// Should have message_start, content_block_start, content_block_delta
if events1[0].Event != "message_start" {
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
}
if events1[1].Event != "content_block_start" {
t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event)
}
if events1[2].Event != "content_block_delta" {
t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event)
}
// Final chunk
resp2 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: " world!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{EvalCount: 5},
}
events2 := conv.Process(resp2)
// Should have content_block_delta, content_block_stop, message_delta, message_stop
hasStop := false
for _, e := range events2 {
if e.Event == "message_stop" {
hasStop = true
}
}
if !hasStop {
t.Error("expected message_stop event in final chunk")
}
}
func TestStreamConverter_WithToolCalls(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
}
events := conv.Process(resp)
hasToolStart := false
hasToolDelta := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
hasToolDelta = true
}
}
}
}
if !hasToolStart {
t.Error("expected tool_use content_block_start event")
}
if !hasToolDelta {
t.Error("expected input_json_delta event")
}
}
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
// Test that unmarshalable arguments (like channels) are handled gracefully
// and don't cause a panic or corrupt stream
conv := NewStreamConverter("msg_123", "test-model")
// Create a channel which cannot be JSON marshaled
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
// Should not panic and should skip the unmarshalable tool call
events := conv.Process(resp)
// Verify no tool_use block was started (since marshal failed before block start)
hasToolStart := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
}
if hasToolStart {
t.Error("expected no tool_use block when arguments cannot be marshaled")
}
}
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
// Test that valid tool calls still work when mixed with invalid ones
conv := NewStreamConverter("msg_123", "test-model")
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_good",
Function: api.ToolCallFunction{
Name: "good_function",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
events := conv.Process(resp)
// Count tool_use blocks - should only have 1 (the valid one)
toolStartCount := 0
toolDeltaCount := 0
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
toolStartCount++
if start.ContentBlock.Name != "good_function" {
t.Errorf("expected tool name 'good_function', got %q", start.ContentBlock.Name)
}
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
toolDeltaCount++
}
}
}
}
if toolStartCount != 1 {
t.Errorf("expected 1 tool_use block, got %d", toolStartCount)
}
if toolDeltaCount != 1 {
t.Errorf("expected 1 input_json_delta, got %d", toolDeltaCount)
}
}
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
tests := []struct {
name string
block ContentBlock
wantKeys []string
}{
{
name: "text block includes empty text field",
block: ContentBlock{
Type: "text",
Text: ptr(""),
},
wantKeys: []string{"type", "text"},
},
{
name: "thinking block includes empty thinking field",
block: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
wantKeys: []string{"type", "thinking"},
},
{
name: "text block with content",
block: ContentBlock{
Type: "text",
Text: ptr("hello"),
},
wantKeys: []string{"type", "text"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.block)
if err != nil {
t.Fatalf("failed to marshal: %v", err)
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
for _, key := range tt.wantKeys {
if _, ok := result[key]; !ok {
t.Errorf("expected key %q to be present in JSON output, got: %s", key, string(data))
}
}
})
}
}
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
// events include the required empty fields for SDK compatibility.
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
t.Run("text block start includes empty text", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Content: "hello"},
}
events := conv.Process(resp)
var foundTextStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "text" {
foundTextStart = true
// Marshal and verify the text field is present
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["text"]; !ok {
t.Error("content_block_start for text should include 'text' field")
}
}
}
}
}
if !foundTextStart {
t.Error("expected text content_block_start event")
}
})
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Thinking: "let me think..."},
}
events := conv.Process(resp)
var foundThinkingStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "thinking" {
foundThinkingStart = true
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["thinking"]; !ok {
t.Error("content_block_start for thinking should include 'thinking' field")
}
}
}
}
}
if !foundThinkingStart {
t.Error("expected thinking content_block_start event")
}
})
}

View File

@@ -226,14 +226,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
bts := scanner.Bytes()
if err := json.Unmarshal(bts, &errorResponse); err != nil {
if response.StatusCode >= http.StatusBadRequest {
return StatusError{
StatusCode: response.StatusCode,
Status: response.Status,
ErrorMessage: string(bts),
}
}
return errors.New(string(bts))
return fmt.Errorf("unmarshal: %w", err)
}
if response.StatusCode == http.StatusUnauthorized {
@@ -347,7 +340,7 @@ type CreateProgressFunc func(ProgressResponse) error
// Create creates a model from a [Modelfile]. fn is a progress function that
// behaves similarly to other methods (see [Client.Pull]).
//
// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.mdx
// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md
func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
var resp ProgressResponse

View File

@@ -55,7 +55,6 @@ func TestClientFromEnvironment(t *testing.T) {
type testError struct {
message string
statusCode int
raw bool // if true, write message as-is instead of JSON encoding
}
func (e testError) Error() string {
@@ -112,20 +111,6 @@ func TestClientStream(t *testing.T) {
},
},
},
{
name: "plain text error response",
responses: []any{
"internal server error",
},
wantErr: "internal server error",
},
{
name: "HTML error page",
responses: []any{
"<html><body>404 Not Found</body></html>",
},
wantErr: "404 Not Found",
},
}
for _, tc := range testCases {
@@ -150,12 +135,6 @@ func TestClientStream(t *testing.T) {
return
}
if str, ok := resp.(string); ok {
fmt.Fprintln(w, str)
flusher.Flush()
continue
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("failed to encode response: %v", err)
}
@@ -194,10 +173,9 @@ func TestClientStream(t *testing.T) {
func TestClientDo(t *testing.T) {
testCases := []struct {
name string
response any
wantErr string
wantStatusCode int
name string
response any
wantErr string
}{
{
name: "immediate error response",
@@ -205,8 +183,7 @@ func TestClientDo(t *testing.T) {
message: "test error message",
statusCode: http.StatusBadRequest,
},
wantErr: "test error message",
wantStatusCode: http.StatusBadRequest,
wantErr: "test error message",
},
{
name: "server error response",
@@ -214,8 +191,7 @@ func TestClientDo(t *testing.T) {
message: "internal error",
statusCode: http.StatusInternalServerError,
},
wantErr: "internal error",
wantStatusCode: http.StatusInternalServerError,
wantErr: "internal error",
},
{
name: "successful response",
@@ -227,26 +203,6 @@ func TestClientDo(t *testing.T) {
Success: true,
},
},
{
name: "plain text error response",
response: testError{
message: "internal server error",
statusCode: http.StatusInternalServerError,
raw: true,
},
wantErr: "internal server error",
wantStatusCode: http.StatusInternalServerError,
},
{
name: "HTML error page",
response: testError{
message: "<html><body>404 Not Found</body></html>",
statusCode: http.StatusNotFound,
raw: true,
},
wantErr: "<html><body>404 Not Found</body></html>",
wantStatusCode: http.StatusNotFound,
},
}
for _, tc := range testCases {
@@ -254,16 +210,11 @@ func TestClientDo(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if errResp, ok := tc.response.(testError); ok {
w.WriteHeader(errResp.statusCode)
if !errResp.raw {
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
} else {
// Write raw message (simulates non-JSON error responses)
fmt.Fprint(w, errResp.message)
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
return
}
@@ -290,15 +241,6 @@ func TestClientDo(t *testing.T) {
if err.Error() != tc.wantErr {
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
}
if tc.wantStatusCode != 0 {
if statusErr, ok := err.(StatusError); ok {
if statusErr.StatusCode != tc.wantStatusCode {
t.Errorf("status code mismatch: got %d, want %d", statusErr.StatusCode, tc.wantStatusCode)
}
} else {
t.Errorf("expected StatusError, got %T", err)
}
}
return
}

View File

@@ -15,19 +15,19 @@ func main() {
}
messages := []api.Message{
{
api.Message{
Role: "system",
Content: "Provide very brief, concise responses",
},
{
api.Message{
Role: "user",
Content: "Name some unusual animals",
},
{
api.Message{
Role: "assistant",
Content: "Monotreme, platypus, echidna",
},
{
api.Message{
Role: "user",
Content: "which of these is the most dangerous?",
},

View File

@@ -3,7 +3,6 @@ package api
import (
"encoding/json"
"fmt"
"iter"
"log/slog"
"math"
"os"
@@ -15,7 +14,6 @@ import (
"github.com/google/uuid"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/internal/orderedmap"
"github.com/ollama/ollama/types/model"
)
@@ -229,79 +227,13 @@ type ToolCallFunction struct {
Arguments ToolCallFunctionArguments `json:"arguments"`
}
// ToolCallFunctionArguments holds tool call arguments in insertion order.
type ToolCallFunctionArguments struct {
om *orderedmap.Map[string, any]
}
// NewToolCallFunctionArguments creates a new empty ToolCallFunctionArguments.
func NewToolCallFunctionArguments() ToolCallFunctionArguments {
return ToolCallFunctionArguments{om: orderedmap.New[string, any]()}
}
// Get retrieves a value by key.
func (t *ToolCallFunctionArguments) Get(key string) (any, bool) {
if t == nil || t.om == nil {
return nil, false
}
return t.om.Get(key)
}
// Set sets a key-value pair, preserving insertion order.
func (t *ToolCallFunctionArguments) Set(key string, value any) {
if t == nil {
return
}
if t.om == nil {
t.om = orderedmap.New[string, any]()
}
t.om.Set(key, value)
}
// Len returns the number of arguments.
func (t *ToolCallFunctionArguments) Len() int {
if t == nil || t.om == nil {
return 0
}
return t.om.Len()
}
// All returns an iterator over all key-value pairs in insertion order.
func (t *ToolCallFunctionArguments) All() iter.Seq2[string, any] {
if t == nil || t.om == nil {
return func(yield func(string, any) bool) {}
}
return t.om.All()
}
// ToMap returns a regular map (order not preserved).
func (t *ToolCallFunctionArguments) ToMap() map[string]any {
if t == nil || t.om == nil {
return nil
}
return t.om.ToMap()
}
type ToolCallFunctionArguments map[string]any
func (t *ToolCallFunctionArguments) String() string {
if t == nil || t.om == nil {
return "{}"
}
bts, _ := json.Marshal(t.om)
bts, _ := json.Marshal(t)
return string(bts)
}
func (t *ToolCallFunctionArguments) UnmarshalJSON(data []byte) error {
t.om = orderedmap.New[string, any]()
return json.Unmarshal(data, t.om)
}
func (t ToolCallFunctionArguments) MarshalJSON() ([]byte, error) {
if t.om == nil {
return []byte("{}"), nil
}
return json.Marshal(t.om)
}
type Tool struct {
Type string `json:"type"`
Items any `json:"items,omitempty"`
@@ -350,78 +282,12 @@ func (pt PropertyType) String() string {
return fmt.Sprintf("%v", []string(pt))
}
// ToolPropertiesMap holds tool properties in insertion order.
type ToolPropertiesMap struct {
om *orderedmap.Map[string, ToolProperty]
}
// NewToolPropertiesMap creates a new empty ToolPropertiesMap.
func NewToolPropertiesMap() *ToolPropertiesMap {
return &ToolPropertiesMap{om: orderedmap.New[string, ToolProperty]()}
}
// Get retrieves a property by name.
func (t *ToolPropertiesMap) Get(key string) (ToolProperty, bool) {
if t == nil || t.om == nil {
return ToolProperty{}, false
}
return t.om.Get(key)
}
// Set sets a property, preserving insertion order.
func (t *ToolPropertiesMap) Set(key string, value ToolProperty) {
if t == nil {
return
}
if t.om == nil {
t.om = orderedmap.New[string, ToolProperty]()
}
t.om.Set(key, value)
}
// Len returns the number of properties.
func (t *ToolPropertiesMap) Len() int {
if t == nil || t.om == nil {
return 0
}
return t.om.Len()
}
// All returns an iterator over all properties in insertion order.
func (t *ToolPropertiesMap) All() iter.Seq2[string, ToolProperty] {
if t == nil || t.om == nil {
return func(yield func(string, ToolProperty) bool) {}
}
return t.om.All()
}
// ToMap returns a regular map (order not preserved).
func (t *ToolPropertiesMap) ToMap() map[string]ToolProperty {
if t == nil || t.om == nil {
return nil
}
return t.om.ToMap()
}
func (t ToolPropertiesMap) MarshalJSON() ([]byte, error) {
if t.om == nil {
return []byte("null"), nil
}
return json.Marshal(t.om)
}
func (t *ToolPropertiesMap) UnmarshalJSON(data []byte) error {
t.om = orderedmap.New[string, ToolProperty]()
return json.Unmarshal(data, t.om)
}
type ToolProperty struct {
AnyOf []ToolProperty `json:"anyOf,omitempty"`
Type PropertyType `json:"type,omitempty"`
Items any `json:"items,omitempty"`
Description string `json:"description,omitempty"`
Enum []any `json:"enum,omitempty"`
Properties *ToolPropertiesMap `json:"properties,omitempty"`
AnyOf []ToolProperty `json:"anyOf,omitempty"`
Type PropertyType `json:"type,omitempty"`
Items any `json:"items,omitempty"`
Description string `json:"description,omitempty"`
Enum []any `json:"enum,omitempty"`
}
// ToTypeScriptType converts a ToolProperty to a TypeScript type string
@@ -470,11 +336,11 @@ func mapToTypeScriptType(jsonType string) string {
}
type ToolFunctionParameters struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required,omitempty"`
Properties *ToolPropertiesMap `json:"properties"`
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required,omitempty"`
Properties map[string]ToolProperty `json:"properties"`
}
func (t *ToolFunctionParameters) String() string {
@@ -687,9 +553,6 @@ type CreateRequest struct {
Renderer string `json:"renderer,omitempty"`
Parser string `json:"parser,omitempty"`
// Requires is the minimum version of Ollama required by the model.
Requires string `json:"requires,omitempty"`
// Info is a map of additional information for the model
Info map[string]any `json:"info,omitempty"`
@@ -740,7 +603,6 @@ type ShowResponse struct {
Tensors []Tensor `json:"tensors,omitempty"`
Capabilities []model.Capability `json:"capabilities,omitempty"`
ModifiedAt time.Time `json:"modified_at,omitempty"`
Requires string `json:"requires,omitempty"`
}
// CopyRequest is the request passed to [Client.Copy].

View File

@@ -11,24 +11,6 @@ import (
"github.com/stretchr/testify/require"
)
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
func testPropsMap(m map[string]ToolProperty) *ToolPropertiesMap {
props := NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
func testArgs(m map[string]any) ToolCallFunctionArguments {
args := NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
func TestKeepAliveParsingFromJSON(t *testing.T) {
tests := []struct {
name string
@@ -327,9 +309,9 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
input: ToolFunctionParameters{
Type: "object",
Required: []string{"name"},
Properties: testPropsMap(map[string]ToolProperty{
Properties: map[string]ToolProperty{
"name": {Type: PropertyType{"string"}},
}),
},
},
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string"}}}`,
},
@@ -337,9 +319,9 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
name: "no required",
input: ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]ToolProperty{
Properties: map[string]ToolProperty{
"name": {Type: PropertyType{"string"}},
}),
},
},
expected: `{"type":"object","properties":{"name":{"type":"string"}}}`,
},
@@ -357,7 +339,7 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
func TestToolCallFunction_IndexAlwaysMarshals(t *testing.T) {
fn := ToolCallFunction{
Name: "echo",
Arguments: testArgs(map[string]any{"message": "hi"}),
Arguments: ToolCallFunctionArguments{"message": "hi"},
}
data, err := json.Marshal(fn)
@@ -522,116 +504,6 @@ func TestThinking_UnmarshalJSON(t *testing.T) {
}
}
func TestToolPropertyNestedProperties(t *testing.T) {
tests := []struct {
name string
input string
expected ToolProperty
}{
{
name: "nested object properties",
input: `{
"type": "object",
"description": "Location details",
"properties": {
"address": {
"type": "string",
"description": "Street address"
},
"city": {
"type": "string",
"description": "City name"
}
}
}`,
expected: ToolProperty{
Type: PropertyType{"object"},
Description: "Location details",
Properties: testPropsMap(map[string]ToolProperty{
"address": {
Type: PropertyType{"string"},
Description: "Street address",
},
"city": {
Type: PropertyType{"string"},
Description: "City name",
},
}),
},
},
{
name: "deeply nested properties",
input: `{
"type": "object",
"description": "Event",
"properties": {
"location": {
"type": "object",
"description": "Location",
"properties": {
"coordinates": {
"type": "object",
"description": "GPS coordinates",
"properties": {
"lat": {"type": "number", "description": "Latitude"},
"lng": {"type": "number", "description": "Longitude"}
}
}
}
}
}
}`,
expected: ToolProperty{
Type: PropertyType{"object"},
Description: "Event",
Properties: testPropsMap(map[string]ToolProperty{
"location": {
Type: PropertyType{"object"},
Description: "Location",
Properties: testPropsMap(map[string]ToolProperty{
"coordinates": {
Type: PropertyType{"object"},
Description: "GPS coordinates",
Properties: testPropsMap(map[string]ToolProperty{
"lat": {Type: PropertyType{"number"}, Description: "Latitude"},
"lng": {Type: PropertyType{"number"}, Description: "Longitude"},
}),
},
}),
},
}),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var prop ToolProperty
err := json.Unmarshal([]byte(tt.input), &prop)
require.NoError(t, err)
// Compare JSON representations since pointer comparison doesn't work
expectedJSON, err := json.Marshal(tt.expected)
require.NoError(t, err)
actualJSON, err := json.Marshal(prop)
require.NoError(t, err)
assert.JSONEq(t, string(expectedJSON), string(actualJSON))
// Round-trip test: marshal and unmarshal again
data, err := json.Marshal(prop)
require.NoError(t, err)
var prop2 ToolProperty
err = json.Unmarshal(data, &prop2)
require.NoError(t, err)
prop2JSON, err := json.Marshal(prop2)
require.NoError(t, err)
assert.JSONEq(t, string(expectedJSON), string(prop2JSON))
})
}
}
func TestToolFunctionParameters_String(t *testing.T) {
tests := []struct {
name string
@@ -643,12 +515,12 @@ func TestToolFunctionParameters_String(t *testing.T) {
params: ToolFunctionParameters{
Type: "object",
Required: []string{"name"},
Properties: testPropsMap(map[string]ToolProperty{
Properties: map[string]ToolProperty{
"name": {
Type: PropertyType{"string"},
Description: "The name of the person",
},
}),
},
},
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
},
@@ -665,7 +537,7 @@ func TestToolFunctionParameters_String(t *testing.T) {
s.Self = s
return s
}(),
Properties: testPropsMap(map[string]ToolProperty{}),
Properties: map[string]ToolProperty{},
},
expected: "",
},
@@ -678,235 +550,3 @@ func TestToolFunctionParameters_String(t *testing.T) {
})
}
}
func TestToolCallFunctionArguments_OrderPreservation(t *testing.T) {
t.Run("marshal preserves insertion order", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("zebra", "z")
args.Set("apple", "a")
args.Set("mango", "m")
data, err := json.Marshal(args)
require.NoError(t, err)
// Should preserve insertion order, not alphabetical
assert.Equal(t, `{"zebra":"z","apple":"a","mango":"m"}`, string(data))
})
t.Run("unmarshal preserves JSON order", func(t *testing.T) {
jsonData := `{"zebra":"z","apple":"a","mango":"m"}`
var args ToolCallFunctionArguments
err := json.Unmarshal([]byte(jsonData), &args)
require.NoError(t, err)
// Verify iteration order matches JSON order
var keys []string
for k := range args.All() {
keys = append(keys, k)
}
assert.Equal(t, []string{"zebra", "apple", "mango"}, keys)
})
t.Run("round trip preserves order", func(t *testing.T) {
original := `{"z":1,"a":2,"m":3,"b":4}`
var args ToolCallFunctionArguments
err := json.Unmarshal([]byte(original), &args)
require.NoError(t, err)
data, err := json.Marshal(args)
require.NoError(t, err)
assert.Equal(t, original, string(data))
})
t.Run("String method returns ordered JSON", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("c", 3)
args.Set("a", 1)
args.Set("b", 2)
assert.Equal(t, `{"c":3,"a":1,"b":2}`, args.String())
})
t.Run("Get retrieves correct values", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("key1", "value1")
args.Set("key2", 42)
v, ok := args.Get("key1")
assert.True(t, ok)
assert.Equal(t, "value1", v)
v, ok = args.Get("key2")
assert.True(t, ok)
assert.Equal(t, 42, v)
_, ok = args.Get("nonexistent")
assert.False(t, ok)
})
t.Run("Len returns correct count", func(t *testing.T) {
args := NewToolCallFunctionArguments()
assert.Equal(t, 0, args.Len())
args.Set("a", 1)
assert.Equal(t, 1, args.Len())
args.Set("b", 2)
assert.Equal(t, 2, args.Len())
})
t.Run("empty args marshal to empty object", func(t *testing.T) {
args := NewToolCallFunctionArguments()
data, err := json.Marshal(args)
require.NoError(t, err)
assert.Equal(t, `{}`, string(data))
})
t.Run("zero value args marshal to empty object", func(t *testing.T) {
var args ToolCallFunctionArguments
assert.Equal(t, "{}", args.String())
})
}
func TestToolPropertiesMap_OrderPreservation(t *testing.T) {
t.Run("marshal preserves insertion order", func(t *testing.T) {
props := NewToolPropertiesMap()
props.Set("zebra", ToolProperty{Type: PropertyType{"string"}})
props.Set("apple", ToolProperty{Type: PropertyType{"number"}})
props.Set("mango", ToolProperty{Type: PropertyType{"boolean"}})
data, err := json.Marshal(props)
require.NoError(t, err)
// Should preserve insertion order, not alphabetical
expected := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}`
assert.Equal(t, expected, string(data))
})
t.Run("unmarshal preserves JSON order", func(t *testing.T) {
jsonData := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}`
var props ToolPropertiesMap
err := json.Unmarshal([]byte(jsonData), &props)
require.NoError(t, err)
// Verify iteration order matches JSON order
var keys []string
for k := range props.All() {
keys = append(keys, k)
}
assert.Equal(t, []string{"zebra", "apple", "mango"}, keys)
})
t.Run("round trip preserves order", func(t *testing.T) {
original := `{"z":{"type":"string"},"a":{"type":"number"},"m":{"type":"boolean"}}`
var props ToolPropertiesMap
err := json.Unmarshal([]byte(original), &props)
require.NoError(t, err)
data, err := json.Marshal(props)
require.NoError(t, err)
assert.Equal(t, original, string(data))
})
t.Run("Get retrieves correct values", func(t *testing.T) {
props := NewToolPropertiesMap()
props.Set("name", ToolProperty{Type: PropertyType{"string"}, Description: "The name"})
props.Set("age", ToolProperty{Type: PropertyType{"integer"}, Description: "The age"})
v, ok := props.Get("name")
assert.True(t, ok)
assert.Equal(t, "The name", v.Description)
v, ok = props.Get("age")
assert.True(t, ok)
assert.Equal(t, "The age", v.Description)
_, ok = props.Get("nonexistent")
assert.False(t, ok)
})
t.Run("Len returns correct count", func(t *testing.T) {
props := NewToolPropertiesMap()
assert.Equal(t, 0, props.Len())
props.Set("a", ToolProperty{})
assert.Equal(t, 1, props.Len())
props.Set("b", ToolProperty{})
assert.Equal(t, 2, props.Len())
})
t.Run("nil props marshal to null", func(t *testing.T) {
var props *ToolPropertiesMap
data, err := json.Marshal(props)
require.NoError(t, err)
assert.Equal(t, `null`, string(data))
})
t.Run("ToMap returns regular map", func(t *testing.T) {
props := NewToolPropertiesMap()
props.Set("a", ToolProperty{Type: PropertyType{"string"}})
props.Set("b", ToolProperty{Type: PropertyType{"number"}})
m := props.ToMap()
assert.Equal(t, 2, len(m))
assert.Equal(t, PropertyType{"string"}, m["a"].Type)
assert.Equal(t, PropertyType{"number"}, m["b"].Type)
})
}
func TestToolCallFunctionArguments_ComplexValues(t *testing.T) {
t.Run("nested objects preserve order", func(t *testing.T) {
jsonData := `{"outer":{"z":1,"a":2},"simple":"value"}`
var args ToolCallFunctionArguments
err := json.Unmarshal([]byte(jsonData), &args)
require.NoError(t, err)
// Outer keys should be in order
var keys []string
for k := range args.All() {
keys = append(keys, k)
}
assert.Equal(t, []string{"outer", "simple"}, keys)
})
t.Run("arrays as values", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("items", []string{"a", "b", "c"})
args.Set("numbers", []int{1, 2, 3})
data, err := json.Marshal(args)
require.NoError(t, err)
assert.Equal(t, `{"items":["a","b","c"],"numbers":[1,2,3]}`, string(data))
})
}
func TestToolPropertiesMap_NestedProperties(t *testing.T) {
t.Run("nested properties preserve order", func(t *testing.T) {
props := NewToolPropertiesMap()
nestedProps := NewToolPropertiesMap()
nestedProps.Set("z_field", ToolProperty{Type: PropertyType{"string"}})
nestedProps.Set("a_field", ToolProperty{Type: PropertyType{"number"}})
props.Set("outer", ToolProperty{
Type: PropertyType{"object"},
Properties: nestedProps,
})
data, err := json.Marshal(props)
require.NoError(t, err)
// Both outer and inner should preserve order
expected := `{"outer":{"type":"object","properties":{"z_field":{"type":"string"},"a_field":{"type":"number"}}}}`
assert.Equal(t, expected, string(data))
})
}

View File

@@ -273,6 +273,10 @@ func main() {
Handler: uiServer.Handler(),
}
if _, err := uiServer.UserData(ctx); err != nil {
slog.Warn("failed to load user data", "error", err)
}
// Start the UI server
slog.Info("starting ui server", "port", port)
go func() {
@@ -316,17 +320,6 @@ func main() {
slog.Debug("no URL scheme request to handle")
}
go func() {
slog.Debug("waiting for ollama server to be ready")
if err := ui.WaitForServer(ctx, 10*time.Second); err != nil {
slog.Warn("ollama server not ready, continuing anyway", "error", err)
}
if _, err := uiServer.UserData(ctx); err != nil {
slog.Warn("failed to load user data", "error", err)
}
}()
osRun(cancel, hasCompletedFirstRun, startHidden)
slog.Info("shutting down desktop server")
@@ -368,7 +361,7 @@ func checkUserLoggedIn(uiServerPort int) bool {
return false
}
resp, err := http.Post(fmt.Sprintf("http://127.0.0.1:%d/api/me", uiServerPort), "application/json", nil)
resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/api/v1/me", uiServerPort))
if err != nil {
slog.Debug("failed to call local auth endpoint", "error", err)
return false

View File

@@ -191,6 +191,13 @@ func LaunchNewApp() {
C.launchApp(appName)
}
// Send a request to the main app thread to load a UI page
func sendUIRequestMessage(path string) {
p := C.CString(path)
defer C.free(unsafe.Pointer(p))
C.uiRequest(p)
}
func registerLaunchAgent(hasCompletedFirstRun bool) {
// Remove any stale Login Item registrations
C.unregisterSelfFromLoginItem()

View File

@@ -263,6 +263,11 @@ func createLoginShortcut() error {
return nil
}
// Send a request to the main app thread to load a UI page
func sendUIRequestMessage(path string) {
wintray.SendUIRequestMessage(path)
}
func LaunchNewApp() {
}

View File

@@ -169,47 +169,37 @@ DlgResult fileDlg(FileDlgParams* params) {
}
NSArray* urls = [panel URLs];
if([urls count] == 0) {
return DLG_CANCEL;
}
if(self->params->allowMultiple) {
if(self->params->allowMultiple && [urls count] >= 1) {
// For multiple files, we need to return all paths separated by null bytes
char* bufPtr = self->params->buf;
int remainingBuf = self->params->nbuf;
// Calculate total required buffer size first
int totalSize = 0;
for(NSURL* url in urls) {
char tempBuf[PATH_MAX];
if(![url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]) {
return DLG_URLFAIL;
}
totalSize += strlen(tempBuf) + 1; // +1 for null terminator
}
totalSize += 1; // Final null terminator
// Calculate total required buffer size first
int totalSize = 0;
for(NSURL* url in urls) {
char tempBuf[PATH_MAX];
if(![url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]) {
return DLG_URLFAIL;
}
totalSize += strlen(tempBuf) + 1; // +1 for null terminator
}
totalSize += 1; // Final null terminator
if(totalSize > self->params->nbuf) {
// Not enough buffer space
return DLG_URLFAIL;
}
if(totalSize > self->params->nbuf) {
// Not enough buffer space
return DLG_URLFAIL;
}
// Now actually copy the paths (we know we have space)
bufPtr = self->params->buf;
for(NSURL* url in urls) {
char tempBuf[PATH_MAX];
[url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX];
int pathLen = strlen(tempBuf);
strcpy(bufPtr, tempBuf);
bufPtr += pathLen + 1;
}
*bufPtr = '\0'; // Final null terminator
} else {
// Single file/directory selection - write path to buffer
NSURL* url = [urls firstObject];
if(![url getFileSystemRepresentation:self->params->buf maxLength:self->params->nbuf]) {
return DLG_URLFAIL;
}
// Now actually copy the paths (we know we have space)
bufPtr = self->params->buf;
for(NSURL* url in urls) {
char tempBuf[PATH_MAX];
[url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX];
int pathLen = strlen(tempBuf);
strcpy(bufPtr, tempBuf);
bufPtr += pathLen + 1;
}
*bufPtr = '\0'; // Final null terminator
}
return DLG_OK;

View File

@@ -15,7 +15,7 @@ const multiFileBufferSize = w32.MAX_PATH * 10
type WinDlgError int
func (e WinDlgError) Error() string {
return fmt.Sprintf("CommDlgExtendedError: %#x", int(e))
return fmt.Sprintf("CommDlgExtendedError: %#x", e)
}
func err() error {

View File

@@ -224,7 +224,9 @@ func (s *Server) cmd(ctx context.Context) (*exec.Cmd, error) {
if _, err := os.Stat(settings.Models); err == nil {
env["OLLAMA_MODELS"] = settings.Models
} else {
slog.Warn("models path not accessible, using default", "path", settings.Models, "err", err)
slog.Warn("models path not accessible, clearing models setting", "path", settings.Models, "err", err)
settings.Models = ""
s.store.SetSettings(settings)
}
}
if settings.ContextLength > 0 {

View File

@@ -469,24 +469,26 @@ export class HealthResponse {
}
export class User {
id: string;
email: string;
name: string;
bio?: string;
avatarurl?: string;
firstname?: string;
lastname?: string;
plan?: string;
email: string;
avatarURL: string;
plan: string;
bio: string;
firstName: string;
lastName: string;
overThreshold: boolean;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.id = source["id"];
this.email = source["email"];
this.name = source["name"];
this.bio = source["bio"];
this.avatarurl = source["avatarurl"];
this.firstname = source["firstname"];
this.lastname = source["lastname"];
this.email = source["email"];
this.avatarURL = source["avatarURL"];
this.plan = source["plan"];
this.bio = source["bio"];
this.firstName = source["firstName"];
this.lastName = source["lastName"];
this.overThreshold = source["overThreshold"];
}
}
export class Attachment {

View File

@@ -15,7 +15,7 @@ import {
import { parseJsonlFromResponse } from "./util/jsonl-parsing";
import { ollamaClient as ollama } from "./lib/ollama-client";
import type { ModelResponse } from "ollama/browser";
import { API_BASE, OLLAMA_DOT_COM } from "./lib/config";
import { API_BASE } from "./lib/config";
// Extend Model class with utility methods
declare module "@/gotypes" {
@@ -27,6 +27,7 @@ declare module "@/gotypes" {
Model.prototype.isCloud = function (): boolean {
return this.model.endsWith("cloud");
};
// Helper function to convert Uint8Array to base64
function uint8ArrayToBase64(uint8Array: Uint8Array): string {
const chunkSize = 0x8000; // 32KB chunks to avoid stack overflow
@@ -41,50 +42,44 @@ function uint8ArrayToBase64(uint8Array: Uint8Array): string {
}
export async function fetchUser(): Promise<User | null> {
const response = await fetch(`${API_BASE}/api/me`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
});
try {
const response = await fetch(`${API_BASE}/api/v1/me`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});
if (response.ok) {
const userData: User = await response.json();
if (userData.avatarurl && !userData.avatarurl.startsWith("http")) {
userData.avatarurl = `${OLLAMA_DOT_COM}${userData.avatarurl}`;
if (response.ok) {
const userData: User = await response.json();
return userData;
}
return userData;
}
if (response.status === 401 || response.status === 403) {
return null;
} catch (error) {
console.error("Error fetching user:", error);
return null;
}
throw new Error(`Failed to fetch user: ${response.status}`);
}
export async function fetchConnectUrl(): Promise<string> {
const response = await fetch(`${API_BASE}/api/me`, {
method: "POST",
const response = await fetch(`${API_BASE}/api/v1/connect`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});
if (response.status === 401) {
const data = await response.json();
if (data.signin_url) {
return data.signin_url;
}
if (!response.ok) {
throw new Error("Failed to fetch connect URL");
}
throw new Error("Failed to fetch connect URL");
const data = await response.json();
return data.connect_url;
}
export async function disconnectUser(): Promise<void> {
const response = await fetch(`${API_BASE}/api/signout`, {
const response = await fetch(`${API_BASE}/api/v1/disconnect`, {
method: "POST",
headers: {
"Content-Type": "application/json",
@@ -209,10 +204,12 @@ export async function* sendMessage(
data: uint8ArrayToBase64(att.data),
}));
// Send think parameter when it's explicitly set (true, false, or a non-empty string).
// Only send think parameter when actually requesting thinking
// Don't send false as it causes issues with some providers
const shouldSendThink =
think !== undefined &&
(typeof think === "boolean" || (typeof think === "string" && think !== ""));
((typeof think === "boolean" && think) ||
(typeof think === "string" && think !== ""));
const response = await fetch(`${API_BASE}/api/v1/chat/${chatId}`, {
method: "POST",
@@ -394,8 +391,7 @@ export async function getInferenceCompute(): Promise<InferenceCompute[]> {
export async function fetchHealth(): Promise<boolean> {
try {
// Use the /api/version endpoint as a health check
const response = await fetch(`${API_BASE}/api/version`, {
const response = await fetch(`${API_BASE}/api/v1/health`, {
method: "GET",
headers: {
"Content-Type": "application/json",
@@ -404,8 +400,7 @@ export async function fetchHealth(): Promise<boolean> {
if (response.ok) {
const data = await response.json();
// If we get a version back, the server is healthy
return !!data.version;
return data.healthy || false;
}
return false;

View File

@@ -299,9 +299,9 @@ export default function Settings() {
</Button>
</div>
</div>
{user?.avatarurl && (
{user?.avatarURL && (
<img
src={user.avatarurl}
src={user.avatarURL}
alt={user?.name}
className="h-10 w-10 rounded-full bg-neutral-200 dark:bg-neutral-700 flex-shrink-0"
onError={(e) => {

View File

@@ -50,33 +50,21 @@ export default function Thinking({
// Position content to show bottom when collapsed
useEffect(() => {
if (isCollapsed && contentRef.current && wrapperRef.current) {
requestAnimationFrame(() => {
if (!contentRef.current || !wrapperRef.current) return;
const contentHeight = contentRef.current.scrollHeight;
const wrapperHeight = wrapperRef.current.clientHeight;
if (contentHeight > wrapperHeight) {
const translateY = -(contentHeight - wrapperHeight);
contentRef.current.style.transform = `translateY(${translateY}px)`;
setHasOverflow(true);
} else {
contentRef.current.style.transform = "translateY(0)";
setHasOverflow(false);
}
});
const contentHeight = contentRef.current.scrollHeight;
const wrapperHeight = wrapperRef.current.clientHeight;
if (contentHeight > wrapperHeight) {
const translateY = -(contentHeight - wrapperHeight);
contentRef.current.style.transform = `translateY(${translateY}px)`;
setHasOverflow(true);
} else {
setHasOverflow(false);
}
} else if (contentRef.current) {
contentRef.current.style.transform = "translateY(0)";
setHasOverflow(false);
}
}, [thinking, isCollapsed]);
useEffect(() => {
if (activelyThinking && wrapperRef.current && !isCollapsed) {
// When expanded and actively thinking, scroll to bottom
wrapperRef.current.scrollTop = wrapperRef.current.scrollHeight;
}
}, [thinking, activelyThinking, isCollapsed]);
const handleToggle = () => {
setIsCollapsed(!isCollapsed);
setHasUserInteracted(true);

View File

@@ -7,7 +7,6 @@ import { createQueryBatcher } from "./useQueryBatcher";
import { useRefetchModels } from "./useModels";
import { useStreamingContext } from "@/contexts/StreamingContext";
import { useSettings } from "./useSettings";
import { getModelCapabilities } from "@/api";
export const useChats = () => {
return useQuery({
@@ -607,24 +606,6 @@ export const useSendMessage = (chatId: string) => {
queryClient.setQueryData(["staleModels"], newStaleMap);
queryClient.invalidateQueries({ queryKey: ["models"] });
// Fetch fresh capabilities for the downloaded model
getModelCapabilities(selectedModel.model)
.then((capabilities) => {
queryClient.setQueryData(
["modelCapabilities", selectedModel.model],
capabilities,
);
})
.catch((error) => {
console.error(
"Failed to fetch capabilities after download:",
error,
);
queryClient.invalidateQueries({
queryKey: ["modelCapabilities", selectedModel.model],
});
});
}
break;
}

View File

@@ -0,0 +1,114 @@
import { useMutation, useQueryClient } from "@tanstack/react-query";
import { useState } from "react";
import { pullModel } from "@/api";
import { useSelectedModel } from "./useSelectedModel";
import { useSettings } from "./useSettings";
interface DownloadProgress {
status: string;
digest?: string;
total?: number;
completed?: number;
done?: boolean;
}
export function useDownloadModel(chatId?: string) {
const queryClient = useQueryClient();
const { selectedModel } = useSelectedModel(chatId);
const { setSettings } = useSettings();
const [downloadProgress, setDownloadProgress] =
useState<DownloadProgress | null>(null);
const [abortController, setAbortController] =
useState<AbortController | null>(null);
const [downloadingChatIds, setDownloadingChatIds] = useState<Set<string>>(
new Set(),
);
const mutation = useMutation({
mutationFn: async (modelName: string) => {
const controller = new AbortController();
setAbortController(controller);
setDownloadProgress({ status: "Starting download..." });
if (chatId) {
setDownloadingChatIds((prev) => new Set(prev).add(chatId));
}
try {
for await (const progress of pullModel(modelName, controller.signal)) {
setDownloadProgress(progress);
if (progress.status === "success") {
// Update selected model to indicate it's now available locally
if (selectedModel && selectedModel.model === modelName) {
setSettings({ SelectedModel: modelName });
}
// Invalidate models query to refresh the list
await queryClient.invalidateQueries({ queryKey: ["models"] });
break;
}
}
} finally {
setAbortController(null);
if (chatId) {
setDownloadingChatIds((prev) => {
const newSet = new Set(prev);
newSet.delete(chatId);
return newSet;
});
}
}
},
onSuccess: () => {
setDownloadProgress(null);
if (chatId) {
setDownloadingChatIds((prev) => {
const newSet = new Set(prev);
newSet.delete(chatId);
return newSet;
});
}
},
onError: (error: Error) => {
const status =
error.name === "AbortError" ? "Download cancelled" : "Download failed";
setDownloadProgress({ status, done: true });
// Clear error message after delay
const delay = error.name === "AbortError" ? 1500 : 3000;
setTimeout(() => {
setDownloadProgress(null);
if (chatId) {
setDownloadingChatIds((prev) => {
const newSet = new Set(prev);
newSet.delete(chatId);
return newSet;
});
}
}, delay);
},
});
const cancelDownload = () => {
if (abortController) {
abortController.abort();
setAbortController(null);
if (chatId) {
setDownloadingChatIds((prev) => {
const newSet = new Set(prev);
newSet.delete(chatId);
return newSet;
});
}
}
};
return {
downloadModel: mutation.mutate,
isDownloading:
mutation.isPending && chatId ? downloadingChatIds.has(chatId) : false,
downloadProgress:
chatId && downloadingChatIds.has(chatId) ? downloadProgress : null,
error: mutation.error,
cancelDownload,
};
}

View File

@@ -1,20 +1,29 @@
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
import { useEffect, useState } from "react";
import { fetchUser, fetchConnectUrl, disconnectUser } from "@/api";
export function useUser() {
const queryClient = useQueryClient();
const [initialDataLoaded, setInitialDataLoaded] = useState(false);
// Wait for initial data to be loaded
useEffect(() => {
const initialPromise = window.__initialUserDataPromise;
if (initialPromise) {
initialPromise.finally(() => {
setInitialDataLoaded(true);
});
} else {
setInitialDataLoaded(true);
}
}, []);
const userQuery = useQuery({
queryKey: ["user"],
queryFn: async () => {
const result = await fetchUser();
return result;
},
queryFn: () => fetchUser(),
staleTime: 5 * 60 * 1000, // Consider data stale after 5 minutes
gcTime: 10 * 60 * 1000, // Keep in cache for 10 minutes
retry: 10,
retryDelay: (attemptIndex) => Math.min(500 * attemptIndex, 2000),
refetchOnMount: true, // Always fetch when component mounts
initialData: null, // Start with null to prevent flashing
});
// Mutation to refresh user data
@@ -40,15 +49,14 @@ export function useUser() {
},
});
const isLoading = userQuery.isLoading || userQuery.isFetching;
const isAuthenticated = Boolean(userQuery.data?.name);
return {
user: userQuery.data,
isLoading,
isLoading:
!initialDataLoaded ||
(userQuery.isLoading && userQuery.data === undefined), // Show loading until initial data is loaded
isError: userQuery.isError,
error: userQuery.error,
isAuthenticated,
isAuthenticated: Boolean(userQuery.data?.name),
refreshUser: refreshUser.mutate,
isRefreshing: refreshUser.isPending,
refetchUser: userQuery.refetch,

View File

@@ -8,6 +8,3 @@ export const API_BASE = import.meta.env.DEV ? DEV_API_URL : "";
export const OLLAMA_HOST = import.meta.env.DEV
? DEV_API_URL
: window.location.origin;
export const OLLAMA_DOT_COM =
import.meta.env.VITE_OLLAMA_DOT_COM_URL || "https://ollama.com";

View File

@@ -147,7 +147,6 @@ export const highlighterPromise = createHighlighter({
"c",
"cpp",
"sql",
"swift",
"yaml",
"markdown",
],

View File

@@ -5,6 +5,13 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { routeTree } from "./routeTree.gen";
import { fetchUser } from "./api";
import { StreamingProvider } from "./contexts/StreamingContext";
import { User } from "@/gotypes";
declare global {
interface Window {
__initialUserDataPromise?: Promise<User | null>;
}
}
const queryClient = new QueryClient({
defaultOptions: {
@@ -17,11 +24,27 @@ const queryClient = new QueryClient({
},
});
fetchUser().then((userData) => {
if (userData) {
// Track initial user data fetch
let initialUserDataPromise: Promise<User | null> | null = null;
// Initialize user data on app startup
const initializeUserData = async () => {
try {
const userData = await fetchUser();
queryClient.setQueryData(["user"], userData);
return userData;
} catch (error) {
console.error("Error initializing user data:", error);
queryClient.setQueryData(["user"], null);
return null;
}
});
};
// Start initialization immediately and track the promise
initialUserDataPromise = initializeUserData();
// Export the promise so hooks can await it
window.__initialUserDataPromise = initialUserDataPromise;
const router = createRouter({
routeTree,

View File

@@ -101,14 +101,15 @@ type HealthResponse struct {
}
type User struct {
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Bio string `json:"bio,omitempty"`
AvatarURL string `json:"avatarurl,omitempty"`
FirstName string `json:"firstname,omitempty"`
LastName string `json:"lastname,omitempty"`
Plan string `json:"plan,omitempty"`
ID string `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
AvatarURL string `json:"avatarURL"`
Plan string `json:"plan"`
Bio string `json:"bio"`
FirstName string `json:"firstName"`
LastName string `json:"lastName"`
OverThreshold bool `json:"overThreshold"`
}
type Attachment struct {

View File

@@ -12,17 +12,18 @@ import (
"log/slog"
"net/http"
"net/http/httputil"
"net/url"
"os"
"runtime"
"runtime/debug"
"slices"
"strconv"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/auth"
"github.com/ollama/ollama/app/server"
"github.com/ollama/ollama/app/store"
"github.com/ollama/ollama/app/tools"
@@ -117,66 +118,40 @@ func (s *Server) log() *slog.Logger {
// ollamaProxy creates a reverse proxy handler to the Ollama server
func (s *Server) ollamaProxy() http.Handler {
var (
proxy http.Handler
proxyMu sync.Mutex
)
ollamaHost := os.Getenv("OLLAMA_HOST")
if ollamaHost == "" {
ollamaHost = "http://127.0.0.1:11434"
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxyMu.Lock()
p := proxy
proxyMu.Unlock()
if !strings.HasPrefix(ollamaHost, "http://") && !strings.HasPrefix(ollamaHost, "https://") {
ollamaHost = "http://" + ollamaHost
}
if p == nil {
proxyMu.Lock()
if proxy == nil {
var err error
for i := range 2 {
if i > 0 {
s.log().Warn("ollama server not ready, retrying", "attempt", i+1)
time.Sleep(1 * time.Second)
}
target, err := url.Parse(ollamaHost)
if err != nil {
s.log().Error("failed to parse OLLAMA_HOST", "error", err, "host", ollamaHost)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "failed to configure proxy", http.StatusInternalServerError)
})
}
err = WaitForServer(context.Background(), 10*time.Second)
if err == nil {
break
}
}
s.log().Info("configuring ollama proxy", "target", target.String())
if err != nil {
proxyMu.Unlock()
s.log().Error("ollama server not ready after retries", "error", err)
http.Error(w, "Ollama server is not ready", http.StatusServiceUnavailable)
return
}
proxy := httputil.NewSingleHostReverseProxy(target)
target := envconfig.Host()
s.log().Info("configuring ollama proxy", "target", target.String())
originalDirector := proxy.Director
proxy.Director = func(req *http.Request) {
originalDirector(req)
req.Host = target.Host
s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host)
}
newProxy := httputil.NewSingleHostReverseProxy(target)
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String())
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
}
originalDirector := newProxy.Director
newProxy.Director = func(req *http.Request) {
originalDirector(req)
req.Host = target.Host
s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host)
}
newProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String())
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
}
proxy = newProxy
p = newProxy
} else {
p = proxy
}
proxyMu.Unlock()
}
p.ServeHTTP(w, r)
})
return proxy
}
type errHandlerFunc func(http.ResponseWriter, *http.Request) error
@@ -289,10 +264,11 @@ func (s *Server) Handler() http.Handler {
ollamaProxy := s.ollamaProxy()
mux.Handle("GET /api/tags", ollamaProxy)
mux.Handle("POST /api/show", ollamaProxy)
mux.Handle("GET /api/version", ollamaProxy)
mux.Handle("HEAD /api/version", ollamaProxy)
mux.Handle("POST /api/me", ollamaProxy)
mux.Handle("POST /api/signout", ollamaProxy)
mux.Handle("GET /api/v1/me", handle(s.me))
mux.Handle("POST /api/v1/disconnect", handle(s.disconnect))
mux.Handle("GET /api/v1/connect", handle(s.connectURL))
mux.Handle("GET /api/v1/health", handle(s.health))
// React app - catch all non-API routes and serve the React app
mux.Handle("GET /", s.appHandler())
@@ -362,7 +338,7 @@ func (s *Server) doSelfSigned(ctx context.Context, method, path string) (*http.R
}
// UserData fetches user data from ollama.com API for the current ollama key
func (s *Server) UserData(ctx context.Context) (*api.UserResponse, error) {
func (s *Server) UserData(ctx context.Context) (*responses.User, error) {
resp, err := s.doSelfSigned(ctx, http.MethodPost, "/api/me")
if err != nil {
return nil, fmt.Errorf("failed to call ollama.com/api/me: %w", err)
@@ -373,7 +349,7 @@ func (s *Server) UserData(ctx context.Context) (*api.UserResponse, error) {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
var user api.UserResponse
var user responses.User
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
return nil, fmt.Errorf("failed to parse user response: %w", err)
}
@@ -392,27 +368,29 @@ func (s *Server) UserData(ctx context.Context) (*api.UserResponse, error) {
return &user, nil
}
// WaitForServer waits for the Ollama server to be ready
func WaitForServer(ctx context.Context, timeout time.Duration) error {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
func waitForServer(ctx context.Context) error {
timeout := time.Now().Add(10 * time.Second)
// TODO: this avoids an error on first load of the app
// however we should either show a loading state or
// wait for the Ollama server to be ready before redirecting
for {
c, err := api.ClientFromEnvironment()
if err != nil {
return err
}
if _, err := c.Version(ctx); err == nil {
slog.Debug("ollama server is ready")
return nil
break
}
if time.Now().After(timeout) {
return fmt.Errorf("timeout waiting for Ollama server to be ready")
}
time.Sleep(10 * time.Millisecond)
}
return errors.New("timeout waiting for Ollama server to be ready")
return nil
}
func (s *Server) createChat(w http.ResponseWriter, r *http.Request) error {
if err := WaitForServer(r.Context(), 10*time.Second); err != nil {
return err
}
waitForServer(r.Context())
id, err := uuid.NewV7()
if err != nil {
@@ -997,7 +975,7 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
for _, toolCall := range res.Message.ToolCalls {
// continues loop as tools were executed
toolsExecuted = true
result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments.ToMap())
result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
if err != nil {
errContent := fmt.Sprintf("Error: %v", err)
toolErrMsg := store.NewMessage("tool", errContent, nil)
@@ -1460,6 +1438,129 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
})
}
func (s *Server) me(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodGet {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return nil
}
user, err := s.UserData(r.Context())
if err != nil {
// If fetching from API fails, try to return cached user data if available
if cachedUser, cacheErr := s.Store.User(); cacheErr == nil && cachedUser != nil {
s.log().Info("API request failed, returning cached user data", "error", err)
responseUser := &responses.User{
Name: cachedUser.Name,
Email: cachedUser.Email,
Plan: cachedUser.Plan,
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(responseUser)
}
s.log().Error("failed to get user data", "error", err)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to get user data",
})
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(user)
}
func (s *Server) disconnect(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodPost {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return nil
}
if err := s.Store.ClearUser(); err != nil {
s.log().Warn("failed to clear cached user data", "error", err)
}
// Get the SSH public key to encode for the delete request
pubKey, err := ollamaAuth.GetPublicKey()
if err != nil {
s.log().Error("failed to get public key", "error", err)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to get public key",
})
}
// Encode the key using base64 URL encoding
encodedKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
// Call the /api/user/keys/{encodedKey} endpoint with DELETE
resp, err := s.doSelfSigned(r.Context(), http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey))
if err != nil {
s.log().Error("failed to call ollama.com/api/user/keys", "error", err)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to disconnect from ollama.com",
})
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
s.log().Error("disconnect request failed", "status", resp.StatusCode)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to disconnect from ollama.com",
})
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(map[string]string{"status": "disconnected"})
}
func (s *Server) connectURL(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodGet {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return nil
}
connectURL, err := auth.BuildConnectURL(OllamaDotCom)
if err != nil {
s.log().Error("failed to build connect URL", "error", err)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to build connect URL",
})
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(map[string]string{
"connect_url": connectURL,
})
}
func (s *Server) health(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodGet {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return nil
}
healthy := false
c, err := api.ClientFromEnvironment()
if err == nil {
if _, err := c.Version(r.Context()); err == nil {
healthy = true
}
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(responses.HealthResponse{
Healthy: healthy,
})
}
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
defer cancel()
@@ -1558,13 +1659,13 @@ func convertToOllamaTool(toolSchema map[string]any) api.Tool {
tool.Function.Parameters.Type = "object"
tool.Function.Parameters.Required = []string{}
tool.Function.Parameters.Properties = api.NewToolPropertiesMap()
tool.Function.Parameters.Properties = make(map[string]api.ToolProperty)
if schemaProps, ok := toolSchema["schema"].(map[string]any); ok {
tool.Function.Parameters.Type = getStringFromMap(schemaProps, "type", "object")
if props, ok := schemaProps["properties"].(map[string]any); ok {
tool.Function.Parameters.Properties = api.NewToolPropertiesMap()
tool.Function.Parameters.Properties = make(map[string]api.ToolProperty)
for propName, propDef := range props {
if propMap, ok := propDef.(map[string]any); ok {
@@ -1572,7 +1673,7 @@ func convertToOllamaTool(toolSchema map[string]any) api.Tool {
Type: api.PropertyType{getStringFromMap(propMap, "type", "string")},
Description: getStringFromMap(propMap, "description", ""),
}
tool.Function.Parameters.Properties.Set(propName, prop)
tool.Function.Parameters.Properties[propName] = prop
}
}
}

View File

@@ -158,16 +158,16 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui
case uint32(UI_REQUEST_MSG_ID):
// Requests for the UI must always come from the main event thread
l := int(wParam)
path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l) //nolint:govet,gosec
path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l)
t.app.UIRun(path)
case WM_COPYDATA:
// Handle URL scheme requests from other instances
if lParam != 0 {
cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam)) //nolint:govet,gosec
if cds.DwData == 1 { // Our identifier for URL scheme messages
cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam))
if cds.DwData == 1 { // Our identifier for URL scheme messages
// Convert the data back to string
data := make([]byte, cds.CbData)
copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData]) //nolint:govet,gosec
copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData])
urlScheme := string(data)
handleURLSchemeRequest(urlScheme)
lResult = 1 // Return non-zero to indicate success

View File

@@ -15,7 +15,7 @@ A Go-based command-line tool for benchmarking Ollama models with configurable pa
```
go build -o ollama-bench bench.go
./ollama-bench -model gpt-oss:20b -epochs 6 -format csv
./bench -model gpt-oss:20b -epochs 6 -format csv
```
Using Go Run (without building)
@@ -29,32 +29,31 @@ go run bench.go -model gpt-oss:20b -epochs 3
### Basic Example
```
./ollama-bench -model gemma3 -epochs 6
./bench -model gemma3 -epochs 6
```
### Benchmark Multiple Models
```
./ollama-bench -model gemma3,gemma3n -epochs 6 -max-tokens 100 -p "Write me a short story" | tee gemma.bench
./bench -model gemma3,gemma3n -epochs 6 -max-tokens 100 -p "Write me a short story" | tee gemma.bench
benchstat -col /name gemma.bench
```
### With Image Prompt
```
./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
./bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
```
### Advanced Example
```
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
./bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
```
## Command Line Options
| Option | Description | Default |
|----------|-------------|---------|
| -model | Comma-separated list of models to benchmark | (required) |
| -epochs | Number of iterations per model | 1 |
| -max-tokens | Maximum tokens for model response | 0 (unlimited) |

View File

@@ -48,8 +48,8 @@ func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool)
case "benchstat":
if verbose {
printHeader := func() {
fmt.Fprintf(w, "sysname: %s\n", runtime.GOOS)
fmt.Fprintf(w, "machine: %s\n", runtime.GOARCH)
fmt.Printf("sysname: %s\n", runtime.GOOS)
fmt.Printf("machine: %s\n", runtime.GOARCH)
}
once.Do(printHeader)
}
@@ -147,17 +147,6 @@ func BenchmarkChat(fOpt flagOptions) error {
return err
}
var out io.Writer = os.Stdout
if fOpt.outputFile != nil && *fOpt.outputFile != "" {
f, err := os.OpenFile(*fOpt.outputFile, os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: cannot open output file %s: %v\n", *fOpt.outputFile, err)
return err
}
defer f.Close()
out = f
}
for _, model := range models {
for range *fOpt.epochs {
options := make(map[string]interface{})
@@ -252,14 +241,13 @@ func BenchmarkChat(fOpt flagOptions) error {
},
}
OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)
OutputMetrics(os.Stdout, *fOpt.format, metrics, *fOpt.verbose)
if *fOpt.keepAlive > 0 {
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
}
}
}
return nil
}

View File

@@ -45,9 +45,6 @@ import (
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd"
"github.com/ollama/ollama/x/imagegen"
imagegenclient "github.com/ollama/ollama/x/imagegen/client"
)
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
@@ -98,10 +95,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) {
if filename == "" {
// No Modelfile found - check if current directory is an image gen model
if imagegen.IsTensorModelDir(".") {
return imagegenclient.CreateModel(args[0], ".", p)
}
reader = strings.NewReader("FROM .\n")
} else {
return errModelfileNotFound
@@ -463,15 +456,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
name := args[0]
// Check if this is a known image generation model (skip Show/Pull)
if imagegen.HasTensorLayers(name) {
if opts.Prompt == "" && !interactive {
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
}
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
}
info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: name}
info, err := client.Show(cmd.Context(), showReq)
@@ -533,10 +517,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
}
// Check for experimental flag
isExperimental, _ := cmd.Flags().GetBool("experimental")
yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
if interactive {
if err := loadOrUnloadModel(cmd, &opts); err != nil {
var sErr api.AuthorizationError
@@ -563,11 +543,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
}
// Use experimental agent loop with tools
if isExperimental {
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode)
}
return generateInteractive(cmd, opts)
}
return generate(cmd, opts)
@@ -837,11 +812,6 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
}
func ShowHandler(cmd *cobra.Command, args []string) error {
// Check if this is an image generation model
if imagegen.HasTensorLayers(args[0]) {
return imagegen.Show(args[0], os.Stdout)
}
client, err := api.ClientFromEnvironment()
if err != nil {
return err
@@ -973,9 +943,6 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
}
rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
if resp.Requires != "" {
rows = append(rows, []string{"", "requires", resp.Requires})
}
return
})
@@ -1463,7 +1430,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
latest.Summary()
}
return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
return &api.Message{Role: role, Content: fullResponse.String()}, nil
}
func generate(cmd *cobra.Command, opts runOptions) error {
@@ -1784,11 +1751,6 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
// Image generation flags (width, height, steps, seed, etc.)
imagegen.RegisterFlags(runCmd)
stopCmd := &cobra.Command{
Use: "stop MODEL",

View File

@@ -291,31 +291,6 @@ Weigh anchor!
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
t.Run("min version", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
Requires: "0.14.0",
}, false, &b); err != nil {
t.Fatal(err)
}
expect := ` Model
architecture test
parameters 7B
quantization FP16
requires 0.14.0
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
}
func TestDeleteHandler(t *testing.T) {

View File

@@ -40,7 +40,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")

148
cmd/testolmo/main.go Normal file
View File

@@ -0,0 +1,148 @@
package main
import (
"context"
"fmt"
"log"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
_ "github.com/ollama/ollama/model/models" // Register all models
"github.com/ollama/ollama/model/renderers"
"github.com/ollama/ollama/sample"
)
func main() {
modelPath := "/Users/parth/.ollama/models/blobs/sha256-a87e10578f328b087f888ac7bd1018555e26028a1130980f20312b4de3a10d70"
fmt.Println("Loading OLMo model...")
m, err := model.New(modelPath, ml.BackendParams{AllocMemory: true})
if err != nil {
log.Fatal(err)
}
if err := m.Backend().Load(context.Background(), func(f float32) {}); err != nil {
log.Fatal(err)
}
fmt.Println("✅ Model loaded successfully!")
// Initialize the cache
cache := m.Config().Cache
if cache != nil {
// Initialize with reasonable defaults:
// - dtype: F16
// - maxSequences: 1 (single sequence)
// - capacity: 2048 (context length)
// - maxBatch: 512
cache.Init(m.Backend(), ml.DTypeF16, 1, 2048, 512)
fmt.Printf("✅ Cache initialized (type: %T)\n", cache)
}
// Use the olmo3 renderer to format the prompt properly
messages := []api.Message{
{Role: "user", Content: "wagwan"},
}
// prompt := "Question: What is machine learning? Answer:"
prompt, err := renderers.RenderWithRenderer("olmo3", messages, nil, nil)
if err != nil {
log.Fatal(err)
}
// prompt = prompt[:len(prompt)]
// prompt := "Question: What is machine learning? Answer:"
fmt.Printf("\nRendered prompt:\n%s\n", prompt)
tp := m.(model.TextProcessor)
tokens, err := tp.Encode(prompt, false)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Tokens: %v (count: %d)\n", tokens, len(tokens))
// Generate 20 tokens
maxTokens := 20
generated := make([]int32, 0, maxTokens)
// Create sampler (temperature=0 for greedy sampling)
sampler := sample.NewSampler(0, 0, 0, 0, -1, nil)
for i := 0; i < maxTokens; i++ {
// Create a new context for each generation step to avoid memory buildup
ctx := m.Backend().NewContext()
var inputTokens []int32
var positions []int32
if i == 0 {
// First iteration: process all prompt tokens
inputTokens = tokens
positions = make([]int32, len(tokens))
for j := range positions {
positions[j] = int32(j)
}
} else {
// Subsequent iterations: only process the newly generated token
// The last token is at position len(tokens)-1 (its index in the sequence)
inputTokens = []int32{tokens[len(tokens)-1]}
positions = []int32{int32(len(tokens) - 1)}
}
sequences := make([]int, len(inputTokens))
// All tokens belong to sequence 0
inputsTensor := ctx.Input().FromInts(inputTokens, len(inputTokens))
outputs := ctx.Input().FromInts([]int32{int32(len(inputTokens) - 1)}, 1)
batch := input.Batch{
Inputs: inputsTensor,
Positions: positions,
Sequences: sequences,
Outputs: outputs,
}
// Forward pass (model.Forward handles cache.StartForward internally)
logits, err := model.Forward(ctx, m, batch)
if err != nil {
ctx.Close()
log.Fatal(err)
}
logits = logits.Contiguous(ctx)
ctx.Forward(logits).Compute(logits)
logitValues := logits.Floats()
// Sample next token
nextToken, err := sampler.Sample(logitValues)
if err != nil {
ctx.Close()
log.Fatal(err)
}
// Close context before moving to next iteration
ctx.Close()
generated = append(generated, nextToken)
tokens = append(tokens, nextToken)
// Decode and print
decoded, _ := tp.Decode([]int32{nextToken})
fmt.Print(decoded)
// Stop on EOS or <|im_end|>
if nextToken == 2 || nextToken == 1 { // Common EOS tokens
break
}
// Check if we generated <|im_end|> (stop token for chat)
if decoded == "<|im_end|>" {
break
}
}
fmt.Println("\n\n✅ Generation completed!")
fullText, _ := tp.Decode(generated)
fmt.Printf("Generated: %s\n", fullText)
}

View File

@@ -6,14 +6,11 @@ import (
"errors"
"fmt"
"io/fs"
"iter"
"log/slog"
"maps"
"os"
"slices"
"strings"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -21,13 +18,8 @@ type ModelParameters struct {
Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"`
// TODO is this needed?
ModelType string `json:"model_type"`
TextModel struct {
VocabSize uint32 `json:"vocab_size"`
HiddenSize uint32 `json:"hidden_size"`
ModelType string `json:"model_type"`
VocabSize uint32 `json:"vocab_size"`
} `json:"text_config"`
}
@@ -41,94 +33,8 @@ type AdapterParameters struct {
} `json:"lora_parameters"`
}
type KV map[string]any
func (kv KV) Architecture() string {
return kv.String("general.architecture", "unknown")
}
type valueTypes interface {
uint8 | int8 | uint16 | int16 |
uint32 | int32 | uint64 | int64 |
string | float32 | float64 | bool
}
type arrayValueTypes interface {
[]uint8 | []int8 | []uint16 | []int16 |
[]uint32 | []int32 | []uint64 | []int64 |
[]string | []float32 | []float64 | []bool
}
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key
}
if val, ok := kv[key].(T); ok {
return val, true
}
return defaultValue[0], false
}
func (kv KV) String(key string, defaultValue ...string) string {
val, _ := keyValue(kv, key, append(defaultValue, "")...)
return val
}
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
}
func (kv KV) Float(key string, defaultValue ...float32) float32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
}
func (kv KV) Bool(key string, defaultValue ...bool) bool {
val, _ := keyValue(kv, key, append(defaultValue, false)...)
return val
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
val, _ := keyValue(kv, key, append(defaultValue, []string{""})...)
return val
}
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
val, _ := keyValue(kv, key, append(defaultValue, []int32{0})...)
return val
}
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
val, _ := keyValue(kv, key, append(defaultValue, []uint32{0})...)
return val
}
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
val, _ := keyValue(kv, key, append(defaultValue, []float32{0})...)
return val
}
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
val, _ := keyValue(kv, key, append(defaultValue, []bool{false})...)
return val
}
func (kv KV) Len() int {
return len(kv)
}
func (kv KV) Keys() iter.Seq[string] {
return maps.Keys(kv)
}
func (kv KV) Value(key string) any {
return kv[key]
}
func (ModelParameters) KV(t *Tokenizer) KV {
kv := KV{
func (ModelParameters) KV(t *Tokenizer) ggml.KV {
kv := ggml.KV{
"general.file_type": uint32(1),
"general.quantization_version": uint32(2),
"tokenizer.ggml.pre": t.Pre,
@@ -157,7 +63,7 @@ func (ModelParameters) KV(t *Tokenizer) KV {
return kv
}
func (p AdapterParameters) KV() KV {
func (p AdapterParameters) KV() ggml.KV {
var alpha float32
if p.LoraParameters.Alpha == 0 {
alpha = float32(p.Alpha)
@@ -165,7 +71,7 @@ func (p AdapterParameters) KV() KV {
alpha = p.LoraParameters.Alpha
}
kv := KV{
kv := ggml.KV{
"adapter.lora.alpha": alpha,
"adapter.type": "lora",
"general.file_type": uint32(1),
@@ -182,14 +88,9 @@ func (ModelParameters) specialTokenTypes() []string {
}
}
type ModelKV interface {
// KV maps parameters to LLM key-values
KV(*Tokenizer) KV
}
type ModelConverter interface {
ModelKV
// KV maps parameters to LLM key-values
KV(*Tokenizer) ggml.KV
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
@@ -206,7 +107,7 @@ type moreParser interface {
type AdapterConverter interface {
// KV maps parameters to LLM key-values
KV(ofs.Config) KV
KV(ggml.KV) ggml.KV
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
@@ -214,7 +115,7 @@ type AdapterConverter interface {
Replacements() []string
}
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
bts, err := fs.ReadFile(fsys, "adapter_config.json")
if err != nil {
return err
@@ -225,8 +126,8 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
return err
}
arch := baseKV.Architecture()
if arch == "" {
arch, ok := baseKV["general.architecture"]
if !ok {
return errors.New("architecture not set for the base model")
}
@@ -252,19 +153,23 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
}
func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
bts, err := fs.ReadFile(fsys, "config.json")
if err != nil {
return nil, nil, err
return err
}
var p ModelParameters
if err := json.Unmarshal(bts, &p); err != nil {
return nil, nil, err
return err
}
if len(p.Architectures) < 1 {
return nil, nil, errors.New("unknown architecture")
return errors.New("unknown architecture")
}
var conv ModelConverter
@@ -277,8 +182,6 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &llama4Model{}
case "Mistral3ForConditionalGeneration":
conv = &mistral3Model{}
case "Ministral3ForCausalLM":
conv = &mistral3CausalModel{}
case "MixtralForCausalLM":
conv = &mixtralModel{}
case "GemmaForCausalLM":
@@ -297,37 +200,33 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &qwen25VLModel{}
case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration":
conv = &qwen3VLModel{}
case "Olmo3ForCausalLM":
case "OLMo2ForCausalLM", "Olmo2ForCausalLM", "OLMo3ForCausalLM", "Olmo3ForCausalLM":
conv = &olmoModel{}
case "BertModel":
conv = &bertModel{}
case "NomicBertModel", "NomicBertMoEModel":
conv = &nomicbertModel{}
case "CohereForCausalLM":
conv = &commandrModel{}
case "GptOssForCausalLM":
conv = &gptossModel{}
case "DeepseekOCRForCausalLM":
conv = &deepseekocr{}
case "DeepseekV3ForCausalLM":
conv = &deepseek2Model{}
default:
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}
if err := json.Unmarshal(bts, conv); err != nil {
return nil, nil, err
return err
}
if t, ok := conv.(moreParser); ok {
if err := t.parseMore(fsys); err != nil {
return nil, nil, err
return err
}
}
t, err := parseTokenizer(fsys, conv.specialTokenTypes())
if err != nil {
return nil, nil, err
return err
}
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
@@ -349,19 +248,6 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
default:
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
}
return conv, t, nil
}
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
kv, t, err := LoadModelMetadata(fsys)
if err != nil {
return err
}
conv := kv.(ModelConverter)
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
if err != nil {
@@ -371,7 +257,7 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
return writeFile(f, conv.KV(t), conv.Tensors(ts))
}
func writeFile(f *os.File, kv KV, ts []*ggml.Tensor) error {
func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
for i := range ts {
ts[i].Shape = slices.Clone(ts[i].Shape)
slices.Reverse(ts[i].Shape)

View File

@@ -88,7 +88,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
return nil
}
func (p *bertModel) KV(t *Tokenizer) KV {
func (p *bertModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "bert"
kv["bert.attention.causal"] = false

View File

@@ -24,7 +24,7 @@ type commandrModel struct {
var _ ModelConverter = (*commandrModel)(nil)
func (p *commandrModel) KV(t *Tokenizer) KV {
func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "command-r"
kv["general.name"] = "command-r"

View File

@@ -1,173 +0,0 @@
package convert
import (
"cmp"
"fmt"
"log/slog"
"regexp"
"strconv"
"github.com/ollama/ollama/fs/ggml"
)
type deepseek2Model struct {
ModelParameters // architectures, vocab_size
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
QKNopeHeadDim uint32 `json:"qk_nope_head_dim"`
QKRopeHeadDim uint32 `json:"qk_rope_head_dim"`
KVLoraRank uint32 `json:"kv_lora_rank"`
QLoraRank uint32 `json:"q_lora_rank"`
VHeadDim uint32 `json:"v_head_dim"`
ExpertCount uint32 `json:"n_routed_experts"`
ExpertSharedCount uint32 `json:"n_shared_experts"`
ExpertIntermediateSize uint32 `json:"moe_intermediate_size"`
ExpertUsedCount uint32 `json:"num_experts_per_tok"`
ExpertWeightsNorm bool `json:"norm_topk_prob"`
ExpertWeightsScale float32 `json:"routed_scaling_factor"`
ScoringFunc string `json:"scoring_func"`
LeadingDenseBlockCount uint32 `json:"first_k_dense_replace"`
RopeScaling struct {
Factor float32 `json:"factor"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
Type string `json:"type"`
MScaleAllDim float32 `json:"mscale_all_dim"`
} `json:"rope_scaling"`
Architecture string
}
func (p *deepseek2Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "deepseek2"
kv["general.type"] = "model"
kv["deepseek2.block_count"] = p.HiddenLayers
numHeads := p.NumAttentionHeads
numKVHeads := p.NumKeyValueHeads
kv["deepseek2.attention.head_count"] = numHeads
kv["deepseek2.attention.head_count_kv"] = numKVHeads
kv["deepseek2.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim
kv["deepseek2.attention.kv_lora_rank"] = p.KVLoraRank
kv["deepseek2.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["deepseek2.attention.q_lora_rank"] = p.QLoraRank
kv["deepseek2.attention.value_length"] = p.VHeadDim
kv["deepseek2.context_length"] = p.MaxPositionEmbeddings
kv["deepseek2.embedding_length"] = p.HiddenSize
kv["deepseek2.expert_count"] = p.ExpertCount
kv["deepseek2.expert_feed_forward_length"] = p.ExpertIntermediateSize
kv["deepseek2.expert_shared_count"] = p.ExpertSharedCount
var scoringFunc uint32
switch p.ScoringFunc {
case "softmax":
// not currently supported in the model, but needed for Deepseek-OCR
scoringFunc = 1
case "sigmoid":
scoringFunc = 2
}
kv["deepseek2.expert_gating_func"] = scoringFunc
kv["deepseek2.expert_used_count"] = p.ExpertUsedCount
kv["deepseek2.expert_weights_norm"] = p.ExpertWeightsNorm
kv["deepseek2.expert_weights_scale"] = p.ExpertWeightsScale
kv["deepseek2.feed_forward_length"] = p.IntermediateSize
kv["deepseek2.leading_dense_block_count"] = p.LeadingDenseBlockCount
kv["deepseek2.rope.dimension_count"] = p.QKRopeHeadDim
kv["deepseek2.rope.freq_base"] = cmp.Or(p.RopeTheta, 10000.0)
kv["deepseek2.rope.scaling.factor"] = p.RopeScaling.Factor
kv["deepseek2.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings
kv["deepseek2.rope.scaling.type"] = p.RopeScaling.Type
kv["deepseek2.rope.scaling.yarn_log_multiplier"] = 0.1 * p.RopeScaling.MScaleAllDim
kv["tokenizer.ggml.pre"] = "deepseek-v3"
return kv
}
func (p *deepseek2Model) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"language_model.", "",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.kv_a_proj_with_mqa", "attn_kv_a_mqa",
"self_attn.kv_a_layernorm", "attn_kv_a_norm",
"self_attn.kv_b_proj", "attn_kv_b",
"self_attn.q_a_proj", "attn_q_a",
"self_attn.q_a_layernorm", "attn_q_a_norm",
"self_attn.q_b_proj", "attn_q_b",
"self_attn.o_proj", "attn_output",
"post_attention_layernorm", "ffn_norm",
"mlp.shared_experts.down_proj", "ffn_down_shexp",
"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
"mlp.shared_experts.up_proj", "ffn_up_shexp",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"mlp.gate.e_score_correction_bias", "exp_probs_b.bias",
"mlp.gate", "ffn_gate_inp",
}
}
func (p *deepseek2Model) Tensors(s []Tensor) (out []*ggml.Tensor) {
merges := make([]merge, p.HiddenLayers*3)
for i := range p.HiddenLayers {
merges[i*3+0] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}
merges[i*3+1] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}
merges[i*3+2] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}
}
skipLayer := func(n string, minValue uint32) bool {
re := regexp.MustCompile(`^blk\.(\d+)`)
matches := re.FindStringSubmatch(n)
if matches == nil {
return false
}
blkNum, err := strconv.Atoi(matches[1])
if err != nil {
return false
}
return uint32(blkNum) >= minValue
}
out, s = mergeTensors(s, merges...)
for _, t := range s {
// skip any additional layers (such as the Multi-Token Prediction layer)
if skipLayer(t.Name(), p.HiddenLayers) {
slog.Debug("skipping layer", "name", t.Name())
continue
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}

View File

@@ -41,7 +41,7 @@ type deepseekocr struct {
} `json:"vision_config"`
}
func (m *deepseekocr) KV(t *Tokenizer) KV {
func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "deepseekocr"
kv["block_count"] = m.LanguageConfig.HiddenLayers

View File

@@ -23,7 +23,7 @@ type gemmaModel struct {
var _ ModelConverter = (*gemmaModel)(nil)
func (p *gemmaModel) KV(t *Tokenizer) KV {
func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma"
kv["gemma.context_length"] = p.MaxPositionEmbeddings

View File

@@ -1,5 +1,7 @@
package convert
import "github.com/ollama/ollama/fs/ggml"
type gemma2Model struct {
gemmaModel
SlidingWindow uint32 `json:"sliding_window"`
@@ -7,7 +9,7 @@ type gemma2Model struct {
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
}
func (p *gemma2Model) KV(t *Tokenizer) KV {
func (p *gemma2Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma2"
kv["gemma2.context_length"] = p.MaxPositionEmbeddings

View File

@@ -6,7 +6,6 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -16,7 +15,7 @@ type gemma2Adapter struct {
var _ AdapterConverter = (*gemma2Adapter)(nil)
func (p *gemma2Adapter) KV(baseKV fs.Config) KV {
func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "gemma2"
return kv

View File

@@ -2,7 +2,8 @@ package convert
import (
"cmp"
"slices"
"github.com/ollama/ollama/fs/ggml"
)
type gemma3Model struct {
@@ -25,26 +26,16 @@ type gemma3Model struct {
NumChannels uint32 `json:"num_channels"` // num_channels 3
PatchSize uint32 `json:"patch_size"` // patch_size 14
} `json:"vision_config"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
RopeLocalTheta float32 `json:"rope_local_base_freq"`
RopeTheta float32 `json:"rope_theta"`
SlidingWindow uint32 `json:"sliding_window"`
SlidingWindowPattern *uint32 `json:"sliding_window_pattern"`
LayerTypes []string `json:"layer_types"`
MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
RopeScaling *struct {
Type string `json:"rope_type"`
Factor float32 `json:"factor"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
ExtrapolationFactor float32 `json:"extrapolation_factor"`
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
} `json:"rope_scaling"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
RopeLocalTheta float32 `json:"rope_local_base_freq"`
RopeGlobalTheta float32 `json:"rope_global_base_freq"`
SlidingWindow uint32 `json:"sliding_window"`
MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
}
const (
@@ -53,7 +44,7 @@ const (
gemma27BLayerCount = 62
)
func (p *gemma3Model) KV(t *Tokenizer) KV {
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3"
@@ -90,38 +81,9 @@ func (p *gemma3Model) KV(t *Tokenizer) KV {
kv["gemma3.attention.key_length"] = p.HeadDim
kv["gemma3.attention.value_length"] = p.HeadDim
kv["gemma3.attention.sliding_window"] = p.SlidingWindow
// The sliding window pattern is either provided as the sliding_window_pattern
// key (an int) or as the layer_types key (a list of strings).
if p.SlidingWindowPattern != nil || len(p.LayerTypes) > 0 {
kv["gemma3.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
for i := range numBlocks {
var isLocal bool
if len(p.LayerTypes) > 0 && int(i) < len(p.LayerTypes) {
isLocal = p.LayerTypes[i] == "sliding_attention"
} else if p.SlidingWindowPattern != nil && *p.SlidingWindowPattern > 0 {
isLocal = (i+1)%*p.SlidingWindowPattern != 0
}
if !yield(isLocal) {
break
}
}
})
}
if p.FinalLogitSoftcap > 0 {
kv["gemma3.final_logit_softcapping"] = p.FinalLogitSoftcap
}
kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
kv["gemma3.rope.freq_base"] = cmp.Or(p.RopeTheta, 1000000.0)
if p.RopeScaling != nil && p.RopeScaling.Type == "yarn" && p.RopeScaling.Factor > 0 {
kv["gemma3.rope.scaling.type"] = "yarn"
kv["gemma3.rope.scaling.factor"] = p.RopeScaling.Factor
kv["gemma3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings
kv["gemma3.rope.scaling.extrapolation_factor"] = cmp.Or(p.RopeScaling.ExtrapolationFactor, float32(1.0))
kv["gemma3.rope.scaling.beta_fast"] = cmp.Or(p.RopeScaling.BetaFast, float32(64.0))
kv["gemma3.rope.scaling.beta_slow"] = cmp.Or(p.RopeScaling.BetaSlow, float32(1.0))
}
kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
kv["gemma3.embedding_length"] = p.HiddenSize
kv["gemma3.feed_forward_length"] = p.IntermediateSize
default:

View File

@@ -38,7 +38,7 @@ type gemma3nModel struct {
VisionModel struct{} `json:"vision_config"`
}
func (m *gemma3nModel) KV(t *Tokenizer) KV {
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3n"
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {

View File

@@ -37,7 +37,7 @@ type gptossModel struct {
var _ ModelConverter = (*gptossModel)(nil)
func (m *gptossModel) KV(t *Tokenizer) KV {
func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gptoss"
kv["general.file_type"] = uint32(4)

View File

@@ -48,7 +48,7 @@ type llamaModel struct {
var _ ModelConverter = (*llamaModel)(nil)
func (p *llamaModel) KV(t *Tokenizer) KV {
func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama"
kv["llama.vocab_size"] = p.VocabSize

View File

@@ -35,7 +35,7 @@ type llama4Model struct {
}
// KV implements ModelConverter.
func (p *llama4Model) KV(t *Tokenizer) KV {
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama4"

View File

@@ -7,7 +7,6 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -19,13 +18,13 @@ type llamaAdapter struct {
var _ AdapterConverter = (*llamaAdapter)(nil)
func (p *llamaAdapter) KV(baseKV fs.Config) KV {
func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "llama"
kv["llama.attention.head_count"] = baseKV.Value("llama.attention.head_count")
kv["llama.attention.head_count_kv"] = baseKV.Value("llama.attention.head_count_kv")
kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
kv["llama.attention.head_count_kv"] = baseKV["llama.attention.head_count_kv"]
p.NumAttentionHeads = baseKV.Value("llama.attention.head_count").(uint32)
p.NumAttentionHeads = baseKV["llama.attention.head_count"].(uint32)
return kv
}

View File

@@ -29,17 +29,6 @@ type mistral3Model struct {
SlidingWindow *uint32 `json:"sliding_window"`
HiddenAct string `json:"hidden_act"`
VocabSize uint32 `json:"vocab_size"`
RopeParameters struct {
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
Factor float32 `json:"factor"`
Llama4ScalingBeta *float32 `json:"llama_4_scaling_beta"`
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
RopeType string `json:"rope_type"`
RopeTheta float32 `json:"rope_theta"`
Mscale *float32 `json:"mscale"`
MscaleAllDim *float32 `json:"mscale_all_dim"`
} `json:"rope_parameters"`
} `json:"text_config"`
VisionModel struct {
NumAttentionHeads uint32 `json:"num_attention_heads"`
@@ -52,15 +41,12 @@ type mistral3Model struct {
HeadDim uint32 `json:"head_dim"`
HiddenAct string `json:"hidden_act"`
RopeTheta float32 `json:"rope_theta"`
RopeParameters struct {
RopeTheta float32 `json:"rope_theta"`
} `json:"rope_parameters"`
} `json:"vision_config"`
MultiModalProjectorBias bool `json:"multimodal_projector_bias"`
ProjectorHiddenAct string `json:"projector_hidden_act"`
}
func (p *mistral3Model) KV(t *Tokenizer) KV {
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.TextModel.VocabSize
@@ -75,25 +61,8 @@ func (p *mistral3Model) KV(t *Tokenizer) KV {
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads)
kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta)
kv["mistral3.rope.scaling.factor"] = p.TextModel.RopeParameters.Factor
kv["mistral3.rope.scaling.type"] = p.TextModel.RopeParameters.RopeType
kv["mistral3.rope.scaling.beta_fast"] = p.TextModel.RopeParameters.BetaFast
kv["mistral3.rope.scaling.beta_slow"] = p.TextModel.RopeParameters.BetaSlow
if p.TextModel.RopeParameters.Mscale != nil {
kv["mistral3.rope.scaling.mscale"] = *p.TextModel.RopeParameters.Mscale
}
if p.TextModel.RopeParameters.MscaleAllDim != nil {
kv["mistral3.rope.scaling.mscale_all_dim"] = *p.TextModel.RopeParameters.MscaleAllDim
}
if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 {
kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings
}
if p.TextModel.RopeParameters.Llama4ScalingBeta != nil {
kv["mistral3.rope.scaling_beta"] = *p.TextModel.RopeParameters.Llama4ScalingBeta
}
kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
// Vision configuration
kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
@@ -105,7 +74,7 @@ func (p *mistral3Model) KV(t *Tokenizer) KV {
kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
kv["mistral3.vision.rope.freq_base"] = cmp.Or(p.VisionModel.RopeTheta, p.VisionModel.RopeParameters.RopeTheta)
kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta
// Multimodal configuration
kv["mistral3.image_token_index"] = p.ImageTokenIndex

View File

@@ -1,181 +0,0 @@
package convert
import (
"cmp"
"fmt"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type mistral3CausalModel struct {
ModelParameters
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
SlidingWindow *uint32 `json:"sliding_window"`
HiddenAct string `json:"hidden_act"`
VocabSize uint32 `json:"vocab_size"`
RopeParameters struct {
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
Factor float32 `json:"factor"`
Llama4ScalingBeta *float32 `json:"llama_4_scaling_beta"`
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
RopeType string `json:"rope_type"`
RopeTheta float32 `json:"rope_theta"`
Mscale *float32 `json:"mscale"`
MscaleAllDim *float32 `json:"mscale_all_dim"`
} `json:"rope_parameters"`
}
func (p *mistral3CausalModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.VocabSize
// Text configuration
kv["mistral3.block_count"] = p.NumHiddenLayers
kv["mistral3.context_length"] = p.MaxPositionEmbeddings
kv["mistral3.embedding_length"] = p.HiddenSize
kv["mistral3.feed_forward_length"] = p.IntermediateSize
kv["mistral3.attention.head_count"] = p.NumAttentionHeads
kv["mistral3.attention.head_count_kv"] = p.NumKeyValueHeads
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["mistral3.attention.key_length"] = p.HeadDim
kv["mistral3.attention.value_length"] = p.HeadDim
kv["mistral3.rope.dimension_count"] = cmp.Or(p.HeadDim, p.HiddenSize/p.NumAttentionHeads)
kv["mistral3.rope.freq_base"] = cmp.Or(p.RopeTheta, p.RopeParameters.RopeTheta)
kv["mistral3.rope.scaling.factor"] = p.RopeParameters.Factor
kv["mistral3.rope.scaling.type"] = p.RopeParameters.RopeType
kv["mistral3.rope.scaling.beta_fast"] = p.RopeParameters.BetaFast
kv["mistral3.rope.scaling.beta_slow"] = p.RopeParameters.BetaSlow
if p.RopeParameters.Mscale != nil {
kv["mistral3.rope.scaling.mscale"] = *p.RopeParameters.Mscale
}
if p.RopeParameters.MscaleAllDim != nil {
kv["mistral3.rope.scaling.mscale_all_dim"] = *p.RopeParameters.MscaleAllDim
}
if p.RopeParameters.OrigMaxPositionEmbeddings > 0 {
kv["mistral3.rope.scaling.original_context_length"] = p.RopeParameters.OrigMaxPositionEmbeddings
kv["mistral3.rope.scaling_beta"] = *p.RopeParameters.Llama4ScalingBeta
}
if p.RopeParameters.Llama4ScalingBeta != nil {
kv["mistral3.rope.scaling_beta"] = *p.RopeParameters.Llama4ScalingBeta
}
return kv
}
func (p *mistral3CausalModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") {
if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
strings.HasSuffix(t.Name(), ".attn_k.weight") {
t.SetRepacker(p.repack)
}
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *mistral3CausalModel) Replacements() []string {
return []string{
"model.norm", "output_norm",
"model.", "",
"layers", "blk",
"transformer.layers", "blk",
"vision_tower", "v",
"ln_pre", "encoder_norm",
"input_layernorm", "attn_norm",
"post_attention_layernorm", "ffn_norm",
"embed_tokens", "token_embd",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"mlp.down_proj", "ffn_down",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"attention.q_proj", "attn_q",
"attention.k_proj", "attn_k",
"attention.v_proj", "attn_v",
"attention.o_proj", "attn_output",
"attention_norm", "attn_norm",
"feed_forward.gate_proj", "ffn_gate",
"feed_forward.down_proj", "ffn_down",
"feed_forward.up_proj", "ffn_up",
"multi_modal_projector", "mm",
"ffn_norm", "ffn_norm",
"lm_head", "output",
}
}
func (p *mistral3CausalModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
var dims []int
for _, dim := range shape {
dims = append(dims, int(dim))
}
var heads uint32
if strings.HasSuffix(name, ".attn_q.weight") {
heads = p.NumAttentionHeads
} else if strings.HasSuffix(name, ".attn_k.weight") {
heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
} else {
return nil, fmt.Errorf("unknown tensor for repack: %s", name)
}
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 1)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}

View File

@@ -12,7 +12,7 @@ type mixtralModel struct {
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
}
func (p *mixtralModel) KV(t *Tokenizer) KV {
func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
kv := p.llamaModel.KV(t)
if p.NumLocalExperts > 0 {

View File

@@ -34,7 +34,7 @@ type mllamaModel struct {
} `json:"vision_config"`
}
func (m *mllamaModel) KV(t *Tokenizer) KV {
func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "mllama"

View File

@@ -1,213 +0,0 @@
package convert
import (
"cmp"
"encoding/json"
"io/fs"
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type nomicbertModel struct {
ModelParameters
NLayers uint32 `json:"n_layers"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
LayerNormEPS float32 `json:"layer_norm_eps"`
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
RopeFreqBase float32 `json:"rope_theta"`
normalizeEmbeddings bool
PoolingType uint32
// MoE parameters (only present in v2 models)
NumExperts uint32 `json:"num_local_experts"`
NumExpertsUsed uint32 `json:"num_experts_per_tok"`
MoEEveryNLayers uint32 `json:"moe_every_n_layers"`
}
var (
_ ModelConverter = (*nomicbertModel)(nil)
_ moreParser = (*nomicbertModel)(nil)
)
func (p *nomicbertModel) parseMore(fsys fs.FS) error {
bts, err := fs.ReadFile(fsys, "modules.json")
if err != nil {
return err
}
var modules []struct {
Type string `json:"type"`
Path string `json:"path"`
}
if err := json.Unmarshal(bts, &modules); err != nil {
return err
}
var pooling string
for _, m := range modules {
switch m.Type {
case "sentence_transformers.models.Pooling":
pooling = m.Path
case "sentence_transformers.models.Normalize":
p.normalizeEmbeddings = true
}
}
if pooling != "" {
bts, err := fs.ReadFile(fsys, filepath.Join(pooling, "config.json"))
if err != nil {
return err
}
var pc struct {
PoolingModeCLSToken bool `json:"pooling_mode_cls_token"`
PoolingModeMeanTokens bool `json:"pooling_mode_mean_tokens"`
}
if err := json.Unmarshal(bts, &pc); err != nil {
return err
}
if pc.PoolingModeMeanTokens {
p.PoolingType = 1
} else if pc.PoolingModeCLSToken {
p.PoolingType = 2
}
}
return nil
}
func (p *nomicbertModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
// Determine architecture based on MoE parameters (following qwen3 pattern)
arch := "nomic-bert"
if p.MoEEveryNLayers > 0 {
arch += "-moe"
}
kv["general.architecture"] = arch
kv["attention.causal"] = false
kv["pooling_type"] = p.PoolingType
kv["normalize_embeddings"] = p.normalizeEmbeddings
kv["block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers)
if contextLength := p.MaxPositionEmbeddings; contextLength > 0 {
kv["context_length"] = contextLength
}
if embeddingLength := p.HiddenSize; embeddingLength > 0 {
kv["embedding_length"] = p.HiddenSize
}
if feedForwardLength := p.IntermediateSize; feedForwardLength > 0 {
kv["feed_forward_length"] = p.IntermediateSize
}
if headCount := p.NumAttentionHeads; headCount > 0 {
kv["attention.head_count"] = p.NumAttentionHeads
}
if kvHeadCount := p.NumKeyValueHeads; kvHeadCount > 0 {
kv["attention.head_count_kv"] = p.NumKeyValueHeads
}
if layerNormEpsilon := cmp.Or(p.LayerNormEPS, p.LayerNormEpsilon); layerNormEpsilon > 0 {
kv["attention.layer_norm_epsilon"] = layerNormEpsilon
}
if p.RopeFreqBase > 0 {
kv["rope.freq_base"] = p.RopeFreqBase
}
// MoE specific parameters (only if MoE is enabled)
if p.NumExperts > 0 {
kv["expert_count"] = p.NumExperts
}
if p.NumExpertsUsed > 0 {
kv["expert_used_count"] = p.NumExpertsUsed
}
if p.MoEEveryNLayers > 0 {
kv["moe_every_n_layers"] = p.MoEEveryNLayers
}
kv["tokenizer.ggml.model"] = "bert"
kv["tokenizer.ggml.token_type_count"] = uint32(2)
// convert to phantom space tokens
for i, e := range t.Tokens {
switch {
case strings.HasPrefix(e, "[") && strings.HasSuffix(e, "]"):
// noop - keep special tokens as-is
case strings.HasPrefix(e, "##"):
t.Tokens[i] = e[2:]
default:
t.Tokens[i] = "\u2581" + e
}
}
kv["tokenizer.ggml.tokens"] = t.Tokens
return kv
}
func (p *nomicbertModel) Tensors(ts []Tensor) []*ggml.Tensor {
out := make([]*ggml.Tensor, 0, len(ts))
for _, t := range ts {
if slices.Contains([]string{
"embeddings.position_ids",
"pooler.dense.weight",
"pooler.dense.bias",
}, t.Name()) {
continue
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (nomicbertModel) Replacements() []string {
return []string{
"encoder.layer", "blk",
"encoder.layers", "blk",
"embeddings.word_embeddings", "token_embd",
"embeddings.token_type_embeddings", "token_types",
"embeddings.LayerNorm", "token_embd_norm",
"attention.self.qkv", "attn_qkv",
"attention.output.dense", "attn_output",
"attention.output.LayerNorm", "attn_output_norm",
"mlp.up", "ffn_up",
"mlp.down", "ffn_down",
"mlp.router", "ffn_gate_inp",
"mlp.experts.up", "ffn_up_exps",
"mlp.experts.down", "ffn_down_exps",
"intermediate.dense", "ffn_up",
"output.dense", "ffn_down",
"output.LayerNorm", "layer_output_norm",
}
}

View File

@@ -6,77 +6,54 @@ import (
"github.com/ollama/ollama/fs/ggml"
)
type ropeScaling struct {
Factor float32 `json:"factor"`
OriginalMaxPositionEmbeds uint32 `json:"original_max_position_embeddings"`
AttentionFactor float32 `json:"attention_factor"`
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
RopeType string `json:"rope_type"`
ExtrapolationFactor float32 `json:"extrapolation_factor"`
}
type olmoModel struct {
ModelParameters
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
RMSNormEPS float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
RopeScaling *ropeScaling `json:"rope_scaling"`
SlidingWindow uint32 `json:"sliding_window"`
LayerTypes []string `json:"layer_types"`
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
RMSNormEPS float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
ClampKQV float32 `json:"f_clamp_kqv"`
SlidingWindow uint32 `json:"sliding_window"`
LayerTypes []string `json:"layer_types"`
}
var _ ModelConverter = (*olmoModel)(nil)
func (p *olmoModel) KV(t *Tokenizer) KV {
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "olmo3"
kv["olmo3.block_count"] = p.NumHiddenLayers
kv["olmo3.context_length"] = p.MaxPositionEmbeddings
kv["olmo3.embedding_length"] = p.HiddenSize
kv["olmo3.feed_forward_length"] = p.IntermediateSize
kv["olmo3.attention.head_count"] = p.NumAttentionHeads
kv["olmo3.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
kv["general.architecture"] = "olmo"
kv["olmo.block_count"] = p.NumHiddenLayers
kv["olmo.context_length"] = p.MaxPositionEmbeddings
kv["olmo.embedding_length"] = p.HiddenSize
kv["olmo.feed_forward_length"] = p.IntermediateSize
kv["olmo.attention.head_count"] = p.NumAttentionHeads
kv["olmo.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
if p.RopeTheta > 0 {
kv["olmo3.rope.freq_base"] = p.RopeTheta
}
if p.RopeScaling != nil {
if p.RopeScaling.Factor > 0 {
kv["olmo3.rope.scaling.factor"] = p.RopeScaling.Factor
}
if p.RopeScaling.OriginalMaxPositionEmbeds > 0 {
kv["olmo3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeds
}
if p.RopeScaling.AttentionFactor > 0 {
kv["olmo3.rope.scaling.attn_factor"] = p.RopeScaling.AttentionFactor
}
if p.RopeScaling.RopeType != "" {
kv["olmo3.rope.scaling.type"] = p.RopeScaling.RopeType
}
kv["olmo.rope.freq_base"] = p.RopeTheta
} else {
kv["olmo.rope.freq_base"] = float32(10000.0)
}
if p.RMSNormEPS > 0 {
kv["olmo3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["olmo.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
}
if p.ClampKQV > 0 {
kv["olmo.attention.clamp_kqv"] = p.ClampKQV
}
if p.SlidingWindow > 0 {
kv["olmo3.attention.sliding_window"] = p.SlidingWindow
kv["olmo.attention.sliding_window"] = p.SlidingWindow
}
if len(p.LayerTypes) > 0 {
slidingPattern := make([]bool, len(p.LayerTypes))
for i, layerType := range p.LayerTypes {
slidingPattern[i] = (layerType == "sliding_attention")
}
kv["olmo3.attention.sliding_window_pattern"] = slidingPattern
kv["olmo.attention.layer_types"] = p.LayerTypes
}
return kv

View File

@@ -37,7 +37,7 @@ type phi3Model struct {
var _ ModelConverter = (*phi3Model)(nil)
func (p *phi3Model) KV(t *Tokenizer) KV {
func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "phi3"
kv["phi3.context_length"] = p.MaxPositionEmbeddings

View File

@@ -22,7 +22,7 @@ type qwen2Model struct {
var _ ModelConverter = (*qwen2Model)(nil)
func (q *qwen2Model) KV(t *Tokenizer) KV {
func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen2"
kv["qwen2.block_count"] = q.HiddenLayers

View File

@@ -29,7 +29,7 @@ type qwen25VLModel struct {
var _ ModelConverter = (*qwen25VLModel)(nil)
func (q *qwen25VLModel) KV(t *Tokenizer) KV {
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen25vl"

View File

@@ -32,7 +32,7 @@ type qwen3Model struct {
}
// KV implements ModelConverter.
func (q *qwen3Model) KV(t *Tokenizer) KV {
func (q *qwen3Model) KV(t *Tokenizer) ggml.KV {
arch := "qwen3"
if q.NumExperts > 0 {
arch += "moe"

View File

@@ -45,7 +45,7 @@ func (m *qwen3VLModel) parseMore(fsys fs.FS) error {
return json.Unmarshal(bts, &m.VisionModel)
}
func (m *qwen3VLModel) KV(t *Tokenizer) KV {
func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV {
kv := m.qwen3Model.KV(t)
arch := "qwen3vl"

View File

@@ -19,7 +19,6 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
fsc "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -29,7 +28,7 @@ type tensorData struct {
Shape []int `json:"shape"`
}
func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors) {
func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "f16")
@@ -60,10 +59,9 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors)
return r, m.KV(), m.Tensors()
}
func generateResultsJSON(t *testing.T, f *os.File, kv fsc.Config, tensors ggml.Tensors) map[string]string {
func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string {
actual := make(map[string]string)
for k := range kv.Keys() {
v := kv.Value(k)
for k, v := range kv {
if s, ok := v.(json.Marshaler); !ok {
actual[k] = fmt.Sprintf("%v", v)
} else {
@@ -279,7 +277,7 @@ func generateSafetensorTestData(t *testing.T, tempDir string, tensorData map[str
func TestConvertAdapter(t *testing.T) {
type AdapterCase struct {
Name string
BaseKV KV
BaseKV map[string]any
Expected map[string]string
}

View File

@@ -49,8 +49,7 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
// temporary fix to handle gemma3 broken configs
// TODO(parthsareen): allow reading of tokenizer.json to allow managing special tokens when using spm
if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>", "<start_function_declaration>", "<end_function_declaration>", "<start_function_call>", "<end_function_call>", "<start_function_response>", "<end_function_response>", "<escape>"}, piece.GetPiece()) {
if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
}

View File

@@ -65,7 +65,6 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
}
slog.Info("discovering available GPUs...")
detectIncompatibleLibraries()
// Warn if any user-overrides are set which could lead to incorrect GPU discovery
overrideWarnings()
@@ -99,9 +98,6 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
continue
} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
continue
} else if jetpack == "" && strings.Contains(filepath.Base(dir), "cuda_jetpack") {
slog.Debug("jetpack not detected (set JETSON_JETPACK or OLLAMA_LLM_LIBRARY to override), skipping", "libDir", dir)
continue
} else if !envconfig.EnableVulkan() && strings.Contains(filepath.Base(dir), "vulkan") {
slog.Info("experimental Vulkan support disabled. To enable, set OLLAMA_VULKAN=1")
continue
@@ -147,7 +143,7 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
wg.Add(1)
go func(i int) {
defer wg.Done()
extraEnvs := ml.GetVisibleDevicesEnv(devices[i:i+1], true)
extraEnvs := ml.GetVisibleDevicesEnv(devices[i : i+1])
devices[i].AddInitValidation(extraEnvs)
if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 {
slog.Debug("filtering device which didn't fully initialize",
@@ -333,8 +329,7 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
defer cancel()
// Apply any dev filters to avoid re-discovering unsupported devices, and get IDs correct
// We avoid CUDA filters here to keep ROCm from failing to discover GPUs in a mixed environment
devFilter := ml.GetVisibleDevicesEnv(devices, false)
devFilter := ml.GetVisibleDevicesEnv(devices)
for dir := range libDirs {
updatedDevices := bootstrapDevices(ctx, []string{ml.LibOllamaPath, dir}, devFilter)
@@ -489,16 +484,3 @@ func overrideWarnings() {
slog.Warn("if GPUs are not correctly discovered, unset and try again")
}
}
func detectIncompatibleLibraries() {
if runtime.GOOS != "windows" {
return
}
basePath, err := exec.LookPath("ggml-base.dll")
if err != nil || basePath == "" {
return
}
if !strings.HasPrefix(basePath, ml.LibOllamaPath) {
slog.Warn("potentially incompatible library detected in PATH", "location", basePath)
}
}

View File

@@ -14,7 +14,6 @@
* [API Reference](https://docs.ollama.com/api)
* [Modelfile Reference](https://docs.ollama.com/modelfile)
* [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility)
* [Anthropic Compatibility](./api/anthropic-compatibility.mdx)
### Resources

View File

@@ -50,7 +50,7 @@ Generate a response for a given prompt with a provided model. This is a streamin
Advanced parameters (optional):
- `format`: the format to return a response in. Format can be `json` or a JSON schema
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `system`: system message to (overrides what is defined in the `Modelfile`)
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
@@ -507,7 +507,7 @@ The `message` object has the following fields:
Advanced parameters (optional):
- `format`: the format to return a response in. Format can be `json` or a JSON schema.
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
@@ -895,11 +895,11 @@ curl http://localhost:11434/api/chat -d '{
"tool_calls": [
{
"function": {
"name": "get_weather",
"name": "get_temperature",
"arguments": {
"city": "Toronto"
}
}
},
}
]
},
@@ -907,7 +907,7 @@ curl http://localhost:11434/api/chat -d '{
{
"role": "tool",
"content": "11 degrees celsius",
"tool_name": "get_weather"
"tool_name": "get_temperature",
}
],
"stream": false,
@@ -1189,7 +1189,7 @@ If you are creating a model from a safetensors directory or from a GGUF file, yo
- `template`: (optional) the prompt template for the model
- `license`: (optional) a string or list of strings containing the license or licenses for the model
- `system`: (optional) a string containing the system prompt for the model
- `parameters`: (optional) a dictionary of parameters for the model (see [Modelfile](./modelfile.mdx#valid-parameters-and-values) for a list of parameters)
- `parameters`: (optional) a dictionary of parameters for the model (see [Modelfile](./modelfile.md#valid-parameters-and-values) for a list of parameters)
- `messages`: (optional) a list of message objects used to create a conversation
- `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
- `quantize` (optional): quantize a non-quantized (e.g. float16) model
@@ -1698,7 +1698,7 @@ Generate embeddings from a model
Advanced parameters:
- `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
- `dimensions`: number of dimensions for the embedding
@@ -1817,7 +1817,7 @@ Generate embeddings from a model
Advanced parameters:
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
### Examples

View File

@@ -1,406 +0,0 @@
---
title: Anthropic compatibility
---
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
## Recommended models
For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.
Pull a model before use:
```shell
ollama pull qwen3-coder
ollama pull glm-4.7:cloud
```
## Usage
### Environment variables
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama # required but ignored
```
### Simple `/v1/messages` example
<CodeGroup dropdown>
```python basic.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama', # required but ignored
)
message = client.messages.create(
model='qwen3-coder',
max_tokens=1024,
messages=[
{'role': 'user', 'content': 'Hello, how are you?'}
]
)
print(message.content[0].text)
```
```javascript basic.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama", // required but ignored
});
const message = await anthropic.messages.create({
model: "qwen3-coder",
max_tokens: 1024,
messages: [{ role: "user", content: "Hello, how are you?" }],
});
console.log(message.content[0].text);
```
```shell basic.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-H "x-api-key: ollama" \
-H "anthropic-version: 2023-06-01" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"messages": [{ "role": "user", "content": "Hello, how are you?" }]
}'
```
</CodeGroup>
### Streaming example
<CodeGroup dropdown>
```python streaming.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
with client.messages.stream(
model='qwen3-coder',
max_tokens=1024,
messages=[{'role': 'user', 'content': 'Count from 1 to 10'}]
) as stream:
for text in stream.text_stream:
print(text, end='', flush=True)
```
```javascript streaming.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const stream = await anthropic.messages.stream({
model: "qwen3-coder",
max_tokens: 1024,
messages: [{ role: "user", content: "Count from 1 to 10" }],
});
for await (const event of stream) {
if (
event.type === "content_block_delta" &&
event.delta.type === "text_delta"
) {
process.stdout.write(event.delta.text);
}
}
```
```shell streaming.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"stream": true,
"messages": [{ "role": "user", "content": "Count from 1 to 10" }]
}'
```
</CodeGroup>
### Tool calling example
<CodeGroup dropdown>
```python tools.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
message = client.messages.create(
model='qwen3-coder',
max_tokens=1024,
tools=[
{
'name': 'get_weather',
'description': 'Get the current weather in a location',
'input_schema': {
'type': 'object',
'properties': {
'location': {
'type': 'string',
'description': 'The city and state, e.g. San Francisco, CA'
}
},
'required': ['location']
}
}
],
messages=[{'role': 'user', 'content': "What's the weather in San Francisco?"}]
)
for block in message.content:
if block.type == 'tool_use':
print(f'Tool: {block.name}')
print(f'Input: {block.input}')
```
```javascript tools.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const message = await anthropic.messages.create({
model: "qwen3-coder",
max_tokens: 1024,
tools: [
{
name: "get_weather",
description: "Get the current weather in a location",
input_schema: {
type: "object",
properties: {
location: {
type: "string",
description: "The city and state, e.g. San Francisco, CA",
},
},
required: ["location"],
},
},
],
messages: [{ role: "user", content: "What's the weather in San Francisco?" }],
});
for (const block of message.content) {
if (block.type === "tool_use") {
console.log("Tool:", block.name);
console.log("Input:", block.input);
}
}
```
```shell tools.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"tools": [
{
"name": "get_weather",
"description": "Get the current weather in a location",
"input_schema": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state"
}
},
"required": ["location"]
}
}
],
"messages": [{ "role": "user", "content": "What is the weather in San Francisco?" }]
}'
```
</CodeGroup>
## Using with Claude Code
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
Or set the environment variables in your shell profile:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
Then run Claude Code with any Ollama model:
```shell
# Local models
claude --model qwen3-coder
claude --model gpt-oss:20b
# Cloud models
claude --model glm-4.7:cloud
claude --model minimax-m2.1:cloud
```
## Endpoints
### `/v1/messages`
#### Supported features
- [x] Messages
- [x] Streaming
- [x] System prompts
- [x] Multi-turn conversations
- [x] Vision (images)
- [x] Tools (function calling)
- [x] Tool results
- [x] Thinking/extended thinking
#### Supported request fields
- [x] `model`
- [x] `max_tokens`
- [x] `messages`
- [x] Text `content`
- [x] Image `content` (base64)
- [x] Array of content blocks
- [x] `tool_use` blocks
- [x] `tool_result` blocks
- [x] `thinking` blocks
- [x] `system` (string or array)
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `top_k`
- [x] `stop_sequences`
- [x] `tools`
- [x] `thinking`
- [ ] `tool_choice`
- [ ] `metadata`
#### Supported response fields
- [x] `id`
- [x] `type`
- [x] `role`
- [x] `model`
- [x] `content` (text, tool_use, thinking blocks)
- [x] `stop_reason` (end_turn, max_tokens, tool_use)
- [x] `usage` (input_tokens, output_tokens)
#### Streaming events
- [x] `message_start`
- [x] `content_block_start`
- [x] `content_block_delta` (text_delta, input_json_delta, thinking_delta)
- [x] `content_block_stop`
- [x] `message_delta`
- [x] `message_stop`
- [x] `ping`
- [x] `error`
## Models
Ollama supports both local and cloud models.
### Local models
Pull a local model before use:
```shell
ollama pull qwen3-coder
```
Recommended local models:
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model
### Cloud models
Cloud models are available immediately without pulling:
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
### Default model names
For tooling that relies on default Anthropic model names such as `claude-3-5-sonnet`, use `ollama cp` to copy an existing model name:
```shell
ollama cp qwen3-coder claude-3-5-sonnet
```
Afterwards, this new model name can be specified in the `model` field:
```shell
curl http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "claude-3-5-sonnet",
"max_tokens": 1024,
"messages": [
{
"role": "user",
"content": "Hello!"
}
]
}'
```
## Differences from the Anthropic API
### Behavior differences
- API key is accepted but not validated
- `anthropic-version` header is accepted but not used
- Token counts are approximations based on the underlying model's tokenizer
### Not supported
The following Anthropic API features are not currently supported:
| Feature | Description |
|---------|-------------|
| `/v1/messages/count_tokens` | Token counting endpoint |
| `tool_choice` | Forcing specific tool use or disabling tools |
| `metadata` | Request metadata (user_id) |
| Prompt caching | `cache_control` blocks for caching prefixes |
| Batches API | `/v1/messages/batches` for async batch processing |
| Citations | `citations` content blocks |
| PDF support | `document` content blocks with PDF files |
| Server-sent errors | `error` events during streaming (errors return HTTP status) |
### Partial support
| Feature | Status |
|---------|--------|
| Image content | Base64 images supported; URL images not supported |
| Extended thinking | Basic support; `budget_tokens` accepted but not enforced |

View File

File diff suppressed because one or more lines are too long

View File

@@ -15,7 +15,7 @@ Also known as "single-shot" tool calling.
```shell
curl -s http://localhost:11434/api/chat -H "Content-Type: application/json" -d '{
"model": "qwen3",
"messages": [{"role": "user", "content": "What is the temperature in New York?"}],
"messages": [{"role": "user", "content": "What's the temperature in New York?"}],
"stream": false,
"tools": [
{
@@ -41,7 +41,7 @@ Also known as "single-shot" tool calling.
curl -s http://localhost:11434/api/chat -H "Content-Type: application/json" -d '{
"model": "qwen3",
"messages": [
{"role": "user", "content": "What is the temperature in New York?"},
{"role": "user", "content": "What's the temperature in New York?"},
{
"role": "assistant",
"tool_calls": [
@@ -90,7 +90,7 @@ Also known as "single-shot" tool calling.
}
return temperatures.get(city, "Unknown")
messages = [{"role": "user", "content": "What is the temperature in New York?"}]
messages = [{"role": "user", "content": "What's the temperature in New York?"}]
# pass functions directly as tools in the tools list or as a JSON schema
response = chat(model="qwen3", messages=messages, tools=[get_temperature], think=True)
@@ -146,7 +146,7 @@ Also known as "single-shot" tool calling.
},
]
const messages = [{ role: 'user', content: "What is the temperature in New York?" }]
const messages = [{ role: 'user', content: "What's the temperature in New York?" }]
const response = await ollama.chat({
model: 'qwen3',
@@ -609,7 +609,7 @@ def get_temperature(city: str) -> str:
return temperatures.get(city, 'Unknown')
messages = [{'role': 'user', 'content': "What is the temperature in New York?"}]
messages = [{'role': 'user', 'content': "What's the temperature in New York?"}]
while True:
stream = chat(
@@ -684,7 +684,7 @@ const getTemperatureTool = {
}
async function agentLoop() {
const messages = [{ role: 'user', content: "What is the temperature in New York?" }]
const messages = [{ role: 'user', content: "What's the temperature in New York?" }]
while (true) {
const stream = await ollama.chat({

View File

@@ -36,6 +36,7 @@ Provide an `images` array. SDKs accept file paths, URLs or raw bytes while the R
}],
"stream": false
}'
"
```
</Tab>
<Tab title="Python">

View File

@@ -49,8 +49,6 @@ Install prerequisites:
- [Ninja](https://github.com/ninja-build/ninja/releases)
- (Optional) NVIDIA GPU support
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
- (Optional) VULKAN GPU support
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
Then, configure and build the project:
@@ -59,17 +57,6 @@ cmake -B build
cmake --build build --config Release
```
> Building for Vulkan requires VULKAN_SDK environment variable:
>
> PowerShell
> ```powershell
> $env:VULKAN_SDK="C:\VulkanSDK\<version>"
> ```
> CMD
> ```cmd
> set VULKAN_SDK=C:\VulkanSDK\<version>
> ```
> [!IMPORTANT]
> Building for ROCm requires additional flags:
> ```
@@ -78,7 +65,6 @@ cmake --build build --config Release
> ```
Lastly, run Ollama:
```shell
@@ -98,9 +84,7 @@ Install prerequisites:
- [ROCm](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/install/quick-start.html)
- (Optional) NVIDIA GPU support
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads)
- (Optional) VULKAN GPU support
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
- Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS)
> [!IMPORTANT]
> Ensure prerequisites are in `PATH` before running CMake.

View File

@@ -32,9 +32,7 @@
"codeblocks": "system"
},
"contextual": {
"options": [
"copy"
]
"options": ["copy"]
},
"navbar": {
"links": [
@@ -54,9 +52,7 @@
"display": "simple"
},
"examples": {
"languages": [
"curl"
]
"languages": ["curl"]
}
},
"redirects": [
@@ -101,7 +97,6 @@
{
"group": "Integrations",
"pages": [
"/integrations/claude-code",
"/integrations/vscode",
"/integrations/jetbrains",
"/integrations/codex",
@@ -144,8 +139,7 @@
"/api/streaming",
"/api/usage",
"/api/errors",
"/api/openai-compatibility",
"/api/anthropic-compatibility"
"/api/openai-compatibility"
]
},
{

View File

@@ -14,11 +14,11 @@ curl -fsSL https://ollama.com/install.sh | sh
## How can I view the logs?
Review the [Troubleshooting](./troubleshooting) docs for more about using logs.
Review the [Troubleshooting](./troubleshooting.md) docs for more about using logs.
## Is my GPU compatible with Ollama?
Please refer to the [GPU docs](./gpu).
Please refer to the [GPU docs](./gpu.md).
## How can I specify the context window size?
@@ -57,13 +57,8 @@ ollama ps
```
<Info>
**Output**:
```
NAME ID SIZE PROCESSOR UNTIL
llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
```
**Output**: ``` NAME ID SIZE PROCESSOR UNTIL llama3:70b bcfb190ca3a7 42 GB
100% GPU 4 minutes from now ```
</Info>
The `Processor` column will show which memory the model was loaded in to:
@@ -390,4 +385,4 @@ Ollama for Windows and macOS register as a login item during installation. You
- In `Task Manager` go to the `Startup apps` tab, search for `ollama` then click `Disable`
**MacOS**
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.

View File

@@ -33,7 +33,7 @@ Check your compute compatibility to see if your card is supported:
| 5.0 | GeForce GTX | `GTX 750 Ti` `GTX 750` `NVS 810` |
| | Quadro | `K2200` `K1200` `K620` `M1200` `M520` `M5000M` `M4000M` `M3000M` `M2000M` `M1000M` `K620M` `M600M` `M500M` |
For building locally to support older GPUs, see [developer](./development#linux-cuda-nvidia)
For building locally to support older GPUs, see [developer.md](./development.md#linux-cuda-nvidia)
### GPU Selection
@@ -54,7 +54,7 @@ sudo modprobe nvidia_uvm`
Ollama supports the following AMD GPUs via the ROCm library:
> **NOTE:**
> [!NOTE]
> Additional AMD GPU support is provided by the Vulkan Library - see below.
@@ -132,9 +132,9 @@ Ollama supports GPU acceleration on Apple devices via the Metal API.
## Vulkan GPU Support
> **NOTE:**
> [!NOTE]
> Vulkan is currently an Experimental feature. To enable, you must set OLLAMA_VULKAN=1 for the Ollama server as
described in the [FAQ](faq#how-do-i-configure-ollama-server)
described in the [FAQ](faq.md#how-do-i-configure-ollama-server)
Additional GPU support on Windows and Linux is provided via
[Vulkan](https://www.vulkan.org/). On Windows most GPU vendors drivers come
@@ -161,6 +161,6 @@ sudo setcap cap_perfmon+ep /usr/local/bin/ollama
To select specific Vulkan GPU(s), you can set the environment variable
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
described in the [FAQ](faq#how-do-i-configure-ollama-server). If you
described in the [FAQ](faq.md#how-do-i-configure-ollama-server). If you
encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
by setting `GGML_VK_VISIBLE_DEVICES=-1`

View File

@@ -1,69 +0,0 @@
---
title: Claude Code
---
## Install
Install [Claude Code](https://code.claude.com/docs/en/overview):
<CodeGroup>
```shell macOS / Linux
curl -fsSL https://claude.ai/install.sh | bash
```
```powershell Windows
irm https://claude.ai/install.ps1 | iex
```
</CodeGroup>
## Usage with Ollama
Claude Code connects to Ollama using the Anthropic-compatible API.
1. Set the environment variables:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
2. Run Claude Code with an Ollama model:
```shell
claude --model qwen3-coder
```
Or run with environment variables inline:
```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
## Connecting to ollama.com
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
2. Set the environment variables:
```shell
export ANTHROPIC_BASE_URL=https://ollama.com
export ANTHROPIC_API_KEY=<your-api-key>
```
3. Run Claude Code with a cloud model:
```shell
claude --model glm-4.7:cloud
```
## Recommended Models
### Cloud models
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
- `qwen3-coder:480b` - Large coding model
### Local models
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model

View File

@@ -1,5 +1,5 @@
---
title: "Linux"
title: Linux
---
## Install
@@ -13,7 +13,8 @@ curl -fsSL https://ollama.com/install.sh | sh
## Manual install
<Note>
If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
If you are upgrading from a prior version, you should remove the old libraries
with `sudo rm -rf /usr/lib/ollama` first.
</Note>
Download and extract the package:
@@ -112,7 +113,11 @@ sudo systemctl status ollama
```
<Note>
While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
While AMD has contributed the `amdgpu` driver upstream to the official linux
kernel source, the version is older and may not support all ROCm features. We
recommend you install the latest driver from
https://www.amd.com/en/support/linux-drivers for best support of your Radeon
GPU.
</Note>
## Customizing
@@ -191,4 +196,4 @@ Remove the downloaded models and Ollama service user and group:
sudo userdel ollama
sudo groupdel ollama
sudo rm -r /usr/share/ollama
```
```

View File

@@ -41,7 +41,6 @@ INSTRUCTION arguments
| [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. |
| [`LICENSE`](#license) | Specifies the legal license. |
| [`MESSAGE`](#message) | Specify message history. |
| [`REQUIRES`](#requires) | Specify the minimum version of Ollama required by the model. |
## Examples
@@ -150,6 +149,9 @@ PARAMETER <parameter> <parametervalue>
| Parameter | Description | Value Type | Example Usage |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
| mirostat | Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | int | mirostat 0 |
| mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
| mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
@@ -249,16 +251,6 @@ MESSAGE user Is Ontario in Canada?
MESSAGE assistant yes
```
### REQUIRES
The `REQUIRES` instruction allows you to specify the minimum version of Ollama required by the model.
```
REQUIRES <version>
```
The version should be a valid Ollama version (e.g. 0.14.0).
## Notes
- the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments.

View File

@@ -1,46 +0,0 @@
# extract-examples
Extracts code examples from MDX files to a temp directory so you can run them.
## Usage
```shell
go run docs/tools/extract-examples/main.go <mdx-file>
```
## Example
```shell
go run docs/tools/extract-examples/main.go docs/api/openai-compatibility.mdx
```
Output:
```
Extracting code examples to: /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
- 01_basic.py
- 01_basic.js
- 01_basic.sh
- 02_responses.py
- 02_responses.js
- 02_responses.sh
- 03_vision.py
- 03_vision.js
- 03_vision.sh
Extracted 9 file(s) to /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
To run examples:
cd /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
npm install # for JS examples
then run individual files with `node file.js`, `python file.py`, `bash file.sh`
```
## How it works
- Parses MDX files looking for fenced code blocks with filenames (e.g., ` ```python basic.py `)
- Groups examples by their `<CodeGroup>` and prefixes filenames with `01_`, `02_`, etc.
- Writes all extracted files to a temp directory

View File

@@ -1,137 +0,0 @@
package main
import (
"bufio"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
)
func main() {
if len(os.Args) < 2 {
fmt.Fprintln(os.Stderr, "Usage: go run extract-examples.go <mdx-file>")
os.Exit(1)
}
mdxFile := os.Args[1]
f, err := os.Open(mdxFile)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
defer f.Close()
// Create temp directory
tempDir, err := os.MkdirTemp("", "mdx-examples-*")
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating temp dir: %v\n", err)
os.Exit(1)
}
fmt.Printf("Extracting code examples to: %s\n\n", tempDir)
// Patterns
codeBlockStart := regexp.MustCompile("^```([a-zA-Z0-9_-]+)\\s+([^\\s]+)$")
codeGroupStart := regexp.MustCompile("^<CodeGroup")
codeGroupEnd := regexp.MustCompile("^</CodeGroup>")
scanner := bufio.NewScanner(f)
inCodeBlock := false
inCodeGroup := false
var currentFile string
var content strings.Builder
count := 0
codeGroupNum := 0
for scanner.Scan() {
line := scanner.Text()
// Track CodeGroup boundaries
if codeGroupStart.MatchString(line) {
inCodeGroup = true
codeGroupNum++
continue
}
if codeGroupEnd.MatchString(line) {
inCodeGroup = false
continue
}
if inCodeBlock {
if line == "```" {
// End of code block - write file
if currentFile != "" {
outPath := filepath.Join(tempDir, currentFile)
if err := os.WriteFile(outPath, []byte(content.String()), 0o644); err != nil {
fmt.Fprintf(os.Stderr, "Error writing %s: %v\n", currentFile, err)
} else {
fmt.Printf(" - %s\n", currentFile)
count++
}
}
inCodeBlock = false
currentFile = ""
content.Reset()
} else {
content.WriteString(line)
content.WriteString("\n")
}
} else {
if matches := codeBlockStart.FindStringSubmatch(line); matches != nil {
inCodeBlock = true
filename := matches[2]
// Prefix with CodeGroup number if inside a CodeGroup
if inCodeGroup {
currentFile = fmt.Sprintf("%02d_%s", codeGroupNum, filename)
} else {
currentFile = filename
}
content.Reset()
}
}
}
if err := scanner.Err(); err != nil {
fmt.Fprintf(os.Stderr, "Error reading file: %v\n", err)
os.Exit(1)
}
// Write package.json for JavaScript dependencies
packageJSON := `{
"name": "mdx-examples",
"type": "module",
"dependencies": {
"openai": "^4",
"ollama": "^0.5"
}
}
`
if err := os.WriteFile(filepath.Join(tempDir, "package.json"), []byte(packageJSON), 0o644); err != nil {
fmt.Fprintf(os.Stderr, "Error writing package.json: %v\n", err)
}
// Write pyproject.toml for Python dependencies
pyprojectTOML := `[project]
name = "mdx-examples"
version = "0.0.0"
dependencies = [
"openai",
"ollama",
]
`
if err := os.WriteFile(filepath.Join(tempDir, "pyproject.toml"), []byte(pyprojectTOML), 0o644); err != nil {
fmt.Fprintf(os.Stderr, "Error writing pyproject.toml: %v\n", err)
}
fmt.Printf("\n")
fmt.Printf("Extracted %d file(s) to %s\n", count, tempDir)
fmt.Printf("\n")
fmt.Printf("To run examples:\n")
fmt.Printf("\n")
fmt.Printf(" cd %s\n npm install # for JS examples\n", tempDir)
fmt.Printf("\n")
fmt.Printf("then run individual files with `node file.js`, `python file.py`, `bash file.sh`\n")
}

View File

@@ -87,7 +87,7 @@ When Ollama starts up, it takes inventory of the GPUs present in the system to d
### Linux NVIDIA Troubleshooting
If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker](./docker)
If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker.md](./docker.md)
Sometimes the Ollama can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem

View File

@@ -1,7 +1,5 @@
package fs
import "iter"
type Config interface {
Architecture() string
String(string, ...string) string
@@ -13,8 +11,4 @@ type Config interface {
Ints(string, ...[]int32) []int32
Floats(string, ...[]float32) []float32
Bools(string, ...[]bool) []bool
Len() int
Keys() iter.Seq[string]
Value(key string) any
}

View File

@@ -6,16 +6,13 @@ import (
"errors"
"fmt"
"io"
"iter"
"log/slog"
"maps"
"math"
"slices"
"strings"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/util/bufioutil"
"github.com/ollama/ollama/ml"
)
type GGML struct {
@@ -241,34 +238,21 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
return val.values
}
func (kv KV) Len() int {
return len(kv)
}
func (kv KV) Keys() iter.Seq[string] {
return maps.Keys(kv)
}
func (kv KV) Value(key string) any {
return kv[key]
}
func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{
"bert",
"deepseek2",
"deepseekocr",
"gemma3",
"gemma3n",
"gptoss", "gpt-oss",
"llama4",
"mistral3",
"mllama",
"nomic-bert",
"olmo3",
"qwen25vl",
"qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe",
"deepseekocr",
"deepseek2",
"nomic-bert",
"olmo2",
}, kv.Architecture())
}
@@ -567,7 +551,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
}, nil
}
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention ml.FlashAttentionType) (kv []uint64, partialOffload, fullOffload uint64) {
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
context *= uint64(numParallel)
embedding := f.KV().EmbeddingLength()
@@ -808,7 +792,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
}
partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
if useFlashAttention == ml.FlashAttentionEnabled {
if useFlashAttention {
// rough estimate of graph size with flash attention on
partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
}
@@ -826,14 +810,6 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool {
return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
}
// KVCacheTypeIsQuantized checks if the requested cache type is a quantized type
func (f GGML) KVCacheTypeIsQuantized(cacheType string) bool {
if cacheType == "" || cacheType == "f16" || cacheType == "f32" || cacheType == "bf16" {
return false
}
return true
}
// SupportsFlashAttention checks if the model supports flash attention
func (f GGML) SupportsFlashAttention() bool {
_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]
@@ -854,11 +830,8 @@ func (f GGML) SupportsFlashAttention() bool {
// FlashAttention checks if the model should enable flash attention
func (f GGML) FlashAttention() bool {
return slices.Contains([]string{
"bert",
"gemma3",
"gptoss", "gpt-oss",
"mistral3",
"olmo3",
"qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe",
}, f.KV().String("general.architecture"))

View File

@@ -8,12 +8,12 @@ import (
"fmt"
"io"
"log/slog"
"maps"
"os"
"runtime"
"slices"
"strings"
"github.com/ollama/ollama/fs"
"golang.org/x/sync/errgroup"
)
@@ -508,7 +508,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
return binary.Write(w, binary.LittleEndian, s)
}
func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
arch := kv.String("general.architecture")
if arch == "" {
return fmt.Errorf("architecture not set")
@@ -526,12 +526,12 @@ func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
return err
}
if err := binary.Write(f, binary.LittleEndian, uint64(kv.Len())); err != nil {
if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
return err
}
for _, key := range slices.Sorted(kv.Keys()) {
if err := ggufWriteKV(f, arch, key, kv.Value(key)); err != nil {
for _, key := range slices.Sorted(maps.Keys(kv)) {
if err := ggufWriteKV(f, arch, key, kv[key]); err != nil {
return err
}
}
@@ -597,10 +597,6 @@ func ggufWriteKV(ws io.WriteSeeker, arch, k string, v any) error {
var err error
switch v := v.(type) {
case int32:
err = writeGGUF(ws, ggufTypeInt32, v)
case int64:
err = writeGGUF(ws, ggufTypeInt64, v)
case uint32, FileType:
err = writeGGUF(ws, ggufTypeUint32, v)
case uint64:
@@ -615,10 +611,6 @@ func ggufWriteKV(ws io.WriteSeeker, arch, k string, v any) error {
err = writeGGUFArray(ws, ggufTypeInt32, v)
case *array[int32]:
err = writeGGUFArray(ws, ggufTypeInt32, v.values)
case []int64:
err = writeGGUFArray(ws, ggufTypeInt64, v)
case *array[int64]:
err = writeGGUFArray(ws, ggufTypeInt64, v.values)
case []uint32:
err = writeGGUFArray(ws, ggufTypeUint32, v)
case *array[uint32]:

View File

@@ -42,10 +42,6 @@ func TestWriteGGUF(t *testing.T) {
"general.architecture": "test",
"general.alignment": uint32(16),
"test.key": "value",
"test.int32_key": int32(-42),
"test.int64_key": int64(-9223372036854775808),
"test.int32_array": []int32{-1, 0, 1, 2147483647, -2147483648},
"test.int64_array": []int64{-1, 0, 1, 9223372036854775807, -9223372036854775808},
"attention.key": "value2",
"tokenizer.key": "value3",
"adapter.key": "value4",
@@ -59,7 +55,7 @@ func TestWriteGGUF(t *testing.T) {
}
defer r.Close()
ff, err := Decode(r, -1)
ff, err := Decode(r, 0)
if err != nil {
t.Fatal(err)
}
@@ -69,19 +65,15 @@ func TestWriteGGUF(t *testing.T) {
"general.alignment": uint32(16),
"general.parameter_count": uint64(54),
"test.key": "value",
"test.int32_key": int32(-42),
"test.int64_key": int64(-9223372036854775808),
"test.int32_array": &array[int32]{size: 5, values: []int32{-1, 0, 1, 2147483647, -2147483648}},
"test.int64_array": &array[int64]{size: 5, values: []int64{-1, 0, 1, 9223372036854775807, -9223372036854775808}},
"test.attention.key": "value2",
"tokenizer.key": "value3",
"adapter.key": "value4",
}, ff.KV(), cmp.AllowUnexported(array[int32]{}, array[int64]{})); diff != "" {
}, ff.KV()); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(Tensors{
Offset: 992,
Offset: 800,
items: []*Tensor{
{Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}},
{Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},

19
go.mod
View File

@@ -15,8 +15,8 @@ require (
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.17.0
golang.org/x/sys v0.37.0
golang.org/x/sync v0.12.0
golang.org/x/sys v0.36.0
)
require (
@@ -28,17 +28,13 @@ require (
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
github.com/tkrajina/typescriptify-golang-structs v0.2.0
github.com/wk8/go-ordered-map/v2 v2.1.8
golang.org/x/image v0.22.0
golang.org/x/mod v0.30.0
golang.org/x/tools v0.38.0
golang.org/x/tools v0.30.0
gonum.org/v1/gonum v0.15.0
)
require (
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/chewxy/hm v1.0.0 // indirect
github.com/chewxy/math32 v1.11.0 // indirect
@@ -48,7 +44,6 @@ require (
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
@@ -81,11 +76,11 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.43.0
golang.org/x/crypto v0.36.0
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
golang.org/x/net v0.46.0 // indirect
golang.org/x/term v0.36.0
golang.org/x/text v0.30.0
golang.org/x/net v0.38.0 // indirect
golang.org/x/term v0.30.0
golang.org/x/text v0.23.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)

39
go.sum
View File

@@ -14,11 +14,7 @@ github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6IC
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
@@ -127,7 +123,6 @@ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
@@ -148,8 +143,6 @@ github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 h1:QwWKgMY28TAXaDl+
github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
@@ -214,8 +207,6 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
@@ -233,8 +224,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@@ -264,8 +255,6 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -278,8 +267,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -289,8 +278,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -306,17 +295,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -330,8 +319,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@@ -4,9 +4,7 @@ package integration
import (
"context"
"errors"
"math"
"strings"
"testing"
"time"
@@ -206,8 +204,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
t.Fatalf("expected %v, got %v (similarity: %f)", expected[0:5], res.Embeddings[0][0:5], sim)
}
if res.PromptEvalCount != 8 {
t.Fatalf("expected 8 prompt tokens, got %d", res.PromptEvalCount)
if res.PromptEvalCount != 6 {
t.Fatalf("expected 6 prompt tokens, got %d", res.PromptEvalCount)
}
}
@@ -253,8 +251,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
t.Fatalf("expected %v, got %v (similarity: %f)", expected[1][0:5], res.Embeddings[1][0:5], sim)
}
if res.PromptEvalCount != 16 {
t.Fatalf("expected 16 prompt tokens, got %d", res.PromptEvalCount)
if res.PromptEvalCount != 12 {
t.Fatalf("expected 12 prompt tokens, got %d", res.PromptEvalCount)
}
}
@@ -277,7 +275,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
cases := []struct {
name string
request api.EmbedRequest
check func(*testing.T, *api.EmbedResponse, error)
check func(*api.EmbedResponse, error)
}{
{
name: "target truncation",
@@ -285,7 +283,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Model: "all-minilm",
Input: "why",
},
check: func(t *testing.T, got *api.EmbedResponse, err error) {
check: func(got *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
@@ -302,11 +300,10 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Input: "why is the sky blue?",
Options: map[string]any{"num_ctx": 3},
},
check: func(t *testing.T, got *api.EmbedResponse, err error) {
check: func(got *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
t.Logf("PromptEvalCount: want=%d got=%d", want.PromptEvalCount, got.PromptEvalCount)
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
}
@@ -320,11 +317,10 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 3},
},
check: func(t *testing.T, got *api.EmbedResponse, err error) {
check: func(got *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
t.Logf("PromptEvalCount: want=%d got=%d", want.PromptEvalCount, got.PromptEvalCount)
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
}
@@ -338,21 +334,21 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 3},
},
check: func(t *testing.T, res *api.EmbedResponse, err error) {
if err.Error() != "the input length exceeds the context length" {
check: func(res *api.EmbedResponse, err error) {
if err.Error() != "input exceeds maximum context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
},
},
{
name: "input after truncate error with context length of 1",
name: "input after truncate error",
request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 1},
},
check: func(t *testing.T, res *api.EmbedResponse, err error) {
check: func(res *api.EmbedResponse, err error) {
if err.Error() != "input after truncation exceeds maximum context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
@@ -366,7 +362,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 0},
},
check: func(t *testing.T, res *api.EmbedResponse, err error) {
check: func(res *api.EmbedResponse, err error) {
if err.Error() != "input after truncation exceeds maximum context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
@@ -379,7 +375,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Input: "why is the sky blue? Why is the sky blue? hi there my",
Options: map[string]any{"num_ctx": 16},
},
check: func(t *testing.T, res *api.EmbedResponse, err error) {
check: func(res *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
@@ -389,8 +385,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
for _, req := range cases {
t.Run(req.name, func(t *testing.T) {
resp, err := embedTestHelper(ctx, client, t, req.request)
req.check(t, resp, err)
req.check(embedTestHelper(ctx, client, t, req.request))
})
}
}
@@ -414,230 +409,3 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req
return client.Embed(ctx, &req)
}
func TestEmbedTruncation(t *testing.T) {
// Use test deadline if set, otherwise default to 2 minutes
timeout := 2 * time.Minute
if deadline, ok := t.Deadline(); ok {
timeout = time.Until(deadline) - 10*time.Second // Reserve 10s buffer
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
for _, model := range libraryEmbedModels {
model := model
t.Run(model, func(t *testing.T) {
// Check if we're running out of time (reserve 20s for current model)
if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second {
t.Skip("skipping remaining tests to avoid timeout")
}
// Give each model its own budget to account for first-time pulls/loads
mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute)
defer mcancel()
t.Run("truncation batch", func(t *testing.T) {
truncTrue := true
req := api.EmbedRequest{
Model: model,
Input: []string{"short", strings.Repeat("long ", 100), "medium text"},
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 30},
}
res, err := embedTestHelper(mctx, client, t, req)
if err != nil {
t.Fatal(err)
}
if len(res.Embeddings) != 3 {
t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
}
if res.PromptEvalCount > 90 {
t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount)
}
})
t.Run("runner token count accuracy", func(t *testing.T) {
baseline := api.EmbedRequest{Model: model, Input: "test"}
baseRes, err := embedTestHelper(mctx, client, t, baseline)
if err != nil {
t.Fatal(err)
}
batch := api.EmbedRequest{
Model: model,
Input: []string{"test", "test", "test"},
}
batchRes, err := embedTestHelper(mctx, client, t, batch)
if err != nil {
t.Fatal(err)
}
expectedCount := baseRes.PromptEvalCount * 3
if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 {
t.Fatalf("expected ~%d tokens (3 × %d), got %d",
expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount)
}
})
})
}
}
// TestEmbedLargeInput tests that embedding models can handle large inputs that would exceed typical batch sizes.
func TestEmbedLargeInput(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
for _, model := range libraryEmbedModels {
model := model
t.Run(model, func(t *testing.T) {
mctx, mcancel := context.WithTimeout(ctx, 2*time.Minute)
defer mcancel()
// Test with progressively larger inputs
testCases := []struct {
name string
inputWords int
}{
{"medium_input_256_words", 256},
{"large_input_512_words", 512},
{"very_large_input_800_words", 800},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
words := make([]string, tc.inputWords)
for i := range words {
words[i] = "word"
}
input := strings.Join(words, " ")
req := api.EmbedRequest{
Model: model,
Input: input,
KeepAlive: &api.Duration{Duration: 30 * time.Second},
}
res, err := embedTestHelper(mctx, client, t, req)
if err != nil {
t.Fatalf("embedding failed for %d words: %v", tc.inputWords, err)
}
if len(res.Embeddings) != 1 {
t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
}
if len(res.Embeddings[0]) == 0 {
t.Fatal("expected non-empty embedding")
}
t.Logf("Successfully embedded %d words (%d tokens)", tc.inputWords, res.PromptEvalCount)
})
}
})
}
}
// TestEmbedStatusCode tests that errors from the embedding endpoint
// properly preserve their HTTP status codes when returned to the client.
// This test specifically checks the error handling path in EmbedHandler
// where api.StatusError errors should maintain their original status code.
func TestEmbedStatusCode(t *testing.T) {
// Use test deadline if set, otherwise default to 2 minutes
timeout := 2 * time.Minute
if deadline, ok := t.Deadline(); ok {
timeout = time.Until(deadline) - 10*time.Second // Reserve 10s buffer
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
for _, model := range libraryEmbedModels {
model := model
t.Run(model, func(t *testing.T) {
// Check if we're running out of time (reserve 20s for current model)
if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second {
t.Skip("skipping remaining tests to avoid timeout")
}
mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute)
defer mcancel()
// Pull the model if needed
if err := PullIfMissing(mctx, client, model); err != nil {
t.Fatal(err)
}
t.Run("truncation error status code", func(t *testing.T) {
truncFalse := false
longInput := strings.Repeat("word ", 100)
req := api.EmbedRequest{
Model: model,
Input: longInput,
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 10},
}
_, err := embedTestHelper(mctx, client, t, req)
if err == nil {
t.Fatal("expected error when truncate=false with long input")
}
// Check that it's a StatusError with the correct status code
var statusErr api.StatusError
if !errors.As(err, &statusErr) {
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
}
// The error should be a 4xx client error (likely 400 Bad Request)
// not a 500 Internal Server Error
if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
}
// Verify the error message is meaningful
if !strings.Contains(err.Error(), "context length") {
t.Errorf("expected error message to mention context length, got: %v", err)
}
})
t.Run("batch truncation error status code", func(t *testing.T) {
truncFalse := false
req := api.EmbedRequest{
Model: model,
Input: []string{
"short input",
strings.Repeat("very long input ", 100),
"another short input",
},
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 10},
}
_, err := embedTestHelper(mctx, client, t, req)
if err == nil {
t.Fatal("expected error when one input exceeds context with truncate=false")
}
// Check that it's a StatusError with the correct status code
var statusErr api.StatusError
if !errors.As(err, &statusErr) {
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
}
// The error should be a 4xx client error, not a 500 Internal Server Error
if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
}
})
})
}
}

View File

@@ -33,9 +33,6 @@ func TestVisionModels(t *testing.T) {
// Qwen 3 VL mixture of experts
model: "qwen3-vl:30b",
},
{
model: "ministral-3",
},
}
for _, v := range testCases {

View File

@@ -11,15 +11,6 @@ import (
"github.com/ollama/ollama/api"
)
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
func TestAPIToolCalling(t *testing.T) {
initialTimeout := 60 * time.Second
streamTimeout := 60 * time.Second
@@ -39,7 +30,6 @@ func TestAPIToolCalling(t *testing.T) {
"mistral": 6,
"qwen2.5": 6,
"qwen2": 6,
"ministral-3": 20,
"mistral-nemo": 9,
"mistral-small": 16,
"mixtral:8x22b": 80,
@@ -66,12 +56,12 @@ func TestAPIToolCalling(t *testing.T) {
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: testPropsMap(map[string]api.ToolProperty{
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "The city and state, e.g. San Francisco, CA",
},
}),
},
},
},
},

View File

@@ -38,7 +38,6 @@ var (
// Note: add newer models at the top of the list to test them first
ollamaEngineChatModels = []string{
"ministral-3",
"qwen3-coder:30b",
"gpt-oss:20b",
"gemma3n:e2b",
@@ -168,7 +167,6 @@ var (
"medllama2",
"megadolphin",
"minicpm-v",
"ministral-3",
"mistral-large",
"mistral-nemo",
"mistral-openorca",
@@ -272,7 +270,6 @@ var (
"mistral",
"qwen2.5",
"qwen2",
"ministral-3",
"mistral-nemo",
"mistral-small",
"mixtral:8x22b",

View File

@@ -1,94 +0,0 @@
// Package orderedmap provides a generic ordered map that maintains insertion order.
// It wraps github.com/wk8/go-ordered-map/v2 to encapsulate the dependency.
package orderedmap
import (
"encoding/json"
"iter"
orderedmap "github.com/wk8/go-ordered-map/v2"
)
// Map is a generic ordered map that maintains insertion order.
type Map[K comparable, V any] struct {
om *orderedmap.OrderedMap[K, V]
}
// New creates a new empty ordered map.
func New[K comparable, V any]() *Map[K, V] {
return &Map[K, V]{
om: orderedmap.New[K, V](),
}
}
// Get retrieves a value by key.
func (m *Map[K, V]) Get(key K) (V, bool) {
if m == nil || m.om == nil {
var zero V
return zero, false
}
return m.om.Get(key)
}
// Set sets a key-value pair. If the key already exists, its value is updated
// but its position in the iteration order is preserved. If the key is new,
// it is appended to the end.
func (m *Map[K, V]) Set(key K, value V) {
if m == nil {
return
}
if m.om == nil {
m.om = orderedmap.New[K, V]()
}
m.om.Set(key, value)
}
// Len returns the number of entries.
func (m *Map[K, V]) Len() int {
if m == nil || m.om == nil {
return 0
}
return m.om.Len()
}
// All returns an iterator over all key-value pairs in insertion order.
func (m *Map[K, V]) All() iter.Seq2[K, V] {
return func(yield func(K, V) bool) {
if m == nil || m.om == nil {
return
}
for pair := m.om.Oldest(); pair != nil; pair = pair.Next() {
if !yield(pair.Key, pair.Value) {
return
}
}
}
}
// ToMap converts to a regular Go map.
// Note: The resulting map does not preserve order.
func (m *Map[K, V]) ToMap() map[K]V {
if m == nil || m.om == nil {
return nil
}
result := make(map[K]V, m.om.Len())
for pair := m.om.Oldest(); pair != nil; pair = pair.Next() {
result[pair.Key] = pair.Value
}
return result
}
// MarshalJSON implements json.Marshaler. The JSON output preserves key order.
func (m *Map[K, V]) MarshalJSON() ([]byte, error) {
if m == nil || m.om == nil {
return []byte("null"), nil
}
return json.Marshal(m.om)
}
// UnmarshalJSON implements json.Unmarshaler. The insertion order matches the
// order of keys in the JSON input.
func (m *Map[K, V]) UnmarshalJSON(data []byte) error {
m.om = orderedmap.New[K, V]()
return json.Unmarshal(data, &m.om)
}

View File

@@ -1,348 +0,0 @@
package orderedmap
import (
"encoding/json"
"slices"
"testing"
)
func TestMap_BasicOperations(t *testing.T) {
m := New[string, int]()
// Test empty map
if m.Len() != 0 {
t.Errorf("expected Len() = 0, got %d", m.Len())
}
v, ok := m.Get("a")
if ok {
t.Error("expected Get on empty map to return false")
}
if v != 0 {
t.Errorf("expected zero value, got %d", v)
}
// Test Set and Get
m.Set("a", 1)
m.Set("b", 2)
m.Set("c", 3)
if m.Len() != 3 {
t.Errorf("expected Len() = 3, got %d", m.Len())
}
v, ok = m.Get("a")
if !ok || v != 1 {
t.Errorf("expected Get(a) = (1, true), got (%d, %v)", v, ok)
}
v, ok = m.Get("b")
if !ok || v != 2 {
t.Errorf("expected Get(b) = (2, true), got (%d, %v)", v, ok)
}
v, ok = m.Get("c")
if !ok || v != 3 {
t.Errorf("expected Get(c) = (3, true), got (%d, %v)", v, ok)
}
// Test updating existing key preserves position
m.Set("a", 10)
v, ok = m.Get("a")
if !ok || v != 10 {
t.Errorf("expected Get(a) = (10, true), got (%d, %v)", v, ok)
}
if m.Len() != 3 {
t.Errorf("expected Len() = 3 after update, got %d", m.Len())
}
}
func TestMap_InsertionOrderPreserved(t *testing.T) {
m := New[string, int]()
// Insert in non-alphabetical order
m.Set("z", 1)
m.Set("a", 2)
m.Set("m", 3)
m.Set("b", 4)
// Verify iteration order matches insertion order
var keys []string
var values []int
for k, v := range m.All() {
keys = append(keys, k)
values = append(values, v)
}
expectedKeys := []string{"z", "a", "m", "b"}
expectedValues := []int{1, 2, 3, 4}
if !slices.Equal(keys, expectedKeys) {
t.Errorf("expected keys %v, got %v", expectedKeys, keys)
}
if !slices.Equal(values, expectedValues) {
t.Errorf("expected values %v, got %v", expectedValues, values)
}
}
func TestMap_UpdatePreservesPosition(t *testing.T) {
m := New[string, int]()
m.Set("first", 1)
m.Set("second", 2)
m.Set("third", 3)
// Update middle element
m.Set("second", 20)
var keys []string
for k := range m.All() {
keys = append(keys, k)
}
// Order should still be first, second, third
expected := []string{"first", "second", "third"}
if !slices.Equal(keys, expected) {
t.Errorf("expected keys %v, got %v", expected, keys)
}
}
func TestMap_MarshalJSON_PreservesOrder(t *testing.T) {
m := New[string, int]()
// Insert in non-alphabetical order
m.Set("z", 1)
m.Set("a", 2)
m.Set("m", 3)
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
// JSON should preserve insertion order, not alphabetical
expected := `{"z":1,"a":2,"m":3}`
if string(data) != expected {
t.Errorf("expected %s, got %s", expected, string(data))
}
}
func TestMap_UnmarshalJSON_PreservesOrder(t *testing.T) {
// JSON with non-alphabetical key order
jsonData := `{"z":1,"a":2,"m":3}`
m := New[string, int]()
if err := json.Unmarshal([]byte(jsonData), m); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
// Verify iteration order matches JSON order
var keys []string
for k := range m.All() {
keys = append(keys, k)
}
expected := []string{"z", "a", "m"}
if !slices.Equal(keys, expected) {
t.Errorf("expected keys %v, got %v", expected, keys)
}
}
func TestMap_JSONRoundTrip(t *testing.T) {
// Test that unmarshal -> marshal produces identical JSON
original := `{"zebra":"z","apple":"a","mango":"m","banana":"b"}`
m := New[string, string]()
if err := json.Unmarshal([]byte(original), m); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
if string(data) != original {
t.Errorf("round trip failed: expected %s, got %s", original, string(data))
}
}
func TestMap_ToMap(t *testing.T) {
m := New[string, int]()
m.Set("a", 1)
m.Set("b", 2)
regular := m.ToMap()
if len(regular) != 2 {
t.Errorf("expected len 2, got %d", len(regular))
}
if regular["a"] != 1 {
t.Errorf("expected regular[a] = 1, got %d", regular["a"])
}
if regular["b"] != 2 {
t.Errorf("expected regular[b] = 2, got %d", regular["b"])
}
}
func TestMap_NilSafety(t *testing.T) {
var m *Map[string, int]
// All operations should be safe on nil
if m.Len() != 0 {
t.Errorf("expected Len() = 0 on nil map, got %d", m.Len())
}
v, ok := m.Get("a")
if ok {
t.Error("expected Get on nil map to return false")
}
if v != 0 {
t.Errorf("expected zero value from nil map, got %d", v)
}
// Set on nil is a no-op
m.Set("a", 1)
if m.Len() != 0 {
t.Errorf("expected Len() = 0 after Set on nil, got %d", m.Len())
}
// All returns empty iterator
var keys []string
for k := range m.All() {
keys = append(keys, k)
}
if len(keys) != 0 {
t.Errorf("expected empty iteration on nil map, got %v", keys)
}
// ToMap returns nil
if m.ToMap() != nil {
t.Error("expected ToMap to return nil on nil map")
}
// MarshalJSON returns null
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
if string(data) != "null" {
t.Errorf("expected null, got %s", string(data))
}
}
func TestMap_EmptyMapMarshal(t *testing.T) {
m := New[string, int]()
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
if string(data) != "{}" {
t.Errorf("expected {}, got %s", string(data))
}
}
func TestMap_NestedValues(t *testing.T) {
m := New[string, any]()
m.Set("string", "hello")
m.Set("number", 42)
m.Set("bool", true)
m.Set("nested", map[string]int{"x": 1})
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
expected := `{"string":"hello","number":42,"bool":true,"nested":{"x":1}}`
if string(data) != expected {
t.Errorf("expected %s, got %s", expected, string(data))
}
}
func TestMap_AllIteratorEarlyExit(t *testing.T) {
m := New[string, int]()
m.Set("a", 1)
m.Set("b", 2)
m.Set("c", 3)
m.Set("d", 4)
// Collect only first 2
var keys []string
for k := range m.All() {
keys = append(keys, k)
if len(keys) == 2 {
break
}
}
expected := []string{"a", "b"}
if !slices.Equal(keys, expected) {
t.Errorf("expected %v, got %v", expected, keys)
}
}
func TestMap_IntegerKeys(t *testing.T) {
m := New[int, string]()
m.Set(3, "three")
m.Set(1, "one")
m.Set(2, "two")
var keys []int
for k := range m.All() {
keys = append(keys, k)
}
// Should preserve insertion order, not numerical order
expected := []int{3, 1, 2}
if !slices.Equal(keys, expected) {
t.Errorf("expected %v, got %v", expected, keys)
}
}
func TestMap_UnmarshalIntoExisting(t *testing.T) {
m := New[string, int]()
m.Set("existing", 999)
// Unmarshal should replace contents
if err := json.Unmarshal([]byte(`{"new":1}`), m); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
_, ok := m.Get("existing")
if ok {
t.Error("existing key should be gone after unmarshal")
}
v, ok := m.Get("new")
if !ok || v != 1 {
t.Errorf("expected Get(new) = (1, true), got (%d, %v)", v, ok)
}
}
func TestMap_LargeOrderPreservation(t *testing.T) {
m := New[string, int]()
// Create many keys in specific order
keys := make([]string, 100)
for i := range 100 {
keys[i] = string(rune('a' + (99 - i))) // reverse order: 'd', 'c', 'b', 'a' (extended)
if i >= 26 {
keys[i] = string(rune('A'+i-26)) + string(rune('a'+i%26))
}
}
for i, k := range keys {
m.Set(k, i)
}
// Verify order preserved
var resultKeys []string
for k := range m.All() {
resultKeys = append(resultKeys, k)
}
if !slices.Equal(keys, resultKeys) {
t.Error("large map should preserve insertion order")
}
}

View File

@@ -140,6 +140,10 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
c.config.CachePadding = 1
}
if c.config.MaskBatchPadding == 0 {
c.config.MaskBatchPadding = 1
}
if c.config.MaskDType == ml.DTypeOther {
c.config.MaskDType = ml.DTypeF32
}
@@ -360,12 +364,15 @@ func roundUp(length, pad int) int {
// token in the history should apply. This is based on both the sequence and causality (the
// position of the history is not ahead of the token in the batch).
func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
// Align and pad the two dimensions as required by the backend
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
length := c.curCellRange.max - c.curCellRange.min + 1
mask := make([]float32, c.curBatchSize*length)
mask := make([]float32, batchSize*length)
for i := range c.curBatchSize {
enabled := !slices.Contains(c.opts.Except, i)
@@ -379,7 +386,13 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
}
}
maskTensor := ctx.Input().FromFloats(mask, length, c.curBatchSize)
// Mask out any padding tokens we added. For padding that we added to the cache history, this
// has already been masked out because the sequence doesn't match.
for i := c.curBatchSize * length; i < len(mask); i++ {
mask[i] = float32(math.Inf(-1))
}
maskTensor := ctx.Input().FromFloats(mask, length, batchSize)
if c.config.MaskDType != ml.DTypeF32 {
maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)

2
llama/build-info.cpp generated vendored
View File

@@ -1,4 +1,4 @@
int LLAMA_BUILD_NUMBER = 0;
char const *LLAMA_COMMIT = "ec98e2002";
char const *LLAMA_COMMIT = "3cfa9c3f125763305b4226bc032f1954f08990dc";
char const *LLAMA_COMPILER = "";
char const *LLAMA_BUILD_TARGET = "";

View File

@@ -17,17 +17,11 @@ include /tools/mtmd/clip.cpp
include /tools/mtmd/mtmd.cpp
include /tools/mtmd/mtmd-audio.cpp
include /tools/mtmd/mtmd-helper.cpp
include /tools/mtmd/models/
include /tools/mtmd/models/*.h
include /tools/mtmd/models/*.cpp
include /src/
include /src/llama.*
include /src/llama-*.*
include /src/unicode-data.*
include /src/unicode.*
include /src/models/
include /src/models/*.h
include /src/models/*.cpp
include /vendor/
include /vendor/miniaudio/
include /vendor/miniaudio/*.h

Some files were not shown because too many files have changed in this diff Show More