Compare commits

..

3 Commits

Author SHA1 Message Date
Eva Ho
def9651b55 adding space at the end to push the content up when new message 2025-11-12 22:05:52 -05:00
Eva Ho
dd63aac6bd no auto scrolling when streaming 2025-11-12 18:19:39 -05:00
Eva Ho
78830f6e9d wip 2025-11-12 14:20:46 -05:00
786 changed files with 51194 additions and 109591 deletions

4
.gitattributes vendored
View File

@@ -15,12 +15,8 @@ ml/backend/**/*.cu linguist-vendored
ml/backend/**/*.cuh linguist-vendored ml/backend/**/*.cuh linguist-vendored
ml/backend/**/*.m linguist-vendored ml/backend/**/*.m linguist-vendored
ml/backend/**/*.metal linguist-vendored ml/backend/**/*.metal linguist-vendored
ml/backend/**/*.comp linguist-vendored
ml/backend/**/*.glsl linguist-vendored
ml/backend/**/CMakeLists.txt linguist-vendored ml/backend/**/CMakeLists.txt linguist-vendored
app/webview linguist-vendored
llama/build-info.cpp linguist-generated llama/build-info.cpp linguist-generated
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.s linguist-generated ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.s linguist-generated

View File

@@ -13,7 +13,7 @@ body:
id: logs id: logs
attributes: attributes:
label: Relevant log output label: Relevant log output
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.mdx#how-to-troubleshoot-issues) for details. description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
render: shell render: shell
validations: validations:
required: false required: false

View File

@@ -16,15 +16,13 @@ jobs:
outputs: outputs:
GOFLAGS: ${{ steps.goflags.outputs.GOFLAGS }} GOFLAGS: ${{ steps.goflags.outputs.GOFLAGS }}
VERSION: ${{ steps.goflags.outputs.VERSION }} VERSION: ${{ steps.goflags.outputs.VERSION }}
vendorsha: ${{ steps.changes.outputs.vendorsha }}
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Set environment - name: Set environment
id: goflags id: goflags
run: | run: |
echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" | tee -a $GITHUB_OUTPUT echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" >>$GITHUB_OUTPUT
echo VERSION="${GITHUB_REF_NAME#v}" | tee -a $GITHUB_OUTPUT echo VERSION="${GITHUB_REF_NAME#v}" >>$GITHUB_OUTPUT
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
darwin-build: darwin-build:
runs-on: macos-14-xlarge runs-on: macos-14-xlarge
@@ -55,9 +53,6 @@ jobs:
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version-file: go.mod go-version-file: go.mod
cache-dependency-path: |
go.sum
Makefile.sync
- run: | - run: |
./scripts/build_darwin.sh ./scripts/build_darwin.sh
- name: Log build results - name: Log build results
@@ -68,7 +63,6 @@ jobs:
name: bundles-darwin name: bundles-darwin
path: | path: |
dist/*.tgz dist/*.tgz
dist/*.tar.zst
dist/*.zip dist/*.zip
dist/*.dmg dist/*.dmg
@@ -191,7 +185,7 @@ jobs:
- uses: actions/cache@v4 - uses: actions/cache@v4
with: with:
path: ${{ github.workspace }}\.ccache path: ${{ github.workspace }}\.ccache
key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}-${{ needs.setup-environment.outputs.vendorsha }} key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}
- name: Build target "${{ matrix.preset }}" - name: Build target "${{ matrix.preset }}"
run: | run: |
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll' Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
@@ -255,9 +249,6 @@ jobs:
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version-file: go.mod go-version-file: go.mod
cache-dependency-path: |
go.sum
Makefile.sync
- name: Verify gcc is actually clang - name: Verify gcc is actually clang
run: | run: |
$ErrorActionPreference='Continue' $ErrorActionPreference='Continue'
@@ -311,9 +302,6 @@ jobs:
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version-file: go.mod go-version-file: go.mod
cache-dependency-path: |
go.sum
Makefile.sync
- uses: actions/download-artifact@v4 - uses: actions/download-artifact@v4
with: with:
pattern: depends-windows* pattern: depends-windows*
@@ -372,17 +360,12 @@ jobs:
outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }} outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
cache-to: type=inline cache-to: type=inline
- name: Deduplicate CUDA libraries
run: |
./scripts/deduplicate_cuda_libs.sh dist/${{ matrix.os }}-${{ matrix.arch }}
- run: | - run: |
for COMPONENT in bin/* lib/ollama/*; do for COMPONENT in bin/* lib/ollama/*; do
case "$COMPONENT" in case "$COMPONENT" in
bin/ollama*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;; lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;; lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;; lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
@@ -397,13 +380,13 @@ jobs:
done done
- run: | - run: |
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | zstd --ultra -22 -T0 >$(basename ${ARCHIVE//.*/}.tar.zst); tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
done done
- uses: actions/upload-artifact@v4 - uses: actions/upload-artifact@v4
with: with:
name: bundles-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.target }} name: bundles-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.target }}
path: | path: |
*.tar.zst *.tgz
# Build each Docker variant (OS, arch, and flavor) separately. Using QEMU is unreliable and slower. # Build each Docker variant (OS, arch, and flavor) separately. Using QEMU is unreliable and slower.
docker-build-push: docker-build-push:
@@ -536,7 +519,7 @@ jobs:
- name: Upload release artifacts - name: Upload release artifacts
run: | run: |
pids=() pids=()
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg ; do for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.exe dist/*.dmg ; do
echo "Uploading $payload" echo "Uploading $payload"
gh release upload ${GITHUB_REF_NAME} $payload --clobber & gh release upload ${GITHUB_REF_NAME} $payload --clobber &
pids[$!]=$! pids[$!]=$!

View File

@@ -22,7 +22,6 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
outputs: outputs:
changed: ${{ steps.changes.outputs.changed }} changed: ${{ steps.changes.outputs.changed }}
vendorsha: ${{ steps.changes.outputs.vendorsha }}
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
with: with:
@@ -38,7 +37,6 @@ jobs:
} }
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
linux: linux:
needs: [changes] needs: [changes]
@@ -85,7 +83,7 @@ jobs:
- uses: actions/cache@v4 - uses: actions/cache@v4
with: with:
path: /github/home/.cache/ccache path: /github/home/.cache/ccache
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }} key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
- run: | - run: |
cmake --preset ${{ matrix.preset }} ${{ matrix.flags }} cmake --preset ${{ matrix.preset }} ${{ matrix.flags }}
cmake --build --preset ${{ matrix.preset }} --parallel cmake --build --preset ${{ matrix.preset }} --parallel
@@ -180,7 +178,7 @@ jobs:
- uses: actions/cache@v4 - uses: actions/cache@v4
with: with:
path: ${{ github.workspace }}\.ccache path: ${{ github.workspace }}\.ccache
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }} key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
- run: | - run: |
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll' Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo' Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
@@ -208,9 +206,6 @@ jobs:
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version-file: 'go.mod' go-version-file: 'go.mod'
cache-dependency-path: |
go.sum
Makefile.sync
- uses: actions/setup-node@v4 - uses: actions/setup-node@v4
with: with:
node-version: '20' node-version: '20'
@@ -231,9 +226,12 @@ jobs:
if: always() if: always()
run: go test -count=1 -benchtime=1x ./... run: go test -count=1 -benchtime=1x ./...
- uses: golangci/golangci-lint-action@v9 # TODO(bmizerany): replace this heavy tool with just the
# tools/checks/binaries we want and then make them all run in parallel
# across jobs, not on a single tiny vm on Github Actions.
- uses: golangci/golangci-lint-action@v6
with: with:
only-new-issues: true args: --timeout 10m0s -v
patches: patches:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@@ -1,4 +1,5 @@
version: "2" run:
timeout: 5m
linters: linters:
enable: enable:
- asasalint - asasalint
@@ -6,46 +7,35 @@ linters:
- bodyclose - bodyclose
- containedctx - containedctx
- gocheckcompilerdirectives - gocheckcompilerdirectives
- gofmt
- gofumpt
- gosimple
- govet
- ineffassign
- intrange - intrange
- makezero - makezero
- misspell - misspell
- nilerr - nilerr
- nolintlint - nolintlint
- nosprintfhostport - nosprintfhostport
- staticcheck
- unconvert - unconvert
- usetesting - usetesting
- wastedassign - wastedassign
- whitespace - whitespace
disable: disable:
- errcheck
- usestdlibvars - usestdlibvars
settings: - errcheck
govet: linters-settings:
disable:
- unusedresult
staticcheck: staticcheck:
checks: checks:
- all - all
- -QF* # disable quick fix suggestions - -SA1019 # omit Deprecated check
- -SA1019
- -ST1000 # package comment format
- -ST1003 # underscores in package names
- -ST1005 # error strings should not be capitalized
- -ST1012 # error var naming (ErrFoo)
- -ST1016 # receiver name consistency
- -ST1020 # comment on exported function format
- -ST1021 # comment on exported type format
- -ST1022 # comment on exported var format
- -ST1023 # omit type from declaration
severity: severity:
default: error default-severity: error
rules: rules:
- linters: - linters:
- gofmt - gofmt
- goimports - goimports
- intrange - intrange
severity: info severity: info
formatters:
enable:
- gofmt
- gofumpt

View File

@@ -2,22 +2,6 @@ cmake_minimum_required(VERSION 3.21)
project(Ollama C CXX) project(Ollama C CXX)
# Handle cross-compilation on macOS: when CMAKE_OSX_ARCHITECTURES is set to a
# single architecture different from the host, override CMAKE_SYSTEM_PROCESSOR
# to match. This is necessary because CMAKE_SYSTEM_PROCESSOR defaults to the
# host architecture, but downstream projects (like MLX) use it to detect the
# target architecture.
if(CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES ";")
# Single architecture specified
if(CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
message(STATUS "Cross-compiling for x86_64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to x86_64")
set(CMAKE_SYSTEM_PROCESSOR "x86_64")
elseif(CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
message(STATUS "Cross-compiling for arm64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to arm64")
set(CMAKE_SYSTEM_PROCESSOR "arm64")
endif()
endif()
include(CheckLanguage) include(CheckLanguage)
include(GNUInstallDirs) include(GNUInstallDirs)
@@ -28,7 +12,7 @@ set(BUILD_SHARED_LIBS ON)
set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS ON) # Recent versions of MLX Requires gnu++17 extensions to compile properly set(CMAKE_CXX_EXTENSIONS OFF)
set(GGML_BUILD ON) set(GGML_BUILD ON)
set(GGML_SHARED ON) set(GGML_SHARED ON)
@@ -48,10 +32,9 @@ if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
set(GGML_CPU_ALL_VARIANTS ON) set(GGML_CPU_ALL_VARIANTS ON)
endif() endif()
if(APPLE) if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
set(CMAKE_BUILD_RPATH "@loader_path") set(CMAKE_BUILD_RPATH "@loader_path")
set(CMAKE_INSTALL_RPATH "@loader_path") set(CMAKE_INSTALL_RPATH "@loader_path")
set(CMAKE_BUILD_WITH_INSTALL_RPATH ON)
endif() endif()
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama) set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
@@ -71,13 +54,6 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cp
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0) add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
# Define GGML version variables for shared library SOVERSION
# These are required by ggml/src/CMakeLists.txt for proper library versioning
set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 0)
set(GGML_VERSION_PATCH 0)
set(GGML_VERSION "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
set(GGML_CPU ON) set(GGML_CPU ON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE) set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE)
@@ -164,7 +140,6 @@ if(CMAKE_HIP_COMPILER)
endif() endif()
endif() endif()
if(NOT APPLE)
find_package(Vulkan) find_package(Vulkan)
if(Vulkan_FOUND) if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
@@ -176,44 +151,3 @@ if(NOT APPLE)
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
) )
endif() endif()
endif()
option(MLX_ENGINE "Enable MLX backend" OFF)
if(MLX_ENGINE)
message(STATUS "Setting up MLX (this takes a while...)")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
# Find CUDA toolkit if MLX is built with CUDA support
find_package(CUDAToolkit)
install(TARGETS mlx mlxc
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
)
# Install the Metal library for macOS arm64 (must be colocated with the binary)
# Metal backend is only built for arm64, not x86_64
if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
install(FILES ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/backend/metal/kernels/mlx.metallib
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
# Manually install cudart and cublas since they might not be picked up as direct dependencies
if(CUDAToolkit_FOUND)
file(GLOB CUDART_LIBS
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
if(CUDART_LIBS)
install(FILES ${CUDART_LIBS}
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
endif()
endif()

View File

@@ -41,7 +41,7 @@
"inherits": [ "CUDA" ], "inherits": [ "CUDA" ],
"cacheVariables": { "cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual", "CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
"CMAKE_CUDA_FLAGS": "-t 4", "CMAKE_CUDA_FLAGS": "-t 2",
"OLLAMA_RUNNER_DIR": "cuda_v13" "OLLAMA_RUNNER_DIR": "cuda_v13"
} }
}, },
@@ -83,28 +83,6 @@
"cacheVariables": { "cacheVariables": {
"OLLAMA_RUNNER_DIR": "vulkan" "OLLAMA_RUNNER_DIR": "vulkan"
} }
},
{
"name": "MLX",
"inherits": [ "Default" ],
"cacheVariables": {
"MLX_ENGINE": "ON",
"OLLAMA_RUNNER_DIR": "mlx"
}
},
{
"name": "MLX CUDA 12",
"inherits": [ "MLX", "CUDA 12" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v12"
}
},
{
"name": "MLX CUDA 13",
"inherits": [ "MLX", "CUDA 13" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
}
} }
], ],
"buildPresets": [ "buildPresets": [
@@ -162,21 +140,6 @@
"name": "Vulkan", "name": "Vulkan",
"targets": [ "ggml-vulkan" ], "targets": [ "ggml-vulkan" ],
"configurePreset": "Vulkan" "configurePreset": "Vulkan"
},
{
"name": "MLX",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX"
},
{
"name": "MLX CUDA 12",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 12"
},
{
"name": "MLX CUDA 13",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 13"
} }
] ]
} }

View File

@@ -16,7 +16,7 @@ See the [development documentation](./docs/development.md) for instructions on h
* New features: new features (e.g. API fields, environment variables) add surface area to Ollama and make it harder to maintain in the long run as they cannot be removed without potentially breaking users in the future. * New features: new features (e.g. API fields, environment variables) add surface area to Ollama and make it harder to maintain in the long run as they cannot be removed without potentially breaking users in the future.
* Refactoring: large code improvements are important, but can be harder or take longer to review and merge. * Refactoring: large code improvements are important, but can be harder or take longer to review and merge.
* Documentation: small updates to fill in or correct missing documentation are helpful, however large documentation additions can be hard to maintain over time. * Documentation: small updates to fill in or correct missing documentation is helpful, however large documentation additions can be hard to maintain over time.
### Issues that may not be accepted ### Issues that may not be accepted
@@ -43,7 +43,7 @@ Tips for proposals:
* Explain how the change will be tested. * Explain how the change will be tested.
Additionally, for bonus points: Provide draft documentation you would expect to Additionally, for bonus points: Provide draft documentation you would expect to
see if the changes were accepted. see if the change were accepted.
## Pull requests ## Pull requests
@@ -66,6 +66,7 @@ Examples:
llm/backend/mlx: support the llama architecture llm/backend/mlx: support the llama architecture
CONTRIBUTING: provide clarity on good commit messages, and bad CONTRIBUTING: provide clarity on good commit messages, and bad
docs: simplify manual installation with shorter curl commands
Bad Examples: Bad Examples:

View File

@@ -39,14 +39,14 @@ ENV CC=clang CXX=clang++
FROM base-${TARGETARCH} AS base FROM base-${TARGETARCH} AS base
ARG CMAKEVERSION ARG CMAKEVERSION
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ENV LDFLAGS=-s ENV LDFLAGS=-s
FROM base AS cpu FROM base AS cpu
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
ARG PARALLEL ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CPU' \ cmake --preset 'CPU' \
&& cmake --build --parallel ${PARALLEL} --preset 'CPU' \ && cmake --build --parallel ${PARALLEL} --preset 'CPU' \
@@ -57,8 +57,6 @@ ARG CUDA11VERSION=11.8
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-} RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH ENV PATH=/usr/local/cuda-11/bin:$PATH
ARG PARALLEL ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 11' \ cmake --preset 'CUDA 11' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \ && cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
@@ -69,8 +67,6 @@ ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-} RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
ENV PATH=/usr/local/cuda-12/bin:$PATH ENV PATH=/usr/local/cuda-12/bin:$PATH
ARG PARALLEL ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 12' \ cmake --preset 'CUDA 12' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \ && cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
@@ -82,8 +78,6 @@ ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
ENV PATH=/usr/local/cuda-13/bin:$PATH ENV PATH=/usr/local/cuda-13/bin:$PATH
ARG PARALLEL ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 13' \ cmake --preset 'CUDA 13' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \ && cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
@@ -93,8 +87,6 @@ RUN --mount=type=cache,target=/root/.ccache \
FROM base AS rocm-6 FROM base AS rocm-6
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
ARG PARALLEL ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'ROCm 6' \ cmake --preset 'ROCm 6' \
&& cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \ && cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
@@ -126,44 +118,11 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL} && cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS vulkan FROM base AS vulkan
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'Vulkan' \ cmake --preset 'Vulkan' \
&& cmake --build --parallel --preset 'Vulkan' \ && cmake --build --parallel --preset 'Vulkan' \
&& cmake --install build --component Vulkan --strip --parallel 8 && cmake --install build --component Vulkan --strip --parallel 8
FROM base AS mlx
ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} \
&& dnf install -y openblas-devel lapack-devel \
&& dnf install -y libcudnn9-cuda-13 libcudnn9-devel-cuda-13 \
&& dnf install -y libnccl libnccl-devel
ENV PATH=/usr/local/cuda-13/bin:$PATH
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
ARG PARALLEL
WORKDIR /go/src/github.com/ollama/ollama
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/ml/backend/mlx x/ml/backend/mlx
COPY go.mod go.sum .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
COPY . .
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
RUN mkdir -p dist/bin
RUN --mount=type=cache,target=/root/.cache/go-build \
go build -tags mlx -trimpath -buildmode=pie -o dist/bin/ollama-mlx .
FROM base AS build FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama WORKDIR /go/src/github.com/ollama/ollama
@@ -184,8 +143,6 @@ FROM --platform=linux/amd64 scratch AS amd64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/ COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/ COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
COPY --from=vulkan dist/lib/ollama /lib/ollama/ COPY --from=vulkan dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/
FROM --platform=linux/arm64 scratch AS arm64 FROM --platform=linux/arm64 scratch AS arm64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
@@ -204,7 +161,7 @@ COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04 FROM ubuntu:24.04
RUN apt-get update \ RUN apt-get update \
&& apt-get install -y ca-certificates libvulkan1 libopenblas0 \ && apt-get install -y ca-certificates libvulkan1 \
&& apt-get clean \ && apt-get clean \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
COPY --from=archive /bin /usr/bin COPY --from=archive /bin /usr/bin

View File

@@ -1,6 +1,6 @@
UPSTREAM=https://github.com/ggml-org/llama.cpp.git UPSTREAM=https://github.com/ggml-org/llama.cpp.git
WORKDIR=llama/vendor WORKDIR=llama/vendor
FETCH_HEAD=ec98e2002 FETCH_HEAD=3cfa9c3f125763305b4226bc032f1954f08990dc
.PHONY: help .PHONY: help
help: help:
@@ -57,7 +57,7 @@ checkout: $(WORKDIR)
$(WORKDIR): $(WORKDIR):
git clone $(UPSTREAM) $(WORKDIR) git clone $(UPSTREAM) $(WORKDIR)
.PHONY: format-patches .PHONE: format-patches
format-patches: llama/patches format-patches: llama/patches
git -C $(WORKDIR) format-patch \ git -C $(WORKDIR) format-patch \
--no-signature \ --no-signature \
@@ -66,11 +66,7 @@ format-patches: llama/patches
-o $(realpath $<) \ -o $(realpath $<) \
$(FETCH_HEAD) $(FETCH_HEAD)
.PHONY: clean .PHONE: clean
clean: checkout clean: checkout
@git -C $(WORKDIR) am --abort || true @git -C $(WORKDIR) am --abort || true
$(RM) llama/patches/.*.patched $(RM) llama/patches/.*.patched
.PHONY: print-base
print-base:
@echo $(FETCH_HEAD)

View File

@@ -299,7 +299,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LibreChat](https://github.com/danny-avila/LibreChat) - [LibreChat](https://github.com/danny-avila/LibreChat)
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt) - [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui) - [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [AI-UI](https://github.com/bajahaw/ai-ui)
- [Saddle](https://github.com/jikkuatwork/saddle) - [Saddle](https://github.com/jikkuatwork/saddle)
- [TagSpaces](https://www.tagspaces.org) (A platform for file-based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions) - [TagSpaces](https://www.tagspaces.org) (A platform for file-based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama) - [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
@@ -367,7 +366,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot, and Ollama4j - [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot, and Ollama4j
- [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models. - [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding - [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding
- [Void](https://github.com/voideditor/void) (Open source AI code editor and Cursor alternative)
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support) - [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption) - [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
- [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library) - [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
@@ -428,7 +426,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.) - [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.) - [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models) - [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
- [KDeps](https://github.com/kdeps/kdeps) (Kdeps is an offline-first AI framework for building Dockerized full-stack AI applications declaratively using Apple PKL and integrates APIs with Ollama on the backend.)
- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare) - [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads) - [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)
- [Hillnote](https://hillnote.com) (A Markdown-first workspace designed to supercharge your AI workflow. Create documents ready to integrate with Claude, ChatGPT, Gemini, Cursor, and more - all while keeping your work on your device.) - [Hillnote](https://hillnote.com) (A Markdown-first workspace designed to supercharge your AI workflow. Create documents ready to integrate with Claude, ChatGPT, Gemini, Cursor, and more - all while keeping your work on your device.)
@@ -555,7 +552,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Parakeet](https://github.com/parakeet-nest/parakeet) is a GoLang library, made to simplify the development of small generative AI applications with Ollama. - [Parakeet](https://github.com/parakeet-nest/parakeet) is a GoLang library, made to simplify the development of small generative AI applications with Ollama.
- [Haverscript](https://github.com/andygill/haverscript) with [examples](https://github.com/andygill/haverscript/tree/main/examples) - [Haverscript](https://github.com/andygill/haverscript) with [examples](https://github.com/andygill/haverscript/tree/main/examples)
- [Ollama for Swift](https://github.com/mattt/ollama-swift) - [Ollama for Swift](https://github.com/mattt/ollama-swift)
- [Swollama for Swift](https://github.com/guitaripod/Swollama) with [DocC](https://guitaripod.github.io/Swollama/documentation/swollama) - [Swollama for Swift](https://github.com/marcusziade/Swollama) with [DocC](https://marcusziade.github.io/Swollama/documentation/swollama/)
- [GoLamify](https://github.com/prasad89/golamify) - [GoLamify](https://github.com/prasad89/golamify)
- [Ollama for Haskell](https://github.com/tusharad/ollama-haskell) - [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API) - [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API)
@@ -618,7 +615,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LSP-AI](https://github.com/SilasMarvin/lsp-ai) (Open-source language server for AI-powered functionality) - [LSP-AI](https://github.com/SilasMarvin/lsp-ai) (Open-source language server for AI-powered functionality)
- [QodeAssist](https://github.com/Palm1r/QodeAssist) (AI-powered coding assistant plugin for Qt Creator) - [QodeAssist](https://github.com/Palm1r/QodeAssist) (AI-powered coding assistant plugin for Qt Creator)
- [Obsidian Quiz Generator plugin](https://github.com/ECuiDev/obsidian-quiz-generator) - [Obsidian Quiz Generator plugin](https://github.com/ECuiDev/obsidian-quiz-generator)
- [AI Summary Helper plugin](https://github.com/philffm/ai-summary-helper) - [AI Summmary Helper plugin](https://github.com/philffm/ai-summary-helper)
- [TextCraft](https://github.com/suncloudsmoon/TextCraft) (Copilot in Word alternative using Ollama) - [TextCraft](https://github.com/suncloudsmoon/TextCraft) (Copilot in Word alternative using Ollama)
- [Alfred Ollama](https://github.com/zeitlings/alfred-ollama) (Alfred Workflow) - [Alfred Ollama](https://github.com/zeitlings/alfred-ollama) (Alfred Workflow)
- [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language - [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language
@@ -626,7 +623,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c) - [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs) - [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
- [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime) - [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime)
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Editor tool to analyze scripts via Ollama) - [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies) - [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)
- [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases) - [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases)
- [NOMYO Router](https://github.com/nomyo-ai/nomyo-router) (A transparent Ollama proxy with model deployment aware routing which auto-manages multiple Ollama instances in a given network) - [NOMYO Router](https://github.com/nomyo-ai/nomyo-router) (A transparent Ollama proxy with model deployment aware routing which auto-manages multiple Ollama instances in a given network)
@@ -636,12 +633,12 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov. - [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
### Observability ### Observability
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama. - [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native intergration to Ollama.
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing. - [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics. - [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production. - [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.
- [Langfuse](https://langfuse.com/docs/integrations/ollama) is an open source LLM observability platform that enables teams to collaboratively monitor, evaluate and debug AI applications. - [Langfuse](https://langfuse.com/docs/integrations/ollama) is an open source LLM observability platform that enables teams to collaboratively monitor, evaluate and debug AI applications.
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications. - [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
### Security ## Security
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server) - [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)

View File

@@ -14,7 +14,7 @@ Please include the following details in your report:
## Security best practices ## Security best practices
While the maintainer team does its best to secure Ollama, users are encouraged to implement their own security best practices, such as: While the maintainer team does their best to secure Ollama, users are encouraged to implement their own security best practices, such as:
- Regularly updating to the latest version of Ollama - Regularly updating to the latest version of Ollama
- Securing access to hosted instances of Ollama - Securing access to hosted instances of Ollama

View File

@@ -1,778 +0,0 @@
package anthropic
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"time"
"github.com/ollama/ollama/api"
)
// Error types matching Anthropic API
type Error struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ErrorResponse struct {
Type string `json:"type"` // always "error"
Error Error `json:"error"`
RequestID string `json:"request_id,omitempty"`
}
// NewError creates a new ErrorResponse with the appropriate error type based on HTTP status code
func NewError(code int, message string) ErrorResponse {
var etype string
switch code {
case http.StatusBadRequest:
etype = "invalid_request_error"
case http.StatusUnauthorized:
etype = "authentication_error"
case http.StatusForbidden:
etype = "permission_error"
case http.StatusNotFound:
etype = "not_found_error"
case http.StatusTooManyRequests:
etype = "rate_limit_error"
case http.StatusServiceUnavailable, 529:
etype = "overloaded_error"
default:
etype = "api_error"
}
return ErrorResponse{
Type: "error",
Error: Error{Type: etype, Message: message},
RequestID: generateID("req"),
}
}
// Request types
// MessagesRequest represents an Anthropic Messages API request
type MessagesRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
Messages []MessageParam `json:"messages"`
System any `json:"system,omitempty"` // string or []ContentBlock
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
Metadata *Metadata `json:"metadata,omitempty"`
}
// MessageParam represents a message in the request
type MessageParam struct {
Role string `json:"role"` // "user" or "assistant"
Content any `json:"content"` // string or []ContentBlock
}
// ContentBlock represents a content block in a message.
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
// only when set, which is required for SDK streaming accumulation.
type ContentBlock struct {
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
Text *string `json:"text,omitempty"`
// For image blocks
Source *ImageSource `json:"source,omitempty"`
// For tool_use blocks
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
// For tool_result blocks
ToolUseID string `json:"tool_use_id,omitempty"`
Content any `json:"content,omitempty"` // string or []ContentBlock
IsError bool `json:"is_error,omitempty"`
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
Thinking *string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ImageSource represents the source of an image
type ImageSource struct {
Type string `json:"type"` // "base64" or "url"
MediaType string `json:"media_type,omitempty"`
Data string `json:"data,omitempty"`
URL string `json:"url,omitempty"`
}
// Tool represents a tool definition
type Tool struct {
Type string `json:"type,omitempty"` // "custom" for user-defined tools
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema json.RawMessage `json:"input_schema,omitempty"`
}
// ToolChoice controls how the model uses tools
type ToolChoice struct {
Type string `json:"type"` // "auto", "any", "tool", "none"
Name string `json:"name,omitempty"`
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
}
// ThinkingConfig controls extended thinking
type ThinkingConfig struct {
Type string `json:"type"` // "enabled" or "disabled"
BudgetTokens int `json:"budget_tokens,omitempty"`
}
// Metadata for the request
type Metadata struct {
UserID string `json:"user_id,omitempty"`
}
// Response types
// MessagesResponse represents an Anthropic Messages API response
type MessagesResponse struct {
ID string `json:"id"`
Type string `json:"type"` // "message"
Role string `json:"role"` // "assistant"
Model string `json:"model"`
Content []ContentBlock `json:"content"`
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
Usage Usage `json:"usage"`
}
// Usage contains token usage information
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
// Streaming event types
// MessageStartEvent is sent at the start of streaming
type MessageStartEvent struct {
Type string `json:"type"` // "message_start"
Message MessagesResponse `json:"message"`
}
// ContentBlockStartEvent signals the start of a content block
type ContentBlockStartEvent struct {
Type string `json:"type"` // "content_block_start"
Index int `json:"index"`
ContentBlock ContentBlock `json:"content_block"`
}
// ContentBlockDeltaEvent contains incremental content updates
type ContentBlockDeltaEvent struct {
Type string `json:"type"` // "content_block_delta"
Index int `json:"index"`
Delta Delta `json:"delta"`
}
// Delta represents an incremental update
type Delta struct {
Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta"
Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ContentBlockStopEvent signals the end of a content block
type ContentBlockStopEvent struct {
Type string `json:"type"` // "content_block_stop"
Index int `json:"index"`
}
// MessageDeltaEvent contains updates to the message
type MessageDeltaEvent struct {
Type string `json:"type"` // "message_delta"
Delta MessageDelta `json:"delta"`
Usage DeltaUsage `json:"usage"`
}
// MessageDelta contains stop information
type MessageDelta struct {
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
}
// DeltaUsage contains cumulative token usage
type DeltaUsage struct {
OutputTokens int `json:"output_tokens"`
}
// MessageStopEvent signals the end of the message
type MessageStopEvent struct {
Type string `json:"type"` // "message_stop"
}
// PingEvent is a keepalive event
type PingEvent struct {
Type string `json:"type"` // "ping"
}
// StreamErrorEvent is an error during streaming
type StreamErrorEvent struct {
Type string `json:"type"` // "error"
Error Error `json:"error"`
}
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
var messages []api.Message
if r.System != nil {
switch sys := r.System.(type) {
case string:
if sys != "" {
messages = append(messages, api.Message{Role: "system", Content: sys})
}
case []any:
// System can be an array of content blocks
var content strings.Builder
for _, block := range sys {
if blockMap, ok := block.(map[string]any); ok {
if blockMap["type"] == "text" {
if text, ok := blockMap["text"].(string); ok {
content.WriteString(text)
}
}
}
}
if content.Len() > 0 {
messages = append(messages, api.Message{Role: "system", Content: content.String()})
}
}
}
for _, msg := range r.Messages {
converted, err := convertMessage(msg)
if err != nil {
return nil, err
}
messages = append(messages, converted...)
}
options := make(map[string]any)
options["num_predict"] = r.MaxTokens
if r.Temperature != nil {
options["temperature"] = *r.Temperature
}
if r.TopP != nil {
options["top_p"] = *r.TopP
}
if r.TopK != nil {
options["top_k"] = *r.TopK
}
if len(r.StopSequences) > 0 {
options["stop"] = r.StopSequences
}
var tools api.Tools
for _, t := range r.Tools {
tool, err := convertTool(t)
if err != nil {
return nil, err
}
tools = append(tools, tool)
}
var think *api.ThinkValue
if r.Thinking != nil && r.Thinking.Type == "enabled" {
think = &api.ThinkValue{Value: true}
}
stream := r.Stream
return &api.ChatRequest{
Model: r.Model,
Messages: messages,
Options: options,
Stream: &stream,
Tools: tools,
Think: think,
}, nil
}
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
func convertMessage(msg MessageParam) ([]api.Message, error) {
var messages []api.Message
role := strings.ToLower(msg.Role)
switch content := msg.Content.(type) {
case string:
messages = append(messages, api.Message{Role: role, Content: content})
case []any:
var textContent strings.Builder
var images []api.ImageData
var toolCalls []api.ToolCall
var thinking string
var toolResults []api.Message
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
return nil, errors.New("invalid content block format")
}
blockType, _ := blockMap["type"].(string)
switch blockType {
case "text":
if text, ok := blockMap["text"].(string); ok {
textContent.WriteString(text)
}
case "image":
source, ok := blockMap["source"].(map[string]any)
if !ok {
return nil, errors.New("invalid image source")
}
sourceType, _ := source["type"].(string)
if sourceType == "base64" {
data, _ := source["data"].(string)
decoded, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return nil, fmt.Errorf("invalid base64 image data: %w", err)
}
images = append(images, decoded)
} else {
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
}
// URL images would need to be fetched - skip for now
case "tool_use":
id, ok := blockMap["id"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'id' field")
}
name, ok := blockMap["name"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'name' field")
}
tc := api.ToolCall{
ID: id,
Function: api.ToolCallFunction{
Name: name,
},
}
if input, ok := blockMap["input"].(map[string]any); ok {
tc.Function.Arguments = mapToArgs(input)
}
toolCalls = append(toolCalls, tc)
case "tool_result":
toolUseID, _ := blockMap["tool_use_id"].(string)
var resultContent string
switch c := blockMap["content"].(type) {
case string:
resultContent = c
case []any:
for _, cb := range c {
if cbMap, ok := cb.(map[string]any); ok {
if cbMap["type"] == "text" {
if text, ok := cbMap["text"].(string); ok {
resultContent += text
}
}
}
}
}
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: resultContent,
ToolCallID: toolUseID,
})
case "thinking":
if t, ok := blockMap["thinking"].(string); ok {
thinking = t
}
}
}
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
m := api.Message{
Role: role,
Content: textContent.String(),
Images: images,
ToolCalls: toolCalls,
Thinking: thinking,
}
messages = append(messages, m)
}
// Add tool results as separate messages
messages = append(messages, toolResults...)
default:
return nil, fmt.Errorf("invalid message content type: %T", content)
}
return messages, nil
}
// convertTool converts an Anthropic Tool to an Ollama api.Tool
func convertTool(t Tool) (api.Tool, error) {
var params api.ToolFunctionParameters
if len(t.InputSchema) > 0 {
if err := json.Unmarshal(t.InputSchema, &params); err != nil {
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
}
}
return api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: t.Name,
Description: t.Description,
Parameters: params,
},
}, nil
}
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
func ToMessagesResponse(id string, r api.ChatResponse) MessagesResponse {
var content []ContentBlock
if r.Message.Thinking != "" {
content = append(content, ContentBlock{
Type: "thinking",
Thinking: ptr(r.Message.Thinking),
})
}
if r.Message.Content != "" {
content = append(content, ContentBlock{
Type: "text",
Text: ptr(r.Message.Content),
})
}
for _, tc := range r.Message.ToolCalls {
content = append(content, ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: tc.Function.Arguments,
})
}
stopReason := mapStopReason(r.DoneReason, len(r.Message.ToolCalls) > 0)
return MessagesResponse{
ID: id,
Type: "message",
Role: "assistant",
Model: r.Model,
Content: content,
StopReason: stopReason,
Usage: Usage{
InputTokens: r.Metrics.PromptEvalCount,
OutputTokens: r.Metrics.EvalCount,
},
}
}
// mapStopReason converts Ollama done_reason to Anthropic stop_reason
func mapStopReason(reason string, hasToolCalls bool) string {
if hasToolCalls {
return "tool_use"
}
switch reason {
case "stop":
return "end_turn"
case "length":
return "max_tokens"
default:
if reason != "" {
return "stop_sequence"
}
return ""
}
}
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
type StreamConverter struct {
ID string
Model string
firstWrite bool
contentIndex int
inputTokens int
outputTokens int
thinkingStarted bool
thinkingDone bool
textStarted bool
toolCallsSent map[string]bool
}
func NewStreamConverter(id, model string) *StreamConverter {
return &StreamConverter{
ID: id,
Model: model,
firstWrite: true,
toolCallsSent: make(map[string]bool),
}
}
// StreamEvent represents a streaming event to be sent to the client
type StreamEvent struct {
Event string
Data any
}
// Process converts an Ollama ChatResponse to Anthropic streaming events
func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
var events []StreamEvent
if c.firstWrite {
c.firstWrite = false
c.inputTokens = r.Metrics.PromptEvalCount
events = append(events, StreamEvent{
Event: "message_start",
Data: MessageStartEvent{
Type: "message_start",
Message: MessagesResponse{
ID: c.ID,
Type: "message",
Role: "assistant",
Model: c.Model,
Content: []ContentBlock{},
Usage: Usage{
InputTokens: c.inputTokens,
OutputTokens: 0,
},
},
},
})
}
if r.Message.Thinking != "" && !c.thinkingDone {
if !c.thinkingStarted {
c.thinkingStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "thinking_delta",
Thinking: r.Message.Thinking,
},
},
})
}
if r.Message.Content != "" {
if c.thinkingStarted && !c.thinkingDone {
c.thinkingDone = true
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
}
if !c.textStarted {
c.textStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "text",
Text: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "text_delta",
Text: r.Message.Content,
},
},
})
}
for _, tc := range r.Message.ToolCalls {
if c.toolCallsSent[tc.ID] {
continue
}
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
c.textStarted = false
}
argsJSON, err := json.Marshal(tc.Function.Arguments)
if err != nil {
slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
continue
}
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: map[string]any{},
},
},
})
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "input_json_delta",
PartialJSON: string(argsJSON),
},
},
})
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.toolCallsSent[tc.ID] = true
c.contentIndex++
}
if r.Done {
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
} else if c.thinkingStarted && !c.thinkingDone {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
}
c.outputTokens = r.Metrics.EvalCount
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
events = append(events, StreamEvent{
Event: "message_delta",
Data: MessageDeltaEvent{
Type: "message_delta",
Delta: MessageDelta{
StopReason: stopReason,
},
Usage: DeltaUsage{
OutputTokens: c.outputTokens,
},
},
})
events = append(events, StreamEvent{
Event: "message_stop",
Data: MessageStopEvent{
Type: "message_stop",
},
})
}
return events
}
// generateID generates a unique ID with the given prefix using crypto/rand
func generateID(prefix string) string {
b := make([]byte, 12)
if _, err := rand.Read(b); err != nil {
// Fallback to time-based ID if crypto/rand fails
return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
}
return fmt.Sprintf("%s_%x", prefix, b)
}
// GenerateMessageID generates a unique message ID
func GenerateMessageID() string {
return generateID("msg")
}
// ptr returns a pointer to the given string value
func ptr(s string) *string {
return &s
}
// mapToArgs converts a map to ToolCallFunctionArguments
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}

View File

@@ -1,953 +0,0 @@
package anthropic
import (
"encoding/base64"
"encoding/json"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
const (
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
func TestFromMessagesRequest_Basic(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", result.Model)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
t.Errorf("unexpected message: %+v", result.Messages[0])
}
if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 1024 {
t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"])
}
}
func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: "You are a helpful assistant.",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." {
t.Errorf("unexpected system message: %+v", result.Messages[0])
}
}
func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: []any{
map[string]any{"type": "text", "text": "You are helpful."},
map[string]any{"type": "text", "text": " Be concise."},
},
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Content != "You are helpful. Be concise." {
t.Errorf("unexpected system message content: %q", result.Messages[0].Content)
}
}
func TestFromMessagesRequest_WithOptions(t *testing.T) {
temp := 0.7
topP := 0.9
topK := 40
req := MessagesRequest{
Model: "test-model",
MaxTokens: 2048,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Temperature: &temp,
TopP: &topP,
TopK: &topK,
StopSequences: []string{"\n", "END"},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Options["temperature"] != 0.7 {
t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"])
}
if result.Options["top_p"] != 0.9 {
t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"])
}
if result.Options["top_k"] != 40 {
t.Errorf("expected top_k 40, got %v", result.Options["top_k"])
}
if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" {
t.Errorf("stop sequences mismatch: %s", diff)
}
}
func TestFromMessagesRequest_WithImage(t *testing.T) {
imgData, _ := base64.StdEncoding.DecodeString(testImage)
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{"type": "text", "text": "What's in this image?"},
map[string]any{
"type": "image",
"source": map[string]any{
"type": "base64",
"media_type": "image/png",
"data": testImage,
},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Content != "What's in this image?" {
t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content)
}
if len(result.Messages[0].Images) != 1 {
t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images))
}
if string(result.Messages[0].Images[0]) != string(imgData) {
t.Error("image data mismatch")
}
}
func TestFromMessagesRequest_WithToolUse(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
"name": "get_weather",
"input": map[string]any{"location": "Paris"},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if len(result.Messages[1].ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls))
}
tc := result.Messages[1].ToolCalls[0]
if tc.ID != "call_123" {
t.Errorf("expected tool call ID 'call_123', got %q", tc.ID)
}
if tc.Function.Name != "get_weather" {
t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name)
}
}
func TestFromMessagesRequest_WithToolResult(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{
"type": "tool_result",
"tool_use_id": "call_123",
"content": "The weather in Paris is sunny, 22°C",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
msg := result.Messages[0]
if msg.Role != "tool" {
t.Errorf("expected role 'tool', got %q", msg.Role)
}
if msg.ToolCallID != "call_123" {
t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID)
}
if msg.Content != "The weather in Paris is sunny, 22°C" {
t.Errorf("unexpected content: %q", msg.Content)
}
}
func TestFromMessagesRequest_WithTools(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "get_weather",
Description: "Get current weather",
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(result.Tools))
}
tool := result.Tools[0]
if tool.Type != "function" {
t.Errorf("expected type 'function', got %q", tool.Type)
}
if tool.Function.Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", tool.Function.Name)
}
if tool.Function.Description != "Get current weather" {
t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description)
}
}
func TestFromMessagesRequest_WithThinking(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Think == nil {
t.Fatal("expected Think to be set")
}
if v, ok := result.Think.Value.(bool); !ok || !v {
t.Errorf("expected Think.Value to be true, got %v", result.Think.Value)
}
}
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "thinking",
"thinking": "Let me think about this...",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
assistantMsg := result.Messages[1]
if assistantMsg.Thinking != "Let me think about this..." {
t.Errorf("expected thinking content, got %q", assistantMsg.Thinking)
}
}
func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"name": "get_weather",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use id")
}
if err.Error() != "tool_use block missing required 'id' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use name")
}
if err.Error() != "tool_use block missing required 'name' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "bad_tool",
InputSchema: json.RawMessage(`{invalid json`),
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for invalid tool schema")
}
}
func TestToMessagesResponse_Basic(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello there!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 5,
},
}
result := ToMessagesResponse("msg_123", resp)
if result.ID != "msg_123" {
t.Errorf("expected ID 'msg_123', got %q", result.ID)
}
if result.Type != "message" {
t.Errorf("expected type 'message', got %q", result.Type)
}
if result.Role != "assistant" {
t.Errorf("expected role 'assistant', got %q", result.Role)
}
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "text" || result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
t.Errorf("unexpected content: %+v", result.Content[0])
}
if result.StopReason != "end_turn" {
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
}
if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 {
t.Errorf("unexpected usage: %+v", result.Usage)
}
}
func TestToMessagesResponse_WithToolCalls(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "tool_use" {
t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type)
}
if result.Content[0].ID != "call_123" {
t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID)
}
if result.Content[0].Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name)
}
if result.StopReason != "tool_use" {
t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason)
}
}
func TestToMessagesResponse_WithThinking(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "The answer is 42.",
Thinking: "Let me think about this...",
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 2 {
t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
}
if result.Content[0].Type != "thinking" {
t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type)
}
if result.Content[0].Thinking == nil || *result.Content[0].Thinking != "Let me think about this..." {
t.Errorf("unexpected thinking content: %v", result.Content[0].Thinking)
}
if result.Content[1].Type != "text" {
t.Errorf("expected second block type 'text', got %q", result.Content[1].Type)
}
}
func TestMapStopReason(t *testing.T) {
tests := []struct {
reason string
hasToolCalls bool
want string
}{
{"stop", false, "end_turn"},
{"length", false, "max_tokens"},
{"stop", true, "tool_use"},
{"other", false, "stop_sequence"},
{"", false, ""},
}
for _, tt := range tests {
got := mapStopReason(tt.reason, tt.hasToolCalls)
if got != tt.want {
t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want)
}
}
}
func TestNewError(t *testing.T) {
tests := []struct {
code int
want string
}{
{400, "invalid_request_error"},
{401, "authentication_error"},
{403, "permission_error"},
{404, "not_found_error"},
{429, "rate_limit_error"},
{500, "api_error"},
{503, "overloaded_error"},
{529, "overloaded_error"},
}
for _, tt := range tests {
result := NewError(tt.code, "test message")
if result.Type != "error" {
t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type)
}
if result.Error.Type != tt.want {
t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want)
}
if result.Error.Message != "test message" {
t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message)
}
if result.RequestID == "" {
t.Errorf("NewError(%d) request_id should not be empty", tt.code)
}
}
}
func TestGenerateMessageID(t *testing.T) {
id1 := GenerateMessageID()
id2 := GenerateMessageID()
if id1 == "" {
t.Error("GenerateMessageID returned empty string")
}
if id1 == id2 {
t.Error("GenerateMessageID returned duplicate IDs")
}
if len(id1) < 10 {
t.Errorf("GenerateMessageID returned short ID: %q", id1)
}
if id1[:4] != "msg_" {
t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4])
}
}
func TestStreamConverter_Basic(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
// First chunk
resp1 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello",
},
Metrics: api.Metrics{PromptEvalCount: 10},
}
events1 := conv.Process(resp1)
if len(events1) < 3 {
t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1))
}
// Should have message_start, content_block_start, content_block_delta
if events1[0].Event != "message_start" {
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
}
if events1[1].Event != "content_block_start" {
t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event)
}
if events1[2].Event != "content_block_delta" {
t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event)
}
// Final chunk
resp2 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: " world!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{EvalCount: 5},
}
events2 := conv.Process(resp2)
// Should have content_block_delta, content_block_stop, message_delta, message_stop
hasStop := false
for _, e := range events2 {
if e.Event == "message_stop" {
hasStop = true
}
}
if !hasStop {
t.Error("expected message_stop event in final chunk")
}
}
func TestStreamConverter_WithToolCalls(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
}
events := conv.Process(resp)
hasToolStart := false
hasToolDelta := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
hasToolDelta = true
}
}
}
}
if !hasToolStart {
t.Error("expected tool_use content_block_start event")
}
if !hasToolDelta {
t.Error("expected input_json_delta event")
}
}
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
// Test that unmarshalable arguments (like channels) are handled gracefully
// and don't cause a panic or corrupt stream
conv := NewStreamConverter("msg_123", "test-model")
// Create a channel which cannot be JSON marshaled
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
// Should not panic and should skip the unmarshalable tool call
events := conv.Process(resp)
// Verify no tool_use block was started (since marshal failed before block start)
hasToolStart := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
}
if hasToolStart {
t.Error("expected no tool_use block when arguments cannot be marshaled")
}
}
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
// Test that valid tool calls still work when mixed with invalid ones
conv := NewStreamConverter("msg_123", "test-model")
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_good",
Function: api.ToolCallFunction{
Name: "good_function",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
events := conv.Process(resp)
// Count tool_use blocks - should only have 1 (the valid one)
toolStartCount := 0
toolDeltaCount := 0
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
toolStartCount++
if start.ContentBlock.Name != "good_function" {
t.Errorf("expected tool name 'good_function', got %q", start.ContentBlock.Name)
}
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
toolDeltaCount++
}
}
}
}
if toolStartCount != 1 {
t.Errorf("expected 1 tool_use block, got %d", toolStartCount)
}
if toolDeltaCount != 1 {
t.Errorf("expected 1 input_json_delta, got %d", toolDeltaCount)
}
}
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
tests := []struct {
name string
block ContentBlock
wantKeys []string
}{
{
name: "text block includes empty text field",
block: ContentBlock{
Type: "text",
Text: ptr(""),
},
wantKeys: []string{"type", "text"},
},
{
name: "thinking block includes empty thinking field",
block: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
wantKeys: []string{"type", "thinking"},
},
{
name: "text block with content",
block: ContentBlock{
Type: "text",
Text: ptr("hello"),
},
wantKeys: []string{"type", "text"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.block)
if err != nil {
t.Fatalf("failed to marshal: %v", err)
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
for _, key := range tt.wantKeys {
if _, ok := result[key]; !ok {
t.Errorf("expected key %q to be present in JSON output, got: %s", key, string(data))
}
}
})
}
}
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
// events include the required empty fields for SDK compatibility.
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
t.Run("text block start includes empty text", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Content: "hello"},
}
events := conv.Process(resp)
var foundTextStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "text" {
foundTextStart = true
// Marshal and verify the text field is present
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["text"]; !ok {
t.Error("content_block_start for text should include 'text' field")
}
}
}
}
}
if !foundTextStart {
t.Error("expected text content_block_start event")
}
})
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Thinking: "let me think..."},
}
events := conv.Process(resp)
var foundThinkingStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "thinking" {
foundThinkingStart = true
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["thinking"]; !ok {
t.Error("content_block_start for thinking should include 'thinking' field")
}
}
}
}
}
if !foundThinkingStart {
t.Error("expected thinking content_block_start event")
}
})
}

View File

@@ -165,7 +165,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
return nil return nil
} }
const maxBufferSize = 8 * format.MegaByte const maxBufferSize = 512 * format.KiloByte
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error { func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf io.Reader var buf io.Reader
@@ -226,14 +226,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
bts := scanner.Bytes() bts := scanner.Bytes()
if err := json.Unmarshal(bts, &errorResponse); err != nil { if err := json.Unmarshal(bts, &errorResponse); err != nil {
if response.StatusCode >= http.StatusBadRequest { return fmt.Errorf("unmarshal: %w", err)
return StatusError{
StatusCode: response.StatusCode,
Status: response.Status,
ErrorMessage: string(bts),
}
}
return errors.New(string(bts))
} }
if response.StatusCode == http.StatusUnauthorized { if response.StatusCode == http.StatusUnauthorized {
@@ -347,7 +340,7 @@ type CreateProgressFunc func(ProgressResponse) error
// Create creates a model from a [Modelfile]. fn is a progress function that // Create creates a model from a [Modelfile]. fn is a progress function that
// behaves similarly to other methods (see [Client.Pull]). // behaves similarly to other methods (see [Client.Pull]).
// //
// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.mdx // [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md
func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error { func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error { return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
var resp ProgressResponse var resp ProgressResponse

View File

@@ -55,7 +55,6 @@ func TestClientFromEnvironment(t *testing.T) {
type testError struct { type testError struct {
message string message string
statusCode int statusCode int
raw bool // if true, write message as-is instead of JSON encoding
} }
func (e testError) Error() string { func (e testError) Error() string {
@@ -112,20 +111,6 @@ func TestClientStream(t *testing.T) {
}, },
}, },
}, },
{
name: "plain text error response",
responses: []any{
"internal server error",
},
wantErr: "internal server error",
},
{
name: "HTML error page",
responses: []any{
"<html><body>404 Not Found</body></html>",
},
wantErr: "404 Not Found",
},
} }
for _, tc := range testCases { for _, tc := range testCases {
@@ -150,12 +135,6 @@ func TestClientStream(t *testing.T) {
return return
} }
if str, ok := resp.(string); ok {
fmt.Fprintln(w, str)
flusher.Flush()
continue
}
if err := json.NewEncoder(w).Encode(resp); err != nil { if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("failed to encode response: %v", err) t.Fatalf("failed to encode response: %v", err)
} }
@@ -197,7 +176,6 @@ func TestClientDo(t *testing.T) {
name string name string
response any response any
wantErr string wantErr string
wantStatusCode int
}{ }{
{ {
name: "immediate error response", name: "immediate error response",
@@ -206,7 +184,6 @@ func TestClientDo(t *testing.T) {
statusCode: http.StatusBadRequest, statusCode: http.StatusBadRequest,
}, },
wantErr: "test error message", wantErr: "test error message",
wantStatusCode: http.StatusBadRequest,
}, },
{ {
name: "server error response", name: "server error response",
@@ -215,7 +192,6 @@ func TestClientDo(t *testing.T) {
statusCode: http.StatusInternalServerError, statusCode: http.StatusInternalServerError,
}, },
wantErr: "internal error", wantErr: "internal error",
wantStatusCode: http.StatusInternalServerError,
}, },
{ {
name: "successful response", name: "successful response",
@@ -227,26 +203,6 @@ func TestClientDo(t *testing.T) {
Success: true, Success: true,
}, },
}, },
{
name: "plain text error response",
response: testError{
message: "internal server error",
statusCode: http.StatusInternalServerError,
raw: true,
},
wantErr: "internal server error",
wantStatusCode: http.StatusInternalServerError,
},
{
name: "HTML error page",
response: testError{
message: "<html><body>404 Not Found</body></html>",
statusCode: http.StatusNotFound,
raw: true,
},
wantErr: "<html><body>404 Not Found</body></html>",
wantStatusCode: http.StatusNotFound,
},
} }
for _, tc := range testCases { for _, tc := range testCases {
@@ -254,17 +210,12 @@ func TestClientDo(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if errResp, ok := tc.response.(testError); ok { if errResp, ok := tc.response.(testError); ok {
w.WriteHeader(errResp.statusCode) w.WriteHeader(errResp.statusCode)
if !errResp.raw {
err := json.NewEncoder(w).Encode(map[string]string{ err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message, "error": errResp.message,
}) })
if err != nil { if err != nil {
t.Fatal("failed to encode error response:", err) t.Fatal("failed to encode error response:", err)
} }
} else {
// Write raw message (simulates non-JSON error responses)
fmt.Fprint(w, errResp.message)
}
return return
} }
@@ -290,15 +241,6 @@ func TestClientDo(t *testing.T) {
if err.Error() != tc.wantErr { if err.Error() != tc.wantErr {
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr) t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
} }
if tc.wantStatusCode != 0 {
if statusErr, ok := err.(StatusError); ok {
if statusErr.StatusCode != tc.wantStatusCode {
t.Errorf("status code mismatch: got %d, want %d", statusErr.StatusCode, tc.wantStatusCode)
}
} else {
t.Errorf("expected StatusError, got %T", err)
}
}
return return
} }

View File

@@ -15,19 +15,19 @@ func main() {
} }
messages := []api.Message{ messages := []api.Message{
{ api.Message{
Role: "system", Role: "system",
Content: "Provide very brief, concise responses", Content: "Provide very brief, concise responses",
}, },
{ api.Message{
Role: "user", Role: "user",
Content: "Name some unusual animals", Content: "Name some unusual animals",
}, },
{ api.Message{
Role: "assistant", Role: "assistant",
Content: "Monotreme, platypus, echidna", Content: "Monotreme, platypus, echidna",
}, },
{ api.Message{
Role: "user", Role: "user",
Content: "which of these is the most dangerous?", Content: "which of these is the most dangerous?",
}, },

View File

@@ -3,7 +3,6 @@ package api
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"iter"
"log/slog" "log/slog"
"math" "math"
"os" "os"
@@ -15,7 +14,6 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/internal/orderedmap"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
@@ -229,79 +227,13 @@ type ToolCallFunction struct {
Arguments ToolCallFunctionArguments `json:"arguments"` Arguments ToolCallFunctionArguments `json:"arguments"`
} }
// ToolCallFunctionArguments holds tool call arguments in insertion order. type ToolCallFunctionArguments map[string]any
type ToolCallFunctionArguments struct {
om *orderedmap.Map[string, any]
}
// NewToolCallFunctionArguments creates a new empty ToolCallFunctionArguments.
func NewToolCallFunctionArguments() ToolCallFunctionArguments {
return ToolCallFunctionArguments{om: orderedmap.New[string, any]()}
}
// Get retrieves a value by key.
func (t *ToolCallFunctionArguments) Get(key string) (any, bool) {
if t == nil || t.om == nil {
return nil, false
}
return t.om.Get(key)
}
// Set sets a key-value pair, preserving insertion order.
func (t *ToolCallFunctionArguments) Set(key string, value any) {
if t == nil {
return
}
if t.om == nil {
t.om = orderedmap.New[string, any]()
}
t.om.Set(key, value)
}
// Len returns the number of arguments.
func (t *ToolCallFunctionArguments) Len() int {
if t == nil || t.om == nil {
return 0
}
return t.om.Len()
}
// All returns an iterator over all key-value pairs in insertion order.
func (t *ToolCallFunctionArguments) All() iter.Seq2[string, any] {
if t == nil || t.om == nil {
return func(yield func(string, any) bool) {}
}
return t.om.All()
}
// ToMap returns a regular map (order not preserved).
func (t *ToolCallFunctionArguments) ToMap() map[string]any {
if t == nil || t.om == nil {
return nil
}
return t.om.ToMap()
}
func (t *ToolCallFunctionArguments) String() string { func (t *ToolCallFunctionArguments) String() string {
if t == nil || t.om == nil { bts, _ := json.Marshal(t)
return "{}"
}
bts, _ := json.Marshal(t.om)
return string(bts) return string(bts)
} }
func (t *ToolCallFunctionArguments) UnmarshalJSON(data []byte) error {
t.om = orderedmap.New[string, any]()
return json.Unmarshal(data, t.om)
}
func (t ToolCallFunctionArguments) MarshalJSON() ([]byte, error) {
if t.om == nil {
return []byte("{}"), nil
}
return json.Marshal(t.om)
}
type Tool struct { type Tool struct {
Type string `json:"type"` Type string `json:"type"`
Items any `json:"items,omitempty"` Items any `json:"items,omitempty"`
@@ -350,78 +282,12 @@ func (pt PropertyType) String() string {
return fmt.Sprintf("%v", []string(pt)) return fmt.Sprintf("%v", []string(pt))
} }
// ToolPropertiesMap holds tool properties in insertion order.
type ToolPropertiesMap struct {
om *orderedmap.Map[string, ToolProperty]
}
// NewToolPropertiesMap creates a new empty ToolPropertiesMap.
func NewToolPropertiesMap() *ToolPropertiesMap {
return &ToolPropertiesMap{om: orderedmap.New[string, ToolProperty]()}
}
// Get retrieves a property by name.
func (t *ToolPropertiesMap) Get(key string) (ToolProperty, bool) {
if t == nil || t.om == nil {
return ToolProperty{}, false
}
return t.om.Get(key)
}
// Set sets a property, preserving insertion order.
func (t *ToolPropertiesMap) Set(key string, value ToolProperty) {
if t == nil {
return
}
if t.om == nil {
t.om = orderedmap.New[string, ToolProperty]()
}
t.om.Set(key, value)
}
// Len returns the number of properties.
func (t *ToolPropertiesMap) Len() int {
if t == nil || t.om == nil {
return 0
}
return t.om.Len()
}
// All returns an iterator over all properties in insertion order.
func (t *ToolPropertiesMap) All() iter.Seq2[string, ToolProperty] {
if t == nil || t.om == nil {
return func(yield func(string, ToolProperty) bool) {}
}
return t.om.All()
}
// ToMap returns a regular map (order not preserved).
func (t *ToolPropertiesMap) ToMap() map[string]ToolProperty {
if t == nil || t.om == nil {
return nil
}
return t.om.ToMap()
}
func (t ToolPropertiesMap) MarshalJSON() ([]byte, error) {
if t.om == nil {
return []byte("null"), nil
}
return json.Marshal(t.om)
}
func (t *ToolPropertiesMap) UnmarshalJSON(data []byte) error {
t.om = orderedmap.New[string, ToolProperty]()
return json.Unmarshal(data, t.om)
}
type ToolProperty struct { type ToolProperty struct {
AnyOf []ToolProperty `json:"anyOf,omitempty"` AnyOf []ToolProperty `json:"anyOf,omitempty"`
Type PropertyType `json:"type,omitempty"` Type PropertyType `json:"type,omitempty"`
Items any `json:"items,omitempty"` Items any `json:"items,omitempty"`
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
Enum []any `json:"enum,omitempty"` Enum []any `json:"enum,omitempty"`
Properties *ToolPropertiesMap `json:"properties,omitempty"`
} }
// ToTypeScriptType converts a ToolProperty to a TypeScript type string // ToTypeScriptType converts a ToolProperty to a TypeScript type string
@@ -474,7 +340,7 @@ type ToolFunctionParameters struct {
Defs any `json:"$defs,omitempty"` Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"` Items any `json:"items,omitempty"`
Required []string `json:"required,omitempty"` Required []string `json:"required,omitempty"`
Properties *ToolPropertiesMap `json:"properties"` Properties map[string]ToolProperty `json:"properties"`
} }
func (t *ToolFunctionParameters) String() string { func (t *ToolFunctionParameters) String() string {
@@ -500,9 +366,6 @@ type TokenLogprob struct {
// Logprob is the log probability of this token. // Logprob is the log probability of this token.
Logprob float64 `json:"logprob"` Logprob float64 `json:"logprob"`
// Bytes contains the raw byte representation of the token
Bytes []int `json:"bytes,omitempty"`
} }
// Logprob contains log probability information for a generated token. // Logprob contains log probability information for a generated token.
@@ -687,9 +550,6 @@ type CreateRequest struct {
Renderer string `json:"renderer,omitempty"` Renderer string `json:"renderer,omitempty"`
Parser string `json:"parser,omitempty"` Parser string `json:"parser,omitempty"`
// Requires is the minimum version of Ollama required by the model.
Requires string `json:"requires,omitempty"`
// Info is a map of additional information for the model // Info is a map of additional information for the model
Info map[string]any `json:"info,omitempty"` Info map[string]any `json:"info,omitempty"`
@@ -740,7 +600,6 @@ type ShowResponse struct {
Tensors []Tensor `json:"tensors,omitempty"` Tensors []Tensor `json:"tensors,omitempty"`
Capabilities []model.Capability `json:"capabilities,omitempty"` Capabilities []model.Capability `json:"capabilities,omitempty"`
ModifiedAt time.Time `json:"modified_at,omitempty"` ModifiedAt time.Time `json:"modified_at,omitempty"`
Requires string `json:"requires,omitempty"`
} }
// CopyRequest is the request passed to [Client.Copy]. // CopyRequest is the request passed to [Client.Copy].

View File

@@ -11,24 +11,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
func testPropsMap(m map[string]ToolProperty) *ToolPropertiesMap {
props := NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
func testArgs(m map[string]any) ToolCallFunctionArguments {
args := NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
func TestKeepAliveParsingFromJSON(t *testing.T) { func TestKeepAliveParsingFromJSON(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -327,9 +309,9 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
input: ToolFunctionParameters{ input: ToolFunctionParameters{
Type: "object", Type: "object",
Required: []string{"name"}, Required: []string{"name"},
Properties: testPropsMap(map[string]ToolProperty{ Properties: map[string]ToolProperty{
"name": {Type: PropertyType{"string"}}, "name": {Type: PropertyType{"string"}},
}), },
}, },
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string"}}}`, expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string"}}}`,
}, },
@@ -337,9 +319,9 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
name: "no required", name: "no required",
input: ToolFunctionParameters{ input: ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]ToolProperty{ Properties: map[string]ToolProperty{
"name": {Type: PropertyType{"string"}}, "name": {Type: PropertyType{"string"}},
}), },
}, },
expected: `{"type":"object","properties":{"name":{"type":"string"}}}`, expected: `{"type":"object","properties":{"name":{"type":"string"}}}`,
}, },
@@ -357,7 +339,7 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
func TestToolCallFunction_IndexAlwaysMarshals(t *testing.T) { func TestToolCallFunction_IndexAlwaysMarshals(t *testing.T) {
fn := ToolCallFunction{ fn := ToolCallFunction{
Name: "echo", Name: "echo",
Arguments: testArgs(map[string]any{"message": "hi"}), Arguments: ToolCallFunctionArguments{"message": "hi"},
} }
data, err := json.Marshal(fn) data, err := json.Marshal(fn)
@@ -522,116 +504,6 @@ func TestThinking_UnmarshalJSON(t *testing.T) {
} }
} }
func TestToolPropertyNestedProperties(t *testing.T) {
tests := []struct {
name string
input string
expected ToolProperty
}{
{
name: "nested object properties",
input: `{
"type": "object",
"description": "Location details",
"properties": {
"address": {
"type": "string",
"description": "Street address"
},
"city": {
"type": "string",
"description": "City name"
}
}
}`,
expected: ToolProperty{
Type: PropertyType{"object"},
Description: "Location details",
Properties: testPropsMap(map[string]ToolProperty{
"address": {
Type: PropertyType{"string"},
Description: "Street address",
},
"city": {
Type: PropertyType{"string"},
Description: "City name",
},
}),
},
},
{
name: "deeply nested properties",
input: `{
"type": "object",
"description": "Event",
"properties": {
"location": {
"type": "object",
"description": "Location",
"properties": {
"coordinates": {
"type": "object",
"description": "GPS coordinates",
"properties": {
"lat": {"type": "number", "description": "Latitude"},
"lng": {"type": "number", "description": "Longitude"}
}
}
}
}
}
}`,
expected: ToolProperty{
Type: PropertyType{"object"},
Description: "Event",
Properties: testPropsMap(map[string]ToolProperty{
"location": {
Type: PropertyType{"object"},
Description: "Location",
Properties: testPropsMap(map[string]ToolProperty{
"coordinates": {
Type: PropertyType{"object"},
Description: "GPS coordinates",
Properties: testPropsMap(map[string]ToolProperty{
"lat": {Type: PropertyType{"number"}, Description: "Latitude"},
"lng": {Type: PropertyType{"number"}, Description: "Longitude"},
}),
},
}),
},
}),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var prop ToolProperty
err := json.Unmarshal([]byte(tt.input), &prop)
require.NoError(t, err)
// Compare JSON representations since pointer comparison doesn't work
expectedJSON, err := json.Marshal(tt.expected)
require.NoError(t, err)
actualJSON, err := json.Marshal(prop)
require.NoError(t, err)
assert.JSONEq(t, string(expectedJSON), string(actualJSON))
// Round-trip test: marshal and unmarshal again
data, err := json.Marshal(prop)
require.NoError(t, err)
var prop2 ToolProperty
err = json.Unmarshal(data, &prop2)
require.NoError(t, err)
prop2JSON, err := json.Marshal(prop2)
require.NoError(t, err)
assert.JSONEq(t, string(expectedJSON), string(prop2JSON))
})
}
}
func TestToolFunctionParameters_String(t *testing.T) { func TestToolFunctionParameters_String(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -643,12 +515,12 @@ func TestToolFunctionParameters_String(t *testing.T) {
params: ToolFunctionParameters{ params: ToolFunctionParameters{
Type: "object", Type: "object",
Required: []string{"name"}, Required: []string{"name"},
Properties: testPropsMap(map[string]ToolProperty{ Properties: map[string]ToolProperty{
"name": { "name": {
Type: PropertyType{"string"}, Type: PropertyType{"string"},
Description: "The name of the person", Description: "The name of the person",
}, },
}), },
}, },
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`, expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
}, },
@@ -665,7 +537,7 @@ func TestToolFunctionParameters_String(t *testing.T) {
s.Self = s s.Self = s
return s return s
}(), }(),
Properties: testPropsMap(map[string]ToolProperty{}), Properties: map[string]ToolProperty{},
}, },
expected: "", expected: "",
}, },
@@ -678,235 +550,3 @@ func TestToolFunctionParameters_String(t *testing.T) {
}) })
} }
} }
func TestToolCallFunctionArguments_OrderPreservation(t *testing.T) {
t.Run("marshal preserves insertion order", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("zebra", "z")
args.Set("apple", "a")
args.Set("mango", "m")
data, err := json.Marshal(args)
require.NoError(t, err)
// Should preserve insertion order, not alphabetical
assert.Equal(t, `{"zebra":"z","apple":"a","mango":"m"}`, string(data))
})
t.Run("unmarshal preserves JSON order", func(t *testing.T) {
jsonData := `{"zebra":"z","apple":"a","mango":"m"}`
var args ToolCallFunctionArguments
err := json.Unmarshal([]byte(jsonData), &args)
require.NoError(t, err)
// Verify iteration order matches JSON order
var keys []string
for k := range args.All() {
keys = append(keys, k)
}
assert.Equal(t, []string{"zebra", "apple", "mango"}, keys)
})
t.Run("round trip preserves order", func(t *testing.T) {
original := `{"z":1,"a":2,"m":3,"b":4}`
var args ToolCallFunctionArguments
err := json.Unmarshal([]byte(original), &args)
require.NoError(t, err)
data, err := json.Marshal(args)
require.NoError(t, err)
assert.Equal(t, original, string(data))
})
t.Run("String method returns ordered JSON", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("c", 3)
args.Set("a", 1)
args.Set("b", 2)
assert.Equal(t, `{"c":3,"a":1,"b":2}`, args.String())
})
t.Run("Get retrieves correct values", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("key1", "value1")
args.Set("key2", 42)
v, ok := args.Get("key1")
assert.True(t, ok)
assert.Equal(t, "value1", v)
v, ok = args.Get("key2")
assert.True(t, ok)
assert.Equal(t, 42, v)
_, ok = args.Get("nonexistent")
assert.False(t, ok)
})
t.Run("Len returns correct count", func(t *testing.T) {
args := NewToolCallFunctionArguments()
assert.Equal(t, 0, args.Len())
args.Set("a", 1)
assert.Equal(t, 1, args.Len())
args.Set("b", 2)
assert.Equal(t, 2, args.Len())
})
t.Run("empty args marshal to empty object", func(t *testing.T) {
args := NewToolCallFunctionArguments()
data, err := json.Marshal(args)
require.NoError(t, err)
assert.Equal(t, `{}`, string(data))
})
t.Run("zero value args marshal to empty object", func(t *testing.T) {
var args ToolCallFunctionArguments
assert.Equal(t, "{}", args.String())
})
}
func TestToolPropertiesMap_OrderPreservation(t *testing.T) {
t.Run("marshal preserves insertion order", func(t *testing.T) {
props := NewToolPropertiesMap()
props.Set("zebra", ToolProperty{Type: PropertyType{"string"}})
props.Set("apple", ToolProperty{Type: PropertyType{"number"}})
props.Set("mango", ToolProperty{Type: PropertyType{"boolean"}})
data, err := json.Marshal(props)
require.NoError(t, err)
// Should preserve insertion order, not alphabetical
expected := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}`
assert.Equal(t, expected, string(data))
})
t.Run("unmarshal preserves JSON order", func(t *testing.T) {
jsonData := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}`
var props ToolPropertiesMap
err := json.Unmarshal([]byte(jsonData), &props)
require.NoError(t, err)
// Verify iteration order matches JSON order
var keys []string
for k := range props.All() {
keys = append(keys, k)
}
assert.Equal(t, []string{"zebra", "apple", "mango"}, keys)
})
t.Run("round trip preserves order", func(t *testing.T) {
original := `{"z":{"type":"string"},"a":{"type":"number"},"m":{"type":"boolean"}}`
var props ToolPropertiesMap
err := json.Unmarshal([]byte(original), &props)
require.NoError(t, err)
data, err := json.Marshal(props)
require.NoError(t, err)
assert.Equal(t, original, string(data))
})
t.Run("Get retrieves correct values", func(t *testing.T) {
props := NewToolPropertiesMap()
props.Set("name", ToolProperty{Type: PropertyType{"string"}, Description: "The name"})
props.Set("age", ToolProperty{Type: PropertyType{"integer"}, Description: "The age"})
v, ok := props.Get("name")
assert.True(t, ok)
assert.Equal(t, "The name", v.Description)
v, ok = props.Get("age")
assert.True(t, ok)
assert.Equal(t, "The age", v.Description)
_, ok = props.Get("nonexistent")
assert.False(t, ok)
})
t.Run("Len returns correct count", func(t *testing.T) {
props := NewToolPropertiesMap()
assert.Equal(t, 0, props.Len())
props.Set("a", ToolProperty{})
assert.Equal(t, 1, props.Len())
props.Set("b", ToolProperty{})
assert.Equal(t, 2, props.Len())
})
t.Run("nil props marshal to null", func(t *testing.T) {
var props *ToolPropertiesMap
data, err := json.Marshal(props)
require.NoError(t, err)
assert.Equal(t, `null`, string(data))
})
t.Run("ToMap returns regular map", func(t *testing.T) {
props := NewToolPropertiesMap()
props.Set("a", ToolProperty{Type: PropertyType{"string"}})
props.Set("b", ToolProperty{Type: PropertyType{"number"}})
m := props.ToMap()
assert.Equal(t, 2, len(m))
assert.Equal(t, PropertyType{"string"}, m["a"].Type)
assert.Equal(t, PropertyType{"number"}, m["b"].Type)
})
}
func TestToolCallFunctionArguments_ComplexValues(t *testing.T) {
t.Run("nested objects preserve order", func(t *testing.T) {
jsonData := `{"outer":{"z":1,"a":2},"simple":"value"}`
var args ToolCallFunctionArguments
err := json.Unmarshal([]byte(jsonData), &args)
require.NoError(t, err)
// Outer keys should be in order
var keys []string
for k := range args.All() {
keys = append(keys, k)
}
assert.Equal(t, []string{"outer", "simple"}, keys)
})
t.Run("arrays as values", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("items", []string{"a", "b", "c"})
args.Set("numbers", []int{1, 2, 3})
data, err := json.Marshal(args)
require.NoError(t, err)
assert.Equal(t, `{"items":["a","b","c"],"numbers":[1,2,3]}`, string(data))
})
}
func TestToolPropertiesMap_NestedProperties(t *testing.T) {
t.Run("nested properties preserve order", func(t *testing.T) {
props := NewToolPropertiesMap()
nestedProps := NewToolPropertiesMap()
nestedProps.Set("z_field", ToolProperty{Type: PropertyType{"string"}})
nestedProps.Set("a_field", ToolProperty{Type: PropertyType{"number"}})
props.Set("outer", ToolProperty{
Type: PropertyType{"object"},
Properties: nestedProps,
})
data, err := json.Marshal(props)
require.NoError(t, err)
// Both outer and inner should preserve order
expected := `{"outer":{"type":"object","properties":{"z_field":{"type":"string"},"a_field":{"type":"number"}}}}`
assert.Equal(t, expected, string(data))
})
}

View File

@@ -273,6 +273,10 @@ func main() {
Handler: uiServer.Handler(), Handler: uiServer.Handler(),
} }
if _, err := uiServer.UserData(ctx); err != nil {
slog.Warn("failed to load user data", "error", err)
}
// Start the UI server // Start the UI server
slog.Info("starting ui server", "port", port) slog.Info("starting ui server", "port", port)
go func() { go func() {
@@ -316,17 +320,6 @@ func main() {
slog.Debug("no URL scheme request to handle") slog.Debug("no URL scheme request to handle")
} }
go func() {
slog.Debug("waiting for ollama server to be ready")
if err := ui.WaitForServer(ctx, 10*time.Second); err != nil {
slog.Warn("ollama server not ready, continuing anyway", "error", err)
}
if _, err := uiServer.UserData(ctx); err != nil {
slog.Warn("failed to load user data", "error", err)
}
}()
osRun(cancel, hasCompletedFirstRun, startHidden) osRun(cancel, hasCompletedFirstRun, startHidden)
slog.Info("shutting down desktop server") slog.Info("shutting down desktop server")
@@ -368,7 +361,7 @@ func checkUserLoggedIn(uiServerPort int) bool {
return false return false
} }
resp, err := http.Post(fmt.Sprintf("http://127.0.0.1:%d/api/me", uiServerPort), "application/json", nil) resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/api/v1/me", uiServerPort))
if err != nil { if err != nil {
slog.Debug("failed to call local auth endpoint", "error", err) slog.Debug("failed to call local auth endpoint", "error", err)
return false return false
@@ -404,8 +397,8 @@ func checkUserLoggedIn(uiServerPort int) bool {
// handleConnectURLScheme fetches the connect URL and opens it in the browser // handleConnectURLScheme fetches the connect URL and opens it in the browser
func handleConnectURLScheme() { func handleConnectURLScheme() {
if checkUserLoggedIn(uiServerPort) { if checkUserLoggedIn(uiServerPort) {
slog.Info("user is already logged in, opening app instead") slog.Info("user is already logged in, opening settings instead")
showWindow(wv.webview.Window()) sendUIRequestMessage("/")
return return
} }
@@ -441,30 +434,37 @@ func openInBrowser(url string) {
} }
} }
// parseURLScheme parses an ollama:// URL and validates it // parseURLScheme parses an ollama:// URL and returns whether it's a connect URL and the UI path
// Supports: ollama:// (open app) and ollama://connect (OAuth) func parseURLScheme(urlSchemeRequest string) (isConnect bool, uiPath string, err error) {
func parseURLScheme(urlSchemeRequest string) (isConnect bool, err error) {
parsedURL, err := url.Parse(urlSchemeRequest) parsedURL, err := url.Parse(urlSchemeRequest)
if err != nil { if err != nil {
return false, fmt.Errorf("invalid URL: %w", err) return false, "", err
} }
// Check if this is a connect URL // Check if this is a connect URL
if parsedURL.Host == "connect" || strings.TrimPrefix(parsedURL.Path, "/") == "connect" { if parsedURL.Host == "connect" || strings.TrimPrefix(parsedURL.Path, "/") == "connect" {
return true, nil return true, "", nil
} }
// Allow bare ollama:// or ollama:/// to open the app // Extract the UI path
if (parsedURL.Host == "" && parsedURL.Path == "") || parsedURL.Path == "/" { path := "/"
return false, nil if parsedURL.Path != "" && parsedURL.Path != "/" {
// For URLs like ollama:///settings, use the path directly
path = parsedURL.Path
} else if parsedURL.Host != "" {
// For URLs like ollama://settings (without triple slash),
// the "settings" part is parsed as the host, not the path.
// We need to convert it to a path by prepending "/"
// This also handles ollama://settings/ where Windows adds a trailing slash
path = "/" + parsedURL.Host
} }
return false, fmt.Errorf("unsupported ollama:// URL path: %s", urlSchemeRequest) return false, path, nil
} }
// handleURLSchemeInCurrentInstance processes URL scheme requests in the current instance // handleURLSchemeInCurrentInstance processes URL scheme requests in the current instance
func handleURLSchemeInCurrentInstance(urlSchemeRequest string) { func handleURLSchemeInCurrentInstance(urlSchemeRequest string) {
isConnect, err := parseURLScheme(urlSchemeRequest) isConnect, uiPath, err := parseURLScheme(urlSchemeRequest)
if err != nil { if err != nil {
slog.Error("failed to parse URL scheme request", "url", urlSchemeRequest, "error", err) slog.Error("failed to parse URL scheme request", "url", urlSchemeRequest, "error", err)
return return
@@ -473,8 +473,6 @@ func handleURLSchemeInCurrentInstance(urlSchemeRequest string) {
if isConnect { if isConnect {
handleConnectURLScheme() handleConnectURLScheme()
} else { } else {
if wv.webview != nil { sendUIRequestMessage(uiPath)
showWindow(wv.webview.Window())
}
} }
} }

View File

@@ -191,6 +191,13 @@ func LaunchNewApp() {
C.launchApp(appName) C.launchApp(appName)
} }
// Send a request to the main app thread to load a UI page
func sendUIRequestMessage(path string) {
p := C.CString(path)
defer C.free(unsafe.Pointer(p))
C.uiRequest(p)
}
func registerLaunchAgent(hasCompletedFirstRun bool) { func registerLaunchAgent(hasCompletedFirstRun bool) {
// Remove any stale Login Item registrations // Remove any stale Login Item registrations
C.unregisterSelfFromLoginItem() C.unregisterSelfFromLoginItem()

View File

@@ -24,14 +24,27 @@ bool firstTimeRun,startHidden; // Set in run before initialization
for (NSURL *url in urls) { for (NSURL *url in urls) {
if ([url.scheme isEqualToString:@"ollama"]) { if ([url.scheme isEqualToString:@"ollama"]) {
NSString *path = url.path; NSString *path = url.path;
if (!path || [path isEqualToString:@""]) {
// For URLs like ollama://settings (without triple slash),
// the "settings" part is parsed as the host, not the path.
// We need to convert it to a path by prepending "/"
if (url.host && ![url.host isEqualToString:@""]) {
path = [@"/" stringByAppendingString:url.host];
} else {
path = @"/";
}
}
if (path && ([path isEqualToString:@"/connect"] || [url.host isEqualToString:@"connect"])) { if ([path isEqualToString:@"/connect"] || [url.host isEqualToString:@"connect"]) {
// Special case: handle connect by opening browser instead of app // Special case: handle connect by opening browser instead of app
handleConnectURL(); handleConnectURL();
} else { } else {
// Set app to be active and visible // Set app to be active and visible
[NSApp setActivationPolicy:NSApplicationActivationPolicyRegular]; [NSApp setActivationPolicy:NSApplicationActivationPolicyRegular];
[NSApp activateIgnoringOtherApps:YES]; [NSApp activateIgnoringOtherApps:YES];
// Open the path with the UI
[self uiRequest:path];
} }
break; break;
@@ -247,7 +260,7 @@ bool firstTimeRun,startHidden; // Set in run before initialization
} }
- (void)openHelp:(id)sender { - (void)openHelp:(id)sender {
NSURL *url = [NSURL URLWithString:@"https://docs.ollama.com/"]; NSURL *url = [NSURL URLWithString:@"https://github.com/ollama/ollama/tree/main/docs"];
[[NSWorkspace sharedWorkspace] openURL:url]; [[NSWorkspace sharedWorkspace] openURL:url];
} }

View File

@@ -138,7 +138,7 @@ func (app *appCallbacks) HandleURLScheme(urlScheme string) {
// handleURLSchemeRequest processes URL scheme requests from other instances // handleURLSchemeRequest processes URL scheme requests from other instances
func handleURLSchemeRequest(urlScheme string) { func handleURLSchemeRequest(urlScheme string) {
isConnect, err := parseURLScheme(urlScheme) isConnect, uiPath, err := parseURLScheme(urlScheme)
if err != nil { if err != nil {
slog.Error("failed to parse URL scheme request", "url", urlScheme, "error", err) slog.Error("failed to parse URL scheme request", "url", urlScheme, "error", err)
return return
@@ -147,9 +147,7 @@ func handleURLSchemeRequest(urlScheme string) {
if isConnect { if isConnect {
handleConnectURLScheme() handleConnectURLScheme()
} else { } else {
if wv.webview != nil { sendUIRequestMessage(uiPath)
showWindow(wv.webview.Window())
}
} }
} }
@@ -263,6 +261,11 @@ func createLoginShortcut() error {
return nil return nil
} }
// Send a request to the main app thread to load a UI page
func sendUIRequestMessage(path string) {
wintray.SendUIRequestMessage(path)
}
func LaunchNewApp() { func LaunchNewApp() {
} }

View File

@@ -169,11 +169,7 @@ DlgResult fileDlg(FileDlgParams* params) {
} }
NSArray* urls = [panel URLs]; NSArray* urls = [panel URLs];
if([urls count] == 0) { if(self->params->allowMultiple && [urls count] >= 1) {
return DLG_CANCEL;
}
if(self->params->allowMultiple) {
// For multiple files, we need to return all paths separated by null bytes // For multiple files, we need to return all paths separated by null bytes
char* bufPtr = self->params->buf; char* bufPtr = self->params->buf;
int remainingBuf = self->params->nbuf; int remainingBuf = self->params->nbuf;
@@ -204,12 +200,6 @@ DlgResult fileDlg(FileDlgParams* params) {
bufPtr += pathLen + 1; bufPtr += pathLen + 1;
} }
*bufPtr = '\0'; // Final null terminator *bufPtr = '\0'; // Final null terminator
} else {
// Single file/directory selection - write path to buffer
NSURL* url = [urls firstObject];
if(![url getFileSystemRepresentation:self->params->buf maxLength:self->params->nbuf]) {
return DLG_URLFAIL;
}
} }
return DLG_OK; return DLG_OK;

View File

@@ -15,7 +15,7 @@ const multiFileBufferSize = w32.MAX_PATH * 10
type WinDlgError int type WinDlgError int
func (e WinDlgError) Error() string { func (e WinDlgError) Error() string {
return fmt.Sprintf("CommDlgExtendedError: %#x", int(e)) return fmt.Sprintf("CommDlgExtendedError: %#x", e)
} }
func err() error { func err() error {

View File

@@ -224,7 +224,9 @@ func (s *Server) cmd(ctx context.Context) (*exec.Cmd, error) {
if _, err := os.Stat(settings.Models); err == nil { if _, err := os.Stat(settings.Models); err == nil {
env["OLLAMA_MODELS"] = settings.Models env["OLLAMA_MODELS"] = settings.Models
} else { } else {
slog.Warn("models path not accessible, using default", "path", settings.Models, "err", err) slog.Warn("models path not accessible, clearing models setting", "path", settings.Models, "err", err)
settings.Models = ""
s.store.SetSettings(settings)
} }
} }
if settings.ContextLength > 0 { if settings.ContextLength > 0 {

View File

@@ -469,24 +469,26 @@ export class HealthResponse {
} }
export class User { export class User {
id: string; id: string;
email: string;
name: string; name: string;
bio?: string; email: string;
avatarurl?: string; avatarURL: string;
firstname?: string; plan: string;
lastname?: string; bio: string;
plan?: string; firstName: string;
lastName: string;
overThreshold: boolean;
constructor(source: any = {}) { constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source); if ('string' === typeof source) source = JSON.parse(source);
this.id = source["id"]; this.id = source["id"];
this.email = source["email"];
this.name = source["name"]; this.name = source["name"];
this.bio = source["bio"]; this.email = source["email"];
this.avatarurl = source["avatarurl"]; this.avatarURL = source["avatarURL"];
this.firstname = source["firstname"];
this.lastname = source["lastname"];
this.plan = source["plan"]; this.plan = source["plan"];
this.bio = source["bio"];
this.firstName = source["firstName"];
this.lastName = source["lastName"];
this.overThreshold = source["overThreshold"];
} }
} }
export class Attachment { export class Attachment {

View File

@@ -27,7 +27,8 @@
"remark-math": "^6.0.0", "remark-math": "^6.0.0",
"streamdown": "^1.4.0", "streamdown": "^1.4.0",
"unist-builder": "^4.0.0", "unist-builder": "^4.0.0",
"unist-util-parents": "^3.0.0" "unist-util-parents": "^3.0.0",
"use-stick-to-bottom": "^1.1.1"
}, },
"devDependencies": { "devDependencies": {
"@chromatic-com/storybook": "^4.0.1", "@chromatic-com/storybook": "^4.0.1",
@@ -12739,6 +12740,15 @@
"punycode": "^2.1.0" "punycode": "^2.1.0"
} }
}, },
"node_modules/use-stick-to-bottom": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/use-stick-to-bottom/-/use-stick-to-bottom-1.1.1.tgz",
"integrity": "sha512-JkDp0b0tSmv7HQOOpL1hT7t7QaoUBXkq045WWWOFDTlLGRzgIIyW7vyzOIJzY7L2XVIG7j1yUxeDj2LHm9Vwng==",
"license": "MIT",
"peerDependencies": {
"react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0"
}
},
"node_modules/use-sync-external-store": { "node_modules/use-sync-external-store": {
"version": "1.5.0", "version": "1.5.0",
"resolved": "https://registry.npmjs.org/use-sync-external-store/-/use-sync-external-store-1.5.0.tgz", "resolved": "https://registry.npmjs.org/use-sync-external-store/-/use-sync-external-store-1.5.0.tgz",

View File

@@ -36,7 +36,8 @@
"remark-math": "^6.0.0", "remark-math": "^6.0.0",
"streamdown": "^1.4.0", "streamdown": "^1.4.0",
"unist-builder": "^4.0.0", "unist-builder": "^4.0.0",
"unist-util-parents": "^3.0.0" "unist-util-parents": "^3.0.0",
"use-stick-to-bottom": "^1.1.1"
}, },
"devDependencies": { "devDependencies": {
"@chromatic-com/storybook": "^4.0.1", "@chromatic-com/storybook": "^4.0.1",

View File

@@ -15,7 +15,6 @@ import {
import { parseJsonlFromResponse } from "./util/jsonl-parsing"; import { parseJsonlFromResponse } from "./util/jsonl-parsing";
import { ollamaClient as ollama } from "./lib/ollama-client"; import { ollamaClient as ollama } from "./lib/ollama-client";
import type { ModelResponse } from "ollama/browser"; import type { ModelResponse } from "ollama/browser";
import { API_BASE, OLLAMA_DOT_COM } from "./lib/config";
// Extend Model class with utility methods // Extend Model class with utility methods
declare module "@/gotypes" { declare module "@/gotypes" {
@@ -27,6 +26,9 @@ declare module "@/gotypes" {
Model.prototype.isCloud = function (): boolean { Model.prototype.isCloud = function (): boolean {
return this.model.endsWith("cloud"); return this.model.endsWith("cloud");
}; };
const API_BASE = import.meta.env.DEV ? "http://127.0.0.1:3001" : "";
// Helper function to convert Uint8Array to base64 // Helper function to convert Uint8Array to base64
function uint8ArrayToBase64(uint8Array: Uint8Array): string { function uint8ArrayToBase64(uint8Array: Uint8Array): string {
const chunkSize = 0x8000; // 32KB chunks to avoid stack overflow const chunkSize = 0x8000; // 32KB chunks to avoid stack overflow
@@ -41,8 +43,9 @@ function uint8ArrayToBase64(uint8Array: Uint8Array): string {
} }
export async function fetchUser(): Promise<User | null> { export async function fetchUser(): Promise<User | null> {
const response = await fetch(`${API_BASE}/api/me`, { try {
method: "POST", const response = await fetch(`${API_BASE}/api/v1/me`, {
method: "GET",
headers: { headers: {
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
@@ -50,41 +53,34 @@ export async function fetchUser(): Promise<User | null> {
if (response.ok) { if (response.ok) {
const userData: User = await response.json(); const userData: User = await response.json();
if (userData.avatarurl && !userData.avatarurl.startsWith("http")) {
userData.avatarurl = `${OLLAMA_DOT_COM}${userData.avatarurl}`;
}
return userData; return userData;
} }
if (response.status === 401 || response.status === 403) { return null;
} catch (error) {
console.error("Error fetching user:", error);
return null; return null;
} }
throw new Error(`Failed to fetch user: ${response.status}`);
} }
export async function fetchConnectUrl(): Promise<string> { export async function fetchConnectUrl(): Promise<string> {
const response = await fetch(`${API_BASE}/api/me`, { const response = await fetch(`${API_BASE}/api/v1/connect`, {
method: "POST", method: "GET",
headers: { headers: {
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
}); });
if (response.status === 401) { if (!response.ok) {
const data = await response.json();
if (data.signin_url) {
return data.signin_url;
}
}
throw new Error("Failed to fetch connect URL"); throw new Error("Failed to fetch connect URL");
} }
const data = await response.json();
return data.connect_url;
}
export async function disconnectUser(): Promise<void> { export async function disconnectUser(): Promise<void> {
const response = await fetch(`${API_BASE}/api/signout`, { const response = await fetch(`${API_BASE}/api/v1/disconnect`, {
method: "POST", method: "POST",
headers: { headers: {
"Content-Type": "application/json", "Content-Type": "application/json",
@@ -209,10 +205,12 @@ export async function* sendMessage(
data: uint8ArrayToBase64(att.data), data: uint8ArrayToBase64(att.data),
})); }));
// Send think parameter when it's explicitly set (true, false, or a non-empty string). // Only send think parameter when actually requesting thinking
// Don't send false as it causes issues with some providers
const shouldSendThink = const shouldSendThink =
think !== undefined && think !== undefined &&
(typeof think === "boolean" || (typeof think === "string" && think !== "")); ((typeof think === "boolean" && think) ||
(typeof think === "string" && think !== ""));
const response = await fetch(`${API_BASE}/api/v1/chat/${chatId}`, { const response = await fetch(`${API_BASE}/api/v1/chat/${chatId}`, {
method: "POST", method: "POST",
@@ -394,8 +392,7 @@ export async function getInferenceCompute(): Promise<InferenceCompute[]> {
export async function fetchHealth(): Promise<boolean> { export async function fetchHealth(): Promise<boolean> {
try { try {
// Use the /api/version endpoint as a health check const response = await fetch(`${API_BASE}/api/v1/health`, {
const response = await fetch(`${API_BASE}/api/version`, {
method: "GET", method: "GET",
headers: { headers: {
"Content-Type": "application/json", "Content-Type": "application/json",
@@ -404,8 +401,7 @@ export async function fetchHealth(): Promise<boolean> {
if (response.ok) { if (response.ok) {
const data = await response.json(); const data = await response.json();
// If we get a version back, the server is healthy return data.healthy || false;
return !!data.version;
} }
return false; return false;

View File

@@ -4,6 +4,11 @@ import { FileUpload } from "./FileUpload";
import { DisplayUpgrade } from "./DisplayUpgrade"; import { DisplayUpgrade } from "./DisplayUpgrade";
import { DisplayStale } from "./DisplayStale"; import { DisplayStale } from "./DisplayStale";
import { DisplayLogin } from "./DisplayLogin"; import { DisplayLogin } from "./DisplayLogin";
import {
ConversationContent,
ConversationScrollButton,
ConversationWithSpacer,
} from "./ai-elements/conversation";
import { import {
useChat, useChat,
useSendMessage, useSendMessage,
@@ -15,14 +20,7 @@ import {
useDismissStaleModel, useDismissStaleModel,
} from "@/hooks/useChats"; } from "@/hooks/useChats";
import { useHealth } from "@/hooks/useHealth"; import { useHealth } from "@/hooks/useHealth";
import { useMessageAutoscroll } from "@/hooks/useMessageAutoscroll"; import { useState, useEffect, useRef, useCallback } from "react";
import {
useState,
useEffect,
useLayoutEffect,
useRef,
useCallback,
} from "react";
import { useQueryClient } from "@tanstack/react-query"; import { useQueryClient } from "@tanstack/react-query";
import { useNavigate } from "@tanstack/react-router"; import { useNavigate } from "@tanstack/react-router";
import { useSelectedModel } from "@/hooks/useSelectedModel"; import { useSelectedModel } from "@/hooks/useSelectedModel";
@@ -47,7 +45,6 @@ export default function Chat({ chatId }: { chatId: string }) {
index: number; index: number;
originalMessage: Message; originalMessage: Message;
} | null>(null); } | null>(null);
const prevChatIdRef = useRef<string>(chatId);
const chatFormCallbackRef = useRef< const chatFormCallbackRef = useRef<
| (( | ((
@@ -102,29 +99,8 @@ export default function Chat({ chatId }: { chatId: string }) {
const sendMessageMutation = useSendMessage(chatId); const sendMessageMutation = useSendMessage(chatId);
const { containerRef, handleNewUserMessage, spacerHeight } = const latestMessageRef = useRef<HTMLDivElement>(null);
useMessageAutoscroll({
messages,
isStreaming,
chatId,
});
// Scroll to bottom only when switching to a different existing chat
useLayoutEffect(() => {
// Only scroll if the chatId actually changed (not just messages updating)
if (
prevChatIdRef.current !== chatId &&
containerRef.current &&
messages.length > 0 &&
chatId !== "new"
) {
// Always scroll to the bottom when opening a chat
containerRef.current.scrollTop = containerRef.current.scrollHeight;
}
prevChatIdRef.current = chatId;
}, [chatId, messages.length]);
// Simplified submit handler - ChatForm handles all the attachment logic
const handleChatFormSubmit = ( const handleChatFormSubmit = (
message: string, message: string,
options: { options: {
@@ -168,7 +144,6 @@ export default function Chat({ chatId }: { chatId: string }) {
// Clear edit mode after submission // Clear edit mode after submission
setEditingMessage(null); setEditingMessage(null);
handleNewUserMessage();
}; };
const handleEditMessage = (content: string, index: number) => { const handleEditMessage = (content: string, index: number) => {
@@ -193,8 +168,6 @@ export default function Chat({ chatId }: { chatId: string }) {
); );
}; };
const isWindows = navigator.platform.toLowerCase().includes("win");
return chatId === "new" || chatQuery ? ( return chatId === "new" || chatQuery ? (
<FileUpload <FileUpload
onFilesAdded={handleFilesProcessed} onFilesAdded={handleFilesProcessed}
@@ -219,14 +192,15 @@ export default function Chat({ chatId }: { chatId: string }) {
</div> </div>
) : ( ) : (
<main className="flex h-screen w-full flex-col relative allow-context-menu select-none"> <main className="flex h-screen w-full flex-col relative allow-context-menu select-none">
<section <ConversationWithSpacer
key={chatId} // This key forces React to recreate the element when chatId changes key={chatId}
ref={containerRef} className={`flex-1 overscroll-contain select-none`}
className={`flex-1 overflow-y-auto overscroll-contain relative min-h-0 select-none ${isWindows ? "xl:pt-4" : "xl:pt-8"}`} isStreaming={isStreaming}
messageCount={messages.length}
> >
<ConversationContent isStreaming={isStreaming}>
<MessageList <MessageList
messages={messages} messages={messages}
spacerHeight={spacerHeight}
isWaitingForLoad={isWaitingForLoad} isWaitingForLoad={isWaitingForLoad}
isStreaming={isStreaming} isStreaming={isStreaming}
downloadProgress={downloadProgress} downloadProgress={downloadProgress}
@@ -236,8 +210,11 @@ export default function Chat({ chatId }: { chatId: string }) {
editingMessageIndex={editingMessage?.index} editingMessageIndex={editingMessage?.index}
error={chatError} error={chatError}
browserToolResult={browserToolResult} browserToolResult={browserToolResult}
latestMessageRef={latestMessageRef}
/> />
</section> </ConversationContent>
<ConversationScrollButton />
</ConversationWithSpacer>
<div className="flex-shrink-0 sticky bottom-0 z-20"> <div className="flex-shrink-0 sticky bottom-0 z-20">
{selectedModel && shouldShowStaleDisplay && ( {selectedModel && shouldShowStaleDisplay && (
@@ -248,14 +225,6 @@ export default function Chat({ chatId }: { chatId: string }) {
dismissStaleModel(selectedModel?.model || "") dismissStaleModel(selectedModel?.model || "")
} }
chatId={chatId} chatId={chatId}
onScrollToBottom={() => {
if (containerRef.current) {
containerRef.current.scrollTo({
top: containerRef.current.scrollHeight,
behavior: "smooth",
});
}
}}
/> />
</div> </div>
)} )}

View File

@@ -3,10 +3,10 @@ import React from "react";
import Message from "./Message"; import Message from "./Message";
import Downloading from "./Downloading"; import Downloading from "./Downloading";
import { ErrorMessage } from "./ErrorMessage"; import { ErrorMessage } from "./ErrorMessage";
import { ConversationSpacer } from "./ai-elements/conversation";
export default function MessageList({ export default function MessageList({
messages, messages,
spacerHeight,
isWaitingForLoad, isWaitingForLoad,
isStreaming, isStreaming,
downloadProgress, downloadProgress,
@@ -14,9 +14,9 @@ export default function MessageList({
editingMessageIndex, editingMessageIndex,
error, error,
browserToolResult, browserToolResult,
latestMessageRef,
}: { }: {
messages: MessageType[]; messages: MessageType[];
spacerHeight: number;
isWaitingForLoad?: boolean; isWaitingForLoad?: boolean;
isStreaming: boolean; isStreaming: boolean;
downloadProgress?: DownloadEvent; downloadProgress?: DownloadEvent;
@@ -24,6 +24,7 @@ export default function MessageList({
editingMessageIndex?: number; editingMessageIndex?: number;
error?: ErrorEvent | null; error?: ErrorEvent | null;
browserToolResult?: any; browserToolResult?: any;
latestMessageRef?: React.RefObject<HTMLDivElement | null>;
}) { }) {
const [showDots, setShowDots] = React.useState(false); const [showDots, setShowDots] = React.useState(false);
const isDownloadingModel = downloadProgress && !downloadProgress.done; const isDownloadingModel = downloadProgress && !downloadProgress.done;
@@ -84,13 +85,19 @@ export default function MessageList({
return ( return (
<div <div
className="mx-auto flex max-w-[768px] flex-1 flex-col px-6 pb-12 select-text" className="mx-auto flex max-w-[768px] w-full flex-1 flex-col px-6 py-12 select-text"
data-role="message-list" data-role="message-list"
> >
{messages.map((message, idx) => { {messages.map((message, idx) => {
const lastToolQuery = lastToolQueries[idx]; const lastToolQuery = lastToolQueries[idx];
const isLastMessage = idx === messages.length - 1;
return ( return (
<div key={`${message.created_at}-${idx}`} data-message-index={idx}> <div
key={`${message.created_at}-${idx}`}
data-message-index={idx}
data-message-role={message.role}
ref={isLastMessage ? latestMessageRef : null}
>
<Message <Message
message={message} message={message}
onEditMessage={onEditMessage} onEditMessage={onEditMessage}
@@ -161,8 +168,8 @@ export default function MessageList({
</section> </section>
)} )}
{/* Dynamic spacer to allow scrolling the last message to the top of the container */} {/* Dynamic spacer */}
<div style={{ height: `${spacerHeight}px` }} aria-hidden="true" /> <ConversationSpacer />
</div> </div>
); );
} }

View File

@@ -299,9 +299,9 @@ export default function Settings() {
</Button> </Button>
</div> </div>
</div> </div>
{user?.avatarurl && ( {user?.avatarURL && (
<img <img
src={user.avatarurl} src={user.avatarURL}
alt={user?.name} alt={user?.name}
className="h-10 w-10 rounded-full bg-neutral-200 dark:bg-neutral-700 flex-shrink-0" className="h-10 w-10 rounded-full bg-neutral-200 dark:bg-neutral-700 flex-shrink-0"
onError={(e) => { onError={(e) => {

View File

@@ -50,9 +50,6 @@ export default function Thinking({
// Position content to show bottom when collapsed // Position content to show bottom when collapsed
useEffect(() => { useEffect(() => {
if (isCollapsed && contentRef.current && wrapperRef.current) { if (isCollapsed && contentRef.current && wrapperRef.current) {
requestAnimationFrame(() => {
if (!contentRef.current || !wrapperRef.current) return;
const contentHeight = contentRef.current.scrollHeight; const contentHeight = contentRef.current.scrollHeight;
const wrapperHeight = wrapperRef.current.clientHeight; const wrapperHeight = wrapperRef.current.clientHeight;
if (contentHeight > wrapperHeight) { if (contentHeight > wrapperHeight) {
@@ -60,23 +57,14 @@ export default function Thinking({
contentRef.current.style.transform = `translateY(${translateY}px)`; contentRef.current.style.transform = `translateY(${translateY}px)`;
setHasOverflow(true); setHasOverflow(true);
} else { } else {
contentRef.current.style.transform = "translateY(0)";
setHasOverflow(false); setHasOverflow(false);
} }
});
} else if (contentRef.current) { } else if (contentRef.current) {
contentRef.current.style.transform = "translateY(0)"; contentRef.current.style.transform = "translateY(0)";
setHasOverflow(false); setHasOverflow(false);
} }
}, [thinking, isCollapsed]); }, [thinking, isCollapsed]);
useEffect(() => {
if (activelyThinking && wrapperRef.current && !isCollapsed) {
// When expanded and actively thinking, scroll to bottom
wrapperRef.current.scrollTop = wrapperRef.current.scrollHeight;
}
}, [thinking, activelyThinking, isCollapsed]);
const handleToggle = () => { const handleToggle = () => {
setIsCollapsed(!isCollapsed); setIsCollapsed(!isCollapsed);
setHasUserInteracted(true); setHasUserInteracted(true);

View File

@@ -0,0 +1,385 @@
"use client";
import clsx from "clsx";
import type { ComponentProps } from "react";
import {
useCallback,
useEffect,
useRef,
createContext,
useContext,
useState,
} from "react";
import { StickToBottom, useStickToBottomContext } from "use-stick-to-bottom";
// Create a context to share the "allow scroll" state and spacer state
const ConversationControlContext = createContext<{
allowScroll: () => void;
spacerHeight: number;
setSpacerHeight: (height: number) => void;
} | null>(null);
export type ConversationProps = ComponentProps<typeof StickToBottom> & {
isStreaming?: boolean;
};
export const Conversation = ({
className,
isStreaming = false,
children,
...props
}: ConversationProps) => {
const shouldStopScrollRef = useRef(true);
const [spacerHeight, setSpacerHeight] = useState(0);
const allowScroll = useCallback(() => {
shouldStopScrollRef.current = false;
setTimeout(() => {
shouldStopScrollRef.current = true;
}, 100);
}, []);
return (
<ConversationControlContext.Provider
value={{ allowScroll, spacerHeight, setSpacerHeight }}
>
<StickToBottom
className={clsx("relative h-full w-full overflow-y-auto", className)}
initial="instant"
resize="instant"
role="log"
{...props}
>
<ConversationContentInternal
isStreaming={isStreaming}
shouldStopScrollRef={shouldStopScrollRef}
>
<>{children}</>
</ConversationContentInternal>
</StickToBottom>
</ConversationControlContext.Provider>
);
};
// New wrapper component that includes spacer management
export const ConversationWithSpacer = ({
isStreaming,
messageCount,
children,
...props
}: ConversationProps & {
messageCount: number;
}) => {
return (
<Conversation isStreaming={isStreaming} {...props}>
<SpacerController
isStreaming={isStreaming ?? false}
messageCount={messageCount}
/>
<>{children}</>
</Conversation>
);
};
// This component manages the spacer state but doesn't render the spacer itself
const SpacerController = ({
isStreaming,
messageCount,
}: {
isStreaming: boolean;
messageCount: number;
}) => {
const context = useContext(ConversationControlContext);
const { scrollToBottom } = useStickToBottomContext();
const previousMessageCountRef = useRef(messageCount);
const scrollContainerRef = useRef<HTMLElement | null>(null);
const [isActiveInteraction, setIsActiveInteraction] = useState(false);
// Get reference to scroll container
useEffect(() => {
const container = document.querySelector('[role="log"]') as HTMLElement;
scrollContainerRef.current = container;
}, []);
// Calculate spacer height based on actual DOM elements
const calculateSpacerHeight = useCallback(() => {
const container = scrollContainerRef.current;
if (!container) {
console.log("❌ No container");
return 0;
}
const containerHeight = container.clientHeight;
// Find all message elements
const messageElements = container.querySelectorAll(
"[data-message-index]",
) as NodeListOf<HTMLElement>;
console.log("📝 Found", messageElements.length, "message elements");
if (messageElements.length === 0) return 0;
// Log all messages and their roles
messageElements.forEach((el, i) => {
const role = el.getAttribute("data-message-role");
console.log(` Message ${i}: role="${role}"`);
});
// Find the last user message
let lastUserMessageElement: HTMLElement | null = null;
let lastUserMessageIndex = -1;
for (let i = messageElements.length - 1; i >= 0; i--) {
const el = messageElements[i];
const role = el.getAttribute("data-message-role");
if (role === "user") {
lastUserMessageElement = el;
lastUserMessageIndex = i;
console.log("✅ Found user message at index:", i);
break;
}
}
if (!lastUserMessageElement) {
console.log("❌ No user message found!");
return 0;
}
// Calculate height of content after the last user message
let contentHeightAfter = 0;
for (let i = lastUserMessageIndex + 1; i < messageElements.length; i++) {
contentHeightAfter += messageElements[i].offsetHeight;
}
const userMessageHeight = lastUserMessageElement.offsetHeight;
// Goal: Position user message at the top with some padding
// We want the user message to start at around 10% from the top of viewport
const targetTopPosition = containerHeight * 0.05; // 10% from top
// Calculate spacer: we need enough space so that when scrolled to bottom:
// spacerHeight = containerHeight - targetTopPosition - userMessageHeight - contentAfter
const calculatedHeight =
containerHeight -
targetTopPosition -
userMessageHeight -
contentHeightAfter;
const baseHeight = Math.max(0, calculatedHeight);
console.log(
"📊 Container:",
containerHeight,
"User msg:",
userMessageHeight,
"Content after:",
contentHeightAfter,
"Target top pos:",
targetTopPosition,
"→ Final spacer:",
baseHeight,
);
return baseHeight;
}, []);
// When a new message is submitted, set initial spacer height and scroll
useEffect(() => {
if (messageCount > previousMessageCountRef.current) {
console.log("🎯 NEW MESSAGE - Setting spacer");
// Allow scrolling by temporarily disabling stopScroll
context?.allowScroll();
setIsActiveInteraction(true);
const container = scrollContainerRef.current;
if (container) {
// Wait for new message to render in DOM
requestAnimationFrame(() => {
requestAnimationFrame(() => {
const spacerHeight = calculateSpacerHeight();
console.log("📏 Calculated spacer:", spacerHeight);
context?.setSpacerHeight(spacerHeight);
// Wait for spacer to be added to DOM, then scroll
requestAnimationFrame(() => {
requestAnimationFrame(() => {
// Use the library's scrollToBottom method
console.log("📜 Scrolling to bottom");
scrollToBottom("instant");
console.log("📜 Final scrollTop:", container.scrollTop);
});
});
});
});
}
}
previousMessageCountRef.current = messageCount;
}, [messageCount, context, calculateSpacerHeight, scrollToBottom]);
// Update active interaction state
useEffect(() => {
if (isStreaming) {
setIsActiveInteraction(true);
}
// Don't automatically set to false when streaming stops
// Let the ResizeObserver handle clearing the spacer naturally
}, [isStreaming]);
// Use ResizeObserver to recalculate spacer as content changes
useEffect(() => {
const container = scrollContainerRef.current;
if (!container || !isActiveInteraction) return;
const resizeObserver = new ResizeObserver(() => {
const newHeight = calculateSpacerHeight();
context?.setSpacerHeight(newHeight);
// Clear active interaction when spacer reaches 0
if (newHeight === 0) {
setIsActiveInteraction(false);
}
});
// Observe all message elements
const messageElements = container.querySelectorAll("[data-message-index]");
messageElements.forEach((element) => {
resizeObserver.observe(element);
});
return () => {
resizeObserver.disconnect();
};
}, [isActiveInteraction, calculateSpacerHeight, context]);
// Remove the effect that clears spacer when not streaming
// This was causing the spacer to disappear prematurely
return null;
};
const ConversationContentInternal = ({
isStreaming,
shouldStopScrollRef,
children,
}: {
isStreaming: boolean;
shouldStopScrollRef: React.MutableRefObject<boolean>;
children: React.ReactNode;
}) => {
const { stopScroll } = useStickToBottomContext();
const stopScrollRef = useRef(stopScroll);
useEffect(() => {
stopScrollRef.current = stopScroll;
}, [stopScroll]);
useEffect(() => {
if (!isStreaming) return;
const interval = setInterval(() => {
if (shouldStopScrollRef.current) {
stopScrollRef.current();
}
}, 16);
if (shouldStopScrollRef.current) {
stopScrollRef.current();
}
return () => clearInterval(interval);
}, [isStreaming, shouldStopScrollRef]);
return <>{children}</>;
};
export type ConversationContentProps = ComponentProps<
typeof StickToBottom.Content
> & {
isStreaming?: boolean;
};
export const ConversationContent = ({
className,
children,
...props
}: ConversationContentProps) => {
return (
<StickToBottom.Content
className={clsx("flex flex-col", className)}
{...props}
>
{children}
</StickToBottom.Content>
);
};
// Spacer component that can be placed anywhere in your content
export const ConversationSpacer = () => {
const context = useContext(ConversationControlContext);
const spacerHeight = context?.spacerHeight ?? 0;
console.log("🎨 Spacer render - height:", spacerHeight);
if (spacerHeight === 0) return null;
return (
<div
style={{
height: `${spacerHeight}px`,
flexShrink: 0,
backgroundColor: "rgba(255,0,0,0.1)", // Temporary for debugging
}}
aria-hidden="true"
/>
);
};
export type ConversationScrollButtonProps = ComponentProps<"button">;
export const ConversationScrollButton = ({
className,
...props
}: ConversationScrollButtonProps) => {
const { isAtBottom, scrollToBottom } = useStickToBottomContext();
const context = useContext(ConversationControlContext);
const handleScrollToBottom = useCallback(() => {
context?.allowScroll();
scrollToBottom();
}, [scrollToBottom, context]);
// Show button if not at bottom AND spacer is not active (height is 0)
const shouldShowButton = !isAtBottom && (context?.spacerHeight ?? 0) === 0;
return (
shouldShowButton && (
<button
className={clsx(
"absolute bottom-4 left-[50%] translate-x-[-50%] rounded-full z-50",
"bg-white dark:bg-neutral-800 border border-neutral-200 dark:border-neutral-700",
"p-1 shadow-lg hover:shadow-xl transition-all",
"text-neutral-700 dark:text-neutral-200 hover:scale-105",
"hover:cursor-pointer",
className,
)}
onClick={handleScrollToBottom}
type="button"
aria-label="Scroll to bottom"
{...props}
>
<svg
className="w-5 h-5"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
>
<polyline points="6 9 12 15 18 9" />
</svg>
</button>
)
);
};

View File

@@ -0,0 +1,8 @@
export {
Conversation,
ConversationContent,
ConversationScrollButton,
type ConversationProps,
type ConversationContentProps,
type ConversationScrollButtonProps,
} from "./conversation";

View File

@@ -7,7 +7,6 @@ import { createQueryBatcher } from "./useQueryBatcher";
import { useRefetchModels } from "./useModels"; import { useRefetchModels } from "./useModels";
import { useStreamingContext } from "@/contexts/StreamingContext"; import { useStreamingContext } from "@/contexts/StreamingContext";
import { useSettings } from "./useSettings"; import { useSettings } from "./useSettings";
import { getModelCapabilities } from "@/api";
export const useChats = () => { export const useChats = () => {
return useQuery({ return useQuery({
@@ -607,24 +606,6 @@ export const useSendMessage = (chatId: string) => {
queryClient.setQueryData(["staleModels"], newStaleMap); queryClient.setQueryData(["staleModels"], newStaleMap);
queryClient.invalidateQueries({ queryKey: ["models"] }); queryClient.invalidateQueries({ queryKey: ["models"] });
// Fetch fresh capabilities for the downloaded model
getModelCapabilities(selectedModel.model)
.then((capabilities) => {
queryClient.setQueryData(
["modelCapabilities", selectedModel.model],
capabilities,
);
})
.catch((error) => {
console.error(
"Failed to fetch capabilities after download:",
error,
);
queryClient.invalidateQueries({
queryKey: ["modelCapabilities", selectedModel.model],
});
});
} }
break; break;
} }

View File

@@ -0,0 +1,114 @@
import { useMutation, useQueryClient } from "@tanstack/react-query";
import { useState } from "react";
import { pullModel } from "@/api";
import { useSelectedModel } from "./useSelectedModel";
import { useSettings } from "./useSettings";
interface DownloadProgress {
status: string;
digest?: string;
total?: number;
completed?: number;
done?: boolean;
}
export function useDownloadModel(chatId?: string) {
const queryClient = useQueryClient();
const { selectedModel } = useSelectedModel(chatId);
const { setSettings } = useSettings();
const [downloadProgress, setDownloadProgress] =
useState<DownloadProgress | null>(null);
const [abortController, setAbortController] =
useState<AbortController | null>(null);
const [downloadingChatIds, setDownloadingChatIds] = useState<Set<string>>(
new Set(),
);
const mutation = useMutation({
mutationFn: async (modelName: string) => {
const controller = new AbortController();
setAbortController(controller);
setDownloadProgress({ status: "Starting download..." });
if (chatId) {
setDownloadingChatIds((prev) => new Set(prev).add(chatId));
}
try {
for await (const progress of pullModel(modelName, controller.signal)) {
setDownloadProgress(progress);
if (progress.status === "success") {
// Update selected model to indicate it's now available locally
if (selectedModel && selectedModel.model === modelName) {
setSettings({ SelectedModel: modelName });
}
// Invalidate models query to refresh the list
await queryClient.invalidateQueries({ queryKey: ["models"] });
break;
}
}
} finally {
setAbortController(null);
if (chatId) {
setDownloadingChatIds((prev) => {
const newSet = new Set(prev);
newSet.delete(chatId);
return newSet;
});
}
}
},
onSuccess: () => {
setDownloadProgress(null);
if (chatId) {
setDownloadingChatIds((prev) => {
const newSet = new Set(prev);
newSet.delete(chatId);
return newSet;
});
}
},
onError: (error: Error) => {
const status =
error.name === "AbortError" ? "Download cancelled" : "Download failed";
setDownloadProgress({ status, done: true });
// Clear error message after delay
const delay = error.name === "AbortError" ? 1500 : 3000;
setTimeout(() => {
setDownloadProgress(null);
if (chatId) {
setDownloadingChatIds((prev) => {
const newSet = new Set(prev);
newSet.delete(chatId);
return newSet;
});
}
}, delay);
},
});
const cancelDownload = () => {
if (abortController) {
abortController.abort();
setAbortController(null);
if (chatId) {
setDownloadingChatIds((prev) => {
const newSet = new Set(prev);
newSet.delete(chatId);
return newSet;
});
}
}
};
return {
downloadModel: mutation.mutate,
isDownloading:
mutation.isPending && chatId ? downloadingChatIds.has(chatId) : false,
downloadProgress:
chatId && downloadingChatIds.has(chatId) ? downloadProgress : null,
error: mutation.error,
cancelDownload,
};
}

View File

@@ -1,347 +0,0 @@
import {
useRef,
useCallback,
useEffect,
useLayoutEffect,
useState,
useMemo,
} from "react";
import type { Message } from "@/gotypes";
// warning: this file is all claude code, needs to be looked into more closely
interface UseMessageAutoscrollOptions {
messages: Message[];
isStreaming: boolean;
chatId: string;
}
interface MessageAutoscrollBehavior {
handleNewUserMessage: () => void;
containerRef: React.RefObject<HTMLElement | null>;
spacerHeight: number;
}
export const useMessageAutoscroll = ({
messages,
isStreaming,
chatId,
}: UseMessageAutoscrollOptions): MessageAutoscrollBehavior => {
const containerRef = useRef<HTMLElement | null>(null);
const pendingScrollToUserMessage = useRef(false);
const [spacerHeight, setSpacerHeight] = useState(0);
const lastScrollHeightRef = useRef(0);
const lastScrollTopRef = useRef(0);
const [isActiveInteraction, setIsActiveInteraction] = useState(false);
const [hasSubmittedMessage, setHasSubmittedMessage] = useState(false);
const prevChatIdRef = useRef<string>(chatId);
// Find the last user message index from React state
const getLastUserMessageIndex = useCallback(() => {
for (let i = messages.length - 1; i >= 0; i--) {
if (messages[i].role === "user") {
return i;
}
}
return -1;
}, [messages]);
const scrollToMessage = useCallback((messageIndex: number) => {
if (!containerRef.current || messageIndex < 0) {
return;
}
const container = containerRef.current;
// select the exact element by its data-message-index to avoid index mismatches
const targetElement = container.querySelector(
`[data-message-index="${messageIndex}"]`,
) as HTMLElement | null;
if (!targetElement) return;
const containerHeight = container.clientHeight;
const containerStyle = window.getComputedStyle(container);
const paddingTop = parseFloat(containerStyle.paddingTop) || 0;
const scrollHeight = container.scrollHeight;
const messageHeight = targetElement.offsetHeight;
// Check if the message is large, which is 70% of the container height
const isLarge = messageHeight > containerHeight * 0.7;
let targetPosition: number = targetElement.offsetTop - paddingTop; // default to scrolling the message to the top of the window
if (isLarge) {
// when the message is large scroll to the bottom of it
targetPosition = scrollHeight - containerHeight;
}
// Ensure we don't scroll past content boundaries
const maxScroll = scrollHeight - containerHeight;
const finalPosition = Math.min(Math.max(0, targetPosition), maxScroll);
container.scrollTo({
top: finalPosition,
behavior: "smooth",
});
}, []);
// Calculate and set the spacer height based on container dimensions
const updateSpacerHeight = useCallback(() => {
if (!containerRef.current) {
return;
}
const containerHeight = containerRef.current.clientHeight;
// Find the last user message to calculate spacer for
const lastUserIndex = getLastUserMessageIndex();
if (lastUserIndex < 0) {
setSpacerHeight(0);
return;
}
const messageElements = containerRef.current.querySelectorAll(
"[data-message-index]",
) as NodeListOf<HTMLElement>;
if (!messageElements || messageElements.length === 0) {
setSpacerHeight(0);
return;
}
const targetElement = containerRef.current.querySelector(
`[data-message-index="${lastUserIndex}"]`,
) as HTMLElement | null;
if (!targetElement) {
setSpacerHeight(0);
return;
}
const elementsAfter = Array.from(messageElements).filter((el) => {
const idx = Number(el.dataset.messageIndex);
return Number.isFinite(idx) && idx > lastUserIndex;
});
const contentHeightAfterTarget = elementsAfter.reduce(
(sum, el) => sum + el.offsetHeight,
0,
);
// Calculate the spacer height needed to position the user message at the top
// Add extra space for assistant response area
const targetMessageHeight = targetElement.offsetHeight;
// Calculate spacer to position the last user message at the top
// For new messages, we want them to appear at the top regardless of content after
// For large messages, we want to preserve the scroll-to-bottom behavior
// which shows part of the message and space for streaming response
let baseHeight: number;
if (contentHeightAfterTarget === 0) {
// No content after the user message (new message case)
// Position it at the top with some padding
baseHeight = Math.max(0, containerHeight - targetMessageHeight);
} else {
// Content exists after the user message
// Calculate spacer to position user message at top
baseHeight = Math.max(
0,
containerHeight - contentHeightAfterTarget - targetMessageHeight,
);
}
// Only apply spacer height when actively interacting (streaming or pending new message)
// When just viewing a chat, don't add extra space
if (!isActiveInteraction) {
setSpacerHeight(0);
return;
}
// Add extra space for assistant response only when streaming
const extraSpaceForAssistant = isStreaming ? containerHeight * 0.4 : 0;
const calculatedHeight = baseHeight + extraSpaceForAssistant;
setSpacerHeight(calculatedHeight);
}, [getLastUserMessageIndex, isStreaming, isActiveInteraction]);
// Handle new user message submission
const handleNewUserMessage = useCallback(() => {
// Mark that we're expecting a new message and should scroll to it
pendingScrollToUserMessage.current = true;
setIsActiveInteraction(true);
setHasSubmittedMessage(true);
}, []);
// Use layoutEffect to scroll immediately after DOM updates
useLayoutEffect(() => {
if (pendingScrollToUserMessage.current) {
// Find the last user message from current state
const targetUserIndex = getLastUserMessageIndex();
if (targetUserIndex >= 0) {
requestAnimationFrame(() => {
updateSpacerHeight();
requestAnimationFrame(() => {
scrollToMessage(targetUserIndex);
pendingScrollToUserMessage.current = false;
});
});
} else {
pendingScrollToUserMessage.current = false;
// Reset active interaction if no target found
setIsActiveInteraction(isStreaming);
}
}
}, [
messages,
getLastUserMessageIndex,
scrollToMessage,
updateSpacerHeight,
isStreaming,
]);
// Update active interaction state based on streaming and message submission
useEffect(() => {
if (
isStreaming ||
pendingScrollToUserMessage.current ||
hasSubmittedMessage
) {
setIsActiveInteraction(true);
} else {
setIsActiveInteraction(false);
}
}, [isStreaming, hasSubmittedMessage]);
useEffect(() => {
if (prevChatIdRef.current !== chatId) {
setIsActiveInteraction(false);
setHasSubmittedMessage(false);
prevChatIdRef.current = chatId;
}
}, [chatId]);
// Recalculate spacer height when messages change
useEffect(() => {
updateSpacerHeight();
}, [messages, updateSpacerHeight]);
// Use ResizeObserver to handle dynamic content changes
useEffect(() => {
if (!containerRef.current) return;
let resizeTimeout: ReturnType<typeof setTimeout>;
let immediateUpdate = false;
const resizeObserver = new ResizeObserver((entries) => {
// Check if this is a significant height change (like collapsing content)
let hasSignificantChange = false;
for (const entry of entries) {
const element = entry.target as HTMLElement;
if (
element.dataset.messageIndex &&
entry.contentRect.height !== element.offsetHeight
) {
const heightDiff = Math.abs(
entry.contentRect.height - element.offsetHeight,
);
if (heightDiff > 50) {
hasSignificantChange = true;
break;
}
}
}
// For significant changes, update immediately
if (hasSignificantChange || immediateUpdate) {
updateSpacerHeight();
immediateUpdate = false;
} else {
// For small changes (like streaming text), debounce
clearTimeout(resizeTimeout);
resizeTimeout = setTimeout(() => {
updateSpacerHeight();
}, 100);
}
});
// Also use MutationObserver for immediate attribute changes
const mutationObserver = new MutationObserver((mutations) => {
// Check if any mutations are related to expanding/collapsing
const hasToggle = mutations.some(
(mutation) =>
mutation.type === "attributes" &&
(mutation.attributeName === "class" ||
mutation.attributeName === "style" ||
mutation.attributeName === "open" ||
mutation.attributeName === "data-expanded"),
);
if (hasToggle) {
immediateUpdate = true;
updateSpacerHeight();
}
});
// Observe the container and all messages
resizeObserver.observe(containerRef.current);
mutationObserver.observe(containerRef.current, {
attributes: true,
subtree: true,
attributeFilter: ["class", "style", "open", "data-expanded"],
});
// Observe all message elements for size changes
const messageElements = containerRef.current.querySelectorAll(
"[data-message-index]",
);
messageElements.forEach((element) => {
resizeObserver.observe(element);
});
return () => {
clearTimeout(resizeTimeout);
resizeObserver.disconnect();
mutationObserver.disconnect();
};
}, [messages, updateSpacerHeight]);
// Track scroll position
useEffect(() => {
if (!containerRef.current) return;
const container = containerRef.current;
const handleScroll = () => {
lastScrollTopRef.current = container.scrollTop;
lastScrollHeightRef.current = container.scrollHeight;
};
container.addEventListener("scroll", handleScroll);
// Initialize scroll tracking
lastScrollTopRef.current = container.scrollTop;
lastScrollHeightRef.current = container.scrollHeight;
return () => {
container.removeEventListener("scroll", handleScroll);
};
}, []);
// Cleanup on unmount
useEffect(() => {
return () => {
pendingScrollToUserMessage.current = false;
};
}, []);
return useMemo(
() => ({
handleNewUserMessage,
containerRef,
spacerHeight,
}),
[handleNewUserMessage, containerRef, spacerHeight],
);
};

View File

@@ -1,20 +1,29 @@
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query"; import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
import { useEffect, useState } from "react";
import { fetchUser, fetchConnectUrl, disconnectUser } from "@/api"; import { fetchUser, fetchConnectUrl, disconnectUser } from "@/api";
export function useUser() { export function useUser() {
const queryClient = useQueryClient(); const queryClient = useQueryClient();
const [initialDataLoaded, setInitialDataLoaded] = useState(false);
// Wait for initial data to be loaded
useEffect(() => {
const initialPromise = window.__initialUserDataPromise;
if (initialPromise) {
initialPromise.finally(() => {
setInitialDataLoaded(true);
});
} else {
setInitialDataLoaded(true);
}
}, []);
const userQuery = useQuery({ const userQuery = useQuery({
queryKey: ["user"], queryKey: ["user"],
queryFn: async () => { queryFn: () => fetchUser(),
const result = await fetchUser();
return result;
},
staleTime: 5 * 60 * 1000, // Consider data stale after 5 minutes staleTime: 5 * 60 * 1000, // Consider data stale after 5 minutes
gcTime: 10 * 60 * 1000, // Keep in cache for 10 minutes gcTime: 10 * 60 * 1000, // Keep in cache for 10 minutes
retry: 10, initialData: null, // Start with null to prevent flashing
retryDelay: (attemptIndex) => Math.min(500 * attemptIndex, 2000),
refetchOnMount: true, // Always fetch when component mounts
}); });
// Mutation to refresh user data // Mutation to refresh user data
@@ -40,15 +49,14 @@ export function useUser() {
}, },
}); });
const isLoading = userQuery.isLoading || userQuery.isFetching;
const isAuthenticated = Boolean(userQuery.data?.name);
return { return {
user: userQuery.data, user: userQuery.data,
isLoading, isLoading:
!initialDataLoaded ||
(userQuery.isLoading && userQuery.data === undefined), // Show loading until initial data is loaded
isError: userQuery.isError, isError: userQuery.isError,
error: userQuery.error, error: userQuery.error,
isAuthenticated, isAuthenticated: Boolean(userQuery.data?.name),
refreshUser: refreshUser.mutate, refreshUser: refreshUser.mutate,
isRefreshing: refreshUser.isPending, isRefreshing: refreshUser.isPending,
refetchUser: userQuery.refetch, refetchUser: userQuery.refetch,

View File

@@ -1,13 +0,0 @@
// API configuration
const DEV_API_URL = "http://127.0.0.1:3001";
// Base URL for fetch API calls (can be relative in production)
export const API_BASE = import.meta.env.DEV ? DEV_API_URL : "";
// Full host URL for Ollama client (needs full origin in production)
export const OLLAMA_HOST = import.meta.env.DEV
? DEV_API_URL
: window.location.origin;
export const OLLAMA_DOT_COM =
import.meta.env.VITE_OLLAMA_DOT_COM_URL || "https://ollama.com";

View File

@@ -147,7 +147,6 @@ export const highlighterPromise = createHighlighter({
"c", "c",
"cpp", "cpp",
"sql", "sql",
"swift",
"yaml", "yaml",
"markdown", "markdown",
], ],

View File

@@ -1,5 +1,4 @@
import { Ollama } from "ollama/browser"; import { Ollama } from "ollama/browser";
import { OLLAMA_HOST } from "./config";
let _ollamaClient: Ollama | null = null; let _ollamaClient: Ollama | null = null;
@@ -7,7 +6,7 @@ export const ollamaClient = new Proxy({} as Ollama, {
get(_target, prop) { get(_target, prop) {
if (!_ollamaClient) { if (!_ollamaClient) {
_ollamaClient = new Ollama({ _ollamaClient = new Ollama({
host: OLLAMA_HOST, host: window.location.origin,
}); });
} }
const value = _ollamaClient[prop as keyof Ollama]; const value = _ollamaClient[prop as keyof Ollama];

View File

@@ -0,0 +1,5 @@
import { clsx, type ClassValue } from "clsx";
export function cn(...inputs: ClassValue[]) {
return clsx(inputs);
}

View File

@@ -5,6 +5,13 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { routeTree } from "./routeTree.gen"; import { routeTree } from "./routeTree.gen";
import { fetchUser } from "./api"; import { fetchUser } from "./api";
import { StreamingProvider } from "./contexts/StreamingContext"; import { StreamingProvider } from "./contexts/StreamingContext";
import { User } from "@/gotypes";
declare global {
interface Window {
__initialUserDataPromise?: Promise<User | null>;
}
}
const queryClient = new QueryClient({ const queryClient = new QueryClient({
defaultOptions: { defaultOptions: {
@@ -17,11 +24,27 @@ const queryClient = new QueryClient({
}, },
}); });
fetchUser().then((userData) => { // Track initial user data fetch
if (userData) { let initialUserDataPromise: Promise<User | null> | null = null;
// Initialize user data on app startup
const initializeUserData = async () => {
try {
const userData = await fetchUser();
queryClient.setQueryData(["user"], userData); queryClient.setQueryData(["user"], userData);
return userData;
} catch (error) {
console.error("Error initializing user data:", error);
queryClient.setQueryData(["user"], null);
return null;
} }
}); };
// Start initialization immediately and track the promise
initialUserDataPromise = initializeUserData();
// Export the promise so hooks can await it
window.__initialUserDataPromise = initialUserDataPromise;
const router = createRouter({ const router = createRouter({
routeTree, routeTree,

View File

@@ -102,13 +102,14 @@ type HealthResponse struct {
type User struct { type User struct {
ID string `json:"id"` ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"` Name string `json:"name"`
Bio string `json:"bio,omitempty"` Email string `json:"email"`
AvatarURL string `json:"avatarurl,omitempty"` AvatarURL string `json:"avatarURL"`
FirstName string `json:"firstname,omitempty"` Plan string `json:"plan"`
LastName string `json:"lastname,omitempty"` Bio string `json:"bio"`
Plan string `json:"plan,omitempty"` FirstName string `json:"firstName"`
LastName string `json:"lastName"`
OverThreshold bool `json:"overThreshold"`
} }
type Attachment struct { type Attachment struct {

View File

@@ -12,17 +12,18 @@ import (
"log/slog" "log/slog"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/url"
"os" "os"
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/auth"
"github.com/ollama/ollama/app/server" "github.com/ollama/ollama/app/server"
"github.com/ollama/ollama/app/store" "github.com/ollama/ollama/app/store"
"github.com/ollama/ollama/app/tools" "github.com/ollama/ollama/app/tools"
@@ -117,66 +118,40 @@ func (s *Server) log() *slog.Logger {
// ollamaProxy creates a reverse proxy handler to the Ollama server // ollamaProxy creates a reverse proxy handler to the Ollama server
func (s *Server) ollamaProxy() http.Handler { func (s *Server) ollamaProxy() http.Handler {
var ( ollamaHost := os.Getenv("OLLAMA_HOST")
proxy http.Handler if ollamaHost == "" {
proxyMu sync.Mutex ollamaHost = "http://127.0.0.1:11434"
)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxyMu.Lock()
p := proxy
proxyMu.Unlock()
if p == nil {
proxyMu.Lock()
if proxy == nil {
var err error
for i := range 2 {
if i > 0 {
s.log().Warn("ollama server not ready, retrying", "attempt", i+1)
time.Sleep(1 * time.Second)
} }
err = WaitForServer(context.Background(), 10*time.Second) if !strings.HasPrefix(ollamaHost, "http://") && !strings.HasPrefix(ollamaHost, "https://") {
if err == nil { ollamaHost = "http://" + ollamaHost
break
}
} }
target, err := url.Parse(ollamaHost)
if err != nil { if err != nil {
proxyMu.Unlock() s.log().Error("failed to parse OLLAMA_HOST", "error", err, "host", ollamaHost)
s.log().Error("ollama server not ready after retries", "error", err) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Ollama server is not ready", http.StatusServiceUnavailable) http.Error(w, "failed to configure proxy", http.StatusInternalServerError)
return })
} }
target := envconfig.Host()
s.log().Info("configuring ollama proxy", "target", target.String()) s.log().Info("configuring ollama proxy", "target", target.String())
newProxy := httputil.NewSingleHostReverseProxy(target) proxy := httputil.NewSingleHostReverseProxy(target)
originalDirector := newProxy.Director originalDirector := proxy.Director
newProxy.Director = func(req *http.Request) { proxy.Director = func(req *http.Request) {
originalDirector(req) originalDirector(req)
req.Host = target.Host req.Host = target.Host
s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host) s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host)
} }
newProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String()) s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String())
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway) http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
} }
proxy = newProxy return proxy
p = newProxy
} else {
p = proxy
}
proxyMu.Unlock()
}
p.ServeHTTP(w, r)
})
} }
type errHandlerFunc func(http.ResponseWriter, *http.Request) error type errHandlerFunc func(http.ResponseWriter, *http.Request) error
@@ -289,10 +264,11 @@ func (s *Server) Handler() http.Handler {
ollamaProxy := s.ollamaProxy() ollamaProxy := s.ollamaProxy()
mux.Handle("GET /api/tags", ollamaProxy) mux.Handle("GET /api/tags", ollamaProxy)
mux.Handle("POST /api/show", ollamaProxy) mux.Handle("POST /api/show", ollamaProxy)
mux.Handle("GET /api/version", ollamaProxy)
mux.Handle("HEAD /api/version", ollamaProxy) mux.Handle("GET /api/v1/me", handle(s.me))
mux.Handle("POST /api/me", ollamaProxy) mux.Handle("POST /api/v1/disconnect", handle(s.disconnect))
mux.Handle("POST /api/signout", ollamaProxy) mux.Handle("GET /api/v1/connect", handle(s.connectURL))
mux.Handle("GET /api/v1/health", handle(s.health))
// React app - catch all non-API routes and serve the React app // React app - catch all non-API routes and serve the React app
mux.Handle("GET /", s.appHandler()) mux.Handle("GET /", s.appHandler())
@@ -362,7 +338,7 @@ func (s *Server) doSelfSigned(ctx context.Context, method, path string) (*http.R
} }
// UserData fetches user data from ollama.com API for the current ollama key // UserData fetches user data from ollama.com API for the current ollama key
func (s *Server) UserData(ctx context.Context) (*api.UserResponse, error) { func (s *Server) UserData(ctx context.Context) (*responses.User, error) {
resp, err := s.doSelfSigned(ctx, http.MethodPost, "/api/me") resp, err := s.doSelfSigned(ctx, http.MethodPost, "/api/me")
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to call ollama.com/api/me: %w", err) return nil, fmt.Errorf("failed to call ollama.com/api/me: %w", err)
@@ -373,7 +349,7 @@ func (s *Server) UserData(ctx context.Context) (*api.UserResponse, error) {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
} }
var user api.UserResponse var user responses.User
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil { if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
return nil, fmt.Errorf("failed to parse user response: %w", err) return nil, fmt.Errorf("failed to parse user response: %w", err)
} }
@@ -392,27 +368,29 @@ func (s *Server) UserData(ctx context.Context) (*api.UserResponse, error) {
return &user, nil return &user, nil
} }
// WaitForServer waits for the Ollama server to be ready func waitForServer(ctx context.Context) error {
func WaitForServer(ctx context.Context, timeout time.Duration) error { timeout := time.Now().Add(10 * time.Second)
deadline := time.Now().Add(timeout) // TODO: this avoids an error on first load of the app
for time.Now().Before(deadline) { // however we should either show a loading state or
// wait for the Ollama server to be ready before redirecting
for {
c, err := api.ClientFromEnvironment() c, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
return err return err
} }
if _, err := c.Version(ctx); err == nil { if _, err := c.Version(ctx); err == nil {
slog.Debug("ollama server is ready") break
return nil }
if time.Now().After(timeout) {
return fmt.Errorf("timeout waiting for Ollama server to be ready")
} }
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
} }
return errors.New("timeout waiting for Ollama server to be ready") return nil
} }
func (s *Server) createChat(w http.ResponseWriter, r *http.Request) error { func (s *Server) createChat(w http.ResponseWriter, r *http.Request) error {
if err := WaitForServer(r.Context(), 10*time.Second); err != nil { waitForServer(r.Context())
return err
}
id, err := uuid.NewV7() id, err := uuid.NewV7()
if err != nil { if err != nil {
@@ -997,7 +975,7 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
for _, toolCall := range res.Message.ToolCalls { for _, toolCall := range res.Message.ToolCalls {
// continues loop as tools were executed // continues loop as tools were executed
toolsExecuted = true toolsExecuted = true
result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments.ToMap()) result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
if err != nil { if err != nil {
errContent := fmt.Sprintf("Error: %v", err) errContent := fmt.Sprintf("Error: %v", err)
toolErrMsg := store.NewMessage("tool", errContent, nil) toolErrMsg := store.NewMessage("tool", errContent, nil)
@@ -1460,6 +1438,129 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
}) })
} }
func (s *Server) me(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodGet {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return nil
}
user, err := s.UserData(r.Context())
if err != nil {
// If fetching from API fails, try to return cached user data if available
if cachedUser, cacheErr := s.Store.User(); cacheErr == nil && cachedUser != nil {
s.log().Info("API request failed, returning cached user data", "error", err)
responseUser := &responses.User{
Name: cachedUser.Name,
Email: cachedUser.Email,
Plan: cachedUser.Plan,
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(responseUser)
}
s.log().Error("failed to get user data", "error", err)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to get user data",
})
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(user)
}
func (s *Server) disconnect(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodPost {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return nil
}
if err := s.Store.ClearUser(); err != nil {
s.log().Warn("failed to clear cached user data", "error", err)
}
// Get the SSH public key to encode for the delete request
pubKey, err := ollamaAuth.GetPublicKey()
if err != nil {
s.log().Error("failed to get public key", "error", err)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to get public key",
})
}
// Encode the key using base64 URL encoding
encodedKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
// Call the /api/user/keys/{encodedKey} endpoint with DELETE
resp, err := s.doSelfSigned(r.Context(), http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey))
if err != nil {
s.log().Error("failed to call ollama.com/api/user/keys", "error", err)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to disconnect from ollama.com",
})
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
s.log().Error("disconnect request failed", "status", resp.StatusCode)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to disconnect from ollama.com",
})
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(map[string]string{"status": "disconnected"})
}
func (s *Server) connectURL(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodGet {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return nil
}
connectURL, err := auth.BuildConnectURL(OllamaDotCom)
if err != nil {
s.log().Error("failed to build connect URL", "error", err)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to build connect URL",
})
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(map[string]string{
"connect_url": connectURL,
})
}
func (s *Server) health(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodGet {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return nil
}
healthy := false
c, err := api.ClientFromEnvironment()
if err == nil {
if _, err := c.Version(r.Context()); err == nil {
healthy = true
}
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(responses.HealthResponse{
Healthy: healthy,
})
}
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error { func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond) ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
defer cancel() defer cancel()
@@ -1558,13 +1659,13 @@ func convertToOllamaTool(toolSchema map[string]any) api.Tool {
tool.Function.Parameters.Type = "object" tool.Function.Parameters.Type = "object"
tool.Function.Parameters.Required = []string{} tool.Function.Parameters.Required = []string{}
tool.Function.Parameters.Properties = api.NewToolPropertiesMap() tool.Function.Parameters.Properties = make(map[string]api.ToolProperty)
if schemaProps, ok := toolSchema["schema"].(map[string]any); ok { if schemaProps, ok := toolSchema["schema"].(map[string]any); ok {
tool.Function.Parameters.Type = getStringFromMap(schemaProps, "type", "object") tool.Function.Parameters.Type = getStringFromMap(schemaProps, "type", "object")
if props, ok := schemaProps["properties"].(map[string]any); ok { if props, ok := schemaProps["properties"].(map[string]any); ok {
tool.Function.Parameters.Properties = api.NewToolPropertiesMap() tool.Function.Parameters.Properties = make(map[string]api.ToolProperty)
for propName, propDef := range props { for propName, propDef := range props {
if propMap, ok := propDef.(map[string]any); ok { if propMap, ok := propDef.(map[string]any); ok {
@@ -1572,7 +1673,7 @@ func convertToOllamaTool(toolSchema map[string]any) api.Tool {
Type: api.PropertyType{getStringFromMap(propMap, "type", "string")}, Type: api.PropertyType{getStringFromMap(propMap, "type", "string")},
Description: getStringFromMap(propMap, "description", ""), Description: getStringFromMap(propMap, "description", ""),
} }
tool.Function.Parameters.Properties.Set(propName, prop) tool.Function.Parameters.Properties[propName] = prop
} }
} }
} }

View File

@@ -158,16 +158,16 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui
case uint32(UI_REQUEST_MSG_ID): case uint32(UI_REQUEST_MSG_ID):
// Requests for the UI must always come from the main event thread // Requests for the UI must always come from the main event thread
l := int(wParam) l := int(wParam)
path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l) //nolint:govet,gosec path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l)
t.app.UIRun(path) t.app.UIRun(path)
case WM_COPYDATA: case WM_COPYDATA:
// Handle URL scheme requests from other instances // Handle URL scheme requests from other instances
if lParam != 0 { if lParam != 0 {
cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam)) //nolint:govet,gosec cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam))
if cds.DwData == 1 { // Our identifier for URL scheme messages if cds.DwData == 1 { // Our identifier for URL scheme messages
// Convert the data back to string // Convert the data back to string
data := make([]byte, cds.CbData) data := make([]byte, cds.CbData)
copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData]) //nolint:govet,gosec copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData])
urlScheme := string(data) urlScheme := string(data)
handleURLSchemeRequest(urlScheme) handleURLSchemeRequest(urlScheme)
lResult = 1 // Return non-zero to indicate success lResult = 1 // Return non-zero to indicate success

View File

@@ -1,115 +0,0 @@
Ollama Benchmark Tool
---------------------
A Go-based command-line tool for benchmarking Ollama models with configurable parameters and multiple output formats.
## Features
* Benchmark multiple models in a single run
* Support for both text and image prompts
* Configurable generation parameters (temperature, max tokens, seed, etc.)
* Supports benchstat and CSV output formats
* Detailed performance metrics (prefill, generate, load, total durations)
## Building from Source
```
go build -o ollama-bench bench.go
./ollama-bench -model gpt-oss:20b -epochs 6 -format csv
```
Using Go Run (without building)
```
go run bench.go -model gpt-oss:20b -epochs 3
```
## Usage
### Basic Example
```
./ollama-bench -model gemma3 -epochs 6
```
### Benchmark Multiple Models
```
./ollama-bench -model gemma3,gemma3n -epochs 6 -max-tokens 100 -p "Write me a short story" | tee gemma.bench
benchstat -col /name gemma.bench
```
### With Image Prompt
```
./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
```
### Advanced Example
```
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
```
## Command Line Options
| Option | Description | Default |
|----------|-------------|---------|
| -model | Comma-separated list of models to benchmark | (required) |
| -epochs | Number of iterations per model | 1 |
| -max-tokens | Maximum tokens for model response | 0 (unlimited) |
| -temperature | Temperature parameter | 0.0 |
| -seed | Random seed | 0 (random) |
| -timeout | Timeout in seconds | 300 |
| -p | Prompt text | "Write a long story." |
| -image | Image file to include in prompt | |
| -k | Keep-alive duration in seconds | 0 |
| -format | Output format (benchstat, csv) | benchstat |
| -output | Output file for results | "" (stdout) |
| -v | Verbose mode | false |
| -debug | Show debug information | false |
## Output Formats
### Markdown Format
The default markdown format is suitable for copying and pasting into a GitHub issue and will look like:
```
Model | Step | Count | Duration | nsPerToken | tokensPerSec |
|-------|------|-------|----------|------------|--------------|
| gpt-oss:20b | prefill | 124 | 30.006458ms | 241987.56 | 4132.44 |
| gpt-oss:20b | generate | 200 | 2.646843954s | 13234219.77 | 75.56 |
| gpt-oss:20b | load | 1 | 121.674208ms | - | - |
| gpt-oss:20b | total | 1 | 2.861047625s | - | - |
```
### Benchstat Format
Compatible with Go's benchstat tool for statistical analysis:
```
BenchmarkModel/name=gpt-oss:20b/step=prefill 128 78125.00 ns/token 12800.00 token/sec
BenchmarkModel/name=gpt-oss:20b/step=generate 512 19531.25 ns/token 51200.00 token/sec
BenchmarkModel/name=gpt-oss:20b/step=load 1 1500000000 ns/request
```
### CSV Format
Machine-readable comma-separated values:
```
NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
gpt-oss:20b,prefill,128,78125.00,12800.00
gpt-oss:20b,generate,512,19531.25,51200.00
gpt-oss:20b,load,1,1500000000,0
```
## Metrics Explained
The tool reports four types of metrics for each model:
* prefill: Time spent processing the prompt
* generate: Time spent generating the response
* load: Model loading time (one-time cost)
* total: Total request duration

View File

@@ -1,321 +0,0 @@
package main
import (
"cmp"
"context"
"flag"
"fmt"
"io"
"os"
"runtime"
"slices"
"strings"
"sync"
"time"
"github.com/ollama/ollama/api"
)
type flagOptions struct {
models *string
epochs *int
maxTokens *int
temperature *float64
seed *int
timeout *int
prompt *string
imageFile *string
keepAlive *float64
format *string
outputFile *string
debug *bool
verbose *bool
}
type Metrics struct {
Model string
Step string
Count int
Duration time.Duration
}
var once sync.Once
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
switch format {
case "benchstat":
if verbose {
printHeader := func() {
fmt.Fprintf(w, "sysname: %s\n", runtime.GOOS)
fmt.Fprintf(w, "machine: %s\n", runtime.GOARCH)
}
once.Do(printHeader)
}
for _, m := range metrics {
if m.Step == "generate" || m.Step == "prefill" {
if m.Count > 0 {
nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d %.2f ns/token %.2f token/sec\n",
m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
} else {
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d 0 ns/token 0 token/sec\n",
m.Model, m.Step, m.Count)
}
} else {
var suffix string
if m.Step == "load" {
suffix = "/step=load"
}
fmt.Fprintf(w, "BenchmarkModel/name=%s%s 1 %d ns/request\n",
m.Model, suffix, m.Duration.Nanoseconds())
}
}
case "csv":
printHeader := func() {
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
fmt.Fprintln(w, strings.Join(headings, ","))
}
once.Do(printHeader)
for _, m := range metrics {
if m.Step == "generate" || m.Step == "prefill" {
var nsPerToken float64
var tokensPerSec float64
if m.Count > 0 {
nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
}
fmt.Fprintf(w, "%s,%s,%d,%.2f,%.2f\n", m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
} else {
fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
}
}
case "markdown":
printHeader := func() {
fmt.Fprintln(w, "| Model | Step | Count | Duration | nsPerToken | tokensPerSec |")
fmt.Fprintln(w, "|-------|------|-------|----------|------------|--------------|")
}
once.Do(printHeader)
for _, m := range metrics {
var nsPerToken, tokensPerSec float64
var nsPerTokenStr, tokensPerSecStr string
if m.Step == "generate" || m.Step == "prefill" {
nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
nsPerTokenStr = fmt.Sprintf("%.2f", nsPerToken)
tokensPerSecStr = fmt.Sprintf("%.2f", tokensPerSec)
} else {
nsPerTokenStr = "-"
tokensPerSecStr = "-"
}
fmt.Fprintf(w, "| %s | %s | %d | %v | %s | %s |\n",
m.Model, m.Step, m.Count, m.Duration, nsPerTokenStr, tokensPerSecStr)
}
default:
fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
}
}
func BenchmarkChat(fOpt flagOptions) error {
models := strings.Split(*fOpt.models, ",")
// todo - add multi-image support
var imgData api.ImageData
var err error
if *fOpt.imageFile != "" {
imgData, err = readImage(*fOpt.imageFile)
if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: Couldn't read image '%s': %v\n", *fOpt.imageFile, err)
return err
}
}
if *fOpt.debug && imgData != nil {
fmt.Fprintf(os.Stderr, "Read file '%s'\n", *fOpt.imageFile)
}
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: Couldn't create ollama client: %v\n", err)
return err
}
var out io.Writer = os.Stdout
if fOpt.outputFile != nil && *fOpt.outputFile != "" {
f, err := os.OpenFile(*fOpt.outputFile, os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: cannot open output file %s: %v\n", *fOpt.outputFile, err)
return err
}
defer f.Close()
out = f
}
for _, model := range models {
for range *fOpt.epochs {
options := make(map[string]interface{})
if *fOpt.maxTokens > 0 {
options["num_predict"] = *fOpt.maxTokens
}
options["temperature"] = *fOpt.temperature
if fOpt.seed != nil && *fOpt.seed > 0 {
options["seed"] = *fOpt.seed
}
var keepAliveDuration *api.Duration
if *fOpt.keepAlive > 0 {
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
keepAliveDuration = &duration
}
req := &api.ChatRequest{
Model: model,
Messages: []api.Message{
{
Role: "user",
Content: *fOpt.prompt,
},
},
Options: options,
KeepAlive: keepAliveDuration,
}
if imgData != nil {
req.Messages[0].Images = []api.ImageData{imgData}
}
var responseMetrics *api.Metrics
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
defer cancel()
err = client.Chat(ctx, req, func(resp api.ChatResponse) error {
if *fOpt.debug {
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Message.Thinking, resp.Message.Content))
}
if resp.Done {
responseMetrics = &resp.Metrics
}
return nil
})
if *fOpt.debug {
fmt.Fprintln(os.Stderr)
}
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
fmt.Fprintf(os.Stderr, "ERROR: Chat request timed out with model '%s' after %vs\n", model, 1)
continue
}
fmt.Fprintf(os.Stderr, "ERROR: Couldn't chat with model '%s': %v\n", model, err)
continue
}
if responseMetrics == nil {
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
continue
}
metrics := []Metrics{
{
Model: model,
Step: "prefill",
Count: responseMetrics.PromptEvalCount,
Duration: responseMetrics.PromptEvalDuration,
},
{
Model: model,
Step: "generate",
Count: responseMetrics.EvalCount,
Duration: responseMetrics.EvalDuration,
},
{
Model: model,
Step: "load",
Count: 1,
Duration: responseMetrics.LoadDuration,
},
{
Model: model,
Step: "total",
Count: 1,
Duration: responseMetrics.TotalDuration,
},
}
OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)
if *fOpt.keepAlive > 0 {
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
}
}
}
return nil
}
func readImage(filePath string) (api.ImageData, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, err
}
defer file.Close()
data, err := io.ReadAll(file)
if err != nil {
return nil, err
}
return api.ImageData(data), nil
}
func main() {
fOpt := flagOptions{
models: flag.String("model", "", "Model to benchmark"),
epochs: flag.Int("epochs", 6, "Number of epochs (iterations) per model"),
maxTokens: flag.Int("max-tokens", 200, "Maximum tokens for model response"),
temperature: flag.Float64("temperature", 0, "Temperature parameter"),
seed: flag.Int("seed", 0, "Random seed"),
timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"),
prompt: flag.String("p", DefaultPrompt, "Prompt to use"),
imageFile: flag.String("image", "", "Filename for an image to include"),
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
format: flag.String("format", "markdown", "Output format [benchstat|csv] (default benchstat)"),
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
verbose: flag.Bool("v", false, "Show system information"),
debug: flag.Bool("debug", false, "Show debug information"),
}
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s [OPTIONS]\n\n", os.Args[0])
fmt.Fprintf(os.Stderr, "Description:\n")
fmt.Fprintf(os.Stderr, " Model benchmarking tool with configurable parameters\n\n")
fmt.Fprintf(os.Stderr, "Options:\n")
flag.PrintDefaults()
fmt.Fprintf(os.Stderr, "\nExamples:\n")
fmt.Fprintf(os.Stderr, " bench -model gpt-oss:20b -epochs 3 -temperature 0.7\n")
}
flag.Parse()
if !slices.Contains([]string{"markdown", "benchstat", "csv"}, *fOpt.format) {
fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
os.Exit(1)
}
if len(*fOpt.models) == 0 {
fmt.Fprintf(os.Stderr, "ERROR: No model(s) specified to benchmark.\n")
flag.Usage()
return
}
BenchmarkChat(fOpt)
}

View File

@@ -1,463 +0,0 @@
package main
import (
"bytes"
"crypto/rand"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func createTestFlagOptions() flagOptions {
models := "test-model"
format := "benchstat"
epochs := 1
maxTokens := 100
temperature := 0.7
seed := 42
timeout := 30
prompt := "test prompt"
imageFile := ""
keepAlive := 5.0
verbose := false
debug := false
return flagOptions{
models: &models,
format: &format,
epochs: &epochs,
maxTokens: &maxTokens,
temperature: &temperature,
seed: &seed,
timeout: &timeout,
prompt: &prompt,
imageFile: &imageFile,
keepAlive: &keepAlive,
verbose: &verbose,
debug: &debug,
}
}
func captureOutput(f func()) string {
oldStdout := os.Stdout
oldStderr := os.Stderr
defer func() {
os.Stdout = oldStdout
os.Stderr = oldStderr
}()
r, w, _ := os.Pipe()
os.Stdout = w
os.Stderr = w
f()
w.Close()
var buf bytes.Buffer
io.Copy(&buf, r)
return buf.String()
}
func createMockOllamaServer(t *testing.T, responses []api.ChatResponse) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/chat" {
t.Errorf("Expected path /api/chat, got %s", r.URL.Path)
http.Error(w, "Not found", http.StatusNotFound)
return
}
if r.Method != "POST" {
t.Errorf("Expected POST method, got %s", r.Method)
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
for _, resp := range responses {
jsonData, err := json.Marshal(resp)
if err != nil {
t.Errorf("Failed to marshal response: %v", err)
return
}
w.Write(jsonData)
w.Write([]byte("\n"))
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
time.Sleep(10 * time.Millisecond) // Simulate some delay
}
}))
}
func TestBenchmarkChat_Success(t *testing.T) {
fOpt := createTestFlagOptions()
mockResponses := []api.ChatResponse{
{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "test response part 1",
},
Done: false,
},
{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "test response part 2",
},
Done: true,
Metrics: api.Metrics{
PromptEvalCount: 10,
PromptEvalDuration: 100 * time.Millisecond,
EvalCount: 50,
EvalDuration: 500 * time.Millisecond,
TotalDuration: 600 * time.Millisecond,
LoadDuration: 50 * time.Millisecond,
},
},
}
server := createMockOllamaServer(t, mockResponses)
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
if !strings.Contains(output, "BenchmarkModel/name=test-model/step=prefill") {
t.Errorf("Expected output to contain prefill metrics, got: %s", output)
}
if !strings.Contains(output, "BenchmarkModel/name=test-model/step=generate") {
t.Errorf("Expected output to contain generate metrics, got: %s", output)
}
if !strings.Contains(output, "ns/token") {
t.Errorf("Expected output to contain ns/token metric, got: %s", output)
}
}
func TestBenchmarkChat_ServerError(t *testing.T) {
fOpt := createTestFlagOptions()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Internal server error", http.StatusInternalServerError)
}))
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected error to be handled internally, got returned error: %v", err)
}
})
if !strings.Contains(output, "ERROR: Couldn't chat with model") {
t.Errorf("Expected error message about chat failure, got: %s", output)
}
}
func TestBenchmarkChat_Timeout(t *testing.T) {
fOpt := createTestFlagOptions()
shortTimeout := 1 // Very short timeout
fOpt.timeout = &shortTimeout
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate a long delay that will cause timeout
time.Sleep(2 * time.Second)
w.Header().Set("Content-Type", "application/json")
response := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "test response",
},
Done: true,
Metrics: api.Metrics{
PromptEvalCount: 10,
PromptEvalDuration: 100 * time.Millisecond,
EvalCount: 50,
EvalDuration: 500 * time.Millisecond,
TotalDuration: 600 * time.Millisecond,
LoadDuration: 50 * time.Millisecond,
},
}
jsonData, _ := json.Marshal(response)
w.Write(jsonData)
}))
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected timeout to be handled internally, got returned error: %v", err)
}
})
if !strings.Contains(output, "ERROR: Chat request timed out") {
t.Errorf("Expected timeout error message, got: %s", output)
}
}
func TestBenchmarkChat_NoMetrics(t *testing.T) {
fOpt := createTestFlagOptions()
mockResponses := []api.ChatResponse{
{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "test response",
},
Done: false, // Never sends Done=true
},
}
server := createMockOllamaServer(t, mockResponses)
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
if !strings.Contains(output, "ERROR: No metrics received") {
t.Errorf("Expected no metrics error message, got: %s", output)
}
}
func TestBenchmarkChat_MultipleModels(t *testing.T) {
fOpt := createTestFlagOptions()
models := "model1,model2"
epochs := 2
fOpt.models = &models
fOpt.epochs = &epochs
callCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.Header().Set("Content-Type", "application/json")
var req api.ChatRequest
body, _ := io.ReadAll(r.Body)
json.Unmarshal(body, &req)
response := api.ChatResponse{
Model: req.Model,
Message: api.Message{
Role: "assistant",
Content: "test response for " + req.Model,
},
Done: true,
Metrics: api.Metrics{
PromptEvalCount: 10,
PromptEvalDuration: 100 * time.Millisecond,
EvalCount: 50,
EvalDuration: 500 * time.Millisecond,
TotalDuration: 600 * time.Millisecond,
LoadDuration: 50 * time.Millisecond,
},
}
jsonData, _ := json.Marshal(response)
w.Write(jsonData)
}))
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
// Should be called 4 times (2 models × 2 epochs)
if callCount != 4 {
t.Errorf("Expected 4 API calls, got %d", callCount)
}
if !strings.Contains(output, "BenchmarkModel/name=model1") || !strings.Contains(output, "BenchmarkModel/name=model2") {
t.Errorf("Expected output for both models, got: %s", output)
}
}
func TestBenchmarkChat_WithImage(t *testing.T) {
fOpt := createTestFlagOptions()
tmpfile, err := os.CreateTemp(t.TempDir(), "testimage")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tmpfile.Name())
content := []byte("fake image data")
if _, err := tmpfile.Write(content); err != nil {
t.Fatalf("Failed to write to temp file: %v", err)
}
tmpfile.Close()
tmpfileName := tmpfile.Name()
fOpt.imageFile = &tmpfileName
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify the request contains image data
var req api.ChatRequest
body, _ := io.ReadAll(r.Body)
json.Unmarshal(body, &req)
if len(req.Messages) == 0 || len(req.Messages[0].Images) == 0 {
t.Error("Expected request to contain images")
}
w.Header().Set("Content-Type", "application/json")
response := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "test response with image",
},
Done: true,
Metrics: api.Metrics{
PromptEvalCount: 10,
PromptEvalDuration: 100 * time.Millisecond,
EvalCount: 50,
EvalDuration: 500 * time.Millisecond,
TotalDuration: 600 * time.Millisecond,
LoadDuration: 50 * time.Millisecond,
},
}
jsonData, _ := json.Marshal(response)
w.Write(jsonData)
}))
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
if !strings.Contains(output, "BenchmarkModel/name=test-model") {
t.Errorf("Expected benchmark output, got: %s", output)
}
}
func TestBenchmarkChat_ImageError(t *testing.T) {
randFileName := func() string {
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
const length = 8
result := make([]byte, length)
rand.Read(result) // Fill with random bytes
for i := range result {
result[i] = charset[result[i]%byte(len(charset))]
}
return string(result) + ".txt"
}
fOpt := createTestFlagOptions()
imageFile := randFileName()
fOpt.imageFile = &imageFile
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err == nil {
t.Error("Expected error from image reading, got nil")
}
})
if !strings.Contains(output, "ERROR: Couldn't read image") {
t.Errorf("Expected image read error message, got: %s", output)
}
}
func TestReadImage_Success(t *testing.T) {
tmpfile, err := os.CreateTemp(t.TempDir(), "testimage")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tmpfile.Name())
content := []byte("fake image data")
if _, err := tmpfile.Write(content); err != nil {
t.Fatalf("Failed to write to temp file: %v", err)
}
tmpfile.Close()
imgData, err := readImage(tmpfile.Name())
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if imgData == nil {
t.Error("Expected image data, got nil")
}
expected := api.ImageData(content)
if string(imgData) != string(expected) {
t.Errorf("Expected image data %v, got %v", expected, imgData)
}
}
func TestReadImage_FileNotFound(t *testing.T) {
imgData, err := readImage("nonexistentfile.jpg")
if err == nil {
t.Error("Expected error for non-existent file, got nil")
}
if imgData != nil {
t.Error("Expected nil image data for non-existent file")
}
}
func TestOptionsMapCreation(t *testing.T) {
fOpt := createTestFlagOptions()
options := make(map[string]interface{})
if *fOpt.maxTokens > 0 {
options["num_predict"] = *fOpt.maxTokens
}
options["temperature"] = *fOpt.temperature
if fOpt.seed != nil && *fOpt.seed > 0 {
options["seed"] = *fOpt.seed
}
if options["num_predict"] != *fOpt.maxTokens {
t.Errorf("Expected num_predict %d, got %v", *fOpt.maxTokens, options["num_predict"])
}
if options["temperature"] != *fOpt.temperature {
t.Errorf("Expected temperature %f, got %v", *fOpt.temperature, options["temperature"])
}
if options["seed"] != *fOpt.seed {
t.Errorf("Expected seed %d, got %v", *fOpt.seed, options["seed"])
}
}

View File

@@ -45,9 +45,6 @@ import (
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/types/syncmap" "github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd"
"github.com/ollama/ollama/x/imagegen"
imagegenclient "github.com/ollama/ollama/x/imagegen/client"
) )
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n" const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
@@ -98,11 +95,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
filename, err := getModelfileName(cmd) filename, err := getModelfileName(cmd)
if os.IsNotExist(err) { if os.IsNotExist(err) {
if filename == "" { if filename == "" {
// No Modelfile found - check if current directory is an image gen model
if imagegen.IsTensorModelDir(".") {
quantize, _ := cmd.Flags().GetString("quantize")
return imagegenclient.CreateModel(args[0], ".", quantize, p)
}
reader = strings.NewReader("FROM .\n") reader = strings.NewReader("FROM .\n")
} else { } else {
return errModelfileNotFound return errModelfileNotFound
@@ -464,7 +456,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
} }
name := args[0] name := args[0]
info, err := func() (*api.ShowResponse, error) { info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: name} showReq := &api.ShowRequest{Name: name}
info, err := client.Show(cmd.Context(), showReq) info, err := client.Show(cmd.Context(), showReq)
@@ -526,19 +517,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions) return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
} }
// Check if this is an image generation model
if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
if opts.Prompt == "" && !interactive {
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
}
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
}
// Check for experimental flag
isExperimental, _ := cmd.Flags().GetBool("experimental")
yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
enableWebsearch, _ := cmd.Flags().GetBool("experimental-websearch")
if interactive { if interactive {
if err := loadOrUnloadModel(cmd, &opts); err != nil { if err := loadOrUnloadModel(cmd, &opts); err != nil {
var sErr api.AuthorizationError var sErr api.AuthorizationError
@@ -565,11 +543,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
} }
} }
// Use experimental agent loop with tools
if isExperimental {
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode, enableWebsearch)
}
return generateInteractive(cmd, opts) return generateInteractive(cmd, opts)
} }
return generate(cmd, opts) return generate(cmd, opts)
@@ -673,11 +646,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
bar, ok := bars[resp.Digest] bar, ok := bars[resp.Digest]
if !ok { if !ok {
msg := resp.Status bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
if msg == "" {
msg = fmt.Sprintf("pushing %s...", resp.Digest[7:19])
}
bar = progress.NewBar(msg, resp.Total, resp.Completed)
bars[resp.Digest] = bar bars[resp.Digest] = bar
p.Add(resp.Digest, bar) p.Add(resp.Digest, bar)
} }
@@ -974,9 +943,6 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize}) rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
} }
rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel}) rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
if resp.Requires != "" {
rows = append(rows, []string{"", "requires", resp.Requires})
}
return return
}) })
@@ -1464,7 +1430,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
latest.Summary() latest.Summary()
} }
return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil return &api.Message{Role: role, Content: fullResponse.String()}, nil
} }
func generate(cmd *cobra.Command, opts runOptions) error { func generate(cmd *cobra.Command, opts runOptions) error {
@@ -1785,12 +1751,6 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)") runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead") runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)") runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
runCmd.Flags().Bool("experimental-websearch", false, "Enable web search tool in experimental mode")
// Image generation flags (width, height, steps, seed, etc.)
imagegen.RegisterFlags(runCmd)
stopCmd := &cobra.Command{ stopCmd := &cobra.Command{
Use: "stop MODEL", Use: "stop MODEL",

View File

@@ -291,31 +291,6 @@ Weigh anchor!
t.Errorf("unexpected output (-want +got):\n%s", diff) t.Errorf("unexpected output (-want +got):\n%s", diff)
} }
}) })
t.Run("min version", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
Requires: "0.14.0",
}, false, &b); err != nil {
t.Fatal(err)
}
expect := ` Model
architecture test
parameters 7B
quantization FP16
requires 0.14.0
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
} }
func TestDeleteHandler(t *testing.T) { func TestDeleteHandler(t *testing.T) {
@@ -1547,79 +1522,6 @@ func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) {
} }
} }
func TestShowInfoImageGen(t *testing.T) {
var b bytes.Buffer
err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "ZImagePipeline",
ParameterSize: "10.3B",
QuantizationLevel: "FP8",
},
Capabilities: []model.Capability{model.CapabilityImageGeneration},
Requires: "0.14.0",
}, false, &b)
if err != nil {
t.Fatal(err)
}
expect := " Model\n" +
" architecture ZImagePipeline \n" +
" parameters 10.3B \n" +
" quantization FP8 \n" +
" requires 0.14.0 \n" +
"\n" +
" Capabilities\n" +
" image \n" +
"\n"
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
}
func TestPushProgressMessage(t *testing.T) {
tests := []struct {
name string
status string
digest string
wantMsg string
}{
{
name: "uses status when provided",
status: "uploading model",
digest: "sha256:abc123456789def",
wantMsg: "uploading model",
},
{
name: "falls back to digest when status empty",
status: "",
digest: "sha256:abc123456789def",
wantMsg: "pushing abc123456789...",
},
{
name: "handles short digest gracefully",
status: "",
digest: "sha256:abc",
wantMsg: "pushing sha256:abc...",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msg := tt.status
if msg == "" {
if len(tt.digest) >= 19 {
msg = fmt.Sprintf("pushing %s...", tt.digest[7:19])
} else {
msg = fmt.Sprintf("pushing %s...", tt.digest)
}
}
if msg != tt.wantMsg {
t.Errorf("got %q, want %q", msg, tt.wantMsg)
}
})
}
}
func TestRunOptions_Copy_Independence(t *testing.T) { func TestRunOptions_Copy_Independence(t *testing.T) {
// Test that modifications to original don't affect copy // Test that modifications to original don't affect copy
originalThink := &api.ThinkValue{Value: "original"} originalThink := &api.ThinkValue{Value: "original"}

View File

@@ -40,7 +40,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " /bye Exit") fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command") fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts") fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.") fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")

View File

@@ -6,14 +6,11 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/fs" "io/fs"
"iter"
"log/slog" "log/slog"
"maps"
"os" "os"
"slices" "slices"
"strings" "strings"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
) )
@@ -21,13 +18,8 @@ type ModelParameters struct {
Architectures []string `json:"architectures"` Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"` VocabSize uint32 `json:"vocab_size"`
// TODO is this needed?
ModelType string `json:"model_type"`
TextModel struct { TextModel struct {
VocabSize uint32 `json:"vocab_size"` VocabSize uint32 `json:"vocab_size"`
HiddenSize uint32 `json:"hidden_size"`
ModelType string `json:"model_type"`
} `json:"text_config"` } `json:"text_config"`
} }
@@ -41,94 +33,8 @@ type AdapterParameters struct {
} `json:"lora_parameters"` } `json:"lora_parameters"`
} }
type KV map[string]any func (ModelParameters) KV(t *Tokenizer) ggml.KV {
kv := ggml.KV{
func (kv KV) Architecture() string {
return kv.String("general.architecture", "unknown")
}
type valueTypes interface {
uint8 | int8 | uint16 | int16 |
uint32 | int32 | uint64 | int64 |
string | float32 | float64 | bool
}
type arrayValueTypes interface {
[]uint8 | []int8 | []uint16 | []int16 |
[]uint32 | []int32 | []uint64 | []int64 |
[]string | []float32 | []float64 | []bool
}
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key
}
if val, ok := kv[key].(T); ok {
return val, true
}
return defaultValue[0], false
}
func (kv KV) String(key string, defaultValue ...string) string {
val, _ := keyValue(kv, key, append(defaultValue, "")...)
return val
}
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
}
func (kv KV) Float(key string, defaultValue ...float32) float32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
}
func (kv KV) Bool(key string, defaultValue ...bool) bool {
val, _ := keyValue(kv, key, append(defaultValue, false)...)
return val
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
val, _ := keyValue(kv, key, append(defaultValue, []string{""})...)
return val
}
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
val, _ := keyValue(kv, key, append(defaultValue, []int32{0})...)
return val
}
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
val, _ := keyValue(kv, key, append(defaultValue, []uint32{0})...)
return val
}
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
val, _ := keyValue(kv, key, append(defaultValue, []float32{0})...)
return val
}
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
val, _ := keyValue(kv, key, append(defaultValue, []bool{false})...)
return val
}
func (kv KV) Len() int {
return len(kv)
}
func (kv KV) Keys() iter.Seq[string] {
return maps.Keys(kv)
}
func (kv KV) Value(key string) any {
return kv[key]
}
func (ModelParameters) KV(t *Tokenizer) KV {
kv := KV{
"general.file_type": uint32(1), "general.file_type": uint32(1),
"general.quantization_version": uint32(2), "general.quantization_version": uint32(2),
"tokenizer.ggml.pre": t.Pre, "tokenizer.ggml.pre": t.Pre,
@@ -157,7 +63,7 @@ func (ModelParameters) KV(t *Tokenizer) KV {
return kv return kv
} }
func (p AdapterParameters) KV() KV { func (p AdapterParameters) KV() ggml.KV {
var alpha float32 var alpha float32
if p.LoraParameters.Alpha == 0 { if p.LoraParameters.Alpha == 0 {
alpha = float32(p.Alpha) alpha = float32(p.Alpha)
@@ -165,7 +71,7 @@ func (p AdapterParameters) KV() KV {
alpha = p.LoraParameters.Alpha alpha = p.LoraParameters.Alpha
} }
kv := KV{ kv := ggml.KV{
"adapter.lora.alpha": alpha, "adapter.lora.alpha": alpha,
"adapter.type": "lora", "adapter.type": "lora",
"general.file_type": uint32(1), "general.file_type": uint32(1),
@@ -182,14 +88,9 @@ func (ModelParameters) specialTokenTypes() []string {
} }
} }
type ModelKV interface {
// KV maps parameters to LLM key-values
KV(*Tokenizer) KV
}
type ModelConverter interface { type ModelConverter interface {
ModelKV // KV maps parameters to LLM key-values
KV(*Tokenizer) ggml.KV
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here. // Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor Tensors([]Tensor) []*ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names. // Replacements returns a list of string pairs to replace in tensor names.
@@ -206,7 +107,7 @@ type moreParser interface {
type AdapterConverter interface { type AdapterConverter interface {
// KV maps parameters to LLM key-values // KV maps parameters to LLM key-values
KV(ofs.Config) KV KV(ggml.KV) ggml.KV
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here. // Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor Tensors([]Tensor) []*ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names. // Replacements returns a list of string pairs to replace in tensor names.
@@ -214,7 +115,7 @@ type AdapterConverter interface {
Replacements() []string Replacements() []string
} }
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error { func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
bts, err := fs.ReadFile(fsys, "adapter_config.json") bts, err := fs.ReadFile(fsys, "adapter_config.json")
if err != nil { if err != nil {
return err return err
@@ -225,8 +126,8 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
return err return err
} }
arch := baseKV.Architecture() arch, ok := baseKV["general.architecture"]
if arch == "" { if !ok {
return errors.New("architecture not set for the base model") return errors.New("architecture not set for the base model")
} }
@@ -252,19 +153,23 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
return writeFile(f, conv.KV(baseKV), conv.Tensors(ts)) return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
} }
func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) { // Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
bts, err := fs.ReadFile(fsys, "config.json") bts, err := fs.ReadFile(fsys, "config.json")
if err != nil { if err != nil {
return nil, nil, err return err
} }
var p ModelParameters var p ModelParameters
if err := json.Unmarshal(bts, &p); err != nil { if err := json.Unmarshal(bts, &p); err != nil {
return nil, nil, err return err
} }
if len(p.Architectures) < 1 { if len(p.Architectures) < 1 {
return nil, nil, errors.New("unknown architecture") return errors.New("unknown architecture")
} }
var conv ModelConverter var conv ModelConverter
@@ -277,8 +182,6 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &llama4Model{} conv = &llama4Model{}
case "Mistral3ForConditionalGeneration": case "Mistral3ForConditionalGeneration":
conv = &mistral3Model{} conv = &mistral3Model{}
case "Ministral3ForCausalLM":
conv = &mistral3CausalModel{}
case "MixtralForCausalLM": case "MixtralForCausalLM":
conv = &mixtralModel{} conv = &mixtralModel{}
case "GemmaForCausalLM": case "GemmaForCausalLM":
@@ -297,37 +200,29 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &qwen25VLModel{} conv = &qwen25VLModel{}
case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration": case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration":
conv = &qwen3VLModel{} conv = &qwen3VLModel{}
case "Olmo3ForCausalLM":
conv = &olmoModel{}
case "BertModel": case "BertModel":
conv = &bertModel{} conv = &bertModel{}
case "NomicBertModel", "NomicBertMoEModel":
conv = &nomicbertModel{}
case "CohereForCausalLM": case "CohereForCausalLM":
conv = &commandrModel{} conv = &commandrModel{}
case "GptOssForCausalLM": case "GptOssForCausalLM":
conv = &gptossModel{} conv = &gptossModel{}
case "DeepseekOCRForCausalLM":
conv = &deepseekocr{}
case "DeepseekV3ForCausalLM":
conv = &deepseek2Model{}
default: default:
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0]) return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
} }
if err := json.Unmarshal(bts, conv); err != nil { if err := json.Unmarshal(bts, conv); err != nil {
return nil, nil, err return err
} }
if t, ok := conv.(moreParser); ok { if t, ok := conv.(moreParser); ok {
if err := t.parseMore(fsys); err != nil { if err := t.parseMore(fsys); err != nil {
return nil, nil, err return err
} }
} }
t, err := parseTokenizer(fsys, conv.specialTokenTypes()) t, err := parseTokenizer(fsys, conv.specialTokenTypes())
if err != nil { if err != nil {
return nil, nil, err return err
} }
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize)) vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
@@ -349,19 +244,6 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
default: default:
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens)) slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
} }
return conv, t, nil
}
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
kv, t, err := LoadModelMetadata(fsys)
if err != nil {
return err
}
conv := kv.(ModelConverter)
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...)) ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
if err != nil { if err != nil {
@@ -371,7 +253,7 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
return writeFile(f, conv.KV(t), conv.Tensors(ts)) return writeFile(f, conv.KV(t), conv.Tensors(ts))
} }
func writeFile(f *os.File, kv KV, ts []*ggml.Tensor) error { func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
for i := range ts { for i := range ts {
ts[i].Shape = slices.Clone(ts[i].Shape) ts[i].Shape = slices.Clone(ts[i].Shape)
slices.Reverse(ts[i].Shape) slices.Reverse(ts[i].Shape)

View File

@@ -88,7 +88,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
return nil return nil
} }
func (p *bertModel) KV(t *Tokenizer) KV { func (p *bertModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t) kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "bert" kv["general.architecture"] = "bert"
kv["bert.attention.causal"] = false kv["bert.attention.causal"] = false

View File

@@ -24,7 +24,7 @@ type commandrModel struct {
var _ ModelConverter = (*commandrModel)(nil) var _ ModelConverter = (*commandrModel)(nil)
func (p *commandrModel) KV(t *Tokenizer) KV { func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t) kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "command-r" kv["general.architecture"] = "command-r"
kv["general.name"] = "command-r" kv["general.name"] = "command-r"

View File

@@ -1,173 +0,0 @@
package convert
import (
"cmp"
"fmt"
"log/slog"
"regexp"
"strconv"
"github.com/ollama/ollama/fs/ggml"
)
type deepseek2Model struct {
ModelParameters // architectures, vocab_size
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
QKNopeHeadDim uint32 `json:"qk_nope_head_dim"`
QKRopeHeadDim uint32 `json:"qk_rope_head_dim"`
KVLoraRank uint32 `json:"kv_lora_rank"`
QLoraRank uint32 `json:"q_lora_rank"`
VHeadDim uint32 `json:"v_head_dim"`
ExpertCount uint32 `json:"n_routed_experts"`
ExpertSharedCount uint32 `json:"n_shared_experts"`
ExpertIntermediateSize uint32 `json:"moe_intermediate_size"`
ExpertUsedCount uint32 `json:"num_experts_per_tok"`
ExpertWeightsNorm bool `json:"norm_topk_prob"`
ExpertWeightsScale float32 `json:"routed_scaling_factor"`
ScoringFunc string `json:"scoring_func"`
LeadingDenseBlockCount uint32 `json:"first_k_dense_replace"`
RopeScaling struct {
Factor float32 `json:"factor"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
Type string `json:"type"`
MScaleAllDim float32 `json:"mscale_all_dim"`
} `json:"rope_scaling"`
Architecture string
}
func (p *deepseek2Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "deepseek2"
kv["general.type"] = "model"
kv["deepseek2.block_count"] = p.HiddenLayers
numHeads := p.NumAttentionHeads
numKVHeads := p.NumKeyValueHeads
kv["deepseek2.attention.head_count"] = numHeads
kv["deepseek2.attention.head_count_kv"] = numKVHeads
kv["deepseek2.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim
kv["deepseek2.attention.kv_lora_rank"] = p.KVLoraRank
kv["deepseek2.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["deepseek2.attention.q_lora_rank"] = p.QLoraRank
kv["deepseek2.attention.value_length"] = p.VHeadDim
kv["deepseek2.context_length"] = p.MaxPositionEmbeddings
kv["deepseek2.embedding_length"] = p.HiddenSize
kv["deepseek2.expert_count"] = p.ExpertCount
kv["deepseek2.expert_feed_forward_length"] = p.ExpertIntermediateSize
kv["deepseek2.expert_shared_count"] = p.ExpertSharedCount
var scoringFunc uint32
switch p.ScoringFunc {
case "softmax":
// not currently supported in the model, but needed for Deepseek-OCR
scoringFunc = 1
case "sigmoid":
scoringFunc = 2
}
kv["deepseek2.expert_gating_func"] = scoringFunc
kv["deepseek2.expert_used_count"] = p.ExpertUsedCount
kv["deepseek2.expert_weights_norm"] = p.ExpertWeightsNorm
kv["deepseek2.expert_weights_scale"] = p.ExpertWeightsScale
kv["deepseek2.feed_forward_length"] = p.IntermediateSize
kv["deepseek2.leading_dense_block_count"] = p.LeadingDenseBlockCount
kv["deepseek2.rope.dimension_count"] = p.QKRopeHeadDim
kv["deepseek2.rope.freq_base"] = cmp.Or(p.RopeTheta, 10000.0)
kv["deepseek2.rope.scaling.factor"] = p.RopeScaling.Factor
kv["deepseek2.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings
kv["deepseek2.rope.scaling.type"] = p.RopeScaling.Type
kv["deepseek2.rope.scaling.yarn_log_multiplier"] = 0.1 * p.RopeScaling.MScaleAllDim
kv["tokenizer.ggml.pre"] = "deepseek-v3"
return kv
}
func (p *deepseek2Model) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"language_model.", "",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.kv_a_proj_with_mqa", "attn_kv_a_mqa",
"self_attn.kv_a_layernorm", "attn_kv_a_norm",
"self_attn.kv_b_proj", "attn_kv_b",
"self_attn.q_a_proj", "attn_q_a",
"self_attn.q_a_layernorm", "attn_q_a_norm",
"self_attn.q_b_proj", "attn_q_b",
"self_attn.o_proj", "attn_output",
"post_attention_layernorm", "ffn_norm",
"mlp.shared_experts.down_proj", "ffn_down_shexp",
"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
"mlp.shared_experts.up_proj", "ffn_up_shexp",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"mlp.gate.e_score_correction_bias", "exp_probs_b.bias",
"mlp.gate", "ffn_gate_inp",
}
}
func (p *deepseek2Model) Tensors(s []Tensor) (out []*ggml.Tensor) {
merges := make([]merge, p.HiddenLayers*3)
for i := range p.HiddenLayers {
merges[i*3+0] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}
merges[i*3+1] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}
merges[i*3+2] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}
}
skipLayer := func(n string, minValue uint32) bool {
re := regexp.MustCompile(`^blk\.(\d+)`)
matches := re.FindStringSubmatch(n)
if matches == nil {
return false
}
blkNum, err := strconv.Atoi(matches[1])
if err != nil {
return false
}
return uint32(blkNum) >= minValue
}
out, s = mergeTensors(s, merges...)
for _, t := range s {
// skip any additional layers (such as the Multi-Token Prediction layer)
if skipLayer(t.Name(), p.HiddenLayers) {
slog.Debug("skipping layer", "name", t.Name())
continue
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}

View File

@@ -1,136 +0,0 @@
package convert
import (
"fmt"
"github.com/ollama/ollama/fs/ggml"
)
type deepseekocr struct {
ModelParameters
LanguageConfig struct {
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
NumRoutedExperts uint32 `json:"n_routed_experts"`
NumSharedExperts uint32 `json:"n_shared_experts"`
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
FirstKDenseReplace uint32 `json:"first_k_dense_replace"`
} `json:"language_config"`
VisionConfig struct {
ImageSize uint32 `json:"image_size"`
Width struct {
Vision struct {
Heads uint32 `json:"heads"`
ImageSize uint32 `json:"image_size"`
Layers uint32 `json:"layers"`
PatchSize uint32 `json:"patch_size"`
Width uint32 `json:"width"`
} `json:"clip-l-14-224"`
Sam struct {
GlobalAttentionIndexes []int32 `json:"global_attn_indexes"`
Heads uint32 `json:"heads"`
Layers uint32 `json:"layers"`
Width uint32 `json:"width"`
} `json:"sam_vit_b"`
}
} `json:"vision_config"`
}
func (m *deepseekocr) KV(t *Tokenizer) KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "deepseekocr"
kv["block_count"] = m.LanguageConfig.HiddenLayers
kv["context_length"] = m.LanguageConfig.MaxPositionEmbeddings
kv["embedding_length"] = m.LanguageConfig.HiddenSize
kv["feed_forward_length"] = m.LanguageConfig.IntermediateSize
kv["attention.head_count"] = m.LanguageConfig.NumAttentionHeads
kv["attention.head_count_kv"] = m.LanguageConfig.NumKeyValueHeads
kv["expert_count"] = m.LanguageConfig.NumRoutedExperts
kv["expert_used_count"] = m.LanguageConfig.NumExpertsPerToken
kv["leading_dense_block_count"] = m.LanguageConfig.FirstKDenseReplace
kv["vision.block_count"] = m.VisionConfig.Width.Vision.Layers
kv["vision.embedding_length"] = m.VisionConfig.Width.Vision.Width
kv["vision.head_count"] = m.VisionConfig.Width.Vision.Heads
kv["vision.image_size"] = m.VisionConfig.Width.Vision.ImageSize
kv["vision.patch_size"] = m.VisionConfig.Width.Vision.PatchSize
kv["sam.block_count"] = m.VisionConfig.Width.Sam.Layers
kv["sam.embedding_length"] = m.VisionConfig.Width.Sam.Width
kv["sam.head_count"] = m.VisionConfig.Width.Sam.Heads
kv["sam.global_attention_indexes"] = m.VisionConfig.Width.Sam.GlobalAttentionIndexes
return kv
}
func (m *deepseekocr) Tensors(s []Tensor) (out []*ggml.Tensor) {
merges := make([]merge, m.LanguageConfig.HiddenLayers*3)
for i := range m.LanguageConfig.HiddenLayers {
merges[i*3+0] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}
merges[i*3+1] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}
merges[i*3+2] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}
}
out, s = mergeTensors(s, merges...)
for _, t := range s {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (m *deepseekocr) Replacements() []string {
return []string{
"model.embed_tokens", "token_embd",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"post_attention_layernorm", "ffn_norm",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"mlp.down_proj", "ffn_down",
"mlp.gate", "ffn_gate_inp",
"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
"mlp.shared_experts.up_proj", "ffn_up_shexp",
"mlp.shared_experts.down_proj", "ffn_down_shexp",
"model.norm", "output_norm",
"lm_head", "output",
"model.vision_model", "v",
"embeddings.patch_embedding", "patch_embd",
"embeddings.class_embedding", "class_embd",
"embeddings.position_embedding", "position_embd",
"transformer.layers", "blk",
"model.projector", "mm",
"model.image_newline", "mm.image_newline",
//nolint:misspell // this misspelling is upstream. fixing it breaks the model
"model.view_seperator", "mm.view_seperator",
"model.sam_model.patch_embed.proj", "s.patch_embd",
"model.sam_model.pos_embed", "s.position_embd",
"model.sam_model.blocks", "s.blk",
"model.sam_model.neck", "s.neck",
"model.sam_model.net_", "s.net_",
}
}

View File

@@ -23,7 +23,7 @@ type gemmaModel struct {
var _ ModelConverter = (*gemmaModel)(nil) var _ ModelConverter = (*gemmaModel)(nil)
func (p *gemmaModel) KV(t *Tokenizer) KV { func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t) kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma" kv["general.architecture"] = "gemma"
kv["gemma.context_length"] = p.MaxPositionEmbeddings kv["gemma.context_length"] = p.MaxPositionEmbeddings

View File

@@ -1,5 +1,7 @@
package convert package convert
import "github.com/ollama/ollama/fs/ggml"
type gemma2Model struct { type gemma2Model struct {
gemmaModel gemmaModel
SlidingWindow uint32 `json:"sliding_window"` SlidingWindow uint32 `json:"sliding_window"`
@@ -7,7 +9,7 @@ type gemma2Model struct {
FinalLogitSoftcap float32 `json:"final_logit_softcapping"` FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
} }
func (p *gemma2Model) KV(t *Tokenizer) KV { func (p *gemma2Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t) kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma2" kv["general.architecture"] = "gemma2"
kv["gemma2.context_length"] = p.MaxPositionEmbeddings kv["gemma2.context_length"] = p.MaxPositionEmbeddings

View File

@@ -6,7 +6,6 @@ import (
"github.com/pdevine/tensor" "github.com/pdevine/tensor"
"github.com/pdevine/tensor/native" "github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
) )
@@ -16,7 +15,7 @@ type gemma2Adapter struct {
var _ AdapterConverter = (*gemma2Adapter)(nil) var _ AdapterConverter = (*gemma2Adapter)(nil)
func (p *gemma2Adapter) KV(baseKV fs.Config) KV { func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
kv := p.AdapterParameters.KV() kv := p.AdapterParameters.KV()
kv["general.architecture"] = "gemma2" kv["general.architecture"] = "gemma2"
return kv return kv

View File

@@ -2,7 +2,8 @@ package convert
import ( import (
"cmp" "cmp"
"slices"
"github.com/ollama/ollama/fs/ggml"
) )
type gemma3Model struct { type gemma3Model struct {
@@ -32,19 +33,9 @@ type gemma3Model struct {
HeadDim uint32 `json:"head_dim"` HeadDim uint32 `json:"head_dim"`
FinalLogitSoftcap float32 `json:"final_logit_softcapping"` FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
RopeLocalTheta float32 `json:"rope_local_base_freq"` RopeLocalTheta float32 `json:"rope_local_base_freq"`
RopeTheta float32 `json:"rope_theta"` RopeGlobalTheta float32 `json:"rope_global_base_freq"`
SlidingWindow uint32 `json:"sliding_window"` SlidingWindow uint32 `json:"sliding_window"`
SlidingWindowPattern *uint32 `json:"sliding_window_pattern"`
LayerTypes []string `json:"layer_types"`
MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"` MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
RopeScaling *struct {
Type string `json:"rope_type"`
Factor float32 `json:"factor"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
ExtrapolationFactor float32 `json:"extrapolation_factor"`
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
} `json:"rope_scaling"`
} }
const ( const (
@@ -53,7 +44,7 @@ const (
gemma27BLayerCount = 62 gemma27BLayerCount = 62
) )
func (p *gemma3Model) KV(t *Tokenizer) KV { func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t) kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3" kv["general.architecture"] = "gemma3"
@@ -90,38 +81,9 @@ func (p *gemma3Model) KV(t *Tokenizer) KV {
kv["gemma3.attention.key_length"] = p.HeadDim kv["gemma3.attention.key_length"] = p.HeadDim
kv["gemma3.attention.value_length"] = p.HeadDim kv["gemma3.attention.value_length"] = p.HeadDim
kv["gemma3.attention.sliding_window"] = p.SlidingWindow kv["gemma3.attention.sliding_window"] = p.SlidingWindow
kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
// The sliding window pattern is either provided as the sliding_window_pattern
// key (an int) or as the layer_types key (a list of strings).
if p.SlidingWindowPattern != nil || len(p.LayerTypes) > 0 {
kv["gemma3.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
for i := range numBlocks {
var isLocal bool
if len(p.LayerTypes) > 0 && int(i) < len(p.LayerTypes) {
isLocal = p.LayerTypes[i] == "sliding_attention"
} else if p.SlidingWindowPattern != nil && *p.SlidingWindowPattern > 0 {
isLocal = (i+1)%*p.SlidingWindowPattern != 0
}
if !yield(isLocal) {
break
}
}
})
}
if p.FinalLogitSoftcap > 0 {
kv["gemma3.final_logit_softcapping"] = p.FinalLogitSoftcap
}
kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0) kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
kv["gemma3.rope.freq_base"] = cmp.Or(p.RopeTheta, 1000000.0) kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
if p.RopeScaling != nil && p.RopeScaling.Type == "yarn" && p.RopeScaling.Factor > 0 {
kv["gemma3.rope.scaling.type"] = "yarn"
kv["gemma3.rope.scaling.factor"] = p.RopeScaling.Factor
kv["gemma3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings
kv["gemma3.rope.scaling.extrapolation_factor"] = cmp.Or(p.RopeScaling.ExtrapolationFactor, float32(1.0))
kv["gemma3.rope.scaling.beta_fast"] = cmp.Or(p.RopeScaling.BetaFast, float32(64.0))
kv["gemma3.rope.scaling.beta_slow"] = cmp.Or(p.RopeScaling.BetaSlow, float32(1.0))
}
kv["gemma3.embedding_length"] = p.HiddenSize kv["gemma3.embedding_length"] = p.HiddenSize
kv["gemma3.feed_forward_length"] = p.IntermediateSize kv["gemma3.feed_forward_length"] = p.IntermediateSize
default: default:

View File

@@ -38,7 +38,7 @@ type gemma3nModel struct {
VisionModel struct{} `json:"vision_config"` VisionModel struct{} `json:"vision_config"`
} }
func (m *gemma3nModel) KV(t *Tokenizer) KV { func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t) kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3n" kv["general.architecture"] = "gemma3n"
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) { kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {

View File

@@ -37,7 +37,7 @@ type gptossModel struct {
var _ ModelConverter = (*gptossModel)(nil) var _ ModelConverter = (*gptossModel)(nil)
func (m *gptossModel) KV(t *Tokenizer) KV { func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t) kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gptoss" kv["general.architecture"] = "gptoss"
kv["general.file_type"] = uint32(4) kv["general.file_type"] = uint32(4)
@@ -110,12 +110,9 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
for name, mxfp4 := range mxfp4s { for name, mxfp4 := range mxfp4s {
dims := mxfp4.blocks.Shape() dims := mxfp4.blocks.Shape()
if !strings.HasSuffix(name, ".weight") {
name = name + ".weight"
}
if strings.Contains(name, "ffn_down_exps") { if strings.Contains(name, "ffn_down_exps") {
out = append(out, &ggml.Tensor{ out = append(out, &ggml.Tensor{
Name: name, Name: name + ".weight",
Kind: uint32(ggml.TensorTypeMXFP4), Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2}, Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
WriterTo: mxfp4, WriterTo: mxfp4,
@@ -124,12 +121,12 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
// gate_up_exps is interleaved, need to split into gate_exps and up_exps // gate_up_exps is interleaved, need to split into gate_exps and up_exps
// e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...] // e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]
out = append(out, &ggml.Tensor{ out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "gate_up", "gate", 1), Name: strings.Replace(name, "gate_up", "gate", 1) + ".weight",
Kind: uint32(ggml.TensorTypeMXFP4), Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2}, Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
WriterTo: mxfp4.slice(1, 0, int(dims[1]), 2), WriterTo: mxfp4.slice(1, 0, int(dims[1]), 2),
}, &ggml.Tensor{ }, &ggml.Tensor{
Name: strings.Replace(name, "gate_up", "up", 1), Name: strings.Replace(name, "gate_up", "up", 1) + ".weight",
Kind: uint32(ggml.TensorTypeMXFP4), Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2}, Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
WriterTo: mxfp4.slice(1, 1, int(dims[1]), 2), WriterTo: mxfp4.slice(1, 1, int(dims[1]), 2),

View File

@@ -48,7 +48,7 @@ type llamaModel struct {
var _ ModelConverter = (*llamaModel)(nil) var _ ModelConverter = (*llamaModel)(nil)
func (p *llamaModel) KV(t *Tokenizer) KV { func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t) kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama" kv["general.architecture"] = "llama"
kv["llama.vocab_size"] = p.VocabSize kv["llama.vocab_size"] = p.VocabSize

View File

@@ -35,7 +35,7 @@ type llama4Model struct {
} }
// KV implements ModelConverter. // KV implements ModelConverter.
func (p *llama4Model) KV(t *Tokenizer) KV { func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t) kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama4" kv["general.architecture"] = "llama4"

View File

@@ -7,7 +7,6 @@ import (
"github.com/pdevine/tensor" "github.com/pdevine/tensor"
"github.com/pdevine/tensor/native" "github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
) )
@@ -19,13 +18,13 @@ type llamaAdapter struct {
var _ AdapterConverter = (*llamaAdapter)(nil) var _ AdapterConverter = (*llamaAdapter)(nil)
func (p *llamaAdapter) KV(baseKV fs.Config) KV { func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
kv := p.AdapterParameters.KV() kv := p.AdapterParameters.KV()
kv["general.architecture"] = "llama" kv["general.architecture"] = "llama"
kv["llama.attention.head_count"] = baseKV.Value("llama.attention.head_count") kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
kv["llama.attention.head_count_kv"] = baseKV.Value("llama.attention.head_count_kv") kv["llama.attention.head_count_kv"] = baseKV["llama.attention.head_count_kv"]
p.NumAttentionHeads = baseKV.Value("llama.attention.head_count").(uint32) p.NumAttentionHeads = baseKV["llama.attention.head_count"].(uint32)
return kv return kv
} }

View File

@@ -29,17 +29,6 @@ type mistral3Model struct {
SlidingWindow *uint32 `json:"sliding_window"` SlidingWindow *uint32 `json:"sliding_window"`
HiddenAct string `json:"hidden_act"` HiddenAct string `json:"hidden_act"`
VocabSize uint32 `json:"vocab_size"` VocabSize uint32 `json:"vocab_size"`
RopeParameters struct {
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
Factor float32 `json:"factor"`
Llama4ScalingBeta *float32 `json:"llama_4_scaling_beta"`
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
RopeType string `json:"rope_type"`
RopeTheta float32 `json:"rope_theta"`
Mscale *float32 `json:"mscale"`
MscaleAllDim *float32 `json:"mscale_all_dim"`
} `json:"rope_parameters"`
} `json:"text_config"` } `json:"text_config"`
VisionModel struct { VisionModel struct {
NumAttentionHeads uint32 `json:"num_attention_heads"` NumAttentionHeads uint32 `json:"num_attention_heads"`
@@ -52,15 +41,12 @@ type mistral3Model struct {
HeadDim uint32 `json:"head_dim"` HeadDim uint32 `json:"head_dim"`
HiddenAct string `json:"hidden_act"` HiddenAct string `json:"hidden_act"`
RopeTheta float32 `json:"rope_theta"` RopeTheta float32 `json:"rope_theta"`
RopeParameters struct {
RopeTheta float32 `json:"rope_theta"`
} `json:"rope_parameters"`
} `json:"vision_config"` } `json:"vision_config"`
MultiModalProjectorBias bool `json:"multimodal_projector_bias"` MultiModalProjectorBias bool `json:"multimodal_projector_bias"`
ProjectorHiddenAct string `json:"projector_hidden_act"` ProjectorHiddenAct string `json:"projector_hidden_act"`
} }
func (p *mistral3Model) KV(t *Tokenizer) KV { func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t) kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3" kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.TextModel.VocabSize kv["mistral3.vocab_size"] = p.TextModel.VocabSize
@@ -75,25 +61,8 @@ func (p *mistral3Model) KV(t *Tokenizer) KV {
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
kv["mistral3.attention.key_length"] = p.TextModel.HeadDim kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads) kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta) kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
kv["mistral3.rope.scaling.factor"] = p.TextModel.RopeParameters.Factor
kv["mistral3.rope.scaling.type"] = p.TextModel.RopeParameters.RopeType
kv["mistral3.rope.scaling.beta_fast"] = p.TextModel.RopeParameters.BetaFast
kv["mistral3.rope.scaling.beta_slow"] = p.TextModel.RopeParameters.BetaSlow
if p.TextModel.RopeParameters.Mscale != nil {
kv["mistral3.rope.scaling.mscale"] = *p.TextModel.RopeParameters.Mscale
}
if p.TextModel.RopeParameters.MscaleAllDim != nil {
kv["mistral3.rope.scaling.mscale_all_dim"] = *p.TextModel.RopeParameters.MscaleAllDim
}
if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 {
kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings
}
if p.TextModel.RopeParameters.Llama4ScalingBeta != nil {
kv["mistral3.rope.scaling_beta"] = *p.TextModel.RopeParameters.Llama4ScalingBeta
}
// Vision configuration // Vision configuration
kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
@@ -105,7 +74,7 @@ func (p *mistral3Model) KV(t *Tokenizer) KV {
kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value // kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
kv["mistral3.vision.rope.freq_base"] = cmp.Or(p.VisionModel.RopeTheta, p.VisionModel.RopeParameters.RopeTheta) kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta
// Multimodal configuration // Multimodal configuration
kv["mistral3.image_token_index"] = p.ImageTokenIndex kv["mistral3.image_token_index"] = p.ImageTokenIndex

View File

@@ -1,181 +0,0 @@
package convert
import (
"cmp"
"fmt"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type mistral3CausalModel struct {
ModelParameters
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
SlidingWindow *uint32 `json:"sliding_window"`
HiddenAct string `json:"hidden_act"`
VocabSize uint32 `json:"vocab_size"`
RopeParameters struct {
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
Factor float32 `json:"factor"`
Llama4ScalingBeta *float32 `json:"llama_4_scaling_beta"`
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
RopeType string `json:"rope_type"`
RopeTheta float32 `json:"rope_theta"`
Mscale *float32 `json:"mscale"`
MscaleAllDim *float32 `json:"mscale_all_dim"`
} `json:"rope_parameters"`
}
func (p *mistral3CausalModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.VocabSize
// Text configuration
kv["mistral3.block_count"] = p.NumHiddenLayers
kv["mistral3.context_length"] = p.MaxPositionEmbeddings
kv["mistral3.embedding_length"] = p.HiddenSize
kv["mistral3.feed_forward_length"] = p.IntermediateSize
kv["mistral3.attention.head_count"] = p.NumAttentionHeads
kv["mistral3.attention.head_count_kv"] = p.NumKeyValueHeads
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["mistral3.attention.key_length"] = p.HeadDim
kv["mistral3.attention.value_length"] = p.HeadDim
kv["mistral3.rope.dimension_count"] = cmp.Or(p.HeadDim, p.HiddenSize/p.NumAttentionHeads)
kv["mistral3.rope.freq_base"] = cmp.Or(p.RopeTheta, p.RopeParameters.RopeTheta)
kv["mistral3.rope.scaling.factor"] = p.RopeParameters.Factor
kv["mistral3.rope.scaling.type"] = p.RopeParameters.RopeType
kv["mistral3.rope.scaling.beta_fast"] = p.RopeParameters.BetaFast
kv["mistral3.rope.scaling.beta_slow"] = p.RopeParameters.BetaSlow
if p.RopeParameters.Mscale != nil {
kv["mistral3.rope.scaling.mscale"] = *p.RopeParameters.Mscale
}
if p.RopeParameters.MscaleAllDim != nil {
kv["mistral3.rope.scaling.mscale_all_dim"] = *p.RopeParameters.MscaleAllDim
}
if p.RopeParameters.OrigMaxPositionEmbeddings > 0 {
kv["mistral3.rope.scaling.original_context_length"] = p.RopeParameters.OrigMaxPositionEmbeddings
kv["mistral3.rope.scaling_beta"] = *p.RopeParameters.Llama4ScalingBeta
}
if p.RopeParameters.Llama4ScalingBeta != nil {
kv["mistral3.rope.scaling_beta"] = *p.RopeParameters.Llama4ScalingBeta
}
return kv
}
func (p *mistral3CausalModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") {
if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
strings.HasSuffix(t.Name(), ".attn_k.weight") {
t.SetRepacker(p.repack)
}
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *mistral3CausalModel) Replacements() []string {
return []string{
"model.norm", "output_norm",
"model.", "",
"layers", "blk",
"transformer.layers", "blk",
"vision_tower", "v",
"ln_pre", "encoder_norm",
"input_layernorm", "attn_norm",
"post_attention_layernorm", "ffn_norm",
"embed_tokens", "token_embd",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"mlp.down_proj", "ffn_down",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"attention.q_proj", "attn_q",
"attention.k_proj", "attn_k",
"attention.v_proj", "attn_v",
"attention.o_proj", "attn_output",
"attention_norm", "attn_norm",
"feed_forward.gate_proj", "ffn_gate",
"feed_forward.down_proj", "ffn_down",
"feed_forward.up_proj", "ffn_up",
"multi_modal_projector", "mm",
"ffn_norm", "ffn_norm",
"lm_head", "output",
}
}
func (p *mistral3CausalModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
var dims []int
for _, dim := range shape {
dims = append(dims, int(dim))
}
var heads uint32
if strings.HasSuffix(name, ".attn_q.weight") {
heads = p.NumAttentionHeads
} else if strings.HasSuffix(name, ".attn_k.weight") {
heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
} else {
return nil, fmt.Errorf("unknown tensor for repack: %s", name)
}
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 1)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}

View File

@@ -12,7 +12,7 @@ type mixtralModel struct {
NumExpertsPerToken uint32 `json:"num_experts_per_tok"` NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
} }
func (p *mixtralModel) KV(t *Tokenizer) KV { func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
kv := p.llamaModel.KV(t) kv := p.llamaModel.KV(t)
if p.NumLocalExperts > 0 { if p.NumLocalExperts > 0 {

View File

@@ -34,7 +34,7 @@ type mllamaModel struct {
} `json:"vision_config"` } `json:"vision_config"`
} }
func (m *mllamaModel) KV(t *Tokenizer) KV { func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t) kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "mllama" kv["general.architecture"] = "mllama"

View File

@@ -1,213 +0,0 @@
package convert
import (
"cmp"
"encoding/json"
"io/fs"
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type nomicbertModel struct {
ModelParameters
NLayers uint32 `json:"n_layers"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
LayerNormEPS float32 `json:"layer_norm_eps"`
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
RopeFreqBase float32 `json:"rope_theta"`
normalizeEmbeddings bool
PoolingType uint32
// MoE parameters (only present in v2 models)
NumExperts uint32 `json:"num_local_experts"`
NumExpertsUsed uint32 `json:"num_experts_per_tok"`
MoEEveryNLayers uint32 `json:"moe_every_n_layers"`
}
var (
_ ModelConverter = (*nomicbertModel)(nil)
_ moreParser = (*nomicbertModel)(nil)
)
func (p *nomicbertModel) parseMore(fsys fs.FS) error {
bts, err := fs.ReadFile(fsys, "modules.json")
if err != nil {
return err
}
var modules []struct {
Type string `json:"type"`
Path string `json:"path"`
}
if err := json.Unmarshal(bts, &modules); err != nil {
return err
}
var pooling string
for _, m := range modules {
switch m.Type {
case "sentence_transformers.models.Pooling":
pooling = m.Path
case "sentence_transformers.models.Normalize":
p.normalizeEmbeddings = true
}
}
if pooling != "" {
bts, err := fs.ReadFile(fsys, filepath.Join(pooling, "config.json"))
if err != nil {
return err
}
var pc struct {
PoolingModeCLSToken bool `json:"pooling_mode_cls_token"`
PoolingModeMeanTokens bool `json:"pooling_mode_mean_tokens"`
}
if err := json.Unmarshal(bts, &pc); err != nil {
return err
}
if pc.PoolingModeMeanTokens {
p.PoolingType = 1
} else if pc.PoolingModeCLSToken {
p.PoolingType = 2
}
}
return nil
}
func (p *nomicbertModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
// Determine architecture based on MoE parameters (following qwen3 pattern)
arch := "nomic-bert"
if p.MoEEveryNLayers > 0 {
arch += "-moe"
}
kv["general.architecture"] = arch
kv["attention.causal"] = false
kv["pooling_type"] = p.PoolingType
kv["normalize_embeddings"] = p.normalizeEmbeddings
kv["block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers)
if contextLength := p.MaxPositionEmbeddings; contextLength > 0 {
kv["context_length"] = contextLength
}
if embeddingLength := p.HiddenSize; embeddingLength > 0 {
kv["embedding_length"] = p.HiddenSize
}
if feedForwardLength := p.IntermediateSize; feedForwardLength > 0 {
kv["feed_forward_length"] = p.IntermediateSize
}
if headCount := p.NumAttentionHeads; headCount > 0 {
kv["attention.head_count"] = p.NumAttentionHeads
}
if kvHeadCount := p.NumKeyValueHeads; kvHeadCount > 0 {
kv["attention.head_count_kv"] = p.NumKeyValueHeads
}
if layerNormEpsilon := cmp.Or(p.LayerNormEPS, p.LayerNormEpsilon); layerNormEpsilon > 0 {
kv["attention.layer_norm_epsilon"] = layerNormEpsilon
}
if p.RopeFreqBase > 0 {
kv["rope.freq_base"] = p.RopeFreqBase
}
// MoE specific parameters (only if MoE is enabled)
if p.NumExperts > 0 {
kv["expert_count"] = p.NumExperts
}
if p.NumExpertsUsed > 0 {
kv["expert_used_count"] = p.NumExpertsUsed
}
if p.MoEEveryNLayers > 0 {
kv["moe_every_n_layers"] = p.MoEEveryNLayers
}
kv["tokenizer.ggml.model"] = "bert"
kv["tokenizer.ggml.token_type_count"] = uint32(2)
// convert to phantom space tokens
for i, e := range t.Tokens {
switch {
case strings.HasPrefix(e, "[") && strings.HasSuffix(e, "]"):
// noop - keep special tokens as-is
case strings.HasPrefix(e, "##"):
t.Tokens[i] = e[2:]
default:
t.Tokens[i] = "\u2581" + e
}
}
kv["tokenizer.ggml.tokens"] = t.Tokens
return kv
}
func (p *nomicbertModel) Tensors(ts []Tensor) []*ggml.Tensor {
out := make([]*ggml.Tensor, 0, len(ts))
for _, t := range ts {
if slices.Contains([]string{
"embeddings.position_ids",
"pooler.dense.weight",
"pooler.dense.bias",
}, t.Name()) {
continue
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (nomicbertModel) Replacements() []string {
return []string{
"encoder.layer", "blk",
"encoder.layers", "blk",
"embeddings.word_embeddings", "token_embd",
"embeddings.token_type_embeddings", "token_types",
"embeddings.LayerNorm", "token_embd_norm",
"attention.self.qkv", "attn_qkv",
"attention.output.dense", "attn_output",
"attention.output.LayerNorm", "attn_output_norm",
"mlp.up", "ffn_up",
"mlp.down", "ffn_down",
"mlp.router", "ffn_gate_inp",
"mlp.experts.up", "ffn_up_exps",
"mlp.experts.down", "ffn_down_exps",
"intermediate.dense", "ffn_up",
"output.dense", "ffn_down",
"output.LayerNorm", "layer_output_norm",
}
}

View File

@@ -1,117 +0,0 @@
package convert
import (
"cmp"
"github.com/ollama/ollama/fs/ggml"
)
type ropeScaling struct {
Factor float32 `json:"factor"`
OriginalMaxPositionEmbeds uint32 `json:"original_max_position_embeddings"`
AttentionFactor float32 `json:"attention_factor"`
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
RopeType string `json:"rope_type"`
ExtrapolationFactor float32 `json:"extrapolation_factor"`
}
type olmoModel struct {
ModelParameters
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
RMSNormEPS float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
RopeScaling *ropeScaling `json:"rope_scaling"`
SlidingWindow uint32 `json:"sliding_window"`
LayerTypes []string `json:"layer_types"`
}
var _ ModelConverter = (*olmoModel)(nil)
func (p *olmoModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "olmo3"
kv["olmo3.block_count"] = p.NumHiddenLayers
kv["olmo3.context_length"] = p.MaxPositionEmbeddings
kv["olmo3.embedding_length"] = p.HiddenSize
kv["olmo3.feed_forward_length"] = p.IntermediateSize
kv["olmo3.attention.head_count"] = p.NumAttentionHeads
kv["olmo3.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
if p.RopeTheta > 0 {
kv["olmo3.rope.freq_base"] = p.RopeTheta
}
if p.RopeScaling != nil {
if p.RopeScaling.Factor > 0 {
kv["olmo3.rope.scaling.factor"] = p.RopeScaling.Factor
}
if p.RopeScaling.OriginalMaxPositionEmbeds > 0 {
kv["olmo3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeds
}
if p.RopeScaling.AttentionFactor > 0 {
kv["olmo3.rope.scaling.attn_factor"] = p.RopeScaling.AttentionFactor
}
if p.RopeScaling.RopeType != "" {
kv["olmo3.rope.scaling.type"] = p.RopeScaling.RopeType
}
}
if p.RMSNormEPS > 0 {
kv["olmo3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
}
if p.SlidingWindow > 0 {
kv["olmo3.attention.sliding_window"] = p.SlidingWindow
}
if len(p.LayerTypes) > 0 {
slidingPattern := make([]bool, len(p.LayerTypes))
for i, layerType := range p.LayerTypes {
slidingPattern[i] = (layerType == "sliding_attention")
}
kv["olmo3.attention.sliding_window_pattern"] = slidingPattern
}
return kv
}
func (p *olmoModel) Tensors(ts []Tensor) []*ggml.Tensor {
out := make([]*ggml.Tensor, 0, len(ts))
for _, t := range ts {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *olmoModel) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.layers", "blk",
"model.norm", "output_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"self_attn.q_norm", "attn_q_norm",
"self_attn.k_norm", "attn_k_norm",
"post_attention_layernorm", "post_attention_norm",
"post_feedforward_layernorm", "post_ffw_norm",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
}
}

View File

@@ -37,7 +37,7 @@ type phi3Model struct {
var _ ModelConverter = (*phi3Model)(nil) var _ ModelConverter = (*phi3Model)(nil)
func (p *phi3Model) KV(t *Tokenizer) KV { func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t) kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "phi3" kv["general.architecture"] = "phi3"
kv["phi3.context_length"] = p.MaxPositionEmbeddings kv["phi3.context_length"] = p.MaxPositionEmbeddings

View File

@@ -22,7 +22,7 @@ type qwen2Model struct {
var _ ModelConverter = (*qwen2Model)(nil) var _ ModelConverter = (*qwen2Model)(nil)
func (q *qwen2Model) KV(t *Tokenizer) KV { func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
kv := q.ModelParameters.KV(t) kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen2" kv["general.architecture"] = "qwen2"
kv["qwen2.block_count"] = q.HiddenLayers kv["qwen2.block_count"] = q.HiddenLayers

View File

@@ -29,7 +29,7 @@ type qwen25VLModel struct {
var _ ModelConverter = (*qwen25VLModel)(nil) var _ ModelConverter = (*qwen25VLModel)(nil)
func (q *qwen25VLModel) KV(t *Tokenizer) KV { func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
kv := q.ModelParameters.KV(t) kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen25vl" kv["general.architecture"] = "qwen25vl"

View File

@@ -32,7 +32,7 @@ type qwen3Model struct {
} }
// KV implements ModelConverter. // KV implements ModelConverter.
func (q *qwen3Model) KV(t *Tokenizer) KV { func (q *qwen3Model) KV(t *Tokenizer) ggml.KV {
arch := "qwen3" arch := "qwen3"
if q.NumExperts > 0 { if q.NumExperts > 0 {
arch += "moe" arch += "moe"

View File

@@ -45,7 +45,7 @@ func (m *qwen3VLModel) parseMore(fsys fs.FS) error {
return json.Unmarshal(bts, &m.VisionModel) return json.Unmarshal(bts, &m.VisionModel)
} }
func (m *qwen3VLModel) KV(t *Tokenizer) KV { func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV {
kv := m.qwen3Model.KV(t) kv := m.qwen3Model.KV(t)
arch := "qwen3vl" arch := "qwen3vl"

View File

@@ -19,7 +19,6 @@ import (
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
fsc "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
) )
@@ -29,7 +28,7 @@ type tensorData struct {
Shape []int `json:"shape"` Shape []int `json:"shape"`
} }
func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors) { func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
t.Helper() t.Helper()
f, err := os.CreateTemp(t.TempDir(), "f16") f, err := os.CreateTemp(t.TempDir(), "f16")
@@ -60,10 +59,9 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors)
return r, m.KV(), m.Tensors() return r, m.KV(), m.Tensors()
} }
func generateResultsJSON(t *testing.T, f *os.File, kv fsc.Config, tensors ggml.Tensors) map[string]string { func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string {
actual := make(map[string]string) actual := make(map[string]string)
for k := range kv.Keys() { for k, v := range kv {
v := kv.Value(k)
if s, ok := v.(json.Marshaler); !ok { if s, ok := v.(json.Marshaler); !ok {
actual[k] = fmt.Sprintf("%v", v) actual[k] = fmt.Sprintf("%v", v)
} else { } else {
@@ -279,7 +277,7 @@ func generateSafetensorTestData(t *testing.T, tempDir string, tensorData map[str
func TestConvertAdapter(t *testing.T) { func TestConvertAdapter(t *testing.T) {
type AdapterCase struct { type AdapterCase struct {
Name string Name string
BaseKV KV BaseKV map[string]any
Expected map[string]string Expected map[string]string
} }

View File

@@ -44,10 +44,7 @@ func (t tensorBase) Kind() uint32 {
t.name == "v.positional_embedding_vlm" || t.name == "v.positional_embedding_vlm" ||
t.name == "v.tile_position_embd.weight" || t.name == "v.tile_position_embd.weight" ||
t.name == "v.pre_tile_position_embd.weight" || t.name == "v.pre_tile_position_embd.weight" ||
t.name == "v.post_tile_position_embd.weight" || t.name == "v.post_tile_position_embd.weight" {
t.name == "s.position_embd" ||
strings.HasSuffix(t.name, "rel_pos_h") ||
strings.HasSuffix(t.name, "rel_pos_w") {
// these tensors are always F32 // these tensors are always F32
return tensorKindFP32 return tensorKindFP32
} }

View File

@@ -96,10 +96,7 @@ type safetensor struct {
func (st safetensor) Kind() uint32 { func (st safetensor) Kind() uint32 {
kind := st.tensorBase.Kind() kind := st.tensorBase.Kind()
if st.dtype == "BF16" && if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 {
!strings.HasPrefix(st.name, "v.") &&
!strings.HasPrefix(st.name, "s.") &&
kind != tensorKindFP32 {
kind = tensorKindBF16 kind = tensorKindBF16
} }

View File

@@ -2,12 +2,10 @@ package convert
import ( import (
"cmp" "cmp"
"errors"
"io" "io"
"iter" "iter"
"path" "path"
"slices" "slices"
"strconv"
"strings" "strings"
"github.com/pdevine/tensor" "github.com/pdevine/tensor"
@@ -96,26 +94,6 @@ func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []
return matched return matched
}) })
slices.SortStableFunc(matched, func(a, b Tensor) int {
x := strings.Split(a.Name(), ".")
y := strings.Split(b.Name(), ".")
if len(x) != len(y) {
return cmp.Compare(len(x), len(y))
}
vals := make([]int, len(x))
for i := range x {
vals[i] = strings.Compare(x[i], y[i])
m, err := strconv.ParseInt(x[i], 0, 0)
n, err2 := strconv.ParseInt(y[i], 0, 0)
if errors.Join(err, err2) == nil {
vals[i] = cmp.Compare(m, n)
}
}
return cmp.Or(vals...)
})
if len(matched) > 0 { if len(matched) > 0 {
out = append(out, &ggml.Tensor{ out = append(out, &ggml.Tensor{
Name: merges[i].name, Name: merges[i].name,

View File

@@ -3,10 +3,8 @@ package convert
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"fmt"
"io" "io"
"iter" "iter"
"math/rand/v2"
"slices" "slices"
"strings" "strings"
"testing" "testing"
@@ -953,45 +951,3 @@ func TestMerge(t *testing.T) {
} }
}) })
} }
func TestMergeOrder(t *testing.T) {
for range 8 {
t.Run("", func(t *testing.T) {
tensors := make([]Tensor, 16)
for i := range tensors {
tensors[i] = &fakeTensor{
name: fmt.Sprintf("layer.%d.weight", i),
shape: []uint64{1},
data: []float32{float32(i)},
}
}
rand.Shuffle(len(tensors), func(i, j int) {
tensors[i], tensors[j] = tensors[j], tensors[i]
})
matched, unmatched := mergeTensors(tensors, merge{"layer.*.weight", "layer.weight"})
if len(unmatched) != 0 {
t.Error("expected no remaining tensors, got", len(unmatched))
}
if len(matched) != 1 {
t.Error("expected 1 merged tensor, got", len(matched))
}
var b bytes.Buffer
if _, err := matched[0].WriteTo(&b); err != nil {
t.Fatal(err)
}
var f32s [16]float32
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.IsSorted(f32s[:]) {
t.Errorf("merged tensor data is not in order: %+v", f32s)
}
})
}
}

View File

@@ -49,8 +49,7 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL) tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
// temporary fix to handle gemma3 broken configs // temporary fix to handle gemma3 broken configs
// TODO(parthsareen): allow reading of tokenizer.json to allow managing special tokens when using spm if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>", "<start_function_declaration>", "<end_function_declaration>", "<start_function_call>", "<end_function_call>", "<start_function_response>", "<end_function_response>", "<escape>"}, piece.GetPiece()) {
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL) tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
} }

View File

@@ -2,7 +2,6 @@ package discover
import ( import (
"bufio" "bufio"
"errors"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
@@ -11,21 +10,12 @@ import (
"reflect" "reflect"
"regexp" "regexp"
"sort" "sort"
"strconv"
"strings" "strings"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
) )
func GetCPUMem() (memInfo, error) { func GetCPUMem() (memInfo, error) {
mem, err := getCPUMem()
if err != nil {
return memInfo{}, err
}
return getCPUMemByCgroups(mem), nil
}
func getCPUMem() (memInfo, error) {
var mem memInfo var mem memInfo
var total, available, free, buffers, cached, freeSwap uint64 var total, available, free, buffers, cached, freeSwap uint64
f, err := os.Open("/proc/meminfo") f, err := os.Open("/proc/meminfo")
@@ -66,32 +56,6 @@ func getCPUMem() (memInfo, error) {
return mem, nil return mem, nil
} }
func getCPUMemByCgroups(mem memInfo) memInfo {
total, err := getUint64ValueFromFile("/sys/fs/cgroup/memory.max")
if err == nil {
mem.TotalMemory = total
}
used, err := getUint64ValueFromFile("/sys/fs/cgroup/memory.current")
if err == nil {
mem.FreeMemory = mem.TotalMemory - used
}
return mem
}
func getUint64ValueFromFile(path string) (uint64, error) {
f, err := os.Open(path)
if err != nil {
return 0, err
}
defer f.Close()
s := bufio.NewScanner(f)
for s.Scan() {
line := s.Text()
return strconv.ParseUint(line, 10, 64)
}
return 0, errors.New("empty file content")
}
const CpuInfoFilename = "/proc/cpuinfo" const CpuInfoFilename = "/proc/cpuinfo"
type linuxCpuInfo struct { type linuxCpuInfo struct {
@@ -110,41 +74,7 @@ func GetCPUDetails() []CPU {
return nil return nil
} }
defer file.Close() defer file.Close()
cpus := linuxCPUDetails(file) return linuxCPUDetails(file)
return overwriteThreadCountByLinuxCgroups(cpus)
}
func overwriteThreadCountByLinuxCgroups(cpus []CPU) []CPU {
file, err := os.Open("/sys/fs/cgroup/cpu.max")
if err != nil {
return cpus
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
if sl := strings.Split(line, " "); len(sl) == 2 {
allowdUs, err := strconv.ParseInt(sl[0], 10, 64)
if err != nil {
slog.Warn("failed to parse CPU allowed micro secs", "error", err)
return cpus
}
unitUs, err := strconv.ParseInt(sl[1], 10, 64)
if err != nil {
slog.Warn("failed to parse CPU unit micro secs", "error", err)
return cpus
}
threads := int(max(allowdUs/unitUs, 1))
cpu := cpus[0]
cpu.CoreCount = threads
cpu.ThreadCount = threads
return []CPU{cpu}
}
}
return cpus
} }
func linuxCPUDetails(file io.Reader) []CPU { func linuxCPUDetails(file io.Reader) []CPU {

View File

@@ -65,11 +65,6 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
} }
slog.Info("discovering available GPUs...") slog.Info("discovering available GPUs...")
detectIncompatibleLibraries()
// Warn if any user-overrides are set which could lead to incorrect GPU discovery
overrideWarnings()
requested := envconfig.LLMLibrary() requested := envconfig.LLMLibrary()
jetpack := cudaJetpack() jetpack := cudaJetpack()
@@ -95,13 +90,10 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
var dirs []string var dirs []string
if dir != "" { if dir != "" {
if requested != "" && filepath.Base(dir) != requested { if requested != "" && filepath.Base(dir) != requested {
slog.Debug("skipping available library at user's request", "requested", requested, "libDir", dir) slog.Debug("skipping available library at users request", "requested", requested, "libDir", dir)
continue continue
} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack { } else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
continue continue
} else if jetpack == "" && strings.Contains(filepath.Base(dir), "cuda_jetpack") {
slog.Debug("jetpack not detected (set JETSON_JETPACK or OLLAMA_LLM_LIBRARY to override), skipping", "libDir", dir)
continue
} else if !envconfig.EnableVulkan() && strings.Contains(filepath.Base(dir), "vulkan") { } else if !envconfig.EnableVulkan() && strings.Contains(filepath.Base(dir), "vulkan") {
slog.Info("experimental Vulkan support disabled. To enable, set OLLAMA_VULKAN=1") slog.Info("experimental Vulkan support disabled. To enable, set OLLAMA_VULKAN=1")
continue continue
@@ -121,7 +113,7 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
// In the second pass, we more deeply initialize the GPUs to weed out devices that // In the second pass, we more deeply initialize the GPUs to weed out devices that
// aren't supported by a given library. We run this phase in parallel to speed up discovery. // aren't supported by a given library. We run this phase in parallel to speed up discovery.
// Only devices that need verification are included in this pass // Only devices that need verification are included in this pass
slog.Debug("evaluating which, if any, devices to filter out", "initial_count", len(devices)) slog.Debug("evluating which if any devices to filter out", "initial_count", len(devices))
ctx2ndPass, cancel := context.WithTimeout(ctx, 30*time.Second) ctx2ndPass, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel() defer cancel()
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -129,25 +121,15 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
supportedMu := sync.Mutex{} supportedMu := sync.Mutex{}
supported := make(map[string]map[string]map[string]int) // [Library][libDir][ID] = pre-deletion devices index supported := make(map[string]map[string]map[string]int) // [Library][libDir][ID] = pre-deletion devices index
for i := range devices { for i := range devices {
libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1]
if !devices[i].NeedsInitValidation() { if !devices[i].NeedsInitValidation() {
// No need to validate, add to the supported map
supportedMu.Lock()
if _, ok := supported[devices[i].Library]; !ok {
supported[devices[i].Library] = make(map[string]map[string]int)
}
if _, ok := supported[devices[i].Library][libDir]; !ok {
supported[devices[i].Library][libDir] = make(map[string]int)
}
supported[devices[i].Library][libDir][devices[i].ID] = i
supportedMu.Unlock()
continue continue
} }
slog.Debug("verifying if device is supported", "library", libDir, "description", devices[i].Description, "compute", devices[i].Compute(), "id", devices[i].ID, "pci_id", devices[i].PCIID) libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1]
slog.Debug("verifying device is supported", "library", libDir, "description", devices[i].Description, "compute", devices[i].Compute(), "id", devices[i].ID, "pci_id", devices[i].PCIID)
wg.Add(1) wg.Add(1)
go func(i int) { go func(i int) {
defer wg.Done() defer wg.Done()
extraEnvs := ml.GetVisibleDevicesEnv(devices[i:i+1], true) extraEnvs := ml.GetVisibleDevicesEnv(devices[i : i+1])
devices[i].AddInitValidation(extraEnvs) devices[i].AddInitValidation(extraEnvs)
if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 { if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 {
slog.Debug("filtering device which didn't fully initialize", slog.Debug("filtering device which didn't fully initialize",
@@ -333,8 +315,7 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
defer cancel() defer cancel()
// Apply any dev filters to avoid re-discovering unsupported devices, and get IDs correct // Apply any dev filters to avoid re-discovering unsupported devices, and get IDs correct
// We avoid CUDA filters here to keep ROCm from failing to discover GPUs in a mixed environment devFilter := ml.GetVisibleDevicesEnv(devices)
devFilter := ml.GetVisibleDevicesEnv(devices, false)
for dir := range libDirs { for dir := range libDirs {
updatedDevices := bootstrapDevices(ctx, []string{ml.LibOllamaPath, dir}, devFilter) updatedDevices := bootstrapDevices(ctx, []string{ml.LibOllamaPath, dir}, devFilter)
@@ -468,37 +449,3 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs map
return devices return devices
} }
func overrideWarnings() {
anyFound := false
m := envconfig.AsMap()
for _, k := range []string{
"CUDA_VISIBLE_DEVICES",
"HIP_VISIBLE_DEVICES",
"ROCR_VISIBLE_DEVICES",
"GGML_VK_VISIBLE_DEVICES",
"GPU_DEVICE_ORDINAL",
"HSA_OVERRIDE_GFX_VERSION",
} {
if e, found := m[k]; found && e.Value != "" {
anyFound = true
slog.Warn("user overrode visible devices", k, e.Value)
}
}
if anyFound {
slog.Warn("if GPUs are not correctly discovered, unset and try again")
}
}
func detectIncompatibleLibraries() {
if runtime.GOOS != "windows" {
return
}
basePath, err := exec.LookPath("ggml-base.dll")
if err != nil || basePath == "" {
return
}
if !strings.HasPrefix(basePath, ml.LibOllamaPath) {
slog.Warn("potentially incompatible library detected in PATH", "location", basePath)
}
}

View File

@@ -14,7 +14,6 @@
* [API Reference](https://docs.ollama.com/api) * [API Reference](https://docs.ollama.com/api)
* [Modelfile Reference](https://docs.ollama.com/modelfile) * [Modelfile Reference](https://docs.ollama.com/modelfile)
* [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility) * [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility)
* [Anthropic Compatibility](./api/anthropic-compatibility.mdx)
### Resources ### Resources

View File

@@ -50,7 +50,7 @@ Generate a response for a given prompt with a provided model. This is a streamin
Advanced parameters (optional): Advanced parameters (optional):
- `format`: the format to return a response in. Format can be `json` or a JSON schema - `format`: the format to return a response in. Format can be `json` or a JSON schema
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature` - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `system`: system message to (overrides what is defined in the `Modelfile`) - `system`: system message to (overrides what is defined in the `Modelfile`)
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`) - `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
@@ -507,7 +507,7 @@ The `message` object has the following fields:
Advanced parameters (optional): Advanced parameters (optional):
- `format`: the format to return a response in. Format can be `json` or a JSON schema. - `format`: the format to return a response in. Format can be `json` or a JSON schema.
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature` - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
@@ -895,11 +895,11 @@ curl http://localhost:11434/api/chat -d '{
"tool_calls": [ "tool_calls": [
{ {
"function": { "function": {
"name": "get_weather", "name": "get_temperature",
"arguments": { "arguments": {
"city": "Toronto" "city": "Toronto"
} }
} },
} }
] ]
}, },
@@ -907,7 +907,7 @@ curl http://localhost:11434/api/chat -d '{
{ {
"role": "tool", "role": "tool",
"content": "11 degrees celsius", "content": "11 degrees celsius",
"tool_name": "get_weather" "tool_name": "get_temperature",
} }
], ],
"stream": false, "stream": false,
@@ -1189,7 +1189,7 @@ If you are creating a model from a safetensors directory or from a GGUF file, yo
- `template`: (optional) the prompt template for the model - `template`: (optional) the prompt template for the model
- `license`: (optional) a string or list of strings containing the license or licenses for the model - `license`: (optional) a string or list of strings containing the license or licenses for the model
- `system`: (optional) a string containing the system prompt for the model - `system`: (optional) a string containing the system prompt for the model
- `parameters`: (optional) a dictionary of parameters for the model (see [Modelfile](./modelfile.mdx#valid-parameters-and-values) for a list of parameters) - `parameters`: (optional) a dictionary of parameters for the model (see [Modelfile](./modelfile.md#valid-parameters-and-values) for a list of parameters)
- `messages`: (optional) a list of message objects used to create a conversation - `messages`: (optional) a list of message objects used to create a conversation
- `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects - `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
- `quantize` (optional): quantize a non-quantized (e.g. float16) model - `quantize` (optional): quantize a non-quantized (e.g. float16) model
@@ -1698,7 +1698,7 @@ Generate embeddings from a model
Advanced parameters: Advanced parameters:
- `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true` - `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature` - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
- `dimensions`: number of dimensions for the embedding - `dimensions`: number of dimensions for the embedding
@@ -1817,7 +1817,7 @@ Generate embeddings from a model
Advanced parameters: Advanced parameters:
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature` - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
### Examples ### Examples

View File

@@ -1,406 +0,0 @@
---
title: Anthropic compatibility
---
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
## Recommended models
For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.
Pull a model before use:
```shell
ollama pull qwen3-coder
ollama pull glm-4.7:cloud
```
## Usage
### Environment variables
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama # required but ignored
```
### Simple `/v1/messages` example
<CodeGroup dropdown>
```python basic.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama', # required but ignored
)
message = client.messages.create(
model='qwen3-coder',
max_tokens=1024,
messages=[
{'role': 'user', 'content': 'Hello, how are you?'}
]
)
print(message.content[0].text)
```
```javascript basic.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama", // required but ignored
});
const message = await anthropic.messages.create({
model: "qwen3-coder",
max_tokens: 1024,
messages: [{ role: "user", content: "Hello, how are you?" }],
});
console.log(message.content[0].text);
```
```shell basic.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-H "x-api-key: ollama" \
-H "anthropic-version: 2023-06-01" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"messages": [{ "role": "user", "content": "Hello, how are you?" }]
}'
```
</CodeGroup>
### Streaming example
<CodeGroup dropdown>
```python streaming.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
with client.messages.stream(
model='qwen3-coder',
max_tokens=1024,
messages=[{'role': 'user', 'content': 'Count from 1 to 10'}]
) as stream:
for text in stream.text_stream:
print(text, end='', flush=True)
```
```javascript streaming.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const stream = await anthropic.messages.stream({
model: "qwen3-coder",
max_tokens: 1024,
messages: [{ role: "user", content: "Count from 1 to 10" }],
});
for await (const event of stream) {
if (
event.type === "content_block_delta" &&
event.delta.type === "text_delta"
) {
process.stdout.write(event.delta.text);
}
}
```
```shell streaming.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"stream": true,
"messages": [{ "role": "user", "content": "Count from 1 to 10" }]
}'
```
</CodeGroup>
### Tool calling example
<CodeGroup dropdown>
```python tools.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
message = client.messages.create(
model='qwen3-coder',
max_tokens=1024,
tools=[
{
'name': 'get_weather',
'description': 'Get the current weather in a location',
'input_schema': {
'type': 'object',
'properties': {
'location': {
'type': 'string',
'description': 'The city and state, e.g. San Francisco, CA'
}
},
'required': ['location']
}
}
],
messages=[{'role': 'user', 'content': "What's the weather in San Francisco?"}]
)
for block in message.content:
if block.type == 'tool_use':
print(f'Tool: {block.name}')
print(f'Input: {block.input}')
```
```javascript tools.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const message = await anthropic.messages.create({
model: "qwen3-coder",
max_tokens: 1024,
tools: [
{
name: "get_weather",
description: "Get the current weather in a location",
input_schema: {
type: "object",
properties: {
location: {
type: "string",
description: "The city and state, e.g. San Francisco, CA",
},
},
required: ["location"],
},
},
],
messages: [{ role: "user", content: "What's the weather in San Francisco?" }],
});
for (const block of message.content) {
if (block.type === "tool_use") {
console.log("Tool:", block.name);
console.log("Input:", block.input);
}
}
```
```shell tools.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"tools": [
{
"name": "get_weather",
"description": "Get the current weather in a location",
"input_schema": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state"
}
},
"required": ["location"]
}
}
],
"messages": [{ "role": "user", "content": "What is the weather in San Francisco?" }]
}'
```
</CodeGroup>
## Using with Claude Code
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
Or set the environment variables in your shell profile:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
Then run Claude Code with any Ollama model:
```shell
# Local models
claude --model qwen3-coder
claude --model gpt-oss:20b
# Cloud models
claude --model glm-4.7:cloud
claude --model minimax-m2.1:cloud
```
## Endpoints
### `/v1/messages`
#### Supported features
- [x] Messages
- [x] Streaming
- [x] System prompts
- [x] Multi-turn conversations
- [x] Vision (images)
- [x] Tools (function calling)
- [x] Tool results
- [x] Thinking/extended thinking
#### Supported request fields
- [x] `model`
- [x] `max_tokens`
- [x] `messages`
- [x] Text `content`
- [x] Image `content` (base64)
- [x] Array of content blocks
- [x] `tool_use` blocks
- [x] `tool_result` blocks
- [x] `thinking` blocks
- [x] `system` (string or array)
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `top_k`
- [x] `stop_sequences`
- [x] `tools`
- [x] `thinking`
- [ ] `tool_choice`
- [ ] `metadata`
#### Supported response fields
- [x] `id`
- [x] `type`
- [x] `role`
- [x] `model`
- [x] `content` (text, tool_use, thinking blocks)
- [x] `stop_reason` (end_turn, max_tokens, tool_use)
- [x] `usage` (input_tokens, output_tokens)
#### Streaming events
- [x] `message_start`
- [x] `content_block_start`
- [x] `content_block_delta` (text_delta, input_json_delta, thinking_delta)
- [x] `content_block_stop`
- [x] `message_delta`
- [x] `message_stop`
- [x] `ping`
- [x] `error`
## Models
Ollama supports both local and cloud models.
### Local models
Pull a local model before use:
```shell
ollama pull qwen3-coder
```
Recommended local models:
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model
### Cloud models
Cloud models are available immediately without pulling:
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
### Default model names
For tooling that relies on default Anthropic model names such as `claude-3-5-sonnet`, use `ollama cp` to copy an existing model name:
```shell
ollama cp qwen3-coder claude-3-5-sonnet
```
Afterwards, this new model name can be specified in the `model` field:
```shell
curl http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "claude-3-5-sonnet",
"max_tokens": 1024,
"messages": [
{
"role": "user",
"content": "Hello!"
}
]
}'
```
## Differences from the Anthropic API
### Behavior differences
- API key is accepted but not validated
- `anthropic-version` header is accepted but not used
- Token counts are approximations based on the underlying model's tokenizer
### Not supported
The following Anthropic API features are not currently supported:
| Feature | Description |
|---------|-------------|
| `/v1/messages/count_tokens` | Token counting endpoint |
| `tool_choice` | Forcing specific tool use or disabling tools |
| `metadata` | Request metadata (user_id) |
| Prompt caching | `cache_control` blocks for caching prefixes |
| Batches API | `/v1/messages/batches` for async batch processing |
| Citations | `citations` content blocks |
| PDF support | `document` content blocks with PDF files |
| Server-sent errors | `error` events during streaming (errors return HTTP status) |
### Partial support
| Feature | Status |
|---------|--------|
| Image content | Base64 images supported; URL images not supported |
| Extended thinking | Basic support; `budget_tokens` accepted but not enforced |

View File

File diff suppressed because one or more lines are too long

View File

@@ -13,23 +13,9 @@ Embeddings turn text into numeric vectors you can store in a vector database, se
## Generate embeddings ## Generate embeddings
Use `/api/embed` with a single string.
<Tabs> <Tabs>
<Tab title="CLI">
Generate embeddings directly from the command line:
```shell
ollama run embeddinggemma "Hello world"
```
You can also pipe text to generate embeddings:
```shell
echo "Hello world" | ollama run embeddinggemma
```
Output is a JSON array.
</Tab>
<Tab title="cURL"> <Tab title="cURL">
```shell ```shell
curl -X POST http://localhost:11434/api/embed \ curl -X POST http://localhost:11434/api/embed \

View File

@@ -15,7 +15,7 @@ Also known as "single-shot" tool calling.
```shell ```shell
curl -s http://localhost:11434/api/chat -H "Content-Type: application/json" -d '{ curl -s http://localhost:11434/api/chat -H "Content-Type: application/json" -d '{
"model": "qwen3", "model": "qwen3",
"messages": [{"role": "user", "content": "What is the temperature in New York?"}], "messages": [{"role": "user", "content": "What's the temperature in New York?"}],
"stream": false, "stream": false,
"tools": [ "tools": [
{ {
@@ -41,7 +41,7 @@ Also known as "single-shot" tool calling.
curl -s http://localhost:11434/api/chat -H "Content-Type: application/json" -d '{ curl -s http://localhost:11434/api/chat -H "Content-Type: application/json" -d '{
"model": "qwen3", "model": "qwen3",
"messages": [ "messages": [
{"role": "user", "content": "What is the temperature in New York?"}, {"role": "user", "content": "What's the temperature in New York?"},
{ {
"role": "assistant", "role": "assistant",
"tool_calls": [ "tool_calls": [
@@ -90,7 +90,7 @@ Also known as "single-shot" tool calling.
} }
return temperatures.get(city, "Unknown") return temperatures.get(city, "Unknown")
messages = [{"role": "user", "content": "What is the temperature in New York?"}] messages = [{"role": "user", "content": "What's the temperature in New York?"}]
# pass functions directly as tools in the tools list or as a JSON schema # pass functions directly as tools in the tools list or as a JSON schema
response = chat(model="qwen3", messages=messages, tools=[get_temperature], think=True) response = chat(model="qwen3", messages=messages, tools=[get_temperature], think=True)
@@ -146,7 +146,7 @@ Also known as "single-shot" tool calling.
}, },
] ]
const messages = [{ role: 'user', content: "What is the temperature in New York?" }] const messages = [{ role: 'user', content: "What's the temperature in New York?" }]
const response = await ollama.chat({ const response = await ollama.chat({
model: 'qwen3', model: 'qwen3',
@@ -609,7 +609,7 @@ def get_temperature(city: str) -> str:
return temperatures.get(city, 'Unknown') return temperatures.get(city, 'Unknown')
messages = [{'role': 'user', 'content': "What is the temperature in New York?"}] messages = [{'role': 'user', 'content': "What's the temperature in New York?"}]
while True: while True:
stream = chat( stream = chat(
@@ -684,7 +684,7 @@ const getTemperatureTool = {
} }
async function agentLoop() { async function agentLoop() {
const messages = [{ role: 'user', content: "What is the temperature in New York?" }] const messages = [{ role: 'user', content: "What's the temperature in New York?" }]
while (true) { while (true) {
const stream = await ollama.chat({ const stream = await ollama.chat({

View File

@@ -36,6 +36,7 @@ Provide an `images` array. SDKs accept file paths, URLs or raw bytes while the R
}], }],
"stream": false "stream": false
}' }'
"
``` ```
</Tab> </Tab>
<Tab title="Python"> <Tab title="Python">

View File

@@ -9,9 +9,15 @@ sidebarTitle: Cloud
Ollama's cloud models are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud service while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldn't fit on a personal computer. Ollama's cloud models are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud service while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldn't fit on a personal computer.
### Supported models Ollama currently supports the following cloud models, with more coming soon:
For a list of supported models, see Ollama's [model library](https://ollama.com/search?c=cloud). - `deepseek-v3.1:671b-cloud`
- `gpt-oss:20b-cloud`
- `gpt-oss:120b-cloud`
- `kimi-k2:1t-cloud`
- `qwen3-coder:480b-cloud`
- `glm-4.6:cloud`
- `minimax-m2:cloud`
### Running Cloud models ### Running Cloud models

View File

@@ -49,8 +49,6 @@ Install prerequisites:
- [Ninja](https://github.com/ninja-build/ninja/releases) - [Ninja](https://github.com/ninja-build/ninja/releases)
- (Optional) NVIDIA GPU support - (Optional) NVIDIA GPU support
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network) - [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
- (Optional) VULKAN GPU support
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
Then, configure and build the project: Then, configure and build the project:
@@ -59,17 +57,6 @@ cmake -B build
cmake --build build --config Release cmake --build build --config Release
``` ```
> Building for Vulkan requires VULKAN_SDK environment variable:
>
> PowerShell
> ```powershell
> $env:VULKAN_SDK="C:\VulkanSDK\<version>"
> ```
> CMD
> ```cmd
> set VULKAN_SDK=C:\VulkanSDK\<version>
> ```
> [!IMPORTANT] > [!IMPORTANT]
> Building for ROCm requires additional flags: > Building for ROCm requires additional flags:
> ``` > ```
@@ -78,7 +65,6 @@ cmake --build build --config Release
> ``` > ```
Lastly, run Ollama: Lastly, run Ollama:
```shell ```shell
@@ -98,9 +84,7 @@ Install prerequisites:
- [ROCm](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/install/quick-start.html) - [ROCm](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/install/quick-start.html)
- (Optional) NVIDIA GPU support - (Optional) NVIDIA GPU support
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads) - [CUDA SDK](https://developer.nvidia.com/cuda-downloads)
- (Optional) VULKAN GPU support
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
- Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS)
> [!IMPORTANT] > [!IMPORTANT]
> Ensure prerequisites are in `PATH` before running CMake. > Ensure prerequisites are in `PATH` before running CMake.

View File

@@ -32,9 +32,7 @@
"codeblocks": "system" "codeblocks": "system"
}, },
"contextual": { "contextual": {
"options": [ "options": ["copy"]
"copy"
]
}, },
"navbar": { "navbar": {
"links": [ "links": [
@@ -54,9 +52,7 @@
"display": "simple" "display": "simple"
}, },
"examples": { "examples": {
"languages": [ "languages": ["curl"]
"curl"
]
} }
}, },
"redirects": [ "redirects": [
@@ -101,7 +97,6 @@
{ {
"group": "Integrations", "group": "Integrations",
"pages": [ "pages": [
"/integrations/claude-code",
"/integrations/vscode", "/integrations/vscode",
"/integrations/jetbrains", "/integrations/jetbrains",
"/integrations/codex", "/integrations/codex",
@@ -144,8 +139,7 @@
"/api/streaming", "/api/streaming",
"/api/usage", "/api/usage",
"/api/errors", "/api/errors",
"/api/openai-compatibility", "/api/openai-compatibility"
"/api/anthropic-compatibility"
] ]
}, },
{ {

View File

@@ -14,11 +14,11 @@ curl -fsSL https://ollama.com/install.sh | sh
## How can I view the logs? ## How can I view the logs?
Review the [Troubleshooting](./troubleshooting) docs for more about using logs. Review the [Troubleshooting](./troubleshooting.md) docs for more about using logs.
## Is my GPU compatible with Ollama? ## Is my GPU compatible with Ollama?
Please refer to the [GPU docs](./gpu). Please refer to the [GPU docs](./gpu.md).
## How can I specify the context window size? ## How can I specify the context window size?
@@ -57,13 +57,8 @@ ollama ps
``` ```
<Info> <Info>
**Output**: ``` NAME ID SIZE PROCESSOR UNTIL llama3:70b bcfb190ca3a7 42 GB
**Output**: 100% GPU 4 minutes from now ```
```
NAME ID SIZE PROCESSOR UNTIL
llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
```
</Info> </Info>
The `Processor` column will show which memory the model was loaded in to: The `Processor` column will show which memory the model was loaded in to:

View File

@@ -33,7 +33,7 @@ Check your compute compatibility to see if your card is supported:
| 5.0 | GeForce GTX | `GTX 750 Ti` `GTX 750` `NVS 810` | | 5.0 | GeForce GTX | `GTX 750 Ti` `GTX 750` `NVS 810` |
| | Quadro | `K2200` `K1200` `K620` `M1200` `M520` `M5000M` `M4000M` `M3000M` `M2000M` `M1000M` `K620M` `M600M` `M500M` | | | Quadro | `K2200` `K1200` `K620` `M1200` `M520` `M5000M` `M4000M` `M3000M` `M2000M` `M1000M` `K620M` `M600M` `M500M` |
For building locally to support older GPUs, see [developer](./development#linux-cuda-nvidia) For building locally to support older GPUs, see [developer.md](./development.md#linux-cuda-nvidia)
### GPU Selection ### GPU Selection
@@ -54,7 +54,7 @@ sudo modprobe nvidia_uvm`
Ollama supports the following AMD GPUs via the ROCm library: Ollama supports the following AMD GPUs via the ROCm library:
> **NOTE:** > [!NOTE]
> Additional AMD GPU support is provided by the Vulkan Library - see below. > Additional AMD GPU support is provided by the Vulkan Library - see below.
@@ -132,9 +132,9 @@ Ollama supports GPU acceleration on Apple devices via the Metal API.
## Vulkan GPU Support ## Vulkan GPU Support
> **NOTE:** > [!NOTE]
> Vulkan is currently an Experimental feature. To enable, you must set OLLAMA_VULKAN=1 for the Ollama server as > Vulkan is currently an Experimental feature. To enable, you must set OLLAMA_VULKAN=1 for the Ollama server as
described in the [FAQ](faq#how-do-i-configure-ollama-server) described in the [FAQ](faq.md#how-do-i-configure-ollama-server)
Additional GPU support on Windows and Linux is provided via Additional GPU support on Windows and Linux is provided via
[Vulkan](https://www.vulkan.org/). On Windows most GPU vendors drivers come [Vulkan](https://www.vulkan.org/). On Windows most GPU vendors drivers come
@@ -161,6 +161,6 @@ sudo setcap cap_perfmon+ep /usr/local/bin/ollama
To select specific Vulkan GPU(s), you can set the environment variable To select specific Vulkan GPU(s), you can set the environment variable
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as `GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
described in the [FAQ](faq#how-do-i-configure-ollama-server). If you described in the [FAQ](faq.md#how-do-i-configure-ollama-server). If you
encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
by setting `GGML_VK_VISIBLE_DEVICES=-1` by setting `GGML_VK_VISIBLE_DEVICES=-1`

View File

@@ -1,69 +0,0 @@
---
title: Claude Code
---
## Install
Install [Claude Code](https://code.claude.com/docs/en/overview):
<CodeGroup>
```shell macOS / Linux
curl -fsSL https://claude.ai/install.sh | bash
```
```powershell Windows
irm https://claude.ai/install.ps1 | iex
```
</CodeGroup>
## Usage with Ollama
Claude Code connects to Ollama using the Anthropic-compatible API.
1. Set the environment variables:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
2. Run Claude Code with an Ollama model:
```shell
claude --model qwen3-coder
```
Or run with environment variables inline:
```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
## Connecting to ollama.com
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
2. Set the environment variables:
```shell
export ANTHROPIC_BASE_URL=https://ollama.com
export ANTHROPIC_API_KEY=<your-api-key>
```
3. Run Claude Code with a cloud model:
```shell
claude --model glm-4.7:cloud
```
## Recommended Models
### Cloud models
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
- `qwen3-coder:480b` - Large coding model
### Local models
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model

Some files were not shown because too many files have changed in this diff Show More