Compare commits

..

5 Commits

Author SHA1 Message Date
Bruce MacDonald
f2a4d058f9 gofmt 2025-06-16 16:34:46 -07:00
Bruce MacDonald
63e7634014 pr feedback 2025-06-16 16:08:38 -07:00
Bruce MacDonald
8d51d92f3b server: cache gguf model capabilities rather than reading off disc 2025-06-16 15:17:36 -07:00
Bruce MacDonald
2348fef568 Revert "server: model info caching system for improved performance"
This reverts commit 8ef643d4978168a8563ae24434a424358ce390e3.
2025-06-16 15:17:02 -07:00
Bruce MacDonald
883f655dd6 server: model info caching system for improved performance
Implements an in-memory cache for loaded models with file modification
time tracking to ensure cache validity. Models are now cached after
first load and retrieved from cache on subsequent requests if the
underlying manifest file hasn't changed.

Key changes:
- Add ModelCache with get/set methods and modification time validation
- Cache models in GetModel() and check cache before disk load
- Move capabilities calculation to model loading time and store in model
- Update capability access to use cached field instead of runtime calculation
- Add test coverage for cache behavior and model loading

This reduces redundant model loading operations and improves response
times for model access.
2025-06-16 15:16:58 -07:00
98 changed files with 1265 additions and 133473 deletions

View File

@@ -54,6 +54,48 @@ jobs:
name: build-${{ matrix.os }}-${{ matrix.arch }}
path: dist/*
darwin-sign:
runs-on: macos-13
environment: release
needs: darwin-build
steps:
- uses: actions/checkout@v4
- run: |
echo $MACOS_SIGNING_KEY | base64 --decode > certificate.p12
security create-keychain -p password build.keychain
security default-keychain -s build.keychain
security unlock-keychain -p password build.keychain
security import certificate.p12 -k build.keychain -P $MACOS_SIGNING_KEY_PASSWORD -T /usr/bin/codesign
security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k password build.keychain
security set-keychain-settings -lut 3600 build.keychain
env:
MACOS_SIGNING_KEY: ${{ secrets.MACOS_SIGNING_KEY }}
MACOS_SIGNING_KEY_PASSWORD: ${{ secrets.MACOS_SIGNING_KEY_PASSWORD }}
- uses: actions/download-artifact@v4
with:
name: build-darwin-amd64
path: dist/darwin-amd64
- uses: actions/download-artifact@v4
with:
name: build-darwin-arm64
path: dist/darwin-arm64
- run: |
export VERSION=${GITHUB_REF_NAME#v}
./scripts/build_darwin.sh sign macapp
env:
APPLE_IDENTITY: ${{ secrets.APPLE_IDENTITY }}
APPLE_PASSWORD: ${{ secrets.APPLE_PASSWORD }}
APPLE_TEAM_ID: ${{ vars.APPLE_TEAM_ID }}
APPLE_ID: ${{ vars.APPLE_ID }}
SDKROOT: /Applications/Xcode_14.1.0.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
DEVELOPER_DIR: /Applications/Xcode_14.1.0.app/Contents/Developer
- uses: actions/upload-artifact@v4
with:
name: dist-darwin
path: |
dist/Ollama-darwin.zip
dist/ollama-darwin.tgz
windows-depends:
strategy:
matrix:
@@ -61,18 +103,21 @@ jobs:
arch: [amd64]
preset: ['CPU']
include:
- os: windows
arch: amd64
preset: 'CUDA 11'
install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
cuda-version: '11.3'
- os: windows
arch: amd64
preset: 'CUDA 12'
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
cuda-version: '12.8'
flags: ''
- os: windows
arch: amd64
preset: 'ROCm 6'
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
rocm-version: '6.2'
flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
env:
@@ -115,9 +160,6 @@ jobs:
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: matrix.preset == 'CPU'
run: |
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
@@ -136,9 +178,9 @@ jobs:
key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}
- name: Build target "${{ matrix.preset }}"
run: |
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}"
cmake --build --parallel --preset "${{ matrix.preset }}"
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8
env:
@@ -188,11 +230,61 @@ jobs:
go-version-file: go.mod
- run: |
go build -o dist/${{ matrix.os }}-${{ matrix.arch }}/ .
- if: matrix.arch == 'arm64'
run: |
Invoke-WebRequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "dist\windows-arm64\vc_redist.arm64.exe"
- run: |
$env:VERSION='${{ github.ref_name }}' -Replace "v(.*)", '$1'
& .\scripts\build_windows.ps1 buildApp
env:
VCToolsRedistDir: stub
- uses: actions/upload-artifact@v4
with:
name: build-${{ matrix.os }}-${{ matrix.arch }}
path: |
dist\${{ matrix.os }}-${{ matrix.arch }}\*.exe
dist\${{ matrix.os }}-${{ matrix.arch }}-app.exe
windows-sign:
runs-on: windows-2022
environment: release
needs: [windows-depends, windows-build]
steps:
- uses: actions/checkout@v4
- uses: google-github-actions/auth@v2
with:
project_id: ollama
credentials_json: ${{ secrets.GOOGLE_SIGNING_CREDENTIALS }}
- run: |
$ErrorActionPreference = "Stop"
Invoke-WebRequest -Uri "https://go.microsoft.com/fwlink/p/?LinkId=323507" -OutFile "${{ runner.temp }}\sdksetup.exe"
Start-Process "${{ runner.temp }}\sdksetup.exe" -ArgumentList @("/q") -NoNewWindow -Wait
Invoke-WebRequest -Uri "https://github.com/GoogleCloudPlatform/kms-integrations/releases/download/cng-v1.0/kmscng-1.0-windows-amd64.zip" -OutFile "${{ runner.temp }}\plugin.zip"
Expand-Archive -Path "${{ runner.temp }}\plugin.zip" -DestinationPath "${{ runner.temp }}\plugin\"
& "${{ runner.temp }}\plugin\*\kmscng.msi" /quiet
echo "${{ vars.OLLAMA_CERT }}" >ollama_inc.crt
- uses: actions/download-artifact@v4
with:
pattern: build-windows-*
path: dist\
merge-multiple: true
- uses: actions/download-artifact@v4
with:
pattern: depends-windows-amd64-*
path: dist\windows-amd64\
merge-multiple: true
- run: |
& .\scripts\build_windows.ps1 gatherDependencies sign buildInstaller distZip
env:
KEY_CONTAINER: ${{ vars.KEY_CONTAINER }}
- uses: actions/upload-artifact@v4
with:
name: dist-windows
path: |
dist\OllamaSetup.exe
dist\ollama-windows-*.zip
linux-build:
strategy:
@@ -225,26 +317,21 @@ jobs:
CGO_CFLAGS=${{ env.CGO_CFLAGS }}
CGO_CXXFLAGS=${{ env.CGO_CXXFLAGS }}
outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
cache-from: type=registry,ref=ollama/ollama:latest
cache-to: type=inline
- run: |
for COMPONENT in bin/* lib/ollama/*; do
case "$COMPONENT" in
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_sbsa) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/*.so) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_v11) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_v12) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
esac
done
working-directory: dist/${{ matrix.os }}-${{ matrix.arch }}
- run: |
echo "Manifests"
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in ; do
echo $ARCHIVE
cat $ARCHIVE
done
- run: |
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
@@ -298,8 +385,8 @@ jobs:
context: .
platforms: ${{ matrix.os }}/${{ matrix.arch }}
build-args: ${{ matrix.build-args }}
outputs: type=image,name=${{ vars.DOCKER_REPO }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
outputs: type=image,name=ollama/ollama,push-by-digest=true,name-canonical=true,push=true
cache-from: type=registry,ref=ollama/ollama:latest
cache-to: type=inline
- run: |
mkdir -p ${{ matrix.os }}-${{ matrix.arch }}
@@ -331,7 +418,7 @@ jobs:
latest=false
suffix=${{ matrix.suffix }}
images: |
${{ vars.DOCKER_REPO }}
ollama/ollama
tags: |
type=ref,enable=true,priority=600,prefix=pr-,event=pr
type=semver,pattern={{version}}
@@ -341,24 +428,56 @@ jobs:
path: ${{ runner.temp }}
merge-multiple: true
- run: |
docker buildx imagetools create $(echo '${{ steps.metadata.outputs.json }}' | jq -cr '.tags | map("-t", .) | join(" ")') $(cat *-${{ matrix.suffix }}.txt | xargs printf '${{ vars.DOCKER_REPO }}@%s ')
docker buildx imagetools inspect ${{ vars.DOCKER_REPO }}:${{ steps.metadata.outputs.version }}
docker buildx imagetools create $(echo '${{ steps.metadata.outputs.json }}' | jq -cr '.tags | map("-t", .) | join(" ")') $(cat *-${{ matrix.suffix }}.txt | xargs printf 'ollama/ollama@%s ')
docker buildx imagetools inspect ollama/ollama:${{ steps.metadata.outputs.version }}
working-directory: ${{ runner.temp }}
# Trigger downstream release process
trigger:
runs-on: ubuntu-latest
environment: release
needs: [darwin-build, windows-build, windows-depends, linux-build]
needs: [darwin-build, windows-build, windows-depends]
steps:
- name: Trigger downstream release process
run: |
curl -L \
-X POST \
-H "Accept: application/vnd.github+json" \
-H "Authorization: Bearer ${{ secrets.RELEASE_TOKEN }}" \
-H "X-GitHub-Api-Version: 2022-11-28" \
https://api.github.com/repos/ollama/${{ vars.RELEASE_REPO }}/dispatches \
-d "{\"event_type\": \"trigger-workflow\", \"client_payload\": {\"run_id\": \"${GITHUB_RUN_ID}\", \"version\": \"${GITHUB_REF_NAME#v}\"}}"
# Aggregate all the assets and ship a release
release:
needs: [darwin-sign, windows-sign, linux-build]
runs-on: linux
environment: release
permissions:
contents: write
env:
GH_TOKEN: ${{ github.token }}
steps:
- uses: actions/checkout@v4
- name: Create or update Release for tag
- uses: actions/download-artifact@v4
with:
name: dist-darwin
path: dist
- uses: actions/download-artifact@v4
with:
name: dist-windows
path: dist
- uses: actions/download-artifact@v4
with:
pattern: dist-linux-*
path: dist
merge-multiple: true
- run: find . -type f -not -name 'sha256sum.txt' | xargs sha256sum | tee sha256sum.txt
working-directory: dist
- name: Create or update Release
run: |
RELEASE_VERSION="$(echo ${GITHUB_REF_NAME} | cut -f1 -d-)"
echo "Looking for existing release for ${RELEASE_VERSION}"
OLD_TAG=$(gh release ls --json name,tagName | jq -r ".[] | select(.name == \"${RELEASE_VERSION}\") | .tagName")
if [ -n "$OLD_TAG" ]; then
@@ -372,12 +491,5 @@ jobs:
--generate-notes \
--prerelease
fi
- name: Trigger downstream release process
run: |
curl -L \
-X POST \
-H "Accept: application/vnd.github+json" \
-H "Authorization: Bearer ${{ secrets.RELEASE_TOKEN }}" \
-H "X-GitHub-Api-Version: 2022-11-28" \
https://api.github.com/repos/ollama/${{ vars.RELEASE_REPO }}/dispatches \
-d "{\"event_type\": \"trigger-workflow\", \"client_payload\": {\"run_id\": \"${GITHUB_RUN_ID}\", \"version\": \"${GITHUB_REF_NAME#v}\", \"origin\": \"${GITHUB_REPOSITORY}\", \"publish\": \"1\"}}"
echo "Uploading artifacts for tag ${GITHUB_REF_NAME}"
gh release upload ${GITHUB_REF_NAME} dist/* --clobber

View File

@@ -36,7 +36,7 @@ jobs:
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
}
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
echo changed=$(changed 'llama/llama.cpp/**' 'ml/backend/ggml/ggml/**') | tee -a $GITHUB_OUTPUT
linux:
needs: [changes]
@@ -46,7 +46,7 @@ jobs:
include:
- preset: CPU
- preset: CUDA
container: nvidia/cuda:12.8.1-devel-ubuntu22.04
container: nvidia/cuda:11.8.0-devel-ubuntu22.04
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
- preset: ROCm
container: rocm/dev-ubuntu-22.04:6.1.2
@@ -78,11 +78,11 @@ jobs:
include:
- preset: CPU
- preset: CUDA
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
- preset: ROCm
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
flags: '-DAMDGPU_TARGETS=gfx1010'
runs-on: windows
steps:
- run: |
@@ -102,7 +102,7 @@ jobs:
$ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_12.8", "nvcc_12.8", "cublas_12.8", "cublas_dev_12.8")) -NoNewWindow -Wait
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_11.3", "nvcc_11.3", "cublas_11.3", "cublas_dev_11.3")) -NoNewWindow -Wait
}
$cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path
@@ -120,9 +120,6 @@ jobs:
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4
with:
@@ -136,8 +133,8 @@ jobs:
path: ${{ github.workspace }}\.ccache
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
- run: |
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
cmake --build --parallel --preset "${{ matrix.preset }}"
env:

View File

@@ -78,13 +78,14 @@ if(CMAKE_CUDA_COMPILER)
find_package(CUDAToolkit)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
set(OLLAMA_CUDA_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/cuda_v${CUDAToolkit_VERSION_MAJOR})
install(TARGETS ggml-cuda
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA
RUNTIME DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
LIBRARY DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
)
endif()
@@ -115,11 +116,7 @@ if(CMAKE_HIP_COMPILER)
set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
install(TARGETS ggml-hip
RUNTIME_DEPENDENCY_SET rocm
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
)
install(RUNTIME_DEPENDENCY_SET rocm
RUNTIME_DEPENDENCIES
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
PRE_EXCLUDE_REGEXES ".*"

View File

@@ -17,12 +17,20 @@
"name": "CUDA",
"inherits": [ "Default" ]
},
{
"name": "CUDA 11",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
}
},
{
"name": "CUDA 12",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
}
},
{
@@ -50,7 +58,6 @@
"name": "ROCm 6",
"inherits": [ "ROCm" ],
"cacheVariables": {
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
}
}
@@ -71,6 +78,11 @@
"configurePreset": "CUDA",
"targets": [ "ggml-cuda" ]
},
{
"name": "CUDA 11",
"inherits": [ "CUDA" ],
"configurePreset": "CUDA 11"
},
{
"name": "CUDA 12",
"inherits": [ "CUDA" ],

View File

@@ -65,7 +65,7 @@ continuation of the sentence:
Examples:
llm/backend/mlx: support the llama architecture
CONTRIBUTING: provide clarity on good commit messages, and bad
CONTRIBUTING: provide clairity on good commit messages, and bad
Bad Examples:

View File

@@ -7,13 +7,12 @@ ARG JETPACK5VERSION=r35.4.1
ARG JETPACK6VERSION=r36.4.0
ARG CMAKEVERSION=3.31.2
# We require gcc v10 minimum. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
# CUDA v11 requires gcc v10. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
RUN yum install -y yum-utils \
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
&& dnf install -y ccache \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
@@ -39,6 +38,15 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --build --parallel --preset 'CPU' \
&& cmake --install build --component CPU --strip --parallel 8
FROM base AS cuda-11
ARG CUDA11VERSION=11.3
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 11' \
&& cmake --build --parallel --preset 'CUDA 11' \
&& cmake --install build --component CUDA --strip --parallel 8
FROM base AS cuda-12
ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
@@ -90,21 +98,23 @@ RUN --mount=type=cache,target=/root/.cache/go-build \
go build -trimpath -buildmode=pie -o /bin/ollama .
FROM --platform=linux/amd64 scratch AS amd64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
FROM --platform=linux/arm64 scratch AS arm64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/cuda_sbsa
COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_jetpack6
FROM scratch AS rocm
COPY --from=rocm-6 dist/lib/ollama /lib/ollama
COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm
FROM ${FLAVOR} AS archive
COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04
FROM ubuntu:20.04
RUN apt-get update \
&& apt-get install -y ca-certificates \
&& apt-get clean \

View File

@@ -1,6 +1,6 @@
<div align="center">
  <a href="https://ollama.com">
<img alt="ollama" width="240" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
<img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
</a>
</div>
@@ -10,7 +10,7 @@ Get up and running with large language models.
### macOS
[Download](https://ollama.com/download/Ollama.dmg)
[Download](https://ollama.com/download/Ollama-darwin.zip)
### Windows
@@ -360,7 +360,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Tkinter-based client](https://github.com/chyok/ollama-gui) (Python tkinter-based Client for Ollama)
- [LLMChat](https://github.com/trendy-design/llmchat) (Privacy focused, 100% local, intuitive all-in-one chat interface)
- [Local Multimodal AI Chat](https://github.com/Leon-Sander/Local-Multimodal-AI-Chat) (Ollama-based LLM Chat with support for multiple features, including PDF RAG, voice chat, image-based interactions, and integration with OpenAI.)
- [ARGO](https://github.com/xark-argo/argo) (Locally download and run Ollama and Huggingface models with RAG and deep research on Mac/Windows/Linux)
- [ARGO](https://github.com/xark-argo/argo) (Locally download and run Ollama and Huggingface models with RAG on Mac/Windows/Linux)
- [OrionChat](https://github.com/EliasPereirah/OrionChat) - OrionChat is a web interface for chatting with different AI providers
- [G1](https://github.com/bklieger-groq/g1) (Prototype of using prompting strategies to improve the LLM's reasoning through o1-like reasoning chains.)
- [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
@@ -409,8 +409,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
### Cloud
@@ -456,7 +454,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
- [AWS-Strands-With-Ollama](https://github.com/rapidarchitect/ollama_strands) - AWS Strands Agents with Ollama Examples
- [ollama-multirun](https://github.com/attogram/ollama-multirun) - A bash shell script to run a single prompt against any or all of your locally installed ollama models, saving the output and performance statistics as easily navigable web pages. ([Demo](https://attogram.github.io/ai_test_zone/))
- [ollama-bash-toolshed](https://github.com/attogram/ollama-bash-toolshed) - Bash scripts to chat with tool using models. Add new tools to your shed with ease. Runs on Ollama.
### Apple Vision Pro
@@ -595,12 +592,10 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
- [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime)
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)
- [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases)
### Supported backends
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
### Observability
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native intergration to Ollama.

View File

@@ -42,23 +42,6 @@ type Client struct {
func checkError(resp *http.Response, body []byte) error {
if resp.StatusCode < http.StatusBadRequest {
if len(body) == 0 {
return nil
}
// streams can contain error message even with StatusOK
var errorResponse struct {
Error string `json:"error,omitempty"`
}
if err := json.Unmarshal(body, &errorResponse); err != nil {
return fmt.Errorf("unmarshal: %w", err)
}
if errorResponse.Error != "" {
return errors.New(errorResponse.Error)
}
return nil
}
@@ -230,9 +213,25 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
scanBuf := make([]byte, 0, maxBufferSize)
scanner.Buffer(scanBuf, maxBufferSize)
for scanner.Scan() {
var errorResponse struct {
Error string `json:"error,omitempty"`
}
bts := scanner.Bytes()
if err := checkError(response, bts); err != nil {
return err
if err := json.Unmarshal(bts, &errorResponse); err != nil {
return fmt.Errorf("unmarshal: %w", err)
}
if errorResponse.Error != "" {
return errors.New(errorResponse.Error)
}
if response.StatusCode >= http.StatusBadRequest {
return StatusError{
StatusCode: response.StatusCode,
Status: response.Status,
ErrorMessage: errorResponse.Error,
}
}
if err := fn(bts); err != nil {

View File

@@ -143,7 +143,6 @@ type Message struct {
Thinking string `json:"thinking,omitempty"`
Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolName string `json:"tool_name,omitempty"`
}
func (m *Message) UnmarshalJSON(b []byte) error {
@@ -468,14 +467,13 @@ type ListModelResponse struct {
// ProcessModelResponse is a single model description in [ProcessResponse].
type ProcessModelResponse struct {
Name string `json:"name"`
Model string `json:"model"`
Size int64 `json:"size"`
Digest string `json:"digest"`
Details ModelDetails `json:"details,omitempty"`
ExpiresAt time.Time `json:"expires_at"`
SizeVRAM int64 `json:"size_vram"`
ContextLength int `json:"context_length"`
Name string `json:"name"`
Model string `json:"model"`
Size int64 `json:"size"`
Digest string `json:"digest"`
Details ModelDetails `json:"details,omitempty"`
ExpiresAt time.Time `json:"expires_at"`
SizeVRAM int64 `json:"size_vram"`
}
type TokenResponse struct {

View File

@@ -18,8 +18,6 @@ import (
const defaultPrivateKey = "id_ed25519"
var ErrInvalidToken = errors.New("invalid token")
func keyPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
@@ -29,39 +27,6 @@ func keyPath() (string, error) {
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
}
func parseToken(token string) (key, sig []byte, _ error) {
keyData, sigData, ok := strings.Cut(token, ":")
if !ok {
return nil, nil, fmt.Errorf("identity: parseToken: %w", ErrInvalidToken)
}
sig, err := base64.StdEncoding.DecodeString(sigData)
if err != nil {
return nil, nil, fmt.Errorf("identity: parseToken: base64 decoding signature: %w", err)
}
return []byte(keyData), sig, nil
}
func Authenticate(token, checkData string) (ssh.PublicKey, error) {
keyShort, sigBytes, err := parseToken(token)
if err != nil {
return nil, err
}
keyLong := append([]byte("ssh-ed25519 "), keyShort...)
pub, _, _, _, err := ssh.ParseAuthorizedKey(keyLong)
if err != nil {
return nil, err
}
if err := pub.Verify([]byte(checkData), &ssh.Signature{
Format: pub.Type(),
Blob: sigBytes,
}); err != nil {
return nil, err
}
return pub, nil
}
func GetPublicKey() (string, error) {
keyPath, err := keyPath()
if err != nil {

View File

@@ -1,254 +0,0 @@
package auth
import (
"bufio"
"encoding/base64"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
"golang.org/x/crypto/ssh"
)
type KeyEntry struct {
Name string
PublicKey string
Endpoints []string
}
type KeyPermission struct {
Name string
Endpoints []string
}
type APIPermissions struct {
permissions map[string]*KeyPermission
lastModified time.Time
mutex sync.RWMutex
}
var ws = regexp.MustCompile(`\s+`)
func authkeyPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", "authorized_keys"), nil
}
func NewAPIPermissions() *APIPermissions {
return &APIPermissions{
permissions: make(map[string]*KeyPermission),
mutex: sync.RWMutex{},
}
}
func (ap *APIPermissions) ReloadIfNeeded() error {
ap.mutex.Lock()
defer ap.mutex.Unlock()
filename, err := authkeyPath()
if err != nil {
return err
}
fileInfo, err := os.Stat(filename)
if err != nil {
return fmt.Errorf("failed to stat file: %v", err)
}
if !fileInfo.ModTime().After(ap.lastModified) {
return nil
}
file, err := os.Open(filename)
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
}
defer file.Close()
ap.lastModified = fileInfo.ModTime()
return ap.parse(file)
}
func (ap *APIPermissions) parse(r io.Reader) error {
ap.permissions = make(map[string]*KeyPermission)
scanner := bufio.NewScanner(r)
var cnt int
for scanner.Scan() {
cnt += 1
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
line = ws.ReplaceAllString(line, " ")
entry, err := ap.parseLine(line)
if err != nil {
slog.Warn(fmt.Sprintf("authorized_keys line %d: skipping invalid line: %v\n", cnt, err))
continue
}
var pubKeyStr string
if entry.PublicKey == "*" {
pubKeyStr = "*"
} else {
pubKey, err := ap.validateAndDecodeKey(entry)
if err != nil {
slog.Warn(fmt.Sprintf("authorized_keys line %d: invalid key for %s: %v\n", cnt, entry.Name, err))
continue
}
pubKeyStr = pubKey
}
if perm, exists := ap.permissions[pubKeyStr]; exists {
if perm.Name == "default" {
perm.Name = entry.Name
}
if len(perm.Endpoints) == 1 && perm.Endpoints[0] == "*" {
// skip redundant entries
continue
} else if len(entry.Endpoints) == 1 && entry.Endpoints[0] == "*" {
// overwrite redundant entries
perm.Endpoints = entry.Endpoints
} else {
perm.Endpoints = append(perm.Endpoints, entry.Endpoints...)
}
} else {
ap.permissions[pubKeyStr] = &KeyPermission{
Name: entry.Name,
Endpoints: entry.Endpoints,
}
}
}
return scanner.Err()
}
func (ap *APIPermissions) parseLine(line string) (*KeyEntry, error) {
parts := strings.SplitN(line, " ", 4)
if len(parts) < 2 {
return nil, fmt.Errorf("key type and public key not found")
}
kind, b64Key := parts[0], parts[1]
name := "default"
eps := "*"
if len(parts) >= 3 && parts[2] != "" {
if parts[2] != "*" {
name = parts[2]
}
}
if len(parts) == 4 && parts[3] != "" {
eps = parts[3]
}
if kind != "ssh-ed25519" && kind != "*" {
return nil, fmt.Errorf("unsupported key type %s", kind)
}
if kind == "*" && b64Key != "*" {
return nil, fmt.Errorf("unsupported key type")
}
var endpoints []string
if eps == "*" {
endpoints = []string{"*"}
} else {
for _, e := range strings.Split(eps, ",") {
e = strings.TrimSpace(e)
if e == "" {
return nil, fmt.Errorf("empty endpoint in list")
} else if e == "*" {
endpoints = []string{"*"}
break
}
endpoints = append(endpoints, e)
}
}
return &KeyEntry{
PublicKey: b64Key,
Name: name,
Endpoints: endpoints,
}, nil
}
func (ap *APIPermissions) validateAndDecodeKey(entry *KeyEntry) (string, error) {
keyBlob, err := base64.StdEncoding.DecodeString(entry.PublicKey)
if err != nil {
return "", fmt.Errorf("base64 decode: %w", err)
}
pub, err := ssh.ParsePublicKey(keyBlob)
if err != nil {
return "", fmt.Errorf("parse key: %w", err)
}
if pub.Type() != ssh.KeyAlgoED25519 {
return "", fmt.Errorf("key is not Ed25519")
}
return entry.PublicKey, nil
}
func (ap *APIPermissions) Authorize(pubKey ssh.PublicKey, endpoint string) (bool, string, error) {
if err := ap.ReloadIfNeeded(); err != nil {
return false, "unknown", err
}
ap.mutex.RLock()
defer ap.mutex.RUnlock()
if wildcardPerm, exists := ap.permissions["*"]; exists {
if len(wildcardPerm.Endpoints) == 1 && wildcardPerm.Endpoints[0] == "*" {
return true, wildcardPerm.Name, nil
}
for _, allowedEndpoint := range wildcardPerm.Endpoints {
if allowedEndpoint == endpoint {
return true, wildcardPerm.Name, nil
}
}
}
keyString := string(ssh.MarshalAuthorizedKey(pubKey))
parts := strings.SplitN(keyString, " ", 2)
var base64Key string
if len(parts) > 1 {
base64Key = parts[1]
} else {
base64Key = parts[0]
}
base64Key = strings.TrimSpace(base64Key)
perm, exists := ap.permissions[base64Key]
if !exists {
return false, "unknown", nil
}
if len(perm.Endpoints) == 1 && perm.Endpoints[0] == "*" {
return true, perm.Name, nil
}
for _, allowedEndpoint := range perm.Endpoints {
if allowedEndpoint == endpoint {
return true, perm.Name, nil
}
}
return false, "unknown", nil
}

View File

@@ -1,133 +0,0 @@
package auth
import (
"bytes"
"reflect"
"testing"
)
const validB64 = "AAAAC3NzaC1lZDI1NTE5AAAAICy1v/Sn0kGhu1LXzCsnx3wlk5ESdncS66JWo13yeJod"
func TestParse(t *testing.T) {
tests := []struct {
name string
file string
want map[string]*KeyPermission
}{
{
name: "two fields only defaults",
file: "ssh-ed25519 " + validB64 + "\n",
want: map[string]*KeyPermission{
validB64: {
Name: "default",
Endpoints: []string{"*"},
},
},
},
{
name: "extra whitespace collapsed and default endpoints",
file: "ssh-ed25519 " + validB64 + " alice\n",
want: map[string]*KeyPermission{
validB64: {
Name: "alice",
Endpoints: []string{"*"},
},
},
},
{
name: "four fields full",
file: "ssh-ed25519 " + validB64 + " bob /api/foo,/api/bar\n",
want: map[string]*KeyPermission{
validB64: {
Name: "bob",
Endpoints: []string{"/api/foo", "/api/bar"},
},
},
},
{
name: "comment lines ignored and multiple entries",
file: "# header\n\nssh-ed25519 " + validB64 + " user1\nssh-ed25519 " + validB64 + " user2 /api/x\n",
want: map[string]*KeyPermission{
validB64: {
Name: "user1",
Endpoints: []string{"*"},
},
},
},
{
name: "three entries variety",
file: "ssh-ed25519 " + validB64 + "\nssh-ed25519 " + validB64 + " alice /api/a,/api/b\nssh-ed25519 " + validB64 + " bob /api/c\n",
want: map[string]*KeyPermission{
validB64: {
Name: "alice",
Endpoints: []string{"*"},
},
},
},
{
name: "two entries w/ wildcard",
file: "ssh-ed25519 " + validB64 + " alice /api/a\n* * * /api/b\n",
want: map[string]*KeyPermission{
validB64: {
Name: "alice",
Endpoints: []string{"/api/a"},
},
"*": {
Name: "default",
Endpoints: []string{"/api/b"},
},
},
},
{
name: "tags for everyone",
file: "* * * /api/tags",
want: map[string]*KeyPermission{
"*": {
Name: "default",
Endpoints: []string{"/api/tags"},
},
},
},
{
name: "default name",
file: "* * somename",
want: map[string]*KeyPermission{
"*": {
Name: "somename",
Endpoints: []string{"*"},
},
},
},
{
name: "unsupported key type",
file: "ssh-rsa AAAAB3Nza...\n",
want: map[string]*KeyPermission{},
},
{
name: "bad base64",
file: "ssh-ed25519 invalid@@@\n",
want: map[string]*KeyPermission{},
},
{
name: "just an asterix",
file: "*\n",
want: map[string]*KeyPermission{},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
perms := NewAPIPermissions()
err := perms.parse(bytes.NewBufferString(tc.file))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(perms.permissions) != len(tc.want) {
t.Fatalf("got %d entries, want %d", len(perms.permissions), len(tc.want))
}
if !reflect.DeepEqual(perms.permissions, tc.want) {
t.Errorf("got %+v, want %+v", perms.permissions, tc.want)
}
})
}
}

View File

@@ -0,0 +1,178 @@
package benchmark
import (
"context"
"flag"
"fmt"
"testing"
"time"
"github.com/ollama/ollama/api"
)
// Command line flags
var modelFlag string
func init() {
flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
flag.Lookup("m").DefValue = "model"
}
// modelName returns the model name from flags, failing the test if not set
func modelName(b *testing.B) string {
if modelFlag == "" {
b.Fatal("Error: -m flag is required for benchmark tests")
}
return modelFlag
}
type TestCase struct {
name string
prompt string
maxTokens int
}
// runGenerateBenchmark contains the common generate and metrics logic
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
start := time.Now()
var ttft time.Duration
var metrics api.Metrics
err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
if ttft == 0 && resp.Response != "" {
ttft = time.Since(start)
}
if resp.Done {
metrics = resp.Metrics
}
return nil
})
// Report custom metrics as part of the benchmark results
b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
// Token throughput metrics
promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
b.ReportMetric(promptThroughput, "prompt_tok/s")
b.ReportMetric(genThroughput, "gen_tok/s")
// Token counts
b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
if err != nil {
b.Fatal(err)
}
}
// BenchmarkColdStart runs benchmarks with model loading from cold state
func BenchmarkColdStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
ctx := b.Context()
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
b.StopTimer()
// Ensure model is unloaded before each iteration
unload(client, m, b)
b.StartTimer()
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// BenchmarkWarmStart runs benchmarks with pre-loaded model
func BenchmarkWarmStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
ctx := b.Context()
// Pre-warm the model
warmup(client, m, tt.prompt, b)
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// setup verifies server and model availability
func setup(b *testing.B) *api.Client {
client, err := api.ClientFromEnvironment()
if err != nil {
b.Fatal(err)
}
if _, err := client.Show(b.Context(), &api.ShowRequest{Model: modelName(b)}); err != nil {
b.Fatalf("Model unavailable: %v", err)
}
return client
}
// warmup ensures the model is loaded and warmed up
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
for range 3 {
err := client.Generate(
context.Background(),
&api.GenerateRequest{
Model: model,
Prompt: prompt,
Options: map[string]any{"num_predict": 50, "temperature": 0.1},
},
func(api.GenerateResponse) error { return nil },
)
if err != nil {
b.Logf("Error during model warm-up: %v", err)
}
}
}
// unload forces model unloading using KeepAlive: 0 parameter
func unload(client *api.Client, model string, b *testing.B) {
req := &api.GenerateRequest{
Model: model,
KeepAlive: &api.Duration{Duration: 0},
}
if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
b.Logf("Unload error: %v", err)
}
time.Sleep(1 * time.Second)
}

View File

@@ -583,13 +583,12 @@ func ListRunningHandler(cmd *cobra.Command, args []string) error {
} else {
until = format.HumanTime(m.ExpiresAt, "Never")
}
ctxStr := strconv.Itoa(m.ContextLength)
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, ctxStr, until})
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, until})
}
}
table := tablewriter.NewWriter(os.Stdout)
table.SetHeader([]string{"NAME", "ID", "SIZE", "PROCESSOR", "CONTEXT", "UNTIL"})
table.SetHeader([]string{"NAME", "ID", "SIZE", "PROCESSOR", "UNTIL"})
table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
table.SetAlignment(tablewriter.ALIGN_LEFT)
table.SetHeaderLine(false)
@@ -1080,11 +1079,10 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
var state *displayResponseState = &displayResponseState{}
var latest api.ChatResponse
var fullResponse strings.Builder
var role string
var thinkTagOpened bool = false
var thinkTagClosed bool = false
role := "assistant"
fn := func(response api.ChatResponse) error {
if response.Message.Content != "" || !opts.HideThinking {
p.StopAndClear()
@@ -1137,14 +1135,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
if errors.Is(err, context.Canceled) {
return nil, nil
}
// this error should ideally be wrapped properly by the client
if strings.Contains(err.Error(), "upstream error") {
p.StopAndClear()
fmt.Println("An error occurred while processing your message. Please try again.")
fmt.Println()
return nil, nil
}
return nil, err
}
@@ -1426,13 +1416,13 @@ func NewCLI() *cobra.Command {
createCmd := &cobra.Command{
Use: "create MODEL",
Short: "Create a model",
Short: "Create a model from a Modelfile",
Args: cobra.ExactArgs(1),
PreRunE: checkServerHeartbeat,
RunE: CreateHandler,
}
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")")
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\"")
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
showCmd := &cobra.Command{

View File

@@ -385,21 +385,18 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
case "modelfile":
fmt.Println(resp.Modelfile)
case "parameters":
fmt.Println("Model defined parameters:")
if resp.Parameters == "" {
fmt.Println(" No additional parameters were specified for this model.")
fmt.Println("No parameters were specified for this model.")
} else {
for _, l := range strings.Split(resp.Parameters, "\n") {
fmt.Printf(" %s\n", l)
if len(opts.Options) > 0 {
fmt.Println("User defined parameters:")
for k, v := range opts.Options {
fmt.Printf("%-*s %v\n", 30, k, v)
}
fmt.Println()
}
}
fmt.Println()
if len(opts.Options) > 0 {
fmt.Println("User defined parameters:")
for k, v := range opts.Options {
fmt.Printf(" %-*s %v\n", 30, k, v)
}
fmt.Println()
fmt.Println("Model defined parameters:")
fmt.Println(resp.Parameters)
}
case "system":
switch {

View File

@@ -190,8 +190,6 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &gemma2Model{}
case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration":
conv = &gemma3Model{Architecture: p.Architectures[0]}
case "Gemma3nForConditionalGeneration":
conv = &gemma3nModel{}
case "Phi3ForCausalLM":
conv = &phi3Model{}
case "Qwen2ForCausalLM":

View File

@@ -1,165 +0,0 @@
package convert
import (
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"gonum.org/v1/gonum/stat/distuv"
)
type gemma3nModel struct {
ModelParameters
TextModel struct {
ActivationSparsityPattern []float32 `json:"activation_sparsity_pattern"`
AltupActiveIdx uint32 `json:"altup_active_idx"`
AltupCoefClip float32 `json:"altup_coef_clip"`
AltupCorrectScale bool `json:"altup_correct_scale"`
AltupLRMultiplier float32 `json:"altup_lr_multiplier"`
AltupNumInputs uint32 `json:"altup_num_inputs"`
HeadDim uint32 `json:"head_dim"`
HiddenSize uint32 `json:"hidden_size"`
HiddenSizePerLayerInput uint32 `json:"hidden_size_per_layer_input"`
IntermediateSize uint32 `json:"intermediate_size"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
NumKVSharedLayers uint32 `json:"num_kv_shared_layers"`
RMSNormEPS float32 `json:"rms_norm_eps"`
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
RopeTheta float32 `json:"rope_theta"`
SlidingWindow uint32 `json:"sliding_window"`
LayerTypes []string `json:"layer_types"`
} `json:"text_config"`
VisionModel struct{} `json:"vision_config"`
}
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3n"
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {
norm := distuv.Normal{Mu: 0, Sigma: 1}
for _, v := range m.TextModel.ActivationSparsityPattern {
if !yield(float32(norm.Quantile(float64(v)))) {
break
}
}
})
kv["gemma3n.altup.active_idx"] = m.TextModel.AltupActiveIdx
kv["gemma3n.altup.correct_scale"] = m.TextModel.AltupCorrectScale
kv["gemma3n.altup.lr_multiplier"] = m.TextModel.AltupLRMultiplier
kv["gemma3n.altup.num_inputs"] = m.TextModel.AltupNumInputs
kv["gemma3n.attention.head_count_kv"] = m.TextModel.NumKeyValueHeads
kv["gemma3n.attention.head_count"] = m.TextModel.NumAttentionHeads
kv["gemma3n.attention.layer_norm_rms_epsilon"] = m.TextModel.RMSNormEPS
kv["gemma3n.attention.sliding_window"] = m.TextModel.SlidingWindow
kv["gemma3n.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
for _, t := range m.TextModel.LayerTypes {
if !yield(t == "sliding_attention") {
break
}
}
})
kv["gemma3n.attention.shared_kv_layers"] = m.TextModel.NumKVSharedLayers
kv["gemma3n.block_count"] = m.TextModel.NumHiddenLayers
kv["gemma3n.context_length"] = m.TextModel.MaxPositionEmbeddings
kv["gemma3n.embedding_length_per_layer_input"] = m.TextModel.HiddenSizePerLayerInput
kv["gemma3n.embedding_length"] = m.TextModel.HiddenSize
kv["gemma3n.feed_forward_length"] = m.TextModel.IntermediateSize
kv["gemma3n.head_dim"] = m.TextModel.HeadDim
kv["gemma3n.rope.freq_base_local"] = m.TextModel.RopeLocalBaseFreq
kv["gemma3n.rope.freq_base"] = m.TextModel.RopeTheta
return kv
}
func (m *gemma3nModel) Tensors(ts []Tensor) []*ggml.Tensor {
out, ts := mergeTensors(ts,
merge{"altup_proj.*.weight", "altup_proj.weight"},
merge{"altup_unembd_proj.*.weight", "altup_unembd_proj.weight"},
)
for _, t := range ts {
switch {
case strings.Contains(t.Name(), "audio_tower"),
strings.Contains(t.Name(), "embed_audio"),
strings.Contains(t.Name(), "vision_tower"),
strings.Contains(t.Name(), "embed_vision"):
// TODO: handle audio and vision towers
continue
case strings.Contains(t.Name(), "altup_predict_coef"),
strings.Contains(t.Name(), "altup_correct_coef"):
if m.TextModel.AltupCoefClip > 0 {
t.SetRepacker(func(name string, data []float32, shape []uint64) (_ []float32, err error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
t, err = tensor.Clamp(t, -m.TextModel.AltupCoefClip, m.TextModel.AltupCoefClip)
if err != nil {
return nil, err
}
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(t.(*tensor.Dense))
})
}
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (m *gemma3nModel) Replacements() []string {
return []string{
"model.language_model.embed_tokens_per_layer", "per_layer_token_embd",
"model.language_model.embed_tokens", "token_embd",
"model.language_model.per_layer_model_projection", "per_layer_model_proj",
"model.language_model.per_layer_projection_norm", "per_layer_proj_norm", "model.language_model.altup_projections", "altup_proj",
"model.language_model.altup_unembed_projections", "altup_unembd_proj",
"model.language_model.norm", "output_norm",
"model.language_model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.q_norm", "attn_q_norm",
"self_attn.k_proj", "attn_k",
"self_attn.k_norm", "attn_k_norm",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"post_attention_layernorm", "post_attention_norm",
"pre_feedforward_layernorm", "ffn_norm",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"mlp.down_proj", "ffn_down",
"post_feedforward_layernorm", "post_ffw_norm",
"per_layer_input_gate", "inp_gate",
"per_layer_projection", "proj",
"post_per_layer_input_norm", "post_norm",
"altup.", "altup_",
"modality_router", "router",
"prediction_coefs", "predict_coef",
"correction_coefs", "correct_coef",
"correct_output_scale", "correct_scale.weight",
"laurel.", "laurel_",
"linear_left", "l",
"linear_right", "r",
"post_laurel_norm", "post_norm",
}
}

View File

@@ -2,6 +2,9 @@ package convert
import (
"fmt"
"io"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
@@ -27,38 +30,65 @@ func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
}
func (p *mixtralModel) Tensors(ts []Tensor) []*ggml.Tensor {
merges := make([]merge, 0, p.NumHiddenLayers*6)
for i := range p.NumHiddenLayers {
merges = append(merges, merge{
fmt.Sprintf("blk.%d.*.w1.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}, merge{
fmt.Sprintf("blk.%d.*.w1.bias", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.bias", i),
}, merge{
fmt.Sprintf("blk.%d.*.w2.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}, merge{
fmt.Sprintf("blk.%d.*.w2.bias", i),
fmt.Sprintf("blk.%d.ffn_up_exps.bias", i),
}, merge{
fmt.Sprintf("blk.%d.*.w3.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}, merge{
fmt.Sprintf("blk.%d.*.w3.bias", i),
fmt.Sprintf("blk.%d.ffn_down_exps.bias", i),
oldnew := []string{
"model.layers", "blk",
"w1", "ffn_gate_exps",
"w2", "ffn_down_exps",
"w3", "ffn_up_exps",
}
for i := range p.NumLocalExperts {
oldnew = append(oldnew, fmt.Sprintf(".block_sparse_moe.experts.%d.", i), ".")
}
// group experts of the same layer (model.layers.%d) and type (w[123]) into a single tensor
namer := strings.NewReplacer(oldnew...)
experts := make(map[string]experts)
// merge experts into a single tensor while removing them from ts
ts = slices.DeleteFunc(ts, func(t Tensor) bool {
if !strings.Contains(t.Name(), ".block_sparse_moe.experts.") {
return false
}
name := namer.Replace(t.Name())
experts[name] = append(experts[name], t)
return true
})
var out []*ggml.Tensor
for n, e := range experts {
// TODO(mxyng): sanity check experts
out = append(out, &ggml.Tensor{
Name: n,
Kind: e[0].Kind(),
Shape: append([]uint64{uint64(len(e))}, e[0].Shape()...),
WriterTo: e,
})
}
out, ts := mergeTensors(ts, merges...)
return append(out, p.llamaModel.Tensors(ts)...)
}
func (p *mixtralModel) Replacements() []string {
return append(
p.llamaModel.Replacements(),
"model.layers", "blk",
"block_sparse_moe.gate", "ffn_gate_inp",
"block_sparse_moe.experts.", ".",
)
}
type experts []Tensor
func (e experts) WriteTo(w io.Writer) (int64, error) {
// TODO(mxyng): experts _should_ be numerically sorted by expert but this should check
for _, t := range e {
// the canonical merged experts tensor stacks all experts along a new, 0 axis,
// e.g. `tensor.Stack(0, e[0], e[1:]...)`, which requires allocating temporary buffers
// this accomplishes the same thing by writing each expert tensor in sequence
if _, err := t.WriteTo(w); err != nil {
return 0, err
}
}
return 0, nil
}

View File

@@ -11,13 +11,14 @@ import (
"io"
"io/fs"
"log/slog"
"maps"
"os"
"path/filepath"
"slices"
"strings"
"testing"
"golang.org/x/exp/maps"
"github.com/ollama/ollama/fs/ggml"
)
@@ -136,7 +137,9 @@ func TestConvertModel(t *testing.T) {
t.Fatal(err)
}
for _, k := range slices.Sorted(maps.Keys(expect)) {
keys := maps.Keys(expect)
slices.Sort(keys)
for _, k := range keys {
if v, ok := actual[k]; !ok {
t.Errorf("missing %s", k)
} else if v != expect[k] {
@@ -340,7 +343,9 @@ func TestConvertAdapter(t *testing.T) {
actual := generateResultsJSON(t, r, m.KV(), m.Tensors())
for _, k := range slices.Sorted(maps.Keys(c.Expected)) {
keys := maps.Keys(c.Expected)
slices.Sort(keys)
for _, k := range keys {
if v, ok := actual[k]; !ok {
t.Errorf("missing %s", k)
} else if v != c.Expected[k] {

View File

@@ -8,12 +8,12 @@ import (
"fmt"
"io"
"io/fs"
"maps"
"slices"
"strings"
"github.com/d4l3k/go-bfloat16"
"github.com/x448/float16"
"golang.org/x/exp/maps"
)
type safetensorMetadata struct {
@@ -46,7 +46,8 @@ func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]T
return nil, err
}
keys := slices.Sorted(maps.Keys(headers))
keys := maps.Keys(headers)
slices.Sort(keys)
names := make(map[string]struct{}, len(keys))

View File

@@ -2,9 +2,7 @@ package convert
import (
"cmp"
"io"
"iter"
"path"
"slices"
"strings"
@@ -76,54 +74,3 @@ func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
}
}
}
type merge struct {
pattern, name string
}
// mergeTensors merges tensors that match a given pattern into a single tensor.
func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []Tensor) {
var matched []Tensor
for i := range merges {
matched, unmatched = slicesSplitFunc(unmatched, func(t Tensor) bool {
matched, _ := path.Match(merges[i].pattern, t.Name())
return matched
})
if len(matched) > 0 {
out = append(out, &ggml.Tensor{
Name: merges[i].name,
Kind: matched[0].Kind(),
Shape: append([]uint64{uint64(len(matched))}, matched[0].Shape()...),
WriterTo: mergeGroup(matched),
})
}
}
return out, unmatched
}
// slicesSplitFunc splits a slice into two slices based on a predicate function.
func slicesSplitFunc[S ~[]E, E comparable](s S, fn func(e E) bool) (matched, unmatched S) {
for _, e := range s {
if fn(e) {
matched = append(matched, e)
} else {
unmatched = append(unmatched, e)
}
}
return matched, unmatched
}
type mergeGroup []Tensor
func (g mergeGroup) WriteTo(w io.Writer) (int64, error) {
for _, t := range g {
if _, err := t.WriteTo(w); err != nil {
return 0, err
}
}
return 0, nil
}

View File

@@ -9,8 +9,6 @@ import (
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
)
@@ -304,99 +302,3 @@ func TestSplitDim(t *testing.T) {
}
})
}
func TestMerge(t *testing.T) {
unmatched := []Tensor{
&fakeTensor{
name: "a.0.b",
shape: []uint64{5, 2},
data: []float32{10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
},
&fakeTensor{
name: "a.1.b",
shape: []uint64{5, 2},
data: []float32{20, 21, 22, 23, 24, 25, 26, 27, 28, 29},
},
&fakeTensor{
name: "c.0.d",
shape: []uint64{5, 2},
data: []float32{30, 31, 32, 33, 34, 35, 36, 37, 38, 39},
},
&fakeTensor{
name: "c.1.d",
shape: []uint64{5, 2},
data: []float32{40, 41, 42, 43, 44, 45, 46, 47, 48, 49},
},
&fakeTensor{
name: "e.0.f",
shape: []uint64{5, 2},
data: []float32{50, 51, 52, 53, 54, 55, 56, 57, 58, 59},
},
}
checkMatched := func(t *testing.T, n int, matched []*ggml.Tensor) {
for i := range n {
got := matched[i]
if diff := cmp.Diff([]uint64{2, 5, 2}, got.Shape); diff != "" {
t.Errorf("unexpected (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := got.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, 20)
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
offset := 10 + (i * 20)
want := make([]float32, 20)
for j := range 20 {
want[j] = float32(offset + j)
}
if diff := cmp.Diff(want, f32s); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
}
t.Run("single merge", func(t *testing.T) {
matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"})
if len(unmatched) != 3 {
t.Error("expected 3 remaining tensors, got", len(unmatched))
}
if len(matched) != 1 {
t.Error("expected 1 merged tensor, got", len(matched))
}
checkMatched(t, 1, matched)
})
t.Run("multiple merges", func(t *testing.T) {
matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"}, merge{"c.*.d", "c.d"})
if len(unmatched) != 1 {
t.Error("expected 1 remaining tensors, got", len(unmatched))
}
if len(matched) != 2 {
t.Error("expected 2 merged tensor, got", len(matched))
}
checkMatched(t, 2, matched)
})
t.Run("no match", func(t *testing.T) {
matched, unmatched := mergeTensors(unmatched, merge{"x.*.y", "x.y"})
if len(unmatched) != 5 {
t.Error("expected 5 remaining tensors, got", len(unmatched))
}
if len(matched) != 0 {
t.Error("expected no merged tensors, got", len(matched))
}
})
}

View File

@@ -8,10 +8,11 @@ import (
"fmt"
"io/fs"
"log/slog"
"maps"
"os"
"slices"
"strings"
"golang.org/x/exp/maps"
)
const (
@@ -259,8 +260,11 @@ func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
tokens[token.ID] = token
}
keys := maps.Keys(tokens)
slices.Sort(keys)
v := Vocabulary{Model: "gpt2"}
for _, k := range slices.Sorted(maps.Keys(tokens)) {
for _, k := range keys {
token := tokens[k]
v.Tokens = append(v.Tokens, token.Content)
v.Scores = append(v.Scores, float32(token.ID))

View File

@@ -3,7 +3,6 @@
package discover
import (
"fmt"
"log/slog"
"os"
"regexp"
@@ -56,13 +55,10 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
}
}
}
return "sbsa"
}
// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers
if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
// The detected driver is older than Feb 2023
slog.Warn("old CUDA driver detected - please upgrade to a newer driver", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
return "v11"
}
return "v12"

View File

@@ -12,7 +12,7 @@ import (
// '../lib/ollama' on Linux and the executable's directory on macOS
// note: distribution builds, additional GPU-specific libraries are
// found in subdirectories of the returned path, such as
// 'cuda_v12', 'rocm', etc.
// 'cuda_v11', 'cuda_v12', 'rocm', etc.
var LibOllamaPath string = func() string {
exe, err := os.Executable()
if err != nil {

View File

@@ -4,7 +4,6 @@
* [Quickstart](../README.md#quickstart)
* [Examples](./examples.md)
* [Importing models](./import.md)
* [MacOS Documentation](./macos.md)
* [Linux Documentation](./linux.md)
* [Windows Documentation](./windows.md)
* [Docker Documentation](./docker.md)

View File

@@ -500,30 +500,21 @@ The `message` object has the following fields:
- `thinking`: (for thinking models) the model's thinking process
- `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
- `tool_calls` (optional): a list of tools in JSON that the model wants to use
- `tool_name` (optional): add the name of the tool that was executed to inform the model of the result
Advanced parameters (optional):
- `format`: the format to return a response in. Format can be `json` or a JSON schema.
- `format`: the format to return a response in. Format can be `json` or a JSON schema.
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
### Tool calling
Tool calling is supported by providing a list of tools in the `tools` parameter. The model will generate a response that includes a list of tool calls. See the [Chat request (Streaming with tools)](#chat-request-streaming-with-tools) example below.
Models can also explain the result of the tool call in the response. See the [Chat request (With history, with tools)](#chat-request-with-history-with-tools) example below.
[See models with tool calling capabilities](https://ollama.com/search?c=tool).
### Structured outputs
Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [Chat request (Structured outputs)](#chat-request-structured-outputs) example below.
### Examples
#### Chat request (Streaming)
#### Chat Request (Streaming)
##### Request
@@ -578,88 +569,6 @@ Final response:
}
```
#### Chat request (Streaming with tools)
##### Request
```shell
curl http://localhost:11434/api/chat -d '{
"model": "llama3.2",
"messages": [
{
"role": "user",
"content": "what is the weather in tokyo?"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather in a given city",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to get the weather for"
}
},
"required": ["city"]
}
}
}
],
"stream": true
}'
```
##### Response
A stream of JSON objects is returned:
```json
{
"model": "llama3.2",
"created_at": "2025-07-07T20:22:19.184789Z",
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": "get_weather",
"arguments": {
"city": "Tokyo"
}
},
}
]
},
"done": false
}
```
Final response:
```json
{
"model":"llama3.2",
"created_at":"2025-07-07T20:22:19.19314Z",
"message": {
"role": "assistant",
"content": ""
},
"done_reason": "stop",
"done": true,
"total_duration": 182242375,
"load_duration": 41295167,
"prompt_eval_count": 169,
"prompt_eval_duration": 24573166,
"eval_count": 15,
"eval_duration": 115959084
}
```
#### Chat request (No streaming)
##### Request
@@ -697,74 +606,6 @@ curl http://localhost:11434/api/chat -d '{
}
```
#### Chat request (No streaming, with tools)
##### Request
```shell
curl http://localhost:11434/api/chat -d '{
"model": "llama3.2",
"messages": [
{
"role": "user",
"content": "what is the weather in tokyo?"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather in a given city",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to get the weather for"
}
},
"required": ["city"]
}
}
}
],
"stream": false
}'
```
##### Response
```json
{
"model": "llama3.2",
"created_at": "2025-07-07T20:32:53.844124Z",
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": "get_weather",
"arguments": {
"city": "Tokyo"
}
},
}
]
},
"done_reason": "stop",
"done": true,
"total_duration": 3244883583,
"load_duration": 2969184542,
"prompt_eval_count": 169,
"prompt_eval_duration": 141656333,
"eval_count": 18,
"eval_duration": 133293625
}
```
#### Chat request (Structured outputs)
##### Request
@@ -871,87 +712,6 @@ Final response:
}
```
#### Chat request (With history, with tools)
##### Request
```shell
curl http://localhost:11434/api/chat -d '{
"model": "llama3.2",
"messages": [
{
"role": "user",
"content": "what is the weather in Toronto?"
},
// the message from the model appended to history
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": "get_temperature",
"arguments": {
"city": "Toronto"
}
},
}
]
},
// the tool call result appended to history
{
"role": "tool",
"content": "11 degrees celsius",
"tool_name": "get_temperature",
}
],
"stream": false,
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather in a given city",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to get the weather for"
}
},
"required": ["city"]
}
}
}
]
}'
```
##### Response
```json
{
"model": "llama3.2",
"created_at": "2025-07-07T20:43:37.688511Z",
"message": {
"role": "assistant",
"content": "The current temperature in Toronto is 11°C."
},
"done_reason": "stop",
"done": true,
"total_duration": 890771750,
"load_duration": 707634750,
"prompt_eval_count": 94,
"prompt_eval_duration": 91703208,
"eval_count": 11,
"eval_duration": 90282125
}
```
#### Chat request (with images)
##### Request

59
docs/benchmark.md Normal file
View File

@@ -0,0 +1,59 @@
# Benchmark
Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.
## When to use
Run these benchmarks when:
- Making changes to the model inference engine
- Modifying model loading/unloading logic
- Changing prompt processing or token generation code
- Implementing a new model architecture
- Testing performance across different hardware setups
## Prerequisites
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
## Usage and Examples
>[!NOTE]
>All commands must be run from the root directory of the Ollama project.
Basic syntax:
```bash
go test -bench=. ./benchmark/... -m $MODEL_NAME
```
Required flags:
- `-bench=.`: Run all benchmarks
- `-m`: Model name to benchmark
Optional flags:
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)
Common usage patterns:
Single benchmark run with a model specified:
```bash
go test -bench=. ./benchmark/... -m llama3.3
```
## Output metrics
The benchmark reports several key metrics:
- `gen_tok/s`: Generated tokens per second
- `prompt_tok/s`: Prompt processing tokens per second
- `ttft_ms`: Time to first token in milliseconds
- `load_ms`: Model load time in milliseconds
- `gen_tokens`: Total tokens generated
- `prompt_tokens`: Total prompt tokens processed
Each benchmark runs two scenarios:
- Cold start: Model is loaded from disk for each test
- Warm start: Model is pre-loaded in memory
Three prompt lengths are tested for each scenario:
- Short prompt (100 tokens)
- Medium prompt (500 tokens)
- Long prompt (1000 tokens)

View File

@@ -118,7 +118,7 @@ To run tests, use `go test`:
go test ./...
```
> NOTE: In rare circumstances, you may need to change a package using the new
> NOTE: In rare cirumstances, you may need to change a package using the new
> "synctest" package in go1.24.
>
> If you do not have the "synctest" package enabled, you will not see build or

View File

@@ -292,7 +292,7 @@ If too many requests are sent to the server, it will respond with a 503 error in
## How does Ollama handle concurrent requests?
Ollama supports two levels of concurrent processing. If your system has sufficient available memory (system memory when using CPU inference, or VRAM for GPU inference) then multiple models can be loaded at the same time. For a given model, if there is sufficient available memory when the model is loaded, it can be configured to allow parallel request processing.
Ollama supports two levels of concurrent processing. If your system has sufficient available memory (system memory when using CPU inference, or VRAM for GPU inference) then multiple models can be loaded at the same time. For a given model, if there is sufficient available memory when the model is loaded, it is configured to allow parallel request processing.
If there is insufficient available memory to load a new model request while one or more models are already loaded, all new requests will be queued until the new model can be loaded. As prior models become idle, one or more will be unloaded to make room for the new model. Queued requests will be processed in order. When using GPU inference new models must be able to completely fit in VRAM to allow concurrent model loads.
@@ -301,7 +301,7 @@ Parallel request processing for a given model results in increasing the context
The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:
- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 * the number of GPUs or 3 for CPU inference.
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default is 1, and will handle 1 request per model at a time.
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
@@ -333,16 +333,3 @@ The currently available K/V cache quantization types are:
How much the cache quantization impacts the model's response quality will depend on the model and the task. Models that have a high GQA count (e.g. Qwen2) may see a larger impact on precision from quantization than models with a low GQA count.
You may need to experiment with different quantization types to find the best balance between memory usage and quality.
## How can I stop Ollama from starting when I login to my computer
Ollama for Windows and macOS register as a login item during installation. You can disable this if you prefer not to have Ollama automatically start. Ollama will respect this setting across upgrades, unless you uninstall the application.
**Windows**
- Remove `%APPDATA%\Microsoft\Windows\Start Menu\Programs\Startup\Ollama.lnk`
**MacOS Monterey (v12)**
- Open `Settings` -> `Users & Groups` -> `Login Items` and find the `Ollama` entry, then click the `-` (minus) to remove
**MacOS Ventura (v13) and later**
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.

View File

@@ -1,14 +1,12 @@
# GPU
## Nvidia
Ollama supports Nvidia GPUs with compute capability 5.0+ and driver version 531 and newer.
Ollama supports Nvidia GPUs with compute capability 5.0+.
Check your compute compatibility to see if your card is supported:
[https://developer.nvidia.com/cuda-gpus](https://developer.nvidia.com/cuda-gpus)
| Compute Capability | Family | Cards |
| ------------------ | ------------------- | ----------------------------------------------------------------------------------------------------------- |
| 12.0 | GeForce RTX 50xx | `RTX 5060` `RTX 5060 Ti` `RTX 5070` `RTX 5070 Ti` `RTX 5080` `RTX 5090` |
| | NVIDIA Professioal | `RTX PRO 4000 Blackwell` `RTX PRO 4500 Blackwell` `RTX PRO 5000 Blackwell` `RTX PRO 6000 Blackwell` |
| 9.0 | NVIDIA | `H200` `H100` |
| 8.9 | GeForce RTX 40xx | `RTX 4090` `RTX 4080 SUPER` `RTX 4080` `RTX 4070 Ti SUPER` `RTX 4070 Ti` `RTX 4070 SUPER` `RTX 4070` `RTX 4060 Ti` `RTX 4060` |
| | NVIDIA Professional | `L4` `L40` `RTX 6000` |

View File

@@ -53,8 +53,6 @@ FROM /path/to/safetensors/directory
If you create the Modelfile in the same directory as the weights, you can use the command `FROM .`.
If you do not create the Modelfile, ollama will act as if there was a Modelfile with the command `FROM .`.
Now run the `ollama create` command from the directory where you created the `Modelfile`:
```shell

View File

@@ -16,7 +16,7 @@ curl -fsSL https://ollama.com/install.sh | sh
Download and extract the package:
```shell
curl -LO https://ollama.com/download/ollama-linux-amd64.tgz
curl -L https://ollama.com/download/ollama-linux-amd64.tgz -o ollama-linux-amd64.tgz
sudo tar -C /usr -xzf ollama-linux-amd64.tgz
```

View File

@@ -1,42 +0,0 @@
# Ollama for macOS
## System Requirements
* MacOS Monterey (v12) or newer
* Apple M series (CPU and GPU support) or x86 (CPU only)
## Filesystem Requirements
The preferred method of installation is to mount the `ollama.dmg` and drag-and-drop the Ollama application to the system-wide `Applications` folder. Upon startup, the Ollama app will verify the `ollama` CLI is present in your PATH, and if not detected, will prompt for permission to create a link in `/usr/local/bin`
Once you've installed Ollama, you'll need additional space for storing the Large Language models, which can be tens to hundreds of GB in size. If your home directory doesn't have enough space, you can change where the binaries are installed, and where the models are stored.
### Changing Install Location
To install the Ollama application somewhere other than `Applications`, place the Ollama application in the desired location, and ensure the CLI `Ollama.app/Contents/Resources/ollama` or a sym-link to the CLI can be found in your path. Upon first start decline the "Move to Applications?" request.
## Troubleshooting
Ollama on MacOS stores files in a few different locations.
- `~/.ollama` contains models and configuration
- `~/.ollama/logs` contains logs
- *app.log* contains most recent logs from the GUI application
- *server.log* contains the most recent server logs
- `<install location>/Ollama.app/Contents/Resources/ollama` the CLI binary
## Uninstall
To fully remove Ollama from your system, remove the following files and folders:
```
sudo rm -rf /Applications/Ollama.app
sudo rm /usr/local/bin/ollama
rm -rf "~/Library/Application Support/Ollama"
rm -rf "~/Library/Saved Application State/com.electron.ollama.savedState"
rm -rf ~/Library/Caches/com.electron.ollama/
rm -rf ~/Library/Caches/ollama
rm -rf ~/Library/WebKit/com.electron.ollama
rm -rf ~/.ollama
```

View File

@@ -150,7 +150,7 @@ PARAMETER <parameter> <parametervalue>
| Parameter | Description | Value Type | Example Usage |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 4096) | int | num_ctx 4096 |
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |

View File

@@ -72,7 +72,7 @@ client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
# Define the schema for the response
class FriendInfo(BaseModel):
name: str
age: int
age: int
is_available: bool
class FriendList(BaseModel):

View File

@@ -9,7 +9,7 @@ cat ~/.ollama/logs/server.log
On **Linux** systems with systemd, the logs can be found with this command:
```shell
journalctl -u ollama --no-pager --follow --pager-end
journalctl -u ollama --no-pager --follow --pager-end
```
When you run Ollama in a **container**, the logs go to stdout/stderr in the container:
@@ -23,7 +23,7 @@ docker logs <container-name>
If manually running `ollama serve` in a terminal, the logs will be on that terminal.
When you run Ollama on **Windows**, there are a few different locations. You can view them in the explorer window by hitting `<cmd>+R` and type in:
- `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
- `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
- `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH)
- `explorer %HOMEPATH%\.ollama` to browse where models and configuration is stored
@@ -38,12 +38,12 @@ Join the [Discord](https://discord.gg/ollama) for help interpreting the logs.
## LLM libraries
Ollama includes multiple LLM libraries compiled for different GPUs and CPU vector features. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library. `cpu_avx2` will perform the best, followed by `cpu_avx` and the slowest but most compatible is `cpu`. Rosetta emulation under MacOS will work with the `cpu` library.
Ollama includes multiple LLM libraries compiled for different GPUs and CPU vector features. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library. `cpu_avx2` will perform the best, followed by `cpu_avx` an the slowest but most compatible is `cpu`. Rosetta emulation under MacOS will work with the `cpu` library.
In the server log, you will see a message that looks something like this (varies from release to release):
```
Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v12 rocm_v5]
Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v11 rocm_v5]
```
**Experimental LLM Library Override**
@@ -97,7 +97,7 @@ If none of those resolve the problem, gather additional information and file an
On linux, AMD GPU access typically requires `video` and/or `render` group membership to access the `/dev/kfd` device. If permissions are not set up correctly, Ollama will detect this and report an error in the server log.
When running in a container, in some Linux distributions and container runtimes, the ollama process may be unable to access the GPU. Use `ls -lnd /dev/kfd /dev/dri /dev/dri/*` on the host system to determine the **numeric** group IDs on your system, and pass additional `--group-add ...` arguments to the container so it can access the required devices. For example, in the following output `crw-rw---- 1 0 44 226, 0 Sep 16 16:55 /dev/dri/card0` the group ID column is `44`
When running in a container, in some Linux distributions and container runtimes, the ollama process may be unable to access the GPU. Use `ls -lnd /dev/kfd /dev/dri /dev/dri/*` on the host system to determine the **numeric** group IDs on your system, and pass additional `--group-add ...` arguments to the container so it can access the required devices. For example, in the following output `crw-rw---- 1 0 44 226, 0 Sep 16 16:55 /dev/dri/card0` the group ID column is `44`
If you are experiencing problems getting Ollama to correctly discover or use your GPU for inference, the following may help isolate the failure.
- `AMD_LOG_LEVEL=3` Enable info log levels in the AMD HIP/ROCm libraries. This can help show more detailed error codes that can help troubleshoot problems

View File

@@ -30,6 +30,20 @@ To install the Ollama application in a location different than your home directo
OllamaSetup.exe /DIR="d:\some\location"
```
### Changing Model Location
To change where Ollama stores the downloaded models instead of using your home directory, set the environment variable `OLLAMA_MODELS` in your user account.
1. Start the Settings (Windows 11) or Control Panel (Windows 10) application and search for _environment variables_.
2. Click on _Edit environment variables for your account_.
3. Edit or create a new variable for your user account for `OLLAMA_MODELS` where you want the models stored
4. Click OK/Apply to save.
If Ollama is already running, Quit the tray application and relaunch it from the Start menu, or a new terminal started after you saved the environment variables.
## API Access
Here's a quick example showing API access from `powershell`

View File

@@ -219,7 +219,7 @@ func Uint(key string, defaultValue uint) func() uint {
var (
// NumParallel sets the number of parallel model requests. NumParallel can be configured via the OLLAMA_NUM_PARALLEL environment variable.
NumParallel = Uint("OLLAMA_NUM_PARALLEL", 1)
NumParallel = Uint("OLLAMA_NUM_PARALLEL", 0)
// MaxRunners sets the maximum number of loaded models. MaxRunners can be configured via the OLLAMA_MAX_LOADED_MODELS environment variable.
MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", 0)
// MaxQueue sets the maximum number of queued requests. MaxQueue can be configured via the OLLAMA_MAX_QUEUE environment variable.

View File

@@ -10,5 +10,4 @@ type Config interface {
Strings(string, ...[]string) []string
Ints(string, ...[]int32) []int32
Floats(string, ...[]float32) []float32
Bools(string, ...[]bool) []bool
}

View File

@@ -34,8 +34,7 @@ func (kv KV) Kind() string {
}
func (kv KV) ParameterCount() uint64 {
val, _ := keyValue(kv, "general.parameter_count", uint64(0))
return val
return keyValue(kv, "general.parameter_count", uint64(0))
}
func (kv KV) FileType() FileType {
@@ -54,27 +53,16 @@ func (kv KV) EmbeddingLength() uint64 {
return uint64(kv.Uint("embedding_length"))
}
func (kv KV) HeadCountMax() uint64 {
// TODO(drifkin): using the max value can cause an overestimation. In the
// future if array values become more popular, we can adapt the more invasive
// <https://github.com/ollama/ollama/pull/10225>
return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1))
func (kv KV) HeadCount() uint64 {
return uint64(kv.Uint("attention.head_count"))
}
func (kv KV) HeadCountMin() uint64 {
return uint64(kv.UintOrMinArrayValue("attention.head_count", 1))
func (kv KV) HeadCountKV() uint64 {
return uint64(kv.Uint("attention.head_count_kv", 1))
}
func (kv KV) HeadCountKVMax() uint64 {
return uint64(kv.UintOrMaxArrayValue("attention.head_count_kv", 1))
}
func (kv KV) HeadCountKVMin() uint64 {
return uint64(kv.UintOrMinArrayValue("attention.head_count_kv", 1))
}
func (kv KV) EmbeddingHeadCountMax() uint64 {
if heads := kv.HeadCountMin(); heads > 0 {
func (kv KV) EmbeddingHeadCount() uint64 {
if heads := kv.HeadCount(); heads > 0 {
return kv.EmbeddingLength() / heads
}
@@ -82,11 +70,15 @@ func (kv KV) EmbeddingHeadCountMax() uint64 {
}
func (kv KV) EmbeddingHeadCountK() uint64 {
return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCountMax())))
return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCount())))
}
func (kv KV) EmbeddingHeadCountV() uint64 {
return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCountMax())))
return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCount())))
}
func (kv KV) GQA() uint64 {
return kv.HeadCount() / kv.HeadCountKV()
}
func (kv KV) ContextLength() uint64 {
@@ -98,83 +90,40 @@ func (kv KV) ChatTemplate() string {
}
func (kv KV) String(key string, defaultValue ...string) string {
val, _ := keyValue(kv, key, append(defaultValue, "")...)
return val
return keyValue(kv, key, append(defaultValue, "")...)
}
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
return keyValue(kv, key, append(defaultValue, 0)...)
}
func (kv KV) Float(key string, defaultValue ...float32) float32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
return keyValue(kv, key, append(defaultValue, 0)...)
}
func (kv KV) Bool(key string, defaultValue ...bool) bool {
val, _ := keyValue(kv, key, append(defaultValue, false)...)
return val
}
func (kv KV) UintOrMaxArrayValue(key string, defaultValue uint32) uint32 {
_, max := kv.UintOrArrayValue(key, defaultValue)
return max
}
func (kv KV) UintOrMinArrayValue(key string, defaultValue uint32) uint32 {
min, _ := kv.UintOrArrayValue(key, defaultValue)
return min
}
func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) {
if u32, ok := keyValue(kv, key, uint32(0)); ok {
return u32, u32
} else if u32s, ok := keyValue(kv, key, &array[uint32]{}); ok {
min := slices.Min(u32s.values)
max := slices.Max(u32s.values)
return min, max
} else if i32s, ok := keyValue(kv, key, &array[int32]{}); ok {
min := slices.Min(i32s.values)
max := slices.Max(i32s.values)
if min < 0 || max < 0 {
slog.Warn("array values are unexpectedly negative", "key", key, "min", min, "max", max)
}
return uint32(min), uint32(max)
}
return defaultValue, defaultValue
return keyValue(kv, key, append(defaultValue, false)...)
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
val, _ := keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]})
return val.values
return keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]}).values
}
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
val, _ := keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]})
return val.values
return keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]}).values
}
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
val, _ := keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]})
return val.values
return keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]}).values
}
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
val, _ := keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]})
return val.values
}
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
val, _ := keyValue(kv, key, &array[bool]{values: append(defaultValue, []bool(nil))[0]})
return val.values
return keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]}).values
}
func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{
"gemma3",
"gemma3n",
"mistral3",
"llama4",
"mllama",
@@ -194,17 +143,17 @@ type arrayValueTypes interface {
*array[string] | *array[float32] | *array[float64] | *array[bool]
}
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) T {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key
}
if val, ok := kv[key].(T); ok {
return val, true
if val, ok := kv[key]; ok {
return val.(T)
}
slog.Debug("key with type not found", "key", key, "default", defaultValue[0])
return defaultValue[0], false
slog.Debug("key not found", "key", key, "default", defaultValue[0])
return defaultValue[0]
}
type Tensors struct {
@@ -476,11 +425,11 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
embedding := f.KV().EmbeddingLength()
heads := f.KV().HeadCountMax()
headsKV := f.KV().HeadCountKVMax()
heads := f.KV().HeadCount()
headsKV := f.KV().HeadCountKV()
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
embeddingHeads := f.KV().EmbeddingHeadCountMax()
embeddingHeads := f.KV().EmbeddingHeadCount()
embeddingHeadsK := f.KV().EmbeddingHeadCountK()
embeddingHeadsV := f.KV().EmbeddingHeadCountV()
@@ -555,7 +504,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
// vocab graph
4*batch*(embedding+vocab)+embedding*vocab*105/128,
)
case "gemma", "gemma2", "gemma3", "gemma3n":
case "gemma", "gemma2", "gemma3":
fullOffload = max(
4*batch*(embedding+vocab),
4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
@@ -568,11 +517,6 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
embedding*embeddingHeadsK*heads*9/16,
)
if f.KV().Architecture() == "gemma3n" {
fullOffload *= 4
partialOffload *= 4
}
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
// engine. Gemma3 always uses the Ollama engine.
if f.KV().Architecture() == "gemma3" {

View File

@@ -269,33 +269,3 @@ func TestKeyValue(t *testing.T) {
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
}
}
func TestHeadCount(t *testing.T) {
valuesArray := []int32{1, 5, 3, 4}
cases := []struct {
kv KV
want uint64
}{
{
kv: KV{
"general.architecture": "abc",
"abc.attention.head_count": &array[int32]{values: valuesArray, size: len(valuesArray)},
},
want: uint64(5),
},
{
kv: KV{
"general.architecture": "abc",
"abc.attention.head_count": uint32(3),
},
want: uint64(3),
},
}
for _, tt := range cases {
got := tt.kv.HeadCountMax()
if got != tt.want {
t.Errorf("unexpected max value: got=%d want=%d", got, tt.want)
}
}
}

View File

@@ -609,10 +609,6 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
err = writeGGUFArray(ws, ggufTypeString, v)
case *array[string]:
err = writeGGUFArray(ws, ggufTypeString, v.values)
case []bool:
err = writeGGUFArray(ws, ggufTypeBool, v)
case *array[bool]:
err = writeGGUFArray(ws, ggufTypeBool, v.values)
default:
return fmt.Errorf("improper type for '%s'", k)
}

View File

@@ -65,7 +65,7 @@ func Open(path string) (f *File, err error) {
return nil, err
}
if f.Version < 2 {
if f.Version != 3 {
return nil, fmt.Errorf("%w version %v", ErrUnsupported, f.Version)
}

4
go.mod
View File

@@ -25,7 +25,6 @@ require (
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
golang.org/x/image v0.22.0
golang.org/x/tools v0.30.0
gonum.org/v1/gonum v0.15.0
)
require (
@@ -45,6 +44,7 @@ require (
github.com/xtgo/set v1.0.0 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
gonum.org/v1/gonum v0.15.0 // indirect
gorgonia.org/vecf32 v0.9.0 // indirect
gorgonia.org/vecf64 v0.9.0 // indirect
)
@@ -71,7 +71,7 @@ require (
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.36.0
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
golang.org/x/net v0.38.0 // indirect
golang.org/x/sys v0.31.0
golang.org/x/term v0.30.0

View File

@@ -1,57 +0,0 @@
//go:build integration && library
package integration
import (
"context"
"log/slog"
"testing"
"time"
"github.com/ollama/ollama/api"
)
// First run of this scenario on a target system will take a long time to download
// ~1.5TB of models. Set a sufficiently large -timeout for your network speed
func TestLibraryModelsGenerate(t *testing.T) {
softTimeout, hardTimeout := getTimeouts(t)
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
chatModels := libraryChatModels
for _, model := range chatModels {
t.Run(model, func(t *testing.T) {
if time.Now().Sub(started) > softTimeout {
t.Skip("skipping remaining tests to avoid excessive runtime")
}
if err := PullIfMissing(ctx, client, model); err != nil {
t.Fatalf("pull failed %s", err)
}
req := api.GenerateRequest{
Model: model,
Prompt: "why is the sky blue?",
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]interface{}{
"temperature": 0.1,
"seed": 123,
},
}
anyResp := []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength"}
// Special cases
if model == "duckdb-nsql" {
anyResp = []string{"select", "from"}
} else if model == "granite3-guardian" || model == "shieldgemma" || model == "llama-guard3" || model == "bespoke-minicheck" {
anyResp = []string{"yes", "no", "safe", "unsafe"}
} else if model == "openthinker" || model == "nexusraven" {
anyResp = []string{"plugin", "im_sep", "components", "function call"}
} else if model == "starcoder" || model == "starcoder2" || model == "magicoder" || model == "deepseek-coder" {
req.Prompt = "def fibonacci():"
anyResp = []string{"f(n)", "sequence", "n-1", "main()", "__main__", "while"}
}
DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
})
}
}

View File

@@ -19,6 +19,35 @@ import (
"github.com/ollama/ollama/format"
)
var (
started = time.Now()
chatModels = []string{
"granite3-moe:latest",
"granite-code:latest",
"nemotron-mini:latest",
"command-r:latest",
"gemma2:latest",
"gemma:latest",
"internlm2:latest",
"phi3.5:latest",
"phi3:latest",
// "phi:latest", // flaky, sometimes generates no response on first query
"stablelm2:latest", // Predictions are off, crashes on small VRAM GPUs
"falcon:latest",
"falcon2:latest",
"minicpm-v:latest",
"mistral:latest",
"orca-mini:latest",
"llama2:latest",
"llama3.1:latest",
"llama3.2:latest",
"llama3.2-vision:latest",
"qwen2.5-coder:latest",
"qwen:latest",
"solar-pro:latest",
}
)
func TestModelsGenerate(t *testing.T) {
softTimeout, hardTimeout := getTimeouts(t)
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
@@ -39,13 +68,6 @@ func TestModelsGenerate(t *testing.T) {
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
}
var chatModels []string
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
chatModels = ollamaEngineChatModels
} else {
chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
}
for _, model := range chatModels {
t.Run(model, func(t *testing.T) {
if time.Now().Sub(started) > softTimeout {

View File

@@ -1,266 +0,0 @@
//go:build integration && perf
package integration
import (
"context"
"fmt"
"io/ioutil"
"log/slog"
"math"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
)
var (
// Models that don't work reliably with the large context prompt in this test case
longContextFlakes = []string{
"granite-code:latest",
"nemotron-mini:latest",
"falcon:latest", // 2k model
"falcon2:latest", // 2k model
"minicpm-v:latest",
"qwen:latest",
"solar-pro:latest",
}
)
// Note: this test case can take a long time to run, particularly on models with
// large contexts. Run with -timeout set to a large value to get reasonable coverage
// Example usage:
//
// go test --tags=integration,perf -count 1 ./integration -v -timeout 90m -run TestModelsPerf 2>&1 | tee int.log
// cat int.log | grep MODEL_PERF_HEADER | head -1| cut -f2- -d: > perf.csv
// cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv
func TestModelsPerf(t *testing.T) {
softTimeout, hardTimeout := getTimeouts(t)
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// TODO use info API eventually
var maxVram uint64
var err error
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
maxVram, err = strconv.ParseUint(s, 10, 64)
if err != nil {
t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
}
} else {
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
}
data, err := ioutil.ReadFile(filepath.Join("testdata", "shakespeare.txt"))
if err != nil {
t.Fatalf("failed to open test data file: %s", err)
}
longPrompt := "summarize the following: " + string(data)
var chatModels []string
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
chatModels = ollamaEngineChatModels
} else {
chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
}
for _, model := range chatModels {
t.Run(model, func(t *testing.T) {
if time.Now().Sub(started) > softTimeout {
t.Skip("skipping remaining tests to avoid excessive runtime")
}
if err := PullIfMissing(ctx, client, model); err != nil {
t.Fatalf("pull failed %s", err)
}
var maxContext int
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
if err != nil {
t.Fatalf("show failed: %s", err)
}
arch := resp.ModelInfo["general.architecture"].(string)
maxContext = int(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))
if maxVram > 0 {
resp, err := client.List(ctx)
if err != nil {
t.Fatalf("list models failed %v", err)
}
for _, m := range resp.Models {
// For these tests we want to exercise a some amount of overflow on the CPU
if m.Name == model && float32(m.Size)*0.75 > float32(maxVram) {
t.Skipf("model %s is too large %s for available VRAM %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
}
}
}
slog.Info("scneario", "model", model, "max_context", maxContext)
loaded := false
defer func() {
// best effort unload once we're done with the model
if loaded {
client.Generate(ctx, &api.GenerateRequest{Model: model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
}
}()
// Some models don't handle the long context data well so skip them to avoid flaky test results
longContextFlake := false
for _, flake := range longContextFlakes {
if model == flake {
longContextFlake = true
break
}
}
// iterate through a few context sizes for coverage without excessive runtime
var contexts []int
keepGoing := true
if maxContext > 16384 {
contexts = []int{4096, 8192, 16384, maxContext}
} else if maxContext > 8192 {
contexts = []int{4096, 8192, maxContext}
} else if maxContext > 4096 {
contexts = []int{4096, maxContext}
} else if maxContext > 0 {
contexts = []int{maxContext}
} else {
t.Fatal("unknown max context size")
}
for _, numCtx := range contexts {
if !keepGoing && numCtx > 8192 { // Always try up to 8k before bailing out
break
}
skipLongPrompt := false
// Workaround bug 11172 temporarily...
maxPrompt := longPrompt
// If we fill the context too full with the prompt, many models
// quickly hit context shifting and go bad.
if len(maxPrompt) > numCtx*2 { // typically yields ~1/2 full context
maxPrompt = maxPrompt[:numCtx*2]
}
testCases := []struct {
prompt string
anyResp []string
}{
{"why is the sky blue?", []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}},
{maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy"}},
}
var gpuPercent int
for _, tc := range testCases {
if len(tc.prompt) > 100 && (longContextFlake || skipLongPrompt) {
slog.Info("skipping long prompt", "model", model, "num_ctx", numCtx, "gpu_percent", gpuPercent)
continue
}
req := api.GenerateRequest{
Model: model,
Prompt: tc.prompt,
KeepAlive: &api.Duration{Duration: 20 * time.Second}, // long enough to ensure a ps returns
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
"num_ctx": numCtx,
},
}
atLeastOne := false
var resp api.GenerateResponse
stream := false
req.Stream = &stream
// Avoid potentially getting stuck indefinitely
limit := 5 * time.Minute
genCtx, cancel := context.WithDeadlineCause(
ctx,
time.Now().Add(limit),
fmt.Errorf("generate on model %s with ctx %d took longer than %v", model, numCtx, limit),
)
defer cancel()
err = client.Generate(genCtx, &req, func(rsp api.GenerateResponse) error {
resp = rsp
return nil
})
if err != nil {
// Avoid excessive test runs, but don't consider a failure with massive context
if numCtx > 16384 && strings.Contains(err.Error(), "took longer") {
slog.Warn("max context was taking too long, skipping", "error", err)
keepGoing = false
skipLongPrompt = true
continue
}
t.Fatalf("generate error: ctx:%d err:%s", numCtx, err)
}
loaded = true
for _, expResp := range tc.anyResp {
if strings.Contains(strings.ToLower(resp.Response), expResp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("response didn't contain expected values: ctx:%d expected:%v response:%s ", numCtx, tc.anyResp, resp.Response)
}
models, err := client.ListRunning(ctx)
if err != nil {
slog.Warn("failed to list running models", "error", err)
continue
}
if len(models.Models) > 1 {
slog.Warn("multiple models loaded, may impact performance results", "loaded", models.Models)
}
for _, m := range models.Models {
if m.Name == model {
if m.SizeVRAM == 0 {
slog.Info("Model fully loaded into CPU")
gpuPercent = 0
keepGoing = false
skipLongPrompt = true
} else if m.SizeVRAM == m.Size {
slog.Info("Model fully loaded into GPU")
gpuPercent = 100
} else {
sizeCPU := m.Size - m.SizeVRAM
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
gpuPercent = int(100 - cpuPercent)
slog.Info("Model split between CPU/GPU", "CPU", cpuPercent, "GPU", gpuPercent)
keepGoing = false
// Heuristic to avoid excessive test run time
if gpuPercent < 90 {
skipLongPrompt = true
}
}
}
}
fmt.Fprintf(os.Stderr, "MODEL_PERF_HEADER:%s,%s,%s,%s,%s,%s,%s\n",
"MODEL",
"CONTEXT",
"GPU PERCENT",
"PROMPT COUNT",
"LOAD TIME",
"PROMPT EVAL TPS",
"EVAL TPS",
)
fmt.Fprintf(os.Stderr, "MODEL_PERF_DATA:%s,%d,%d,%d,%0.2f,%0.2f,%0.2f\n",
model,
numCtx,
gpuPercent,
resp.PromptEvalCount,
float64(resp.LoadDuration)/1000000000.0,
float64(resp.PromptEvalCount)/(float64(resp.PromptEvalDuration)/1000000000.0),
float64(resp.EvalCount)/(float64(resp.EvalDuration)/1000000000.0),
)
}
}
})
}
}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -32,229 +32,6 @@ const (
smol = "llama3.2:1b"
)
var (
started = time.Now()
// Note: add newer models at the top of the list to test them first
ollamaEngineChatModels = []string{
"gemma3n:e2b",
"mistral-small3.2:latest",
"deepseek-r1:1.5b",
"llama3.2-vision:latest",
"qwen2.5-coder:latest",
"qwen2.5vl:3b",
"qwen3:0.6b", // dense
"qwen3:30b", // MOE
"gemma3:1b",
"llama3.1:latest",
"llama3.2:latest",
"gemma2:latest",
"minicpm-v:latest", // arch=qwen2
"granite-code:latest", // arch=llama
}
llamaRunnerChatModels = []string{
"mistral:latest",
"falcon3:latest",
"granite3-moe:latest",
"command-r:latest",
"nemotron-mini:latest",
"phi3.5:latest",
"solar-pro:latest",
"internlm2:latest",
"codellama:latest", // arch=llama
"phi3:latest",
"falcon2:latest",
"gemma:latest",
"llama2:latest",
"nous-hermes:latest",
"orca-mini:latest",
"qwen:latest",
"stablelm2:latest", // Predictions are off, crashes on small VRAM GPUs
"falcon:latest",
}
// Some library models are quite large - ensure large VRAM and sufficient disk space
// before running scenarios based on this set
libraryChatModels = []string{
"alfred",
"athene-v2",
"aya-expanse",
"aya",
"bakllava",
"bespoke-minicheck",
"codebooga",
"codegeex4",
"codegemma",
"codellama",
"codeqwen",
"codestral",
"codeup",
"cogito",
"command-a",
"command-r-plus",
"command-r",
"command-r7b-arabic",
"command-r7b",
"dbrx",
"deepcoder",
"deepscaler",
"deepseek-coder-v2",
"deepseek-coder",
"deepseek-llm",
"deepseek-r1",
// "deepseek-v2.5", // requires 155 GB VRAM
"deepseek-v2",
// "deepseek-v3", // requires 482 GB VRAM
"devstral",
"dolphin-llama3",
"dolphin-mistral",
"dolphin-mixtral",
"dolphin-phi",
"dolphin3",
"dolphincoder",
"duckdb-nsql",
"everythinglm",
"exaone-deep",
"exaone3.5",
"falcon",
"falcon2",
"falcon3",
"firefunction-v2",
"gemma",
"gemma2",
"gemma3",
"gemma3n",
"glm4",
"goliath",
"granite-code",
"granite3-dense",
"granite3-guardian",
"granite3-moe",
"granite3.1-dense",
"granite3.1-moe",
"granite3.2-vision",
"granite3.2",
"granite3.3",
"hermes3",
"internlm2",
"llama-guard3",
"llama-pro",
"llama2-chinese",
"llama2-uncensored",
"llama2",
"llama3-chatqa",
"llama3-gradient",
"llama3-groq-tool-use",
"llama3.1",
"llama3.2-vision",
"llama3.2",
"llama3.3",
"llama3",
"llama4",
"llava-llama3",
"llava-phi3",
"llava",
"magicoder",
"magistral",
"marco-o1",
"mathstral",
"meditron",
"medllama2",
"megadolphin",
"minicpm-v",
"mistral-large",
"mistral-nemo",
"mistral-openorca",
"mistral-small",
"mistral-small3.1",
"mistral-small3.2",
"mistral",
"mistrallite",
"mixtral",
"moondream",
"nemotron-mini",
"nemotron",
"neural-chat",
"nexusraven",
"notus",
"nous-hermes",
"nous-hermes2-mixtral",
"nous-hermes2",
"nuextract",
"olmo2",
"open-orca-platypus2",
"openchat",
"opencoder",
"openhermes",
"openthinker",
"orca-mini",
"orca2",
// "phi", // unreliable
"phi3.5",
"phi3",
"phi4-mini-reasoning",
"phi4-mini",
"phi4-reasoning",
"phi4",
"phind-codellama",
"qwen",
"qwen2-math",
"qwen2.5-coder",
"qwen2.5",
"qwen2.5vl",
"qwen2",
"qwen3:0.6b", // dense
"qwen3:30b", // MOE
"qwq",
"r1-1776",
"reader-lm",
"reflection",
"sailor2",
"samantha-mistral",
"shieldgemma",
"smallthinker",
"smollm",
"smollm2",
"solar-pro",
"solar",
"sqlcoder",
"stable-beluga",
"stable-code",
"stablelm-zephyr",
"stablelm2",
"starcoder",
"starcoder2",
"starling-lm",
"tinydolphin",
"tinyllama",
"tulu3",
"vicuna",
"wizard-math",
"wizard-vicuna-uncensored",
"wizard-vicuna",
"wizardcoder",
"wizardlm-uncensored",
"wizardlm2",
"xwinlm",
"yarn-llama2",
"yarn-mistral",
"yi-coder",
"yi",
"zephyr",
}
libraryEmbedModels = []string{
"all-minilm",
"bge-large",
"bge-m3",
"granite-embedding",
"mxbai-embed-large",
"nomic-embed-text",
"paraphrase-multilingual",
"snowflake-arctic-embed",
"snowflake-arctic-embed2",
}
)
func Init() {
lifecycle.InitLogging()
}
@@ -494,10 +271,6 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
t.Errorf("generate stalled. Response so far:%s", buf.String())
}
case <-done:
if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") {
slog.Warn("model is too large for the target test system", "model", genReq.Model, "error", genErr)
return
}
require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
// Verify the response contains the expected data
response := buf.String()

View File

@@ -25,9 +25,6 @@ type Causal struct {
opts CausalOptions
// maxBatch is the largest batch that we might receive
maxBatch int
// config controls mostly backend-specific optimizations
config *ml.CacheConfig
@@ -150,7 +147,6 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
c.DType = dtype
c.cellRanges = make(map[int]cellRange)
c.backend = backend
c.maxBatch = maxBatch
}
func (c *Causal) SetConfig(config ml.CacheConfig) {
@@ -643,64 +639,48 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
return ErrNotSupported
}
ctx := c.backend.NewContext()
defer ctx.Close()
seqRange := c.cellRanges[seq]
size := seqRange.max - seqRange.min + 1
for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
size := min(seqRange.max-start+1, c.maxBatch)
offsets := make([]int32, size)
offsets := make([]int32, size)
for i := range offsets {
cell := c.cells[seqRange.min+i]
var batchFirst, batchLast int
batchFirst = -1
for i := range offsets {
cell := c.cells[start+i]
if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
offsets[i] = offset
if batchFirst < 0 {
batchFirst = i
}
batchLast = i
}
if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
offsets[i] = offset
}
}
if batchFirst < 0 {
kShift := ctx.Input().FromIntSlice(offsets, len(offsets))
for i, key := range c.keys {
if key == nil {
continue
}
offsets = offsets[batchFirst : batchLast+1]
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
ctx := c.backend.NewContext()
kShift := ctx.Input().FromIntSlice(offsets, len(offsets))
key = key.View(ctx, rowSize*seqRange.min,
kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2),
size,
)
for i, key := range c.keys {
if key == nil {
continue
}
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
key = key.View(ctx, rowSize*(start+batchFirst),
kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2),
len(offsets),
)
roped, err := c.shiftFn(ctx, i, key, kShift)
if err != nil {
ctx.Close()
return err
}
ctx.Forward(roped.Copy(ctx, key))
roped, err := c.shiftFn(ctx, i, key, kShift)
if err != nil {
return err
}
ctx.Compute()
ctx.Close()
ctx.Forward(roped.Copy(ctx, key))
}
ctx.Compute()
return nil
}

View File

@@ -150,7 +150,7 @@ index 4cce5166..7f6617fa 100644
llama_model_loader::llama_model_loader(
const std::string & fname,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 3a4e72a3..db62973f 100644
index 3a4e72a3..831b68c0 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1402,6 +1402,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {

View File

@@ -22,10 +22,10 @@ multiple batches of processing until everything is complete.
4 files changed, 59 insertions(+), 79 deletions(-)
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index dca22d8b..1f3a3956 100644
index c22687e4..c5948e8f 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -947,9 +947,12 @@ int llama_context::decode(llama_batch & inp_batch) {
@@ -950,9 +950,12 @@ int llama_context::decode(llama_batch & inp_batch) {
// find KV slot
if (!kv_self->find_slot(ubatch)) {
@@ -41,7 +41,7 @@ index dca22d8b..1f3a3956 100644
}
ggml_backend_sched_reset(sched.get());
@@ -1965,9 +1968,12 @@ void llama_context::opt_epoch_iter(
@@ -1967,9 +1970,12 @@ void llama_context::opt_epoch_iter(
// TODO: not sure if this is needed
if (!kv_self->find_slot(ubatch)) {

View File

@@ -10,10 +10,10 @@ Subject: [PATCH] add argsort and cuda copy for i32
3 files changed, 192 insertions(+), 2 deletions(-)
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 955fec59..654e2f28 100644
index becdae07..7a44b6cf 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -6822,6 +6822,45 @@ static void ggml_compute_forward_argsort_f32(
@@ -6890,6 +6890,45 @@ static void ggml_compute_forward_argsort_f32(
}
}
@@ -59,7 +59,7 @@ index 955fec59..654e2f28 100644
void ggml_compute_forward_argsort(
const ggml_compute_params * params,
ggml_tensor * dst) {
@@ -6833,6 +6872,10 @@ void ggml_compute_forward_argsort(
@@ -6901,6 +6940,10 @@ void ggml_compute_forward_argsort(
{
ggml_compute_forward_argsort_f32(params, dst);
} break;
@@ -195,7 +195,7 @@ index 607ded85..53b02634 100644
+ }
}
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index d027271f..4abd01d7 100644
index 2d46176e..47383486 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -38,6 +38,13 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
@@ -257,7 +257,7 @@ index d027271f..4abd01d7 100644
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q8_0 * dsti = (block_q8_0 *) cdsti;
@@ -633,6 +678,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
@@ -631,6 +676,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
@@ -266,7 +266,7 @@ index d027271f..4abd01d7 100644
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -688,6 +735,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
@@ -686,6 +733,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_f32_f16<cpy_1_f16_f32>;

View File

@@ -7,31 +7,31 @@ This enables matching up devices and information reported by the backend
with tools (e.g. nvidia-smi) and system management libraries (e.g. nvml).
---
ggml/include/ggml-backend.h | 1 +
ggml/src/ggml-cuda/ggml-cuda.cu | 39 ++++++++++++++++++++++++++++++++
ggml/src/ggml-cuda/ggml-cuda.cu | 33 ++++++++++++++++++++++++++++++++
ggml/src/ggml-metal/ggml-metal.m | 1 +
3 files changed, 41 insertions(+)
3 files changed, 35 insertions(+)
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index 74e46716..48839339 100644
index 74e46716..a880df33 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -152,6 +152,7 @@ extern "C" {
struct ggml_backend_dev_props {
const char * name;
const char * description;
+ const char * id;
+ const char * uuid;
size_t memory_free;
size_t memory_total;
enum ggml_backend_dev_type type;
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index cb0d8528..d6960174 100644
index cb0d8528..4c829153 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2884,6 +2884,7 @@ struct ggml_backend_cuda_device_context {
int device;
std::string name;
std::string description;
+ std::string id;
+ std::string uuid;
};
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
@@ -39,9 +39,9 @@ index cb0d8528..d6960174 100644
return ctx->description.c_str();
}
+static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) {
+static const char * ggml_backend_cuda_device_get_uuid(ggml_backend_dev_t dev) {
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
+ return ctx->id.c_str();
+ return ctx->uuid.c_str();
+}
+
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
@@ -51,17 +51,17 @@ index cb0d8528..d6960174 100644
static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
props->name = ggml_backend_cuda_device_get_name(dev);
props->description = ggml_backend_cuda_device_get_description(dev);
+ props->id = ggml_backend_cuda_device_get_id(dev);
+ props->uuid = ggml_backend_cuda_device_get_uuid(dev);
props->type = ggml_backend_cuda_device_get_type(dev);
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
@@ -3458,6 +3465,38 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
@@ -3458,6 +3465,32 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
dev_ctx->description = prop.name;
+ #if !defined(GGML_USE_HIP)
+ char id[64];
+ snprintf(id, sizeof(id),
+ char uuid[64];
+ snprintf(uuid, sizeof(uuid),
+ "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
+ (unsigned char)prop.uuid.bytes[0],
+ (unsigned char)prop.uuid.bytes[1],
@@ -80,29 +80,23 @@ index cb0d8528..d6960174 100644
+ (unsigned char)prop.uuid.bytes[14],
+ (unsigned char)prop.uuid.bytes[15]
+ );
+ dev_ctx->id = id;
+ dev_ctx->uuid = uuid;
+ #else
+ #ifdef _WIN32
+ char id[16];
+ snprintf(id, sizeof(id), "%d", i);
+ dev_ctx->id = id;
+ #else
+ dev_ctx->id = "GPU-" + std::string(prop.uuid.bytes, 16);
+ #endif
+ dev_ctx->uuid = "GPU-" + std::string(prop.uuid.bytes, 16);
+ #endif
+
ggml_backend_dev_t dev = new ggml_backend_device {
/* .iface = */ ggml_backend_cuda_device_interface,
/* .reg = */ &reg,
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 1b56f858..a9eeebc6 100644
index 1b56f858..ee4f2dcb 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -5703,6 +5703,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
props->name = ggml_backend_metal_device_get_name(dev);
props->description = ggml_backend_metal_device_get_description(dev);
+ props->id = "0";
+ props->uuid = "0";
props->type = ggml_backend_metal_device_get_type(dev);
ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
props->caps = (struct ggml_backend_dev_caps) {

View File

@@ -1,32 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen <daniel@ollama.com>
Date: Sun, 22 Jun 2025 09:22:05 -0700
Subject: [PATCH] temporary prevent rocm+cuda mixed loading
---
ggml/src/ggml-backend-reg.cpp | 12 ++++++++++--
1 file changed, 10 insertions(+), 2 deletions(-)
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index 4e67d243..8f49f084 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -573,8 +573,16 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
ggml_backend_load_best("blas", silent, dir_path);
ggml_backend_load_best("cann", silent, dir_path);
- ggml_backend_load_best("cuda", silent, dir_path);
- ggml_backend_load_best("hip", silent, dir_path);
+
+ // Avoid mixed hip+cuda configurations
+ const char * hip_devices = std::getenv("HIP_VISIBLE_DEVICES");
+ const char * rocr_devices = std::getenv("ROCR_VISIBLE_DEVICES");
+ if (!hip_devices && !rocr_devices) {
+ ggml_backend_load_best("cuda", silent, dir_path);
+ } else {
+ ggml_backend_load_best("hip", silent, dir_path);
+ }
+
ggml_backend_load_best("kompute", silent, dir_path);
ggml_backend_load_best("metal", silent, dir_path);
ggml_backend_load_best("rpc", silent, dir_path);

View File

@@ -1,169 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Thu, 19 Jun 2025 08:05:21 +0300
Subject: [PATCH] metal : add mean kernel (#14267)
* metal : add mean kernel
ggml-ci
* cont : dedup implementation
ggml-ci
---
ggml/src/ggml-metal/ggml-metal.m | 33 ++++++++++++++++---
ggml/src/ggml-metal/ggml-metal.metal | 48 ++++++++++++++++++++++------
2 files changed, 67 insertions(+), 14 deletions(-)
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index a9eeebc6..110c9ece 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -489,6 +489,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_COS,
GGML_METAL_KERNEL_TYPE_NEG,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+ GGML_METAL_KERNEL_TYPE_MEAN,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -1436,6 +1437,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
@@ -1634,6 +1636,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_LOG:
return false; // TODO: implement
case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@@ -2362,11 +2365,30 @@ static bool ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
{
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (dst->op) {
+ case GGML_OP_SUM_ROWS:
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+ break;
+ case GGML_OP_MEAN:
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
+ break;
+ default:
+ GGML_ABORT("fatal error");
+ }
+
+ int nth = 32; // SIMD width
+
+ while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+ nth *= 2;
+ }
+ nth = MIN(nth, ne00);
ggml_metal_kargs_sum_rows args = {
/*.ne00 =*/ ne00,
@@ -2396,11 +2418,12 @@ static bool ggml_metal_encode_node(
};
[encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&args length:sizeof(args) atIndex:2];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_SOFT_MAX:
{
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 9cfddf45..08e8d807 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -956,31 +956,61 @@ kernel void kernel_neg(
dst[tpig] = -src0[tpig];
}
+template <bool norm>
kernel void kernel_sum_rows(
+ constant ggml_metal_kargs_sum_rows & args,
device const float * src0,
device float * dst,
- constant ggml_metal_kargs_sum_rows & args,
- uint3 tpig[[thread_position_in_grid]]) {
- int64_t i3 = tpig.z;
- int64_t i2 = tpig.y;
- int64_t i1 = tpig.x;
+ threadgroup float * shmem_f32 [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ int64_t i3 = tgpig.z;
+ int64_t i2 = tgpig.y;
+ int64_t i1 = tgpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
+ if (sgitg == 0) {
+ shmem_f32[tiisg] = 0.0f;
+ }
+
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
- float row_sum = 0;
+ float sumf = 0;
- for (int64_t i0 = 0; i0 < args.ne00; i0++) {
- row_sum += src_row[i0];
+ for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+ sumf += src_row[i0];
}
- dst_row[0] = row_sum;
+ sumf = simd_sum(sumf);
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ shmem_f32[sgitg] = sumf;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sumf = shmem_f32[tiisg];
+ sumf = simd_sum(sumf);
+
+ if (tpitg.x == 0) {
+ dst_row[0] = norm ? sumf / args.ne00 : sumf;
+ }
}
+typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
+
+template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
+template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
+
template<typename T>
kernel void kernel_soft_max(
device const char * src0,

View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,50 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Oliver Simons <osimons@nvidia.com>
Date: Tue, 22 Jul 2025 11:02:28 +0200
Subject: [PATCH] Enable CUDA Graphs for gemma3n.
Similar to
https://github.com/ggml-org/llama.cpp/pull/14741,
though ollama has a slightly different model graph
than llama.cpp which requires different workaround
checks.
---
ggml/src/ggml-cuda/ggml-cuda.cu | 16 ++++++++++++----
1 file changed, 12 insertions(+), 4 deletions(-)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 2b9fabf4..28ccf4be 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2474,6 +2474,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
+ const std::string gemma3n_per_layer_proj_src1_name = " (reshaped)";
+ const std::string gemma3n_node_name = "node_";
+
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
@@ -2495,12 +2498,17 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
#endif
}
- if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
- // disable CUDA graphs for batch size > 1 for now.
- // Changes in batch size or context size can cause changes to the grid size of some kernels.
+ // workarounds to exclude Gemma3n's `project_per_layer_input` operation from the batch-size heuristic, specific to ollama's implementation of gemma3n
+ // number of layers is different for per_layer_proj between gemma3n:2b and gemma3n:4b, which is why we don't check that value here
+ if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && !(node->ne[0] == 256
+ && node->ne[2] == 1
+ && node->ne[3] == 1
+ && node->src[0] ? std::string(node->src[0]->name).find(gemma3n_node_name) != std::string::npos : false
+ && node->src[1] ? node->src[1]->name == gemma3n_per_layer_proj_src1_name : false)) {
+ // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
use_cuda_graph = false;
#ifndef NDEBUG
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+ GGML_LOG_INFO("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
#endif
}

View File

@@ -151,12 +151,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
}
if graphPartialOffload == 0 {
headsKV := f.KV().HeadCountKVMin()
if headsKV == 0 {
headsKV = 1
}
gqa := f.KV().HeadCountMax() / headsKV
graphPartialOffload = gqa * kvTotal / 6
graphPartialOffload = f.KV().GQA() * kvTotal / 6
}
if graphFullOffload == 0 {
graphFullOffload = graphPartialOffload

View File

@@ -139,13 +139,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
gpus = discover.GetCPUInfo()
}
// Verify the requested context size is <= the model training size
trainCtx := f.KV().ContextLength()
if opts.NumCtx/numParallel > int(trainCtx) && trainCtx > 0 {
slog.Warn("requested context size too large for model", "num_ctx", opts.NumCtx, "num_parallel", numParallel, "n_ctx_train", trainCtx)
opts.NumCtx = int(trainCtx) * numParallel
}
estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
if len(gpus) > 1 || gpus[0].Library != "cpu" {
switch {
@@ -318,7 +311,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
params = append(params, "--mmproj", projectors[0])
}
// iterate through compatible GPU libraries such as 'cuda_v12', 'rocm', etc.
// iterate through compatible GPU libraries such as 'cuda_v12', 'cuda_v11', 'rocm', etc.
// adding each library's respective path to the LD_LIBRARY_PATH, until finally running
// without any LD_LIBRARY_PATH flags
for {

View File

@@ -124,9 +124,9 @@ type DeviceMemory struct {
// may not be persistent across instances of the runner.
Name string
// ID is an identifier for the device for matching with system
// management libraries.
ID string
// UUID is a unique persistent identifier for the device for matching
// with system management libraries
UUID string
// Weights is the per-layer memory needed for the model weights.
Weights []Memory
@@ -156,8 +156,8 @@ func (m DeviceMemory) LogValue() slog.Value {
attrs = append(attrs, slog.Any("Graph", m.Graph))
}
if len(attrs) > 0 && m.ID != "" {
attrs = append([]slog.Attr{slog.String("ID", m.ID)}, attrs...)
if len(attrs) > 0 && m.UUID != "" {
attrs = append([]slog.Attr{slog.String("UUID", m.UUID)}, attrs...)
}
return slog.GroupValue(attrs...)
@@ -253,7 +253,6 @@ type Tensor interface {
Neg(ctx Context) Tensor
Add(ctx Context, t2 Tensor) Tensor
Sub(ctx Context, t2 Tensor) Tensor
Mul(ctx Context, t2 Tensor) Tensor
Div(ctx Context, t2 Tensor) Tensor
@@ -277,7 +276,6 @@ type Tensor interface {
Tanh(ctx Context) Tensor
GELU(ctx Context) Tensor
SILU(ctx Context) Tensor
RELU(ctx Context) Tensor
Sigmoid(ctx Context) Tensor
Reshape(ctx Context, shape ...int) Tensor
@@ -299,12 +297,6 @@ type Tensor interface {
TopK(ctx Context, k int) Tensor
Argsort(ctx Context) Tensor
Mean(ctx Context) Tensor
Variance(ctx Context) Tensor
Stddev(ctx Context) Tensor
Sqr(ctx Context) Tensor
Sqrt(ctx Context) Tensor
Clamp(ctx Context, min, max float32) Tensor
}
// ScaledDotProductAttention implements a fused attention

View File

@@ -138,7 +138,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
requiredMemory.CPU.Name = C.GoString(C.ggml_backend_dev_name(cpuDeviceBufferType.d))
var props C.struct_ggml_backend_dev_props
C.ggml_backend_dev_get_props(cpuDeviceBufferType.d, &props)
requiredMemory.CPU.ID = C.GoString(props.id)
requiredMemory.CPU.UUID = C.GoString(props.uuid)
requiredMemory.CPU.Weights = make([]ml.Memory, blocks+1)
requiredMemory.CPU.Cache = make([]ml.Memory, blocks+1)
@@ -155,7 +155,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d))
var props C.struct_ggml_backend_dev_props
C.ggml_backend_dev_get_props(d, &props)
requiredMemory.GPUs[i].ID = C.GoString(props.id)
requiredMemory.GPUs[i].UUID = C.GoString(props.uuid)
requiredMemory.GPUs[i].Weights = make([]ml.Memory, blocks+1)
requiredMemory.GPUs[i].Cache = make([]ml.Memory, blocks+1)
}
@@ -297,9 +297,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
if _, ok := meta.Tensors().GroupLayers()["output"]; !ok && t.Name == "token_embd.weight" {
createTensor(tensor{source: t, target: "output.weight"}, output.bts, blocks)
}
case contains(t.Name, "cls", "output", "output_norm",
"altup_proj", "altup_unembd_proj",
"per_layer_token_embd", "per_layer_model_proj", "per_layer_proj_norm"):
case contains(t.Name, "cls", "output", "output_norm"):
createTensor(tensor{source: t}, output.bts, blocks)
case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
// TODO: assign vision tensors to the gpu if possible
@@ -355,26 +353,6 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
bbs[c] = b
}
// Mimic llama runner logs summarizing layers and memory
gpuLayers := 0
for _, layer := range layers {
if C.ggml_backend_dev_type(layer.d) == C.GGML_BACKEND_DEVICE_TYPE_GPU {
gpuLayers++
}
}
slog.Info(fmt.Sprintf("offloading %d repeating layers to GPU", gpuLayers))
switch C.ggml_backend_dev_type(output.d) {
case C.GGML_BACKEND_DEVICE_TYPE_CPU:
slog.Info("offloading output layer to CPU")
case C.GGML_BACKEND_DEVICE_TYPE_GPU:
slog.Info("offloading output layer to GPU")
gpuLayers++
case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
slog.Info("offloading output layer to ACCEL")
}
slog.Info(fmt.Sprintf("offloaded %d/%d layers to GPU", gpuLayers, len(layers)+1))
for bs := range maps.Values(bbs) {
slog.Info("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)), "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
}
@@ -420,7 +398,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
C.int(len(schedBackends)),
C.size_t(maxGraphNodes),
C._Bool(false),
C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
C._Bool(false),
),
schedBackends: schedBackends,
@@ -624,9 +602,7 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
}
func (c *Context) Compute(tensors ...ml.Tensor) {
if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS {
panic(fmt.Errorf("error computing ggml graph: %v", status))
}
C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
C.ggml_backend_sched_reset(c.b.sched)
needSync := true
@@ -915,13 +891,6 @@ func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) Sub(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sub(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}
func (t *Tensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor {
if dim < 0 || dim >= C.GGML_MAX_DIMS {
panic("invalid dimension")
@@ -1229,13 +1198,6 @@ func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
}
}
func (t *Tensor) RELU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
return &Tensor{
b: t.b,
@@ -1311,42 +1273,3 @@ func (t *Tensor) Argsort(ctx ml.Context) ml.Tensor {
t: C.ggml_argsort(ctx.(*Context).ctx, t.t, C.GGML_SORT_ORDER_ASC),
}
}
func (t *Tensor) Mean(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_mean(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Variance(ctx ml.Context) ml.Tensor {
return t.Add(ctx, t.Mean(ctx).Scale(ctx, -1)).
Sqr(ctx).
SumRows(ctx).
Scale(ctx, 1/float64(t.Dim(0)))
}
func (t *Tensor) Stddev(ctx ml.Context) ml.Tensor {
return t.Variance(ctx).Sqrt(ctx)
}
func (t *Tensor) Sqr(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sqr(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sqrt(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
}
}

View File

@@ -152,7 +152,7 @@ extern "C" {
struct ggml_backend_dev_props {
const char * name;
const char * description;
const char * id;
const char * uuid;
size_t memory_free;
size_t memory_total;
enum ggml_backend_dev_type type;

View File

@@ -573,16 +573,8 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
ggml_backend_load_best("blas", silent, dir_path);
ggml_backend_load_best("cann", silent, dir_path);
// Avoid mixed hip+cuda configurations
const char * hip_devices = std::getenv("HIP_VISIBLE_DEVICES");
const char * rocr_devices = std::getenv("ROCR_VISIBLE_DEVICES");
if (!hip_devices && !rocr_devices) {
ggml_backend_load_best("cuda", silent, dir_path);
} else {
ggml_backend_load_best("hip", silent, dir_path);
}
ggml_backend_load_best("cuda", silent, dir_path);
ggml_backend_load_best("hip", silent, dir_path);
ggml_backend_load_best("kompute", silent, dir_path);
ggml_backend_load_best("metal", silent, dir_path);
ggml_backend_load_best("rpc", silent, dir_path);

View File

@@ -362,26 +362,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#endif // FP16_AVAILABLE
}
// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
template<bool norm>
static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x;
const int col = threadIdx.x;
float sum = 0.0f;
for (int i = col; i < ncols; i += blockDim.x) {
sum += x[row * ncols + i];
}
sum = warp_reduce_sum(sum);
if (col != 0) {
return;
}
dst[row] = norm ? sum / ncols : sum;
}
template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll

View File

@@ -35,7 +35,6 @@
#include "ggml-cuda/ssm-scan.cuh"
#include "ggml-cuda/sum.cuh"
#include "ggml-cuda/sumrows.cuh"
#include "ggml-cuda/mean.cuh"
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
@@ -2323,9 +2322,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SUM_ROWS:
ggml_cuda_op_sum_rows(ctx, dst);
break;
case GGML_OP_MEAN:
ggml_cuda_op_mean(ctx, dst);
break;
case GGML_OP_SSM_CONV:
ggml_cuda_op_ssm_conv(ctx, dst);
break;
@@ -2474,9 +2470,6 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
const std::string gemma3n_per_layer_proj_src1_name = " (reshaped)";
const std::string gemma3n_node_name = "node_";
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
@@ -2498,17 +2491,12 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
#endif
}
// workarounds to exclude Gemma3n's `project_per_layer_input` operation from the batch-size heuristic, specific to ollama's implementation of gemma3n
// number of layers is different for per_layer_proj between gemma3n:2b and gemma3n:4b, which is why we don't check that value here
if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && !(node->ne[0] == 256
&& node->ne[2] == 1
&& node->ne[3] == 1
&& node->src[0] ? std::string(node->src[0]->name).find(gemma3n_node_name) != std::string::npos : false
&& node->src[1] ? node->src[1]->name == gemma3n_per_layer_proj_src1_name : false)) {
// Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
// disable CUDA graphs for batch size > 1 for now.
// Changes in batch size or context size can cause changes to the grid size of some kernels.
use_cuda_graph = false;
#ifndef NDEBUG
GGML_LOG_INFO("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
#endif
}
@@ -2896,7 +2884,7 @@ struct ggml_backend_cuda_device_context {
int device;
std::string name;
std::string description;
std::string id;
std::string uuid;
};
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
@@ -2909,9 +2897,9 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t
return ctx->description.c_str();
}
static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) {
static const char * ggml_backend_cuda_device_get_uuid(ggml_backend_dev_t dev) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
return ctx->id.c_str();
return ctx->uuid.c_str();
}
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
@@ -2928,7 +2916,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend
static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
props->name = ggml_backend_cuda_device_get_name(dev);
props->description = ggml_backend_cuda_device_get_description(dev);
props->id = ggml_backend_cuda_device_get_id(dev);
props->uuid = ggml_backend_cuda_device_get_uuid(dev);
props->type = ggml_backend_cuda_device_get_type(dev);
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
@@ -3223,7 +3211,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_POOL_2D:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_ARGSORT:
case GGML_OP_ACC:
return true;
@@ -3479,8 +3466,8 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
dev_ctx->description = prop.name;
#if !defined(GGML_USE_HIP)
char id[64];
snprintf(id, sizeof(id),
char uuid[64];
snprintf(uuid, sizeof(uuid),
"GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
(unsigned char)prop.uuid.bytes[0],
(unsigned char)prop.uuid.bytes[1],
@@ -3499,15 +3486,9 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
(unsigned char)prop.uuid.bytes[14],
(unsigned char)prop.uuid.bytes[15]
);
dev_ctx->id = id;
dev_ctx->uuid = uuid;
#else
#ifdef _WIN32
char id[16];
snprintf(id, sizeof(id), "%d", i);
dev_ctx->id = id;
#else
dev_ctx->id = "GPU-" + std::string(prop.uuid.bytes, 16);
#endif
dev_ctx->uuid = "GPU-" + std::string(prop.uuid.bytes, 16);
#endif
ggml_backend_dev_t dev = new ggml_backend_device {

View File

@@ -1,19 +0,0 @@
#include "mean.cuh"
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
}

View File

@@ -1,3 +0,0 @@
#include "common.cuh"
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -1,9 +1,25 @@
#include "sumrows.cuh"
static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x;
const int col = threadIdx.x;
float sum = 0.0f;
for (int i = col; i < ncols; i += blockDim.x) {
sum += x[row * ncols + i];
}
sum = warp_reduce_sum(sum);
if (col == 0) {
dst[row] = sum;
}
}
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -19,8 +35,5 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
}

View File

@@ -1,4 +1,5 @@
#include "common.cuh"
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -3434,61 +3434,31 @@ kernel void kernel_neg(
dst[tpig] = -src0[tpig];
}
template <bool norm>
kernel void kernel_sum_rows(
constant ggml_metal_kargs_sum_rows & args,
device const float * src0,
device float * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
int64_t i3 = tgpig.z;
int64_t i2 = tgpig.y;
int64_t i1 = tgpig.x;
constant ggml_metal_kargs_sum_rows & args,
uint3 tpig[[thread_position_in_grid]]) {
int64_t i3 = tpig.z;
int64_t i2 = tpig.y;
int64_t i1 = tpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
float sumf = 0;
float row_sum = 0;
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
sumf += src_row[i0];
for (int64_t i0 = 0; i0 < args.ne00; i0++) {
row_sum += src_row[i0];
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
if (tpitg.x == 0) {
dst_row[0] = norm ? sumf / args.ne00 : sumf;
}
dst_row[0] = row_sum;
}
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
template<typename T>
kernel void kernel_soft_max(
device const char * src0,

View File

@@ -489,7 +489,6 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_COS,
GGML_METAL_KERNEL_TYPE_NEG,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_KERNEL_TYPE_MEAN,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -1437,7 +1436,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
@@ -1636,7 +1634,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_LOG:
return false; // TODO: implement
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@@ -2365,30 +2362,11 @@ static bool ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
{
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
id<MTLComputePipelineState> pipeline = nil;
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
switch (dst->op) {
case GGML_OP_SUM_ROWS:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
break;
case GGML_OP_MEAN:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
break;
default:
GGML_ABORT("fatal error");
}
int nth = 32; // SIMD width
while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
nth *= 2;
}
nth = MIN(nth, ne00);
ggml_metal_kargs_sum_rows args = {
/*.ne00 =*/ ne00,
@@ -2418,12 +2396,11 @@ static bool ggml_metal_encode_node(
};
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&args length:sizeof(args) atIndex:2];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SOFT_MAX:
{
@@ -5726,7 +5703,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
props->name = ggml_backend_metal_device_get_name(dev);
props->description = ggml_backend_metal_device_get_description(dev);
props->id = "0";
props->uuid = "0";
props->type = ggml_backend_metal_device_get_type(dev);
ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
props->caps = (struct ggml_backend_dev_caps) {

View File

@@ -956,61 +956,31 @@ kernel void kernel_neg(
dst[tpig] = -src0[tpig];
}
template <bool norm>
kernel void kernel_sum_rows(
constant ggml_metal_kargs_sum_rows & args,
device const float * src0,
device float * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
int64_t i3 = tgpig.z;
int64_t i2 = tgpig.y;
int64_t i1 = tgpig.x;
constant ggml_metal_kargs_sum_rows & args,
uint3 tpig[[thread_position_in_grid]]) {
int64_t i3 = tpig.z;
int64_t i2 = tpig.y;
int64_t i1 = tpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
float sumf = 0;
float row_sum = 0;
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
sumf += src_row[i0];
for (int64_t i0 = 0; i0 < args.ne00; i0++) {
row_sum += src_row[i0];
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
if (tpitg.x == 0) {
dst_row[0] = norm ? sumf / args.ne00 : sumf;
}
dst_row[0] = row_sum;
}
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
template<typename T>
kernel void kernel_soft_max(
device const char * src0,

View File

@@ -4,6 +4,6 @@ package metal
//go:generate sh -c "{ echo // Code generated by 'go generate'. DO NOT EDIT.; sed -e '/__embed_ggml-common.h__/r ../ggml-common.h' -e '/__embed_ggml-common.h__/d' -e '/#include \"ggml-metal-impl.h\"/r ggml-metal-impl.h' -e '/#include \"ggml-metal-impl.h\"/d' ggml-metal.metal; } >ggml-metal-embed.metal"
// #cgo CPPFLAGS: -DGGML_METAL_NDEBUG -DGGML_METAL_EMBED_LIBRARY -DGGML_METAL_USE_BF16 -I.. -I../../include
// #cgo CPPFLAGS: -DGGML_METAL_NDEBUG -DGGML_METAL_EMBED_LIBRARY -I.. -I../../include
// #cgo LDFLAGS: -framework Metal -framework MetalKit
import "C"

View File

@@ -1,51 +0,0 @@
package gemma3n
import (
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.SentencePieceModel
*TextModel
}
// Forward implements model.Model.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
return m.TextModel.Forward(ctx, batch, m.Cache)
}
func New(c fs.Config) (model.Model, error) {
m := Model{
TextModel: newTextModel(c),
SentencePieceModel: model.NewSentencePieceModel(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
}
m.Cache = kvcache.NewWrapperCache(
kvcache.NewCausalCache(m.Shift),
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
)
return &m, nil
}
func init() {
model.Register("gemma3n", New)
}

View File

@@ -1,359 +0,0 @@
package gemma3n
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
)
type TextModel struct {
TokenEmbedding *TextScaledWordEmbedding `gguf:"token_embd"`
*PerLayerProjector
AltupEmbd *nn.Linear `gguf:"altup_proj"`
AltupUnembd *nn.Linear `gguf:"altup_unembd_proj"`
TextLayers []TextLayer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
TextOptions
}
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
// Create a tensor of a single float32 value of 1.0 to use for altup correction
one := ctx.Input().FromFloatSlice([]float32{1.0}, 1)
inputs := m.TokenEmbedding.Forward(ctx, batch.Inputs, math.Sqrt(float64(m.hiddenSize)))
inputsPerLayer := m.PerLayerProjector.Forward(ctx, batch, inputs, &m.TextOptions)
targetMagnitude := inputs.Sqr(ctx).Mean(ctx).Sqrt(ctx)
targetMagnitude = targetMagnitude.Repeat(ctx, 2, m.altupInputs-1)
hiddenState := inputs.Repeat(ctx, 2, m.altupInputs-1)
altupProj := m.AltupEmbd.Forward(ctx, hiddenState)
altupProj = altupProj.Mul(ctx, targetMagnitude.Div(ctx, altupProj.Sqr(ctx).Mean(ctx).Sqrt(ctx)))
hiddenStates := inputs.Concat(ctx, altupProj, 2)
firstSharedKeyValue := m.hiddenLayers - m.sharedKeyValueLayers
for i, layer := range m.TextLayers {
if i < firstSharedKeyValue {
cache.SetLayer(i)
} else if m.isLocal(i) {
cache.SetLayer(firstSharedKeyValue - 2)
} else {
cache.SetLayer(firstSharedKeyValue - 1)
}
var layerType int
ropeBase := m.ropeBase
if m.isLocal(i) {
layerType = 1
ropeBase = m.ropeBaseLocal
}
cache.(*kvcache.WrapperCache).SetLayerType(layerType)
// inputPerLayer = inputsPerLayer[:, i, :]
inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2))
hiddenStates = layer.Forward(ctx, hiddenStates, inputPerLayer, positions, one, cache, i >= firstSharedKeyValue, ropeBase, float64(m.activationSparsityScale[i]), &m.TextOptions)
}
// hiddenStates = hiddenStates[:, :, 0]
hiddenStates0 := hiddenStates.View(ctx, 0, hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1))
targetMagnitude = hiddenStates0.Sqr(ctx).Mean(ctx).Sqrt(ctx)
targetMagnitude = targetMagnitude.Repeat(ctx, 2, m.altupInputs-1)
// hiddenState = hiddenStates[:, :, 1:]
hiddenState = hiddenStates.View(ctx, hiddenStates.Stride(2), hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), m.altupInputs-1)
altupUnembdProj := m.AltupUnembd.Forward(ctx, hiddenState)
altupUnembdProj = altupUnembdProj.Mul(ctx, targetMagnitude.Div(ctx, altupUnembdProj.Sqr(ctx).Mean(ctx).Sqrt(ctx)))
hiddenStates = hiddenStates0.Concat(ctx, altupUnembdProj, 2)
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)))
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeBase := m.ropeBase
if m.isLocal(layer) {
ropeBase = m.ropeBaseLocal
}
return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
}
type TextScaledWordEmbedding struct {
*nn.Embedding
}
func (e TextScaledWordEmbedding) Forward(ctx ml.Context, inputIDs ml.Tensor, scale float64) ml.Tensor {
return e.Embedding.Forward(ctx, inputIDs).Scale(ctx, scale)
}
type PerLayerProjector struct {
TokenEmbedding *TextScaledWordEmbedding `gguf:"per_layer_token_embd"`
Projector *nn.Linear `gguf:"per_layer_model_proj"`
Norm *nn.RMSNorm `gguf:"per_layer_proj_norm"`
}
func (p PerLayerProjector) Forward(ctx ml.Context, batch input.Batch, inputs ml.Tensor, opts *TextOptions) ml.Tensor {
inputsPerLayer := p.TokenEmbedding.Forward(ctx, batch.Inputs, math.Sqrt(float64(opts.hiddenSizePerLayerInput)))
inputsPerLayer = inputsPerLayer.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, batch.Inputs.Dim(0), batch.Inputs.Dim(1))
perLayerProjection := p.Projector.Forward(ctx, inputs)
perLayerProjection = perLayerProjection.Scale(ctx, math.Sqrt(float64(opts.hiddenSize)))
perLayerProjection = perLayerProjection.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))
perLayerProjection = p.Norm.Forward(ctx, perLayerProjection, opts.eps)
if inputsPerLayer != nil {
perLayerProjection = perLayerProjection.Add(ctx, inputsPerLayer)
perLayerProjection = perLayerProjection.Scale(ctx, 1/math.Sqrt(2))
}
return perLayerProjection
}
type TextLayer struct {
*AltUp
*Laurel
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Attention *TextAttention
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *TextMLP
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
PerLayerInputGate *nn.Linear `gguf:"inp_gate"`
PerLayerProjection *nn.Linear `gguf:"proj"`
PostPerLayerNorm *nn.RMSNorm `gguf:"post_norm"`
}
func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, positions, one ml.Tensor, cache kvcache.Cache, sharedKV bool, ropeBase float32, activationSparsityScale float64, opts *TextOptions) ml.Tensor {
predictions := d.Predict(ctx, hiddenStates, opts)
active := opts.altupActive(ctx, predictions)
attn := d.AttentionNorm.Forward(ctx, active, opts.eps)
laurel := d.Laurel.Forward(ctx, attn, opts)
attn = d.Attention.Forward(ctx, attn, positions, cache, sharedKV, ropeBase, opts)
attn = d.PostAttentionNorm.Forward(ctx, attn, opts.eps)
attn = active.Add(ctx, attn)
attn = attn.Add(ctx, laurel).Scale(ctx, 1/math.Sqrt(2))
mlp := d.MLPNorm.Forward(ctx, attn, opts.eps)
mlp = d.MLP.Forward(ctx, mlp, activationSparsityScale)
mlp = d.PostMLPNorm.Forward(ctx, mlp, opts.eps)
mlp = attn.Add(ctx, mlp)
predictions = d.Correct(ctx, predictions, mlp, one, opts)
active = opts.altupActive(ctx, predictions)
if opts.altupCorrectScale {
active = d.ScaleCorrectedOutput(ctx, active)
}
active = d.PerLayerInputGate.Forward(ctx, active)
active = active.GELU(ctx)
active = active.Mul(ctx, perLayerInput)
active = d.PerLayerProjection.Forward(ctx, active)
active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)
// inactive := predictions[:, :, 1:]
inactive := predictions.View(ctx, predictions.Stride(2), predictions.Dim(0), predictions.Stride(1), predictions.Dim(1), predictions.Stride(2), predictions.Dim(2)-1)
active = inactive.Add(ctx, active)
predictions0 := predictions.View(ctx, 0, predictions.Dim(0), predictions.Stride(1), predictions.Dim(1))
return predictions0.Concat(ctx, active, 2)
}
type AltUp struct {
CorrectionScale ml.Tensor `gguf:"altup_correct_scale.weight"`
PredictionCoefficient *nn.Linear `gguf:"altup_predict_coef"`
CorrectionCoefficient *nn.Linear `gguf:"altup_correct_coef"`
Router *nn.Linear `gguf:"altup_router"`
RouterNorm *nn.RMSNorm `gguf:"altup_router_norm"`
}
func (a AltUp) computeRouterModalities(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
routerInputs := a.RouterNorm.Forward(ctx, hiddenStates, opts.eps).Scale(ctx, 1.0/float64(opts.hiddenSize))
return a.Router.Forward(ctx, routerInputs).Tanh(ctx)
}
func (a AltUp) Predict(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
modalities := a.computeRouterModalities(ctx, opts.altupActive(ctx, hiddenStates), opts)
coefficients := a.PredictionCoefficient.Forward(ctx, modalities)
coefficients = coefficients.Reshape(ctx, opts.altupInputs, opts.altupInputs, coefficients.Dim(1), coefficients.Dim(2))
predictions := coefficients.Mulmat(ctx, hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx))
predictions = predictions.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
return predictions.Add(ctx, hiddenStates)
}
func (a AltUp) Correct(ctx ml.Context, predictions, activated, one ml.Tensor, opts *TextOptions) ml.Tensor {
innovation := activated.Sub(ctx, opts.altupActive(ctx, predictions))
innovation = innovation.Repeat(ctx, 2, opts.altupInputs)
modalities := a.computeRouterModalities(ctx, activated, opts)
coefficients := a.CorrectionCoefficient.Forward(ctx, modalities)
coefficients = coefficients.Add(ctx, one)
coefficients = coefficients.Reshape(ctx, 1, coefficients.Dim(0), coefficients.Dim(1))
coefficients = coefficients.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
corrected := innovation.Mul(ctx, coefficients)
corrected = corrected.Add(ctx, predictions)
return corrected
}
func (a AltUp) ScaleCorrectedOutput(ctx ml.Context, predictions ml.Tensor) ml.Tensor {
return predictions.Mul(ctx, a.CorrectionScale)
}
type Laurel struct {
LinearLeft *nn.Linear `gguf:"laurel_l"`
LinearRight *nn.Linear `gguf:"laurel_r"`
PostLaurelNorm *nn.RMSNorm `gguf:"laurel_post_norm"`
}
func (l Laurel) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = l.LinearLeft.Forward(ctx, hiddenStates)
hiddenStates = l.LinearRight.Forward(ctx, hiddenStates)
hiddenStates = l.PostLaurelNorm.Forward(ctx, hiddenStates, opts.eps)
return hiddenStates.Add(ctx, residual)
}
type TextAttention struct {
Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, sharedKV bool, ropeBase float32, opts *TextOptions) ml.Tensor {
batchSize := hiddenStates.Dim(1)
query := attn.Query.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
query = attn.QueryNorm.Forward(ctx, query, opts.eps)
query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
var key, value ml.Tensor
if !sharedKV {
key = attn.Key.Forward(ctx, hiddenStates)
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
key = attn.KeyNorm.Forward(ctx, key, opts.eps)
key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
value = attn.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
value = value.RMSNorm(ctx, nil, opts.eps)
}
attention := nn.Attention(ctx, query, key, value, 1., cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return attn.Output.Forward(ctx, attention)
}
type TextMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, activationSparsityScale float64) ml.Tensor {
upStates := mlp.Up.Forward(ctx, hiddenStates)
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates)
if activationSparsityScale > 0 {
mean := hiddenStates.Mean(ctx)
std := hiddenStates.Stddev(ctx).Scale(ctx, activationSparsityScale)
cutoff := mean.Add(ctx, std)
hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx)
}
hiddenStates = hiddenStates.GELU(ctx).Mul(ctx, upStates)
hiddenStates = mlp.Down.Forward(ctx, hiddenStates)
return hiddenStates
}
type TextOptions struct {
hiddenLayers int
hiddenSize int
hiddenSizePerLayerInput int
numHeads, numKVHeads int
keyLength, valueLength int
sharedKeyValueLayers int
altupActiveIndex int
altupInputs int
altupCorrectScale bool
eps float32
ropeBase float32
ropeBaseLocal float32
ropeScale float32
slidingWindowPattern []bool
activationSparsityScale []float32
}
func (o *TextOptions) altupActive(ctx ml.Context, t ml.Tensor) ml.Tensor {
// t[:, :, o.altupActiveIndex]
return t.View(ctx, o.altupActiveIndex*t.Stride(2), t.Dim(0), t.Stride(1), t.Dim(1))
}
func (o *TextOptions) headDim() int {
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
}
func (o *TextOptions) isLocal(i int) bool {
return o.slidingWindowPattern[i]
}
func newTextModel(c fs.Config) *TextModel {
return &TextModel{
TextLayers: make([]TextLayer, c.Uint("block_count")),
TextOptions: TextOptions{
hiddenLayers: int(c.Uint("block_count")),
hiddenSize: int(c.Uint("embedding_length")),
hiddenSizePerLayerInput: int(c.Uint("embedding_length_per_layer_input")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
keyLength: int(c.Uint("attention.key_length")),
valueLength: int(c.Uint("attention.value_length")),
sharedKeyValueLayers: int(c.Uint("attention.shared_kv_layers")),
altupActiveIndex: int(c.Uint("altup.active_idx")),
altupInputs: int(c.Uint("altup.num_inputs")),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeBase: c.Float("rope.freq_base", 1_000_000),
ropeBaseLocal: c.Float("rope.freq_base_local", 10_000),
ropeScale: c.Float("rope.freq_scale", 1.0),
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
activationSparsityScale: c.Floats("activation_sparsity_scale"),
},
}
}

View File

@@ -2,7 +2,6 @@ package llama
import (
"cmp"
"fmt"
"math"
"github.com/ollama/ollama/fs"
@@ -34,14 +33,6 @@ type Model struct {
}
func New(c fs.Config) (model.Model, error) {
// This model currently only supports the gpt2 tokenizer
if c.String("tokenizer.ggml.model") == "llama" {
return nil, fmt.Errorf("unsupported tokenizer: llama")
}
// Best effort detection of library/deepseek-coder model(s) which are incompatible
if c.String("general.name") == "deepseek-ai" {
return nil, fmt.Errorf("unsupported model: %s", c.String("general.name"))
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),

View File

@@ -3,7 +3,6 @@ package models
import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3"

View File

@@ -2,9 +2,7 @@ package qwen2
import (
"cmp"
"fmt"
"math"
"strings"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
@@ -128,14 +126,6 @@ func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor
}
func New(c fs.Config) (model.Model, error) {
// This model currently only supports the gpt2 tokenizer
if c.String("tokenizer.ggml.model") == "llama" {
return nil, fmt.Errorf("unsupported tokenizer: llama")
}
// detect library/qwen model(s) which are incompatible
if strings.HasPrefix(c.String("general.name"), "Qwen2-beta") {
return nil, fmt.Errorf("unsupported model: %s", c.String("general.name"))
}
m := Model{
Layers: make([]DecoderLayer, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding(

View File

@@ -87,7 +87,7 @@ func (v *Vocabulary) Decode(id int32) string {
func (v *Vocabulary) SpecialVocabulary() []string {
v.specialOnce.Do(func() {
for i := range v.Values {
if v.Types[i] == TOKEN_TYPE_CONTROL || v.Types[i] == TOKEN_TYPE_USER_DEFINED {
if v.Types[i] == TOKEN_TYPE_CONTROL {
v.special = append(v.special, v.Values[i])
}
}

View File

@@ -1,16 +0,0 @@
package model
import "testing"
func TestVocabulary_SpecialVocabulary(t *testing.T) {
vocab := &Vocabulary{
Values: []string{"<|startoftext|>", "<|endoftext|>", "<|tool_call_start|>", "<|tool_call_end|>", "hi"},
Types: []int32{TOKEN_TYPE_CONTROL, TOKEN_TYPE_CONTROL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_NORMAL},
}
specialVocab := vocab.SpecialVocabulary()
if len(specialVocab) != 4 {
t.Errorf("expected 4 special tokens, got %d", len(specialVocab))
}
}

View File

@@ -423,7 +423,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
}
}
types := []string{"jpeg", "jpg", "png", "webp"}
types := []string{"jpeg", "jpg", "png"}
valid := false
for _, t := range types {
prefix := "data:image/" + t + ";base64,"

View File

@@ -27,6 +27,7 @@ function checkEnv() {
$env:VCToolsRedistDir=(get-item "${MSVC_INSTALL}\VC\Redist\MSVC\*")[0]
}
# Locate CUDA versions
# Note: this assumes every version found will be built
$cudaList=(get-item "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v*\bin\" -ea 'silentlycontinue')
if ($cudaList.length -eq 0) {
$d=(get-command -ea 'silentlycontinue' nvcc).path
@@ -93,6 +94,19 @@ function buildOllama() {
$hashEnv = @{}
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
if ("$script:CUDA_DIRS".Contains("v11")) {
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V11")) { $v11="$_" }}
$env:CUDAToolkit_ROOT=$hashEnv[$v11]
write-host "Building CUDA v11 backend libraries"
# Note: cuda v11 requires msvc 2019 so force the older generator
# to avoid 2022 (or newer) from being used as the default
& cmake --fresh --preset "CUDA 11" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset "CUDA 11" --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
if ("$script:CUDA_DIRS".Contains("v12")) {
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12")) { $v12="$_" }}
$env:CUDAToolkit_ROOT=$hashEnv[$v12]
@@ -113,17 +127,12 @@ function buildOllama() {
$env:HIPCXX="${env:HIP_PATH}\bin\clang++.exe"
$env:HIP_PLATFORM="amd"
$env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
& cmake --fresh --preset "ROCm 6" -G Ninja `
-DCMAKE_C_COMPILER=clang `
-DCMAKE_CXX_COMPILER=clang++ `
-DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
-DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
--install-prefix $script:DIST_DIR
& cmake --fresh --preset "ROCm 6" -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ --install-prefix $script:DIST_DIR
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
$env:HIPCXX=""
$env:HIP_PLATFORM=""
$env:CMAKE_PREFIX_PATH=""
& cmake --build --preset "ROCm 6" --config Release --parallel $script:JOBS
& cmake --build --preset "ROCm" --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "HIP" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}

View File

@@ -10,7 +10,9 @@ OLLAMA_COMMON_BUILD_ARGS="--build-arg=VERSION \
--build-arg=GOFLAGS \
--build-arg=OLLAMA_CUSTOM_CPU_DEFS \
--build-arg=OLLAMA_SKIP_CUDA_GENERATE \
--build-arg=OLLAMA_SKIP_CUDA_11_GENERATE \
--build-arg=OLLAMA_SKIP_CUDA_12_GENERATE \
--build-arg=CUDA_V11_ARCHITECTURES \
--build-arg=CUDA_V12_ARCHITECTURES \
--build-arg=OLLAMA_SKIP_ROCM_GENERATE \
--build-arg=OLLAMA_FAST_BUILD \

115
server/cache/capabilities.go vendored Normal file
View File

@@ -0,0 +1,115 @@
package cache
import (
"fmt"
"log/slog"
"os"
"slices"
"sync"
"time"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/thinking"
"github.com/ollama/ollama/types/model"
)
// cacheEntry stores capabilities and the modification time of the model file
type cacheEntry struct {
capabilities []model.Capability
modTime time.Time
}
// ggufCapabilities is a cache for gguf model capabilities
var ggufCapabilities = &sync.Map{}
// ModelInfo contains the minimal information needed to determine capabilities
type ModelInfo struct {
ModelPath string
ProjectorPaths []string
Template *template.Template
}
// Capabilities returns the capabilities that the model supports
func Capabilities(info ModelInfo) []model.Capability {
capabilities, err := ggufCapabilties(info.ModelPath)
if err != nil {
slog.Error("could not determine gguf capabilities", "error", err)
}
if info.Template == nil {
return capabilities
}
// Check for tools capability
if slices.Contains(info.Template.Vars(), "tools") {
capabilities = append(capabilities, model.CapabilityTools)
}
// Check for insert capability
if slices.Contains(info.Template.Vars(), "suffix") {
capabilities = append(capabilities, model.CapabilityInsert)
}
// Check for vision capability in projector-based models
if len(info.ProjectorPaths) > 0 {
capabilities = append(capabilities, model.CapabilityVision)
}
// Check for thinking capability
openingTag, closingTag := thinking.InferTags(info.Template.Template)
if openingTag != "" && closingTag != "" {
capabilities = append(capabilities, model.CapabilityThinking)
}
return capabilities
}
func ggufCapabilties(modelPath string) ([]model.Capability, error) {
// Get file info to check modification time
fileInfo, err := os.Stat(modelPath)
if err != nil {
return nil, err
}
currentModTime := fileInfo.ModTime()
// Check if we have a cached entry
if cached, ok := ggufCapabilities.Load(modelPath); ok {
entry := cached.(cacheEntry)
// If the file hasn't been modified since we cached it, return the cached capabilities
if entry.modTime.Equal(currentModTime) {
return entry.capabilities, nil
}
}
// If not cached or file was modified, read the model file to determine capabilities
capabilities := []model.Capability{}
r, err := os.Open(modelPath)
if err != nil {
return nil, err
}
defer r.Close()
f, err := ggml.Decode(r, 1024)
if err != nil {
return nil, err
}
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
capabilities = append(capabilities, model.CapabilityEmbedding)
} else {
capabilities = append(capabilities, model.CapabilityCompletion)
}
if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok {
capabilities = append(capabilities, model.CapabilityVision)
}
// Cache the capabilities with the modification time
ggufCapabilities.Store(modelPath, cacheEntry{
capabilities: capabilities,
modTime: currentModTime,
})
return capabilities, nil
}

211
server/cache/capabilities_test.go vendored Normal file
View File

@@ -0,0 +1,211 @@
package cache
import (
"bytes"
"maps"
"os"
"slices"
"testing"
"time"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
// testGGUF creates a temporary GGUF model file for testing with custom key-value pairs
func testGGUF(tb testing.TB, customKV ggml.KV) string {
tb.Helper()
f, err := os.CreateTemp(tb.TempDir(), "test*.gguf")
if err != nil {
tb.Fatal(err)
}
defer f.Close()
kv := ggml.KV{}
maps.Copy(kv, customKV)
tensors := []*ggml.Tensor{
{
Name: "token_embd.weight",
Kind: 0,
Shape: []uint64{1, 1},
WriterTo: bytes.NewBuffer(make([]byte, 4)),
},
}
if err := ggml.WriteGGUF(f, kv, tensors); err != nil {
tb.Fatal(err)
}
return f.Name()
}
func TestCapabilities(t *testing.T) {
ggufCapabilities.Range(func(key, value any) bool {
ggufCapabilities.Delete(key)
return true
})
// Create test model paths
completionModelPath := testGGUF(t, ggml.KV{
"general.architecture": "llama",
})
visionModelPath := testGGUF(t, ggml.KV{
"general.architecture": "llama",
"llama.vision.block_count": uint32(1),
})
embeddingModelPath := testGGUF(t, ggml.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(1),
})
// Create templates
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
chatTemplate, err := template.Parse("{{ .prompt }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
testCases := []struct {
name string
model ModelInfo
expectedCaps []model.Capability
}{
{
name: "model with completion capability",
model: ModelInfo{
ModelPath: completionModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion},
},
{
name: "model with completion, tools, and insert capability",
model: ModelInfo{
ModelPath: completionModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with tools capability",
model: ModelInfo{
ModelPath: completionModelPath,
Template: toolsTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools},
},
{
name: "model with vision capability from gguf",
model: ModelInfo{
ModelPath: visionModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
},
{
name: "model with vision capability from projector",
model: ModelInfo{
ModelPath: completionModelPath,
ProjectorPaths: []string{"/path/to/projector"},
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
},
{
name: "model with vision, tools, and insert capability",
model: ModelInfo{
ModelPath: visionModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with embedding capability",
model: ModelInfo{
ModelPath: embeddingModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityEmbedding},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// First call - should read from file
caps := Capabilities(tc.model)
slices.Sort(caps)
slices.Sort(tc.expectedCaps)
if !slices.Equal(caps, tc.expectedCaps) {
t.Errorf("Expected capabilities %v, got %v", tc.expectedCaps, caps)
}
// Verify caching for models that read from GGUF
if tc.model.ModelPath != "" {
// Check that entry is cached
_, ok := ggufCapabilities.Load(tc.model.ModelPath)
if !ok {
t.Error("Expected capabilities to be cached")
}
// Second call - should use cache
caps2 := Capabilities(tc.model)
slices.Sort(caps2)
if !slices.Equal(caps, caps2) {
t.Errorf("Cached capabilities don't match original: expected %v, got %v", caps, caps2)
}
}
})
}
// Test cache invalidation on file modification
t.Run("cache invalidation", func(t *testing.T) {
// Use completion model for this test
info := ModelInfo{
ModelPath: completionModelPath,
Template: chatTemplate,
}
// Get initial cached entry
cached, ok := ggufCapabilities.Load(completionModelPath)
if !ok {
t.Fatal("Expected model to be cached from previous tests")
}
entry := cached.(cacheEntry)
// Modify the file's timestamp to the future
future := time.Now().Add(time.Hour)
err := os.Chtimes(completionModelPath, future, future)
if err != nil {
t.Fatalf("Failed to update file timestamp: %v", err)
}
// Call should re-read from file due to changed modtime
caps := Capabilities(info)
if len(caps) != 1 || caps[0] != model.CapabilityCompletion {
t.Errorf("Expected [CapabilityCompletion], got %v", caps)
}
// Check that cache was updated with new modtime
cached2, ok := ggufCapabilities.Load(completionModelPath)
if !ok {
t.Error("Expected capabilities to be cached after re-read")
}
entry2 := cached2.(cacheEntry)
if entry2.modTime.Equal(entry.modTime) {
t.Error("Expected cache entry to have updated modTime")
}
})
}

View File

@@ -23,10 +23,9 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/gguf"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/server/cache"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/thinking"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
@@ -68,60 +67,14 @@ type Model struct {
Template *template.Template
}
// Capabilities returns the capabilities that the model supports
func (m *Model) Capabilities() []model.Capability {
capabilities := []model.Capability{}
// Check for completion capability
f, err := gguf.Open(m.ModelPath)
if err == nil {
defer f.Close()
if f.KeyValue("pooling_type").Valid() {
capabilities = append(capabilities, model.CapabilityEmbedding)
} else {
// If no embedding is specified, we assume the model supports completion
capabilities = append(capabilities, model.CapabilityCompletion)
}
if f.KeyValue("vision.block_count").Valid() {
capabilities = append(capabilities, model.CapabilityVision)
}
} else {
slog.Error("couldn't open model file", "error", err)
}
if m.Template == nil {
return capabilities
}
// Check for tools capability
if slices.Contains(m.Template.Vars(), "tools") {
capabilities = append(capabilities, model.CapabilityTools)
}
// Check for insert capability
if slices.Contains(m.Template.Vars(), "suffix") {
capabilities = append(capabilities, model.CapabilityInsert)
}
// Check for vision capability in projector-based models
if len(m.ProjectorPaths) > 0 {
capabilities = append(capabilities, model.CapabilityVision)
}
// Check for thinking capability
openingTag, closingTag := thinking.InferTags(m.Template.Template)
if openingTag != "" && closingTag != "" {
capabilities = append(capabilities, model.CapabilityThinking)
}
return capabilities
}
// CheckCapabilities checks if the model has the specified capabilities returning an error describing
// any missing or unknown capabilities
func (m *Model) CheckCapabilities(want ...model.Capability) error {
available := m.Capabilities()
available := cache.Capabilities(cache.ModelInfo{
ModelPath: m.ModelPath,
ProjectorPaths: m.ProjectorPaths,
Template: m.Template,
})
var errs []error
// Map capabilities to their corresponding error

View File

@@ -9,131 +9,6 @@ import (
"github.com/ollama/ollama/types/model"
)
func TestModelCapabilities(t *testing.T) {
// Create completion model (llama architecture without vision)
completionModelPath, _ := createBinFile(t, ggml.KV{
"general.architecture": "llama",
}, []*ggml.Tensor{})
// Create vision model (llama architecture with vision block count)
visionModelPath, _ := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.vision.block_count": uint32(1),
}, []*ggml.Tensor{})
// Create embedding model (bert architecture with pooling type)
embeddingModelPath, _ := createBinFile(t, ggml.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(1),
}, []*ggml.Tensor{})
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
chatTemplate, err := template.Parse("{{ .prompt }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
testModels := []struct {
name string
model Model
expectedCaps []model.Capability
}{
{
name: "model with completion capability",
model: Model{
ModelPath: completionModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion},
},
{
name: "model with completion, tools, and insert capability",
model: Model{
ModelPath: completionModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with tools capability",
model: Model{
ModelPath: completionModelPath,
Template: toolsTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools},
},
{
name: "model with vision capability",
model: Model{
ModelPath: visionModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
},
{
name: "model with vision, tools, and insert capability",
model: Model{
ModelPath: visionModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with embedding capability",
model: Model{
ModelPath: embeddingModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityEmbedding},
},
}
// compare two slices of model.Capability regardless of order
compareCapabilities := func(a, b []model.Capability) bool {
if len(a) != len(b) {
return false
}
aCount := make(map[model.Capability]int)
for _, cap := range a {
aCount[cap]++
}
bCount := make(map[model.Capability]int)
for _, cap := range b {
bCount[cap]++
}
for cap, count := range aCount {
if bCount[cap] != count {
return false
}
}
return true
}
for _, tt := range testModels {
t.Run(tt.name, func(t *testing.T) {
// Test Capabilities method
caps := tt.model.Capabilities()
if !compareCapabilities(caps, tt.expectedCaps) {
t.Errorf("Expected capabilities %v, got %v", tt.expectedCaps, caps)
}
})
}
}
func TestModelCheckCapabilities(t *testing.T) {
// Create simple model file for tests that don't depend on GGUF content
completionModelPath, _ := createBinFile(t, ggml.KV{

View File

@@ -59,7 +59,7 @@ type DiskCache struct {
testHookBeforeFinalWrite func(f *os.File)
}
// PutBytes is a convenience function for c.Put(d, strings.NewReader(s), int64(len(s))).
// PutString is a convenience function for c.Put(d, strings.NewReader(s), int64(len(s))).
func PutBytes[S string | []byte](c *DiskCache, d Digest, data S) error {
return c.Put(d, bytes.NewReader([]byte(data)), int64(len(data)))
}

View File

@@ -231,8 +231,6 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
// do not quantize relative position bias (T5)
quantize = quantize && !strings.Contains(name, "attn_rel_b.weight")
quantize = quantize && !strings.Contains(name, "per_layer_token_embd.weight")
newType := fsggml.TensorType(t.Kind)
if quantize {
// get more optimal quantization type based on the tensor shape, layer, etc.

View File

@@ -28,13 +28,13 @@ import (
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/cache"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/template"
@@ -56,8 +56,6 @@ var mode string = gin.DebugMode
type Server struct {
addr net.Addr
sched *Scheduler
perms *auth.APIPermissions
}
func init() {
@@ -72,38 +70,6 @@ func init() {
gin.SetMode(mode)
}
func loggedFormatter(param gin.LogFormatterParams) string {
var statusColor, methodColor, resetColor string
if param.IsOutputColor() {
statusColor = param.StatusCodeColor()
methodColor = param.MethodColor()
resetColor = param.ResetColor()
}
if param.Latency > time.Minute {
param.Latency = param.Latency.Truncate(time.Second)
}
username := "default"
if userVal, exists := param.Keys["username"]; exists {
if name, ok := userVal.(string); ok {
username = name
}
}
return fmt.Sprintf(
"[Ollama] %s |%s %3d %s| %13v | %15s | %-20s |%s %-7s %s %#v\n%s",
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
statusColor, param.StatusCode, resetColor,
param.Latency,
param.ClientIP,
username,
methodColor, param.Method, resetColor,
param.Path,
param.ErrorMessage,
)
}
var (
errRequired = errors.New("is required")
errBadTemplate = errors.New("template error")
@@ -854,13 +820,17 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
resp := &api.ShowResponse{
License: strings.Join(m.License, "\n"),
System: m.System,
Template: m.Template.String(),
Details: modelDetails,
Messages: msgs,
Capabilities: m.Capabilities(),
ModifiedAt: manifest.fi.ModTime(),
License: strings.Join(m.License, "\n"),
System: m.System,
Template: m.Template.String(),
Details: modelDetails,
Messages: msgs,
Capabilities: cache.Capabilities(cache.ModelInfo{
ModelPath: m.ModelPath,
Template: m.Template,
ProjectorPaths: m.ProjectorPaths,
}),
ModifiedAt: manifest.fi.ModTime(),
}
var params []string
@@ -877,11 +847,8 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
resp.Parameters = strings.Join(params, "\n")
if len(req.Options) > 0 {
if m.Options == nil {
m.Options = make(map[string]any)
}
for k, v := range req.Options {
for k, v := range req.Options {
if _, ok := req.Options[k]; ok {
m.Options[k] = v
}
}
@@ -1146,43 +1113,6 @@ func allowedHost(host string) bool {
return false
}
func allowedEndpointsMiddleware(perms *auth.APIPermissions) gin.HandlerFunc {
return func(c *gin.Context) {
if !envconfig.UseAuth() || (c.Request.Method == "HEAD" && c.Request.URL.Path == "/") {
c.Next()
return
}
token := strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer ")
if token == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
pubKey, err := auth.Authenticate(token, fmt.Sprintf("%s,%s", c.Request.Method, c.Request.RequestURI))
if err != nil {
slog.Error("authentication error", "error", err)
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
authorized, name, err := perms.Authorize(pubKey, c.Request.URL.Path)
c.Set("username", name)
if err != nil {
slog.Error("authorization error", "error", err)
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
if !authorized {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
c.Next()
}
}
func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
return func(c *gin.Context) {
if addr == nil {
@@ -1249,13 +1179,10 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
}
corsConfig.AllowOrigins = envconfig.AllowedOrigins()
r := gin.New()
r := gin.Default()
r.HandleMethodNotAllowed = true
r.Use(
gin.LoggerWithFormatter(loggedFormatter),
gin.Recovery(),
cors.New(corsConfig),
allowedEndpointsMiddleware(s.perms),
allowedHostsMiddleware(s.addr),
)
@@ -1265,7 +1192,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
// Local model cache management
// Local model cache management (new implementation is at end of function)
r.POST("/api/pull", s.PullHandler)
r.POST("/api/push", s.PushHandler)
r.HEAD("/api/tags", s.ListHandler)
@@ -1297,7 +1224,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
// wrap old with new
rs := &registry.Local{
Client: rc,
Logger: slog.Default(),
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
Fallback: r,
Prune: PruneLayers,
@@ -1342,12 +1269,6 @@ func Serve(ln net.Listener) error {
s := &Server{addr: ln.Addr()}
if envconfig.UseAuth() {
perms := auth.NewAPIPermissions()
perms.ReloadIfNeeded()
s.perms = perms
}
var rc *ollama.Registry
if useClient2 {
var err error
@@ -1488,9 +1409,6 @@ func (s *Server) PsHandler(c *gin.Context) {
Details: modelDetails,
ExpiresAt: v.expiresAt,
}
if v.Options != nil {
mr.ContextLength = v.Options.NumCtx / v.numParallel
}
// The scheduler waits to set expiresAt, so if a model is loading it's
// possible that it will be set to the unix epoch. For those cases, just
// calculate the time w/ the sessionDuration instead.

View File

@@ -16,7 +16,6 @@ import (
"os"
"path/filepath"
"reflect"
"slices"
"sort"
"strings"
"testing"
@@ -83,6 +82,19 @@ func createTestFile(t *testing.T, name string) (string, string) {
return f.Name(), digest
}
// equalStringSlices checks if two slices of strings are equal.
func equalStringSlices(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
type panicTransport struct{}
func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {
@@ -435,7 +447,7 @@ func TestRoutes(t *testing.T) {
"stop \"foo\"",
"top_p 0.9",
}
if !slices.Equal(params, expectedParams) {
if !equalStringSlices(params, expectedParams) {
t.Errorf("expected parameters %v, got %v", expectedParams, params)
}
paramCount, ok := showResp.ModelInfo["general.parameter_count"].(float64)

View File

@@ -57,7 +57,9 @@ type Scheduler struct {
var defaultModelsPerGPU = 3
// Default automatic value for parallel setting
var defaultParallel = 1
// Model will still need to fit in VRAM. If this setting won't fit
// we'll back off down to 1 to try to get it to fit
var defaultParallel = 2
var ErrMaxQueue = errors.New("server busy, please try again. maximum pending requests exceeded")
@@ -189,7 +191,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
}
// Load model for fitting
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
ggml, err := llm.LoadModel(pending.model.ModelPath, 0)
if err != nil {
pending.errCh <- err
break

View File

@@ -6,7 +6,6 @@ import (
"encoding/json"
"errors"
"io"
"maps"
"math"
"slices"
"strings"
@@ -15,6 +14,7 @@ import (
"text/template/parse"
"github.com/agnivade/levenshtein"
"golang.org/x/exp/maps"
"github.com/ollama/ollama/api"
)
@@ -157,7 +157,9 @@ func (t *Template) Vars() []string {
set[strings.ToLower(n)] = struct{}{}
}
return slices.Sorted(maps.Keys(set))
vars = maps.Keys(set)
slices.Sort(vars)
return vars
}
type Values struct {
@@ -308,23 +310,21 @@ func (t *Template) Execute(w io.Writer, v Values) error {
}
// collate messages based on role. consecutive messages of the same role are merged
// into a single message (except for tool messages which preserve individual metadata).
// collate also collects and returns all system messages.
// into a single message. collate also collects and returns all system messages.
// collate mutates message content adding image tags ([img-%d]) as needed
// todo(parthsareen): revisit for contextual image support
func collate(msgs []api.Message) (string, []*api.Message) {
var system []string
var collated []*api.Message
for i := range msgs {
if msgs[i].Role == "system" {
system = append(system, msgs[i].Content)
msg := msgs[i]
if msg.Role == "system" {
system = append(system, msg.Content)
}
// merges consecutive messages of the same role into a single message (except for tool messages)
if len(collated) > 0 && collated[len(collated)-1].Role == msgs[i].Role && msgs[i].Role != "tool" {
collated[len(collated)-1].Content += "\n\n" + msgs[i].Content
if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
collated[len(collated)-1].Content += "\n\n" + msg.Content
} else {
collated = append(collated, &msgs[i])
collated = append(collated, &msg)
}
}

View File

@@ -163,12 +163,10 @@ func TestParse(t *testing.T) {
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
{"{{ range .Messages }}{{ if eq .Role \"tool\" }}Tool Result: {{ .ToolName }} {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role", "toolname"}},
{`{{- range .Messages }}
{{- if eq .Role "system" }}SYSTEM:
{{- else if eq .Role "user" }}USER:
{{- else if eq .Role "assistant" }}ASSISTANT:
{{- else if eq .Role "tool" }}TOOL:
{{- end }} {{ .Content }}
{{- end }}`, []string{"content", "messages", "role"}},
{`{{- if .Messages }}
@@ -378,99 +376,3 @@ func TestExecuteWithSuffix(t *testing.T) {
})
}
}
func TestCollate(t *testing.T) {
cases := []struct {
name string
msgs []api.Message
expected []*api.Message
system string
}{
{
name: "consecutive user messages are merged",
msgs: []api.Message{
{Role: "user", Content: "Hello"},
{Role: "user", Content: "How are you?"},
},
expected: []*api.Message{
{Role: "user", Content: "Hello\n\nHow are you?"},
},
system: "",
},
{
name: "consecutive tool messages are NOT merged",
msgs: []api.Message{
{Role: "tool", Content: "sunny", ToolName: "get_weather"},
{Role: "tool", Content: "72F", ToolName: "get_temperature"},
},
expected: []*api.Message{
{Role: "tool", Content: "sunny", ToolName: "get_weather"},
{Role: "tool", Content: "72F", ToolName: "get_temperature"},
},
system: "",
},
{
name: "tool messages preserve all fields",
msgs: []api.Message{
{Role: "user", Content: "What's the weather?"},
{Role: "tool", Content: "sunny", ToolName: "get_conditions"},
{Role: "tool", Content: "72F", ToolName: "get_temperature"},
},
expected: []*api.Message{
{Role: "user", Content: "What's the weather?"},
{Role: "tool", Content: "sunny", ToolName: "get_conditions"},
{Role: "tool", Content: "72F", ToolName: "get_temperature"},
},
system: "",
},
{
name: "mixed messages with system",
msgs: []api.Message{
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "Hi there!"},
{Role: "user", Content: "What's the weather?"},
{Role: "tool", Content: "sunny", ToolName: "get_weather"},
{Role: "tool", Content: "72F", ToolName: "get_temperature"},
{Role: "user", Content: "Thanks"},
},
expected: []*api.Message{
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "Hi there!"},
{Role: "user", Content: "What's the weather?"},
{Role: "tool", Content: "sunny", ToolName: "get_weather"},
{Role: "tool", Content: "72F", ToolName: "get_temperature"},
{Role: "user", Content: "Thanks"},
},
system: "You are helpful",
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
system, collated := collate(tt.msgs)
if diff := cmp.Diff(system, tt.system); diff != "" {
t.Errorf("system mismatch (-got +want):\n%s", diff)
}
// Compare the messages
if len(collated) != len(tt.expected) {
t.Errorf("expected %d messages, got %d", len(tt.expected), len(collated))
return
}
for i := range collated {
if collated[i].Role != tt.expected[i].Role {
t.Errorf("message %d role mismatch: got %q, want %q", i, collated[i].Role, tt.expected[i].Role)
}
if collated[i].Content != tt.expected[i].Content {
t.Errorf("message %d content mismatch: got %q, want %q", i, collated[i].Content, tt.expected[i].Content)
}
if collated[i].ToolName != tt.expected[i].ToolName {
t.Errorf("message %d tool name mismatch: got %q, want %q", i, collated[i].ToolName, tt.expected[i].ToolName)
}
}
})
}
}

View File

@@ -18,8 +18,9 @@ const (
)
type Parser struct {
tag string
tools []api.Tool
tag string
names []string
properties []string
state toolsState
buffer []byte
@@ -33,10 +34,15 @@ func NewParser(tmpl *template.Template, tools []api.Tool) *Parser {
}
func NewParserWithTag(tools []api.Tool, tag string) *Parser {
return &Parser{
tag: tag,
tools: tools,
var p Parser
for _, t := range tools {
p.names = append(p.names, t.Function.Name)
for r := range t.Function.Parameters.Properties {
p.properties = append(p.properties, r)
}
}
p.tag = tag
return &p
}
// Add processes a string input to parse tool calls and content that
@@ -115,24 +121,36 @@ func (p *Parser) findTag() (int, bool) {
// parseToolCall finds the next complete tool call in the buffer
// incrementing n and advancing the buffer.
func (p *Parser) parseToolCall() *api.ToolCall {
tool, end := findTool(p.tools, p.buffer)
if tool == nil {
var name string
var args map[string]any
var end int = len(p.buffer)
// find tool name
var i int
for _, n := range p.names {
if i = bytes.Index(p.buffer, []byte(n)); i != -1 {
if i+len(n) < end {
name = n
end = i + len(n)
}
}
}
if name == "" {
return nil
}
var args map[string]any
if found, i := findArguments(p.buffer); found == nil {
if args, i = p.findArguments(); args == nil {
return nil
} else {
args = found
if i > end {
end = i
}
}
if i > end {
end = i
}
tc := &api.ToolCall{
Function: api.ToolCallFunction{
Name: tool.Function.Name,
Name: name,
Arguments: args,
Index: p.n,
},
@@ -143,144 +161,83 @@ func (p *Parser) parseToolCall() *api.ToolCall {
return tc
}
// findTool finds the first tool name in the list that matches the
// beginning of the buffer, returning nil if no tool is found
// or if the buffer ends with a partial tool name since we need
// to wait for more data to disambiguate.
// The second return value is the end position of the tool name
// if one is found, otherwise 0.
func findTool(tools []api.Tool, buf []byte) (*api.Tool, int) {
if len(buf) == 0 {
return nil, 0
}
// check if buffer ends with a partial tool name
// this prevents matching "get" when seeing "get_weather"
var longest string
for _, t := range tools {
if len(t.Function.Name) > len(longest) {
longest = t.Function.Name
}
}
// Only check up to longest characters from the end
for i := 1; i <= min(len(buf), len(longest)); i++ {
tail := buf[len(buf)-i:]
for _, t := range tools {
name := []byte(t.Function.Name)
if len(tail) < len(name) && bytes.HasPrefix(name, tail) {
return nil, 0
}
}
}
// find first occurrence of the longest tool name
var found *api.Tool
start := -1
end := -1
for i := range tools {
name := []byte(tools[i].Function.Name)
pos := bytes.Index(buf, name)
if pos == -1 {
continue
}
// Skip if we have a better match already
if start != -1 {
if pos > start {
continue
}
if pos == start && len(name) <= len(found.Function.Name) {
continue
}
}
found = &tools[i]
start = pos
end = pos + len(name)
}
if found != nil {
return found, end
}
return nil, 0
}
// findArguments returns the first object that appears to be
// arguments for the provided tool in the provided buffer,
// returning nil if no arguments are found and the end position
// TODO (jmorganca): this does not support parsing omitted arguments
// objects for functions that have all-optional parameters
// e.g. `{"name": "get_conditions", "arguments": {}}` will work but
// `{"name": "get_conditions"}` will not currently work
func findArguments(buffer []byte) (map[string]any, int) {
if len(buffer) == 0 {
// arguments and the position where the arguments end, returning nil and 0 if
// an invalid JSON object or non-arguments object is found first
func (p *Parser) findArguments() (map[string]any, int) {
if len(p.buffer) == 0 {
return nil, 0
}
var braces int
var start int = -1
var end int
var object []byte
for i, c := range buffer {
// find any outer json object
for i, c := range p.buffer {
if c == '{' {
if braces == 0 {
braces++
if start == -1 {
start = i
}
braces++
} else if c == '}' && braces > 0 {
}
if c == '}' {
braces--
if braces == 0 && start != -1 {
object := buffer[start : i+1]
var data map[string]any
if err := json.Unmarshal(object, &data); err != nil {
start = -1
continue
}
var findObject func(obj map[string]any) (map[string]any, bool)
findObject = func(obj map[string]any) (map[string]any, bool) {
if _, hasName := obj["name"]; hasName {
if args, ok := obj["arguments"].(map[string]any); ok {
return args, true
}
if args, ok := obj["parameters"].(map[string]any); ok {
return args, true
}
return nil, true
}
for _, v := range obj {
switch child := v.(type) {
case map[string]any:
if result, found := findObject(child); found {
return result, true
}
case []any:
for _, item := range child {
if childObj, ok := item.(map[string]any); ok {
if result, found := findObject(childObj); found {
return result, true
}
}
}
}
}
return nil, false
}
if args, found := findObject(data); found {
return args, i
}
return data, i
end = i + 1
object = p.buffer[start:end]
break
}
}
}
if braces > 0 {
return nil, 0
}
var data map[string]any
// not valid json
if err := json.Unmarshal(object, &data); err != nil {
return nil, 0
}
var find func(obj any) map[string]any
find = func(obj any) map[string]any {
switch v := obj.(type) {
case map[string]any:
// check if the object keys are valid tool properties
// TODO (jmorganca): check only sets of properties that
// go together instead of the entire set
for _, prop := range p.properties {
if _, exists := v[prop]; exists {
return v
}
}
for _, value := range v {
if result := find(value); result != nil {
return result
}
}
case []any:
for _, item := range v {
if result := find(item); result != nil {
return result
}
}
}
return nil
}
result := find(data)
if result != nil {
return result, end
}
return nil, 0
}

View File

@@ -52,8 +52,7 @@ func TestParser(t *testing.T) {
Enum []any `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Required: []string{"city"},
Type: "object",
Properties: map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
@@ -105,88 +104,6 @@ func TestParser(t *testing.T) {
},
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "say_hello",
Description: "Say hello",
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "say_hello_world",
Description: "Say hello world",
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "get_address",
Description: "Get the address of a given location",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Properties: map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
}{
"location": {
Type: api.PropertyType{"string"},
Description: "The location to get the address for",
},
},
},
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "add",
Description: "Add two numbers",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Properties: map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
}{
"a": {
Type: api.PropertyType{"string"},
Description: "The first number to add",
},
"b": {
Type: api.PropertyType{"string"},
Description: "The second number to add",
},
},
},
},
},
}
tests := []struct {
@@ -227,21 +144,6 @@ func TestParser(t *testing.T) {
},
},
},
{
name: "empty args",
inputs: []string{`<tool_call>{"name": "get_conditions", "arguments": {}}</tool_call>`},
content: "",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{},
},
},
},
},
{
name: "text before tool call",
inputs: []string{`Let me check the weather. <tool_call>{"name": "get_temperature", "arguments": {"city": "New York"}}</tool_call>`},
@@ -259,13 +161,6 @@ func TestParser(t *testing.T) {
},
},
},
{
name: "qwen no args with text",
inputs: []string{"Let me say hello to the user. I'll use the say_hello tool. "},
content: "Let me say hello to the user. I'll use the say_hello tool. ",
tmpl: qwen,
calls: nil,
},
{
name: "two tool calls in a list",
inputs: []string{`[TOOL_CALLS] [{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}, {"name": "get_conditions", "arguments": {"location": "Tokyo"}}][/TOOL_CALLS]`},
@@ -294,7 +189,7 @@ func TestParser(t *testing.T) {
},
},
{
name: "qwen two tool calls",
name: "two tool calls",
inputs: []string{`Okay, let's call both tools! <tool_call>{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}</tool_call><tool_call>{"name": "get_conditions", "arguments": {"location": "Tokyo"}}</tool_call>`},
content: "Okay, let's call both tools! ",
tmpl: qwen,
@@ -320,55 +215,6 @@ func TestParser(t *testing.T) {
},
},
},
{
name: "empty args followed by args",
inputs: []string{`Let me say hello and check the weather. <tool_call>{"name": "say_hello", "arguments": {}}</tool_call><tool_call>{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}</tool_call>`},
content: "Let me say hello and check the weather. ",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{},
},
},
{
Function: api.ToolCallFunction{
Index: 1,
Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{
"city": "London",
"format": "fahrenheit",
},
},
},
},
},
{
name: "qwen empty followed by args",
inputs: []string{`Let me check the weather. <tool_call>{"name": "get_conditions", "arguments": {}}</tool_call><tool_call>{"name": "get_conditions", "arguments": {"location": "Tokyo"}}`},
content: "Let me check the weather. ",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{},
},
},
{
Function: api.ToolCallFunction{
Index: 1,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
},
},
},
},
{
name: "deepseek",
inputs: []string{"<think>Wait, I need to call a tool</think><|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_temperature\n```json\n{\"city\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"},
@@ -534,30 +380,6 @@ func TestParser(t *testing.T) {
},
{
name: "list partial",
inputs: []string{
"[{",
"\"name\": \"get_conditions\", ",
"\"arguments\": {",
"\"location\": \"Tokyo\"",
"}",
"}",
},
content: "",
tmpl: list,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
},
},
},
},
{
name: "list invalid",
inputs: []string{
"[",
"{",
@@ -571,33 +393,6 @@ func TestParser(t *testing.T) {
tmpl: list,
calls: nil,
},
{
name: "list trailing ]",
inputs: []string{
"[",
"{",
"\"name\": \"get_conditions\", ",
"\"arguments\": {",
"\"location\": \"Tokyo\"",
"}",
"}",
"]",
"]",
},
content: "",
tmpl: list,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
},
},
},
},
{
name: "list not a tool call",
inputs: []string{
@@ -609,193 +404,6 @@ func TestParser(t *testing.T) {
tmpl: list,
calls: nil,
},
{
name: "tool name with collision",
inputs: []string{
"<tool_call>",
"{",
"\"name\": \"say_hello",
"_world\",",
"\"arguments\": {}}",
"}",
},
content: "",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello_world",
Arguments: api.ToolCallFunctionArguments{},
},
},
},
},
{
name: "tool name with collision multiple",
inputs: []string{
"<tool_call>",
"{",
"\"name\": \"say_hello",
"_world\",",
"\"arguments\": {}}",
"</tool_call>",
"<tool_call>",
"{",
"\"name\": \"say_hello",
"\",",
"\"arguments\": {}}",
"</tool_call>",
},
content: "",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello_world",
Arguments: api.ToolCallFunctionArguments{},
},
},
{
Function: api.ToolCallFunction{
Index: 1,
Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{},
},
},
},
},
{
name: "tool name with collision non streaming",
inputs: []string{
`<tool_call>{"name": "say_hello`,
},
content: "",
tmpl: qwen,
calls: nil,
},
{
name: "tool name with collision non streaming multiple",
inputs: []string{
`<tool_call>{"name": "say_hello", "arguments": {}}</tool_call><tool_call>{"name": "say_hello_world", "arguments": {}}`,
},
content: "",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{},
},
},
{
Function: api.ToolCallFunction{
Index: 1,
Name: "say_hello_world",
Arguments: api.ToolCallFunctionArguments{},
},
},
},
},
{
name: "tool name with collision non streaming shorter",
inputs: []string{
`<tool_call>{"name": "say_hello", "arguments": {}}</tool_call>`,
},
content: "",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{},
},
},
},
},
{
name: "tool name with collision non streaming longer",
inputs: []string{
`<tool_call>{"name": "say_hello_world", "arguments": {}}</tool_call>`,
},
content: "",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello_world",
Arguments: api.ToolCallFunctionArguments{},
},
},
},
},
{
name: "tool name with substring of another",
inputs: []string{
"{",
"\"name\": \"get_address\",",
"\"arguments\": {",
"\"location\": \"London\"",
"}",
"}",
},
content: "",
tmpl: json,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_address",
Arguments: api.ToolCallFunctionArguments{
"location": "London",
},
},
},
},
},
{
name: "tool name with substring of another",
inputs: []string{
`<tool_call>{"name": "get_address", "arguments": {"location": "London"}}</tool_call>`,
},
content: "",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_address",
Arguments: api.ToolCallFunctionArguments{
"location": "London",
},
},
},
},
},
{
name: "args before name",
inputs: []string{
`<tool_call>{"arguments": {"a": "5", "b": "10"}, "name": "add"}</tool_call>`,
},
content: "",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "add",
Arguments: api.ToolCallFunctionArguments{
"a": "5",
"b": "10",
},
},
},
},
},
}
for _, tt := range tests {
@@ -1142,7 +750,7 @@ func TestFindArguments(t *testing.T) {
},
{
name: "valid arguments in array",
buffer: []byte(`[{"name": "get_temperature", "arguments": {"format": "fahrenheit", "location": "San Francisco, CA"}}`),
buffer: []byte(`[{"arguments": {"format": "fahrenheit", "location": "San Francisco, CA"}}`),
want: map[string]any{
"format": "fahrenheit",
"location": "San Francisco, CA",
@@ -1158,14 +766,14 @@ func TestFindArguments(t *testing.T) {
},
{
name: "one arg",
buffer: []byte(`get_temperature({"location": "San Francisco, CA"})`),
buffer: []byte(`get_weather({"location": "San Francisco, CA"})`),
want: map[string]any{
"location": "San Francisco, CA",
},
},
{
name: "two args",
buffer: []byte(`[{"name": "get_temperature", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}]`),
buffer: []byte(`[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}]`),
want: map[string]any{
"location": "San Francisco, CA",
"format": "fahrenheit",
@@ -1173,14 +781,7 @@ func TestFindArguments(t *testing.T) {
},
{
name: "deepseek",
buffer: []byte("<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_temperature\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"),
want: map[string]any{
"location": "Tokyo",
},
},
{
name: "deepseek",
buffer: []byte(`", "arguments": {"location": "Tokyo"}}</tool_call>`),
buffer: []byte("<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"),
want: map[string]any{
"location": "Tokyo",
},
@@ -1188,8 +789,13 @@ func TestFindArguments(t *testing.T) {
}
for _, tt := range tests {
parser := &Parser{
buffer: tt.buffer,
properties: []string{"format", "location"},
}
t.Run(tt.name, func(t *testing.T) {
got, _ := findArguments(tt.buffer)
got, _ := parser.findArguments()
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff)