Compare commits

..

1 Commits

Author SHA1 Message Date
Bruce MacDonald
d5eae8248d runner: enable returning more info from runner processing
Currently we return only the text predicted from the LLM. This was nice in
that it was simple, but there may be other info we want to know from the
processing. This change adds the ability to return more information from the
runner than just the text predicted.
2025-06-13 16:26:57 -07:00
69 changed files with 1122 additions and 132334 deletions

View File

@@ -54,6 +54,48 @@ jobs:
name: build-${{ matrix.os }}-${{ matrix.arch }}
path: dist/*
darwin-sign:
runs-on: macos-13
environment: release
needs: darwin-build
steps:
- uses: actions/checkout@v4
- run: |
echo $MACOS_SIGNING_KEY | base64 --decode > certificate.p12
security create-keychain -p password build.keychain
security default-keychain -s build.keychain
security unlock-keychain -p password build.keychain
security import certificate.p12 -k build.keychain -P $MACOS_SIGNING_KEY_PASSWORD -T /usr/bin/codesign
security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k password build.keychain
security set-keychain-settings -lut 3600 build.keychain
env:
MACOS_SIGNING_KEY: ${{ secrets.MACOS_SIGNING_KEY }}
MACOS_SIGNING_KEY_PASSWORD: ${{ secrets.MACOS_SIGNING_KEY_PASSWORD }}
- uses: actions/download-artifact@v4
with:
name: build-darwin-amd64
path: dist/darwin-amd64
- uses: actions/download-artifact@v4
with:
name: build-darwin-arm64
path: dist/darwin-arm64
- run: |
export VERSION=${GITHUB_REF_NAME#v}
./scripts/build_darwin.sh sign macapp
env:
APPLE_IDENTITY: ${{ secrets.APPLE_IDENTITY }}
APPLE_PASSWORD: ${{ secrets.APPLE_PASSWORD }}
APPLE_TEAM_ID: ${{ vars.APPLE_TEAM_ID }}
APPLE_ID: ${{ vars.APPLE_ID }}
SDKROOT: /Applications/Xcode_14.1.0.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
DEVELOPER_DIR: /Applications/Xcode_14.1.0.app/Contents/Developer
- uses: actions/upload-artifact@v4
with:
name: dist-darwin
path: |
dist/Ollama-darwin.zip
dist/ollama-darwin.tgz
windows-depends:
strategy:
matrix:
@@ -61,18 +103,21 @@ jobs:
arch: [amd64]
preset: ['CPU']
include:
- os: windows
arch: amd64
preset: 'CUDA 11'
install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
cuda-version: '11.3'
- os: windows
arch: amd64
preset: 'CUDA 12'
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
cuda-version: '12.8'
flags: ''
- os: windows
arch: amd64
preset: 'ROCm 6'
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
rocm-version: '6.2'
flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
env:
@@ -115,9 +160,6 @@ jobs:
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: matrix.preset == 'CPU'
run: |
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
@@ -136,9 +178,9 @@ jobs:
key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}
- name: Build target "${{ matrix.preset }}"
run: |
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}"
cmake --build --parallel --preset "${{ matrix.preset }}"
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8
env:
@@ -188,11 +230,61 @@ jobs:
go-version-file: go.mod
- run: |
go build -o dist/${{ matrix.os }}-${{ matrix.arch }}/ .
- if: matrix.arch == 'arm64'
run: |
Invoke-WebRequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "dist\windows-arm64\vc_redist.arm64.exe"
- run: |
$env:VERSION='${{ github.ref_name }}' -Replace "v(.*)", '$1'
& .\scripts\build_windows.ps1 buildApp
env:
VCToolsRedistDir: stub
- uses: actions/upload-artifact@v4
with:
name: build-${{ matrix.os }}-${{ matrix.arch }}
path: |
dist\${{ matrix.os }}-${{ matrix.arch }}\*.exe
dist\${{ matrix.os }}-${{ matrix.arch }}-app.exe
windows-sign:
runs-on: windows-2022
environment: release
needs: [windows-depends, windows-build]
steps:
- uses: actions/checkout@v4
- uses: google-github-actions/auth@v2
with:
project_id: ollama
credentials_json: ${{ secrets.GOOGLE_SIGNING_CREDENTIALS }}
- run: |
$ErrorActionPreference = "Stop"
Invoke-WebRequest -Uri "https://go.microsoft.com/fwlink/p/?LinkId=323507" -OutFile "${{ runner.temp }}\sdksetup.exe"
Start-Process "${{ runner.temp }}\sdksetup.exe" -ArgumentList @("/q") -NoNewWindow -Wait
Invoke-WebRequest -Uri "https://github.com/GoogleCloudPlatform/kms-integrations/releases/download/cng-v1.0/kmscng-1.0-windows-amd64.zip" -OutFile "${{ runner.temp }}\plugin.zip"
Expand-Archive -Path "${{ runner.temp }}\plugin.zip" -DestinationPath "${{ runner.temp }}\plugin\"
& "${{ runner.temp }}\plugin\*\kmscng.msi" /quiet
echo "${{ vars.OLLAMA_CERT }}" >ollama_inc.crt
- uses: actions/download-artifact@v4
with:
pattern: build-windows-*
path: dist\
merge-multiple: true
- uses: actions/download-artifact@v4
with:
pattern: depends-windows-amd64-*
path: dist\windows-amd64\
merge-multiple: true
- run: |
& .\scripts\build_windows.ps1 gatherDependencies sign buildInstaller distZip
env:
KEY_CONTAINER: ${{ vars.KEY_CONTAINER }}
- uses: actions/upload-artifact@v4
with:
name: dist-windows
path: |
dist\OllamaSetup.exe
dist\ollama-windows-*.zip
linux-build:
strategy:
@@ -225,26 +317,21 @@ jobs:
CGO_CFLAGS=${{ env.CGO_CFLAGS }}
CGO_CXXFLAGS=${{ env.CGO_CXXFLAGS }}
outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
cache-from: type=registry,ref=ollama/ollama:latest
cache-to: type=inline
- run: |
for COMPONENT in bin/* lib/ollama/*; do
case "$COMPONENT" in
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_sbsa) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/*.so) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_v11) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_v12) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
esac
done
working-directory: dist/${{ matrix.os }}-${{ matrix.arch }}
- run: |
echo "Manifests"
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in ; do
echo $ARCHIVE
cat $ARCHIVE
done
- run: |
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
@@ -298,8 +385,8 @@ jobs:
context: .
platforms: ${{ matrix.os }}/${{ matrix.arch }}
build-args: ${{ matrix.build-args }}
outputs: type=image,name=${{ vars.DOCKER_REPO }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
outputs: type=image,name=ollama/ollama,push-by-digest=true,name-canonical=true,push=true
cache-from: type=registry,ref=ollama/ollama:latest
cache-to: type=inline
- run: |
mkdir -p ${{ matrix.os }}-${{ matrix.arch }}
@@ -331,7 +418,7 @@ jobs:
latest=false
suffix=${{ matrix.suffix }}
images: |
${{ vars.DOCKER_REPO }}
ollama/ollama
tags: |
type=ref,enable=true,priority=600,prefix=pr-,event=pr
type=semver,pattern={{version}}
@@ -341,24 +428,56 @@ jobs:
path: ${{ runner.temp }}
merge-multiple: true
- run: |
docker buildx imagetools create $(echo '${{ steps.metadata.outputs.json }}' | jq -cr '.tags | map("-t", .) | join(" ")') $(cat *-${{ matrix.suffix }}.txt | xargs printf '${{ vars.DOCKER_REPO }}@%s ')
docker buildx imagetools inspect ${{ vars.DOCKER_REPO }}:${{ steps.metadata.outputs.version }}
docker buildx imagetools create $(echo '${{ steps.metadata.outputs.json }}' | jq -cr '.tags | map("-t", .) | join(" ")') $(cat *-${{ matrix.suffix }}.txt | xargs printf 'ollama/ollama@%s ')
docker buildx imagetools inspect ollama/ollama:${{ steps.metadata.outputs.version }}
working-directory: ${{ runner.temp }}
# Trigger downstream release process
trigger:
runs-on: ubuntu-latest
environment: release
needs: [darwin-build, windows-build, windows-depends, linux-build]
needs: [darwin-build, windows-build, windows-depends]
steps:
- name: Trigger downstream release process
run: |
curl -L \
-X POST \
-H "Accept: application/vnd.github+json" \
-H "Authorization: Bearer ${{ secrets.RELEASE_TOKEN }}" \
-H "X-GitHub-Api-Version: 2022-11-28" \
https://api.github.com/repos/ollama/${{ vars.RELEASE_REPO }}/dispatches \
-d "{\"event_type\": \"trigger-workflow\", \"client_payload\": {\"run_id\": \"${GITHUB_RUN_ID}\", \"version\": \"${GITHUB_REF_NAME#v}\"}}"
# Aggregate all the assets and ship a release
release:
needs: [darwin-sign, windows-sign, linux-build]
runs-on: linux
environment: release
permissions:
contents: write
env:
GH_TOKEN: ${{ github.token }}
steps:
- uses: actions/checkout@v4
- name: Create or update Release for tag
- uses: actions/download-artifact@v4
with:
name: dist-darwin
path: dist
- uses: actions/download-artifact@v4
with:
name: dist-windows
path: dist
- uses: actions/download-artifact@v4
with:
pattern: dist-linux-*
path: dist
merge-multiple: true
- run: find . -type f -not -name 'sha256sum.txt' | xargs sha256sum | tee sha256sum.txt
working-directory: dist
- name: Create or update Release
run: |
RELEASE_VERSION="$(echo ${GITHUB_REF_NAME} | cut -f1 -d-)"
echo "Looking for existing release for ${RELEASE_VERSION}"
OLD_TAG=$(gh release ls --json name,tagName | jq -r ".[] | select(.name == \"${RELEASE_VERSION}\") | .tagName")
if [ -n "$OLD_TAG" ]; then
@@ -372,12 +491,5 @@ jobs:
--generate-notes \
--prerelease
fi
- name: Trigger downstream release process
run: |
curl -L \
-X POST \
-H "Accept: application/vnd.github+json" \
-H "Authorization: Bearer ${{ secrets.RELEASE_TOKEN }}" \
-H "X-GitHub-Api-Version: 2022-11-28" \
https://api.github.com/repos/ollama/${{ vars.RELEASE_REPO }}/dispatches \
-d "{\"event_type\": \"trigger-workflow\", \"client_payload\": {\"run_id\": \"${GITHUB_RUN_ID}\", \"version\": \"${GITHUB_REF_NAME#v}\", \"origin\": \"${GITHUB_REPOSITORY}\", \"publish\": \"1\"}}"
echo "Uploading artifacts for tag ${GITHUB_REF_NAME}"
gh release upload ${GITHUB_REF_NAME} dist/* --clobber

View File

@@ -36,7 +36,7 @@ jobs:
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
}
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
echo changed=$(changed 'llama/llama.cpp/**' 'ml/backend/ggml/ggml/**') | tee -a $GITHUB_OUTPUT
linux:
needs: [changes]
@@ -46,7 +46,7 @@ jobs:
include:
- preset: CPU
- preset: CUDA
container: nvidia/cuda:12.8.1-devel-ubuntu22.04
container: nvidia/cuda:11.8.0-devel-ubuntu22.04
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
- preset: ROCm
container: rocm/dev-ubuntu-22.04:6.1.2
@@ -78,11 +78,11 @@ jobs:
include:
- preset: CPU
- preset: CUDA
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
- preset: ROCm
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
flags: '-DAMDGPU_TARGETS=gfx1010'
runs-on: windows
steps:
- run: |
@@ -102,7 +102,7 @@ jobs:
$ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_12.8", "nvcc_12.8", "cublas_12.8", "cublas_dev_12.8")) -NoNewWindow -Wait
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_11.3", "nvcc_11.3", "cublas_11.3", "cublas_dev_11.3")) -NoNewWindow -Wait
}
$cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path
@@ -120,9 +120,6 @@ jobs:
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4
with:
@@ -136,8 +133,8 @@ jobs:
path: ${{ github.workspace }}\.ccache
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
- run: |
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
cmake --build --parallel --preset "${{ matrix.preset }}"
env:

View File

@@ -78,13 +78,14 @@ if(CMAKE_CUDA_COMPILER)
find_package(CUDAToolkit)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
set(OLLAMA_CUDA_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/cuda_v${CUDAToolkit_VERSION_MAJOR})
install(TARGETS ggml-cuda
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA
RUNTIME DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
LIBRARY DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
)
endif()
@@ -115,11 +116,7 @@ if(CMAKE_HIP_COMPILER)
set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
install(TARGETS ggml-hip
RUNTIME_DEPENDENCY_SET rocm
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
)
install(RUNTIME_DEPENDENCY_SET rocm
RUNTIME_DEPENDENCIES
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
PRE_EXCLUDE_REGEXES ".*"

View File

@@ -17,12 +17,20 @@
"name": "CUDA",
"inherits": [ "Default" ]
},
{
"name": "CUDA 11",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
}
},
{
"name": "CUDA 12",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
}
},
{
@@ -50,7 +58,6 @@
"name": "ROCm 6",
"inherits": [ "ROCm" ],
"cacheVariables": {
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
}
}
@@ -71,6 +78,11 @@
"configurePreset": "CUDA",
"targets": [ "ggml-cuda" ]
},
{
"name": "CUDA 11",
"inherits": [ "CUDA" ],
"configurePreset": "CUDA 11"
},
{
"name": "CUDA 12",
"inherits": [ "CUDA" ],

View File

@@ -7,13 +7,12 @@ ARG JETPACK5VERSION=r35.4.1
ARG JETPACK6VERSION=r36.4.0
ARG CMAKEVERSION=3.31.2
# We require gcc v10 minimum. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
# CUDA v11 requires gcc v10. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
RUN yum install -y yum-utils \
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
&& dnf install -y ccache \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
@@ -39,6 +38,15 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --build --parallel --preset 'CPU' \
&& cmake --install build --component CPU --strip --parallel 8
FROM base AS cuda-11
ARG CUDA11VERSION=11.3
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 11' \
&& cmake --build --parallel --preset 'CUDA 11' \
&& cmake --install build --component CUDA --strip --parallel 8
FROM base AS cuda-12
ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
@@ -90,21 +98,23 @@ RUN --mount=type=cache,target=/root/.cache/go-build \
go build -trimpath -buildmode=pie -o /bin/ollama .
FROM --platform=linux/amd64 scratch AS amd64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
FROM --platform=linux/arm64 scratch AS arm64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/cuda_sbsa
COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_jetpack6
FROM scratch AS rocm
COPY --from=rocm-6 dist/lib/ollama /lib/ollama
COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm
FROM ${FLAVOR} AS archive
COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04
FROM ubuntu:20.04
RUN apt-get update \
&& apt-get install -y ca-certificates \
&& apt-get clean \

View File

@@ -1,6 +1,6 @@
<div align="center">
  <a href="https://ollama.com">
<img alt="ollama" width="240" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
<img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
</a>
</div>
@@ -10,7 +10,7 @@ Get up and running with large language models.
### macOS
[Download](https://ollama.com/download/Ollama.dmg)
[Download](https://ollama.com/download/Ollama-darwin.zip)
### Windows
@@ -407,9 +407,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
### Cloud
@@ -455,7 +452,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
- [AWS-Strands-With-Ollama](https://github.com/rapidarchitect/ollama_strands) - AWS Strands Agents with Ollama Examples
- [ollama-multirun](https://github.com/attogram/ollama-multirun) - A bash shell script to run a single prompt against any or all of your locally installed ollama models, saving the output and performance statistics as easily navigable web pages. ([Demo](https://attogram.github.io/ai_test_zone/))
- [ollama-bash-toolshed](https://github.com/attogram/ollama-bash-toolshed) - Bash scripts to chat with tool using models. Add new tools to your shed with ease. Runs on Ollama.
### Apple Vision Pro
@@ -594,7 +590,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
- [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime)
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)
### Supported backends

View File

@@ -143,7 +143,6 @@ type Message struct {
Thinking string `json:"thinking,omitempty"`
Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolName string `json:"tool_name,omitempty"`
}
func (m *Message) UnmarshalJSON(b []byte) error {

View File

@@ -0,0 +1,178 @@
package benchmark
import (
"context"
"flag"
"fmt"
"testing"
"time"
"github.com/ollama/ollama/api"
)
// Command line flags
var modelFlag string
func init() {
flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
flag.Lookup("m").DefValue = "model"
}
// modelName returns the model name from flags, failing the test if not set
func modelName(b *testing.B) string {
if modelFlag == "" {
b.Fatal("Error: -m flag is required for benchmark tests")
}
return modelFlag
}
type TestCase struct {
name string
prompt string
maxTokens int
}
// runGenerateBenchmark contains the common generate and metrics logic
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
start := time.Now()
var ttft time.Duration
var metrics api.Metrics
err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
if ttft == 0 && resp.Response != "" {
ttft = time.Since(start)
}
if resp.Done {
metrics = resp.Metrics
}
return nil
})
// Report custom metrics as part of the benchmark results
b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
// Token throughput metrics
promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
b.ReportMetric(promptThroughput, "prompt_tok/s")
b.ReportMetric(genThroughput, "gen_tok/s")
// Token counts
b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
if err != nil {
b.Fatal(err)
}
}
// BenchmarkColdStart runs benchmarks with model loading from cold state
func BenchmarkColdStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
ctx := b.Context()
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
b.StopTimer()
// Ensure model is unloaded before each iteration
unload(client, m, b)
b.StartTimer()
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// BenchmarkWarmStart runs benchmarks with pre-loaded model
func BenchmarkWarmStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
ctx := b.Context()
// Pre-warm the model
warmup(client, m, tt.prompt, b)
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// setup verifies server and model availability
func setup(b *testing.B) *api.Client {
client, err := api.ClientFromEnvironment()
if err != nil {
b.Fatal(err)
}
if _, err := client.Show(b.Context(), &api.ShowRequest{Model: modelName(b)}); err != nil {
b.Fatalf("Model unavailable: %v", err)
}
return client
}
// warmup ensures the model is loaded and warmed up
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
for range 3 {
err := client.Generate(
context.Background(),
&api.GenerateRequest{
Model: model,
Prompt: prompt,
Options: map[string]any{"num_predict": 50, "temperature": 0.1},
},
func(api.GenerateResponse) error { return nil },
)
if err != nil {
b.Logf("Error during model warm-up: %v", err)
}
}
}
// unload forces model unloading using KeepAlive: 0 parameter
func unload(client *api.Client, model string, b *testing.B) {
req := &api.GenerateRequest{
Model: model,
KeepAlive: &api.Duration{Duration: 0},
}
if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
b.Logf("Unload error: %v", err)
}
time.Sleep(1 * time.Second)
}

View File

@@ -190,8 +190,6 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &gemma2Model{}
case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration":
conv = &gemma3Model{Architecture: p.Architectures[0]}
case "Gemma3nForConditionalGeneration":
conv = &gemma3nModel{}
case "Phi3ForCausalLM":
conv = &phi3Model{}
case "Qwen2ForCausalLM":

View File

@@ -1,165 +0,0 @@
package convert
import (
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"gonum.org/v1/gonum/stat/distuv"
)
type gemma3nModel struct {
ModelParameters
TextModel struct {
ActivationSparsityPattern []float32 `json:"activation_sparsity_pattern"`
AltupActiveIdx uint32 `json:"altup_active_idx"`
AltupCoefClip float32 `json:"altup_coef_clip"`
AltupCorrectScale bool `json:"altup_correct_scale"`
AltupLRMultiplier float32 `json:"altup_lr_multiplier"`
AltupNumInputs uint32 `json:"altup_num_inputs"`
HeadDim uint32 `json:"head_dim"`
HiddenSize uint32 `json:"hidden_size"`
HiddenSizePerLayerInput uint32 `json:"hidden_size_per_layer_input"`
IntermediateSize uint32 `json:"intermediate_size"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
NumKVSharedLayers uint32 `json:"num_kv_shared_layers"`
RMSNormEPS float32 `json:"rms_norm_eps"`
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
RopeTheta float32 `json:"rope_theta"`
SlidingWindow uint32 `json:"sliding_window"`
LayerTypes []string `json:"layer_types"`
} `json:"text_config"`
VisionModel struct{} `json:"vision_config"`
}
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3n"
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {
norm := distuv.Normal{Mu: 0, Sigma: 1}
for _, v := range m.TextModel.ActivationSparsityPattern {
if !yield(float32(norm.Quantile(float64(v)))) {
break
}
}
})
kv["gemma3n.altup.active_idx"] = m.TextModel.AltupActiveIdx
kv["gemma3n.altup.correct_scale"] = m.TextModel.AltupCorrectScale
kv["gemma3n.altup.lr_multiplier"] = m.TextModel.AltupLRMultiplier
kv["gemma3n.altup.num_inputs"] = m.TextModel.AltupNumInputs
kv["gemma3n.attention.head_count_kv"] = m.TextModel.NumKeyValueHeads
kv["gemma3n.attention.head_count"] = m.TextModel.NumAttentionHeads
kv["gemma3n.attention.layer_norm_rms_epsilon"] = m.TextModel.RMSNormEPS
kv["gemma3n.attention.sliding_window"] = m.TextModel.SlidingWindow
kv["gemma3n.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
for _, t := range m.TextModel.LayerTypes {
if !yield(t == "sliding_attention") {
break
}
}
})
kv["gemma3n.attention.shared_kv_layers"] = m.TextModel.NumKVSharedLayers
kv["gemma3n.block_count"] = m.TextModel.NumHiddenLayers
kv["gemma3n.context_length"] = m.TextModel.MaxPositionEmbeddings
kv["gemma3n.embedding_length_per_layer_input"] = m.TextModel.HiddenSizePerLayerInput
kv["gemma3n.embedding_length"] = m.TextModel.HiddenSize
kv["gemma3n.feed_forward_length"] = m.TextModel.IntermediateSize
kv["gemma3n.head_dim"] = m.TextModel.HeadDim
kv["gemma3n.rope.freq_base_local"] = m.TextModel.RopeLocalBaseFreq
kv["gemma3n.rope.freq_base"] = m.TextModel.RopeTheta
return kv
}
func (m *gemma3nModel) Tensors(ts []Tensor) []*ggml.Tensor {
out, ts := mergeTensors(ts,
merge{"altup_proj.*.weight", "altup_proj.weight"},
merge{"altup_unembd_proj.*.weight", "altup_unembd_proj.weight"},
)
for _, t := range ts {
switch {
case strings.Contains(t.Name(), "audio_tower"),
strings.Contains(t.Name(), "embed_audio"),
strings.Contains(t.Name(), "vision_tower"),
strings.Contains(t.Name(), "embed_vision"):
// TODO: handle audio and vision towers
continue
case strings.Contains(t.Name(), "altup_predict_coef"),
strings.Contains(t.Name(), "altup_correct_coef"):
if m.TextModel.AltupCoefClip > 0 {
t.SetRepacker(func(name string, data []float32, shape []uint64) (_ []float32, err error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
t, err = tensor.Clamp(t, -m.TextModel.AltupCoefClip, m.TextModel.AltupCoefClip)
if err != nil {
return nil, err
}
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(t.(*tensor.Dense))
})
}
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (m *gemma3nModel) Replacements() []string {
return []string{
"model.language_model.embed_tokens_per_layer", "per_layer_token_embd",
"model.language_model.embed_tokens", "token_embd",
"model.language_model.per_layer_model_projection", "per_layer_model_proj",
"model.language_model.per_layer_projection_norm", "per_layer_proj_norm", "model.language_model.altup_projections", "altup_proj",
"model.language_model.altup_unembed_projections", "altup_unembd_proj",
"model.language_model.norm", "output_norm",
"model.language_model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.q_norm", "attn_q_norm",
"self_attn.k_proj", "attn_k",
"self_attn.k_norm", "attn_k_norm",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"post_attention_layernorm", "post_attention_norm",
"pre_feedforward_layernorm", "ffn_norm",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"mlp.down_proj", "ffn_down",
"post_feedforward_layernorm", "post_ffw_norm",
"per_layer_input_gate", "inp_gate",
"per_layer_projection", "proj",
"post_per_layer_input_norm", "post_norm",
"altup.", "altup_",
"modality_router", "router",
"prediction_coefs", "predict_coef",
"correction_coefs", "correct_coef",
"correct_output_scale", "correct_scale.weight",
"laurel.", "laurel_",
"linear_left", "l",
"linear_right", "r",
"post_laurel_norm", "post_norm",
}
}

View File

@@ -2,6 +2,9 @@ package convert
import (
"fmt"
"io"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
@@ -27,38 +30,65 @@ func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
}
func (p *mixtralModel) Tensors(ts []Tensor) []*ggml.Tensor {
merges := make([]merge, 0, p.NumHiddenLayers*6)
for i := range p.NumHiddenLayers {
merges = append(merges, merge{
fmt.Sprintf("blk.%d.*.w1.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}, merge{
fmt.Sprintf("blk.%d.*.w1.bias", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.bias", i),
}, merge{
fmt.Sprintf("blk.%d.*.w2.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}, merge{
fmt.Sprintf("blk.%d.*.w2.bias", i),
fmt.Sprintf("blk.%d.ffn_up_exps.bias", i),
}, merge{
fmt.Sprintf("blk.%d.*.w3.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}, merge{
fmt.Sprintf("blk.%d.*.w3.bias", i),
fmt.Sprintf("blk.%d.ffn_down_exps.bias", i),
oldnew := []string{
"model.layers", "blk",
"w1", "ffn_gate_exps",
"w2", "ffn_down_exps",
"w3", "ffn_up_exps",
}
for i := range p.NumLocalExperts {
oldnew = append(oldnew, fmt.Sprintf(".block_sparse_moe.experts.%d.", i), ".")
}
// group experts of the same layer (model.layers.%d) and type (w[123]) into a single tensor
namer := strings.NewReplacer(oldnew...)
experts := make(map[string]experts)
// merge experts into a single tensor while removing them from ts
ts = slices.DeleteFunc(ts, func(t Tensor) bool {
if !strings.Contains(t.Name(), ".block_sparse_moe.experts.") {
return false
}
name := namer.Replace(t.Name())
experts[name] = append(experts[name], t)
return true
})
var out []*ggml.Tensor
for n, e := range experts {
// TODO(mxyng): sanity check experts
out = append(out, &ggml.Tensor{
Name: n,
Kind: e[0].Kind(),
Shape: append([]uint64{uint64(len(e))}, e[0].Shape()...),
WriterTo: e,
})
}
out, ts := mergeTensors(ts, merges...)
return append(out, p.llamaModel.Tensors(ts)...)
}
func (p *mixtralModel) Replacements() []string {
return append(
p.llamaModel.Replacements(),
"model.layers", "blk",
"block_sparse_moe.gate", "ffn_gate_inp",
"block_sparse_moe.experts.", ".",
)
}
type experts []Tensor
func (e experts) WriteTo(w io.Writer) (int64, error) {
// TODO(mxyng): experts _should_ be numerically sorted by expert but this should check
for _, t := range e {
// the canonical merged experts tensor stacks all experts along a new, 0 axis,
// e.g. `tensor.Stack(0, e[0], e[1:]...)`, which requires allocating temporary buffers
// this accomplishes the same thing by writing each expert tensor in sequence
if _, err := t.WriteTo(w); err != nil {
return 0, err
}
}
return 0, nil
}

View File

@@ -2,9 +2,7 @@ package convert
import (
"cmp"
"io"
"iter"
"path"
"slices"
"strings"
@@ -76,54 +74,3 @@ func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
}
}
}
type merge struct {
pattern, name string
}
// mergeTensors merges tensors that match a given pattern into a single tensor.
func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []Tensor) {
var matched []Tensor
for i := range merges {
matched, unmatched = slicesSplitFunc(unmatched, func(t Tensor) bool {
matched, _ := path.Match(merges[i].pattern, t.Name())
return matched
})
if len(matched) > 0 {
out = append(out, &ggml.Tensor{
Name: merges[i].name,
Kind: matched[0].Kind(),
Shape: append([]uint64{uint64(len(matched))}, matched[0].Shape()...),
WriterTo: mergeGroup(matched),
})
}
}
return out, unmatched
}
// slicesSplitFunc splits a slice into two slices based on a predicate function.
func slicesSplitFunc[S ~[]E, E comparable](s S, fn func(e E) bool) (matched, unmatched S) {
for _, e := range s {
if fn(e) {
matched = append(matched, e)
} else {
unmatched = append(unmatched, e)
}
}
return matched, unmatched
}
type mergeGroup []Tensor
func (g mergeGroup) WriteTo(w io.Writer) (int64, error) {
for _, t := range g {
if _, err := t.WriteTo(w); err != nil {
return 0, err
}
}
return 0, nil
}

View File

@@ -9,8 +9,6 @@ import (
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
)
@@ -304,99 +302,3 @@ func TestSplitDim(t *testing.T) {
}
})
}
func TestMerge(t *testing.T) {
unmatched := []Tensor{
&fakeTensor{
name: "a.0.b",
shape: []uint64{5, 2},
data: []float32{10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
},
&fakeTensor{
name: "a.1.b",
shape: []uint64{5, 2},
data: []float32{20, 21, 22, 23, 24, 25, 26, 27, 28, 29},
},
&fakeTensor{
name: "c.0.d",
shape: []uint64{5, 2},
data: []float32{30, 31, 32, 33, 34, 35, 36, 37, 38, 39},
},
&fakeTensor{
name: "c.1.d",
shape: []uint64{5, 2},
data: []float32{40, 41, 42, 43, 44, 45, 46, 47, 48, 49},
},
&fakeTensor{
name: "e.0.f",
shape: []uint64{5, 2},
data: []float32{50, 51, 52, 53, 54, 55, 56, 57, 58, 59},
},
}
checkMatched := func(t *testing.T, n int, matched []*ggml.Tensor) {
for i := range n {
got := matched[i]
if diff := cmp.Diff([]uint64{2, 5, 2}, got.Shape); diff != "" {
t.Errorf("unexpected (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := got.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, 20)
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
offset := 10 + (i * 20)
want := make([]float32, 20)
for j := range 20 {
want[j] = float32(offset + j)
}
if diff := cmp.Diff(want, f32s); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
}
t.Run("single merge", func(t *testing.T) {
matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"})
if len(unmatched) != 3 {
t.Error("expected 3 remaining tensors, got", len(unmatched))
}
if len(matched) != 1 {
t.Error("expected 1 merged tensor, got", len(matched))
}
checkMatched(t, 1, matched)
})
t.Run("multiple merges", func(t *testing.T) {
matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"}, merge{"c.*.d", "c.d"})
if len(unmatched) != 1 {
t.Error("expected 1 remaining tensors, got", len(unmatched))
}
if len(matched) != 2 {
t.Error("expected 2 merged tensor, got", len(matched))
}
checkMatched(t, 2, matched)
})
t.Run("no match", func(t *testing.T) {
matched, unmatched := mergeTensors(unmatched, merge{"x.*.y", "x.y"})
if len(unmatched) != 5 {
t.Error("expected 5 remaining tensors, got", len(unmatched))
}
if len(matched) != 0 {
t.Error("expected no merged tensors, got", len(matched))
}
})
}

View File

@@ -3,7 +3,6 @@
package discover
import (
"fmt"
"log/slog"
"os"
"regexp"
@@ -56,13 +55,10 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
}
}
}
return "sbsa"
}
// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers
if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
// The detected driver is older than Feb 2023
slog.Warn("old CUDA driver detected - please upgrade to a newer driver", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
return "v11"
}
return "v12"

View File

@@ -12,7 +12,7 @@ import (
// '../lib/ollama' on Linux and the executable's directory on macOS
// note: distribution builds, additional GPU-specific libraries are
// found in subdirectories of the returned path, such as
// 'cuda_v12', 'rocm', etc.
// 'cuda_v11', 'cuda_v12', 'rocm', etc.
var LibOllamaPath string = func() string {
exe, err := os.Executable()
if err != nil {

View File

@@ -500,7 +500,6 @@ The `message` object has the following fields:
- `thinking`: (for thinking models) the model's thinking process
- `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
- `tool_calls` (optional): a list of tools in JSON that the model wants to use
- `tool_name` (optional): add the name of the tool that was executed to inform the model of the result
Advanced parameters (optional):
@@ -509,21 +508,13 @@ Advanced parameters (optional):
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
### Tool calling
Tool calling is supported by providing a list of tools in the `tools` parameter. The model will generate a response that includes a list of tool calls. See the [Chat request (Streaming with tools)](#chat-request-streaming-with-tools) example below.
Models can also explain the result of the tool call in the response. See the [Chat request (With history, with tools)](#chat-request-with-history-with-tools) example below.
[See models with tool calling capabilities](https://ollama.com/search?c=tool).
### Structured outputs
Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [Chat request (Structured outputs)](#chat-request-structured-outputs) example below.
### Examples
#### Chat request (Streaming)
#### Chat Request (Streaming)
##### Request
@@ -578,88 +569,6 @@ Final response:
}
```
#### Chat request (Streaming with tools)
##### Request
```shell
curl http://localhost:11434/api/chat -d '{
"model": "llama3.2",
"messages": [
{
"role": "user",
"content": "what is the weather in tokyo?"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather in a given city",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to get the weather for"
}
},
"required": ["city"]
}
}
}
],
"stream": true
}'
```
##### Response
A stream of JSON objects is returned:
```json
{
"model": "llama3.2",
"created_at": "2025-07-07T20:22:19.184789Z",
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": "get_weather",
"arguments": {
"city": "Tokyo"
}
},
}
]
},
"done": false
}
```
Final response:
```json
{
"model":"llama3.2",
"created_at":"2025-07-07T20:22:19.19314Z",
"message": {
"role": "assistant",
"content": ""
},
"done_reason": "stop",
"done": true,
"total_duration": 182242375,
"load_duration": 41295167,
"prompt_eval_count": 169,
"prompt_eval_duration": 24573166,
"eval_count": 15,
"eval_duration": 115959084
}
```
#### Chat request (No streaming)
##### Request
@@ -697,74 +606,6 @@ curl http://localhost:11434/api/chat -d '{
}
```
#### Chat request (No streaming, with tools)
##### Request
```shell
curl http://localhost:11434/api/chat -d '{
"model": "llama3.2",
"messages": [
{
"role": "user",
"content": "what is the weather in tokyo?"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather in a given city",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to get the weather for"
}
},
"required": ["city"]
}
}
}
],
"stream": false
}'
```
##### Response
```json
{
"model": "llama3.2",
"created_at": "2025-07-07T20:32:53.844124Z",
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": "get_weather",
"arguments": {
"city": "Tokyo"
}
},
}
]
},
"done_reason": "stop",
"done": true,
"total_duration": 3244883583,
"load_duration": 2969184542,
"prompt_eval_count": 169,
"prompt_eval_duration": 141656333,
"eval_count": 18,
"eval_duration": 133293625
}
```
#### Chat request (Structured outputs)
##### Request
@@ -871,87 +712,6 @@ Final response:
}
```
#### Chat request (With history, with tools)
##### Request
```shell
curl http://localhost:11434/api/chat -d '{
"model": "llama3.2",
"messages": [
{
"role": "user",
"content": "what is the weather in Toronto?"
},
// the message from the model appended to history
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": "get_temperature",
"arguments": {
"city": "Toronto"
}
},
}
]
},
// the tool call result appended to history
{
"role": "tool",
"content": "11 degrees celsius",
"tool_name": "get_temperature",
}
],
"stream": false,
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather in a given city",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to get the weather for"
}
},
"required": ["city"]
}
}
}
]
}'
```
##### Response
```json
{
"model": "llama3.2",
"created_at": "2025-07-07T20:43:37.688511Z",
"message": {
"role": "assistant",
"content": "The current temperature in Toronto is 11°C."
},
"done_reason": "stop",
"done": true,
"total_duration": 890771750,
"load_duration": 707634750,
"prompt_eval_count": 94,
"prompt_eval_duration": 91703208,
"eval_count": 11,
"eval_duration": 90282125
}
```
#### Chat request (with images)
##### Request

59
docs/benchmark.md Normal file
View File

@@ -0,0 +1,59 @@
# Benchmark
Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.
## When to use
Run these benchmarks when:
- Making changes to the model inference engine
- Modifying model loading/unloading logic
- Changing prompt processing or token generation code
- Implementing a new model architecture
- Testing performance across different hardware setups
## Prerequisites
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
## Usage and Examples
>[!NOTE]
>All commands must be run from the root directory of the Ollama project.
Basic syntax:
```bash
go test -bench=. ./benchmark/... -m $MODEL_NAME
```
Required flags:
- `-bench=.`: Run all benchmarks
- `-m`: Model name to benchmark
Optional flags:
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)
Common usage patterns:
Single benchmark run with a model specified:
```bash
go test -bench=. ./benchmark/... -m llama3.3
```
## Output metrics
The benchmark reports several key metrics:
- `gen_tok/s`: Generated tokens per second
- `prompt_tok/s`: Prompt processing tokens per second
- `ttft_ms`: Time to first token in milliseconds
- `load_ms`: Model load time in milliseconds
- `gen_tokens`: Total tokens generated
- `prompt_tokens`: Total prompt tokens processed
Each benchmark runs two scenarios:
- Cold start: Model is loaded from disk for each test
- Warm start: Model is pre-loaded in memory
Three prompt lengths are tested for each scenario:
- Short prompt (100 tokens)
- Medium prompt (500 tokens)
- Long prompt (1000 tokens)

View File

@@ -1,14 +1,12 @@
# GPU
## Nvidia
Ollama supports Nvidia GPUs with compute capability 5.0+ and driver version 531 and newer.
Ollama supports Nvidia GPUs with compute capability 5.0+.
Check your compute compatibility to see if your card is supported:
[https://developer.nvidia.com/cuda-gpus](https://developer.nvidia.com/cuda-gpus)
| Compute Capability | Family | Cards |
| ------------------ | ------------------- | ----------------------------------------------------------------------------------------------------------- |
| 12.0 | GeForce RTX 50xx | `RTX 5060` `RTX 5060 Ti` `RTX 5070` `RTX 5070 Ti` `RTX 5080` `RTX 5090` |
| | NVIDIA Professioal | `RTX PRO 4000 Blackwell` `RTX PRO 4500 Blackwell` `RTX PRO 5000 Blackwell` `RTX PRO 6000 Blackwell` |
| 9.0 | NVIDIA | `H200` `H100` |
| 8.9 | GeForce RTX 40xx | `RTX 4090` `RTX 4080 SUPER` `RTX 4080` `RTX 4070 Ti SUPER` `RTX 4070 Ti` `RTX 4070 SUPER` `RTX 4070` `RTX 4060 Ti` `RTX 4060` |
| | NVIDIA Professional | `L4` `L40` `RTX 6000` |

View File

@@ -43,7 +43,7 @@ Ollama includes multiple LLM libraries compiled for different GPUs and CPU vecto
In the server log, you will see a message that looks something like this (varies from release to release):
```
Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v12 rocm_v5]
Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v11 rocm_v5]
```
**Experimental LLM Library Override**

View File

@@ -10,5 +10,4 @@ type Config interface {
Strings(string, ...[]string) []string
Ints(string, ...[]int32) []int32
Floats(string, ...[]float32) []float32
Bools(string, ...[]bool) []bool
}

View File

@@ -34,8 +34,7 @@ func (kv KV) Kind() string {
}
func (kv KV) ParameterCount() uint64 {
val, _ := keyValue(kv, "general.parameter_count", uint64(0))
return val
return keyValue(kv, "general.parameter_count", uint64(0))
}
func (kv KV) FileType() FileType {
@@ -54,27 +53,16 @@ func (kv KV) EmbeddingLength() uint64 {
return uint64(kv.Uint("embedding_length"))
}
func (kv KV) HeadCountMax() uint64 {
// TODO(drifkin): using the max value can cause an overestimation. In the
// future if array values become more popular, we can adapt the more invasive
// <https://github.com/ollama/ollama/pull/10225>
return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1))
func (kv KV) HeadCount() uint64 {
return uint64(kv.Uint("attention.head_count"))
}
func (kv KV) HeadCountMin() uint64 {
return uint64(kv.UintOrMinArrayValue("attention.head_count", 1))
func (kv KV) HeadCountKV() uint64 {
return uint64(kv.Uint("attention.head_count_kv", 1))
}
func (kv KV) HeadCountKVMax() uint64 {
return uint64(kv.UintOrMaxArrayValue("attention.head_count_kv", 1))
}
func (kv KV) HeadCountKVMin() uint64 {
return uint64(kv.UintOrMinArrayValue("attention.head_count_kv", 1))
}
func (kv KV) EmbeddingHeadCountMax() uint64 {
if heads := kv.HeadCountMin(); heads > 0 {
func (kv KV) EmbeddingHeadCount() uint64 {
if heads := kv.HeadCount(); heads > 0 {
return kv.EmbeddingLength() / heads
}
@@ -82,11 +70,15 @@ func (kv KV) EmbeddingHeadCountMax() uint64 {
}
func (kv KV) EmbeddingHeadCountK() uint64 {
return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCountMax())))
return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCount())))
}
func (kv KV) EmbeddingHeadCountV() uint64 {
return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCountMax())))
return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCount())))
}
func (kv KV) GQA() uint64 {
return kv.HeadCount() / kv.HeadCountKV()
}
func (kv KV) ContextLength() uint64 {
@@ -98,83 +90,40 @@ func (kv KV) ChatTemplate() string {
}
func (kv KV) String(key string, defaultValue ...string) string {
val, _ := keyValue(kv, key, append(defaultValue, "")...)
return val
return keyValue(kv, key, append(defaultValue, "")...)
}
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
return keyValue(kv, key, append(defaultValue, 0)...)
}
func (kv KV) Float(key string, defaultValue ...float32) float32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
return keyValue(kv, key, append(defaultValue, 0)...)
}
func (kv KV) Bool(key string, defaultValue ...bool) bool {
val, _ := keyValue(kv, key, append(defaultValue, false)...)
return val
}
func (kv KV) UintOrMaxArrayValue(key string, defaultValue uint32) uint32 {
_, max := kv.UintOrArrayValue(key, defaultValue)
return max
}
func (kv KV) UintOrMinArrayValue(key string, defaultValue uint32) uint32 {
min, _ := kv.UintOrArrayValue(key, defaultValue)
return min
}
func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) {
if u32, ok := keyValue(kv, key, uint32(0)); ok {
return u32, u32
} else if u32s, ok := keyValue(kv, key, &array[uint32]{}); ok {
min := slices.Min(u32s.values)
max := slices.Max(u32s.values)
return min, max
} else if i32s, ok := keyValue(kv, key, &array[int32]{}); ok {
min := slices.Min(i32s.values)
max := slices.Max(i32s.values)
if min < 0 || max < 0 {
slog.Warn("array values are unexpectedly negative", "key", key, "min", min, "max", max)
}
return uint32(min), uint32(max)
}
return defaultValue, defaultValue
return keyValue(kv, key, append(defaultValue, false)...)
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
val, _ := keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]})
return val.values
return keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]}).values
}
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
val, _ := keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]})
return val.values
return keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]}).values
}
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
val, _ := keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]})
return val.values
return keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]}).values
}
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
val, _ := keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]})
return val.values
}
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
val, _ := keyValue(kv, key, &array[bool]{values: append(defaultValue, []bool(nil))[0]})
return val.values
return keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]}).values
}
func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{
"gemma3",
"gemma3n",
"mistral3",
"llama4",
"mllama",
@@ -194,17 +143,17 @@ type arrayValueTypes interface {
*array[string] | *array[float32] | *array[float64] | *array[bool]
}
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) T {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key
}
if val, ok := kv[key].(T); ok {
return val, true
if val, ok := kv[key]; ok {
return val.(T)
}
slog.Debug("key with type not found", "key", key, "default", defaultValue[0])
return defaultValue[0], false
slog.Debug("key not found", "key", key, "default", defaultValue[0])
return defaultValue[0]
}
type Tensors struct {
@@ -476,11 +425,11 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
embedding := f.KV().EmbeddingLength()
heads := f.KV().HeadCountMax()
headsKV := f.KV().HeadCountKVMax()
heads := f.KV().HeadCount()
headsKV := f.KV().HeadCountKV()
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
embeddingHeads := f.KV().EmbeddingHeadCountMax()
embeddingHeads := f.KV().EmbeddingHeadCount()
embeddingHeadsK := f.KV().EmbeddingHeadCountK()
embeddingHeadsV := f.KV().EmbeddingHeadCountV()
@@ -555,7 +504,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
// vocab graph
4*batch*(embedding+vocab)+embedding*vocab*105/128,
)
case "gemma", "gemma2", "gemma3", "gemma3n":
case "gemma", "gemma2", "gemma3":
fullOffload = max(
4*batch*(embedding+vocab),
4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
@@ -568,11 +517,6 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
embedding*embeddingHeadsK*heads*9/16,
)
if f.KV().Architecture() == "gemma3n" {
fullOffload *= 4
partialOffload *= 4
}
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
// engine. Gemma3 always uses the Ollama engine.
if f.KV().Architecture() == "gemma3" {

View File

@@ -269,33 +269,3 @@ func TestKeyValue(t *testing.T) {
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
}
}
func TestHeadCount(t *testing.T) {
valuesArray := []int32{1, 5, 3, 4}
cases := []struct {
kv KV
want uint64
}{
{
kv: KV{
"general.architecture": "abc",
"abc.attention.head_count": &array[int32]{values: valuesArray, size: len(valuesArray)},
},
want: uint64(5),
},
{
kv: KV{
"general.architecture": "abc",
"abc.attention.head_count": uint32(3),
},
want: uint64(3),
},
}
for _, tt := range cases {
got := tt.kv.HeadCountMax()
if got != tt.want {
t.Errorf("unexpected max value: got=%d want=%d", got, tt.want)
}
}
}

View File

@@ -527,17 +527,23 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
return err
}
for _, key := range slices.Sorted(maps.Keys(kv)) {
keys := slices.Collect(maps.Keys(kv))
slices.Sort(keys)
for _, key := range keys {
if err := ggufWriteKV(f, key, kv[key]); err != nil {
return err
}
}
slices.SortStableFunc(ts, func(a, b *Tensor) int {
if i, j := a.block(), b.block(); i > 0 && j > 0 {
if i, j := a.block(), b.block(); i < 0 && j > 0 {
return 1
} else if i > 0 && j < 0 {
return -1
} else {
return cmp.Compare(i, j)
}
return cmp.Compare(a.Name, b.Name)
})
var s uint64
@@ -609,10 +615,6 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
err = writeGGUFArray(ws, ggufTypeString, v)
case *array[string]:
err = writeGGUFArray(ws, ggufTypeString, v.values)
case []bool:
err = writeGGUFArray(ws, ggufTypeBool, v)
case *array[bool]:
err = writeGGUFArray(ws, ggufTypeBool, v.values)
default:
return fmt.Errorf("improper type for '%s'", k)
}

View File

@@ -2,82 +2,62 @@ package ggml
import (
"bytes"
"math/rand/v2"
"os"
"strings"
"slices"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestWriteGGUF(t *testing.T) {
r := rand.New(rand.NewPCG(0, 0))
for range 8 {
t.Run("shuffle", func(t *testing.T) {
t.Parallel()
w, err := os.CreateTemp(t.TempDir(), "*.bin")
if err != nil {
t.Fatal(err)
}
defer w.Close()
ts := []*Tensor{
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.1.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.2.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.3.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.4.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.5.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
}
if err := WriteGGUF(w, KV{
"general.alignment": uint32(16),
}, []*Tensor{
{Name: "test.0", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.1", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.2", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.3", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.4", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.5", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
}); err != nil {
t.Fatal(err)
}
r.Shuffle(len(ts), func(i, j int) {
ts[i], ts[j] = ts[j], ts[i]
})
r, err := os.Open(w.Name())
if err != nil {
t.Fatal(err)
}
defer r.Close()
w, err := os.CreateTemp(t.TempDir(), strings.ReplaceAll(t.Name(), "/", "_")+"*.bin")
if err != nil {
t.Fatal(err)
}
defer w.Close()
ff, err := Decode(r, 0)
if err != nil {
t.Fatal(err)
}
if err := WriteGGUF(w, KV{
"general.alignment": uint32(16),
}, ts); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(ff.KV(), KV{
"general.alignment": uint32(16),
"general.parameter_count": uint64(36),
}); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
r, err := os.Open(w.Name())
if err != nil {
t.Fatal(err)
}
defer r.Close()
ff, err := Decode(r, 0)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(KV{
"general.alignment": uint32(16),
"general.parameter_count": uint64(54),
}, ff.KV()); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(Tensors{
Offset: 608,
items: []*Tensor{
{Name: "blk.0.attn_norm.weight", Offset: 0, Shape: []uint64{2, 3}},
{Name: "blk.1.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
{Name: "blk.2.attn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
{Name: "blk.3.attn_norm.weight", Offset: 96, Shape: []uint64{2, 3}},
{Name: "blk.4.attn_norm.weight", Offset: 128, Shape: []uint64{2, 3}},
{Name: "blk.5.attn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
{Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
{Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
{Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},
},
}, ff.Tensors(), cmp.AllowUnexported(Tensors{})); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
})
if diff := cmp.Diff(ff.Tensors(), Tensors{
Offset: 336,
items: []*Tensor{
{Name: "test.0", Offset: 0, Shape: []uint64{2, 3}},
{Name: "test.1", Offset: 32, Shape: []uint64{2, 3}},
{Name: "test.2", Offset: 64, Shape: []uint64{2, 3}},
{Name: "test.3", Offset: 96, Shape: []uint64{2, 3}},
{Name: "test.4", Offset: 128, Shape: []uint64{2, 3}},
{Name: "test.5", Offset: 160, Shape: []uint64{2, 3}},
},
}, cmp.AllowUnexported(Tensors{})); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
}

View File

@@ -65,7 +65,7 @@ func Open(path string) (f *File, err error) {
return nil, err
}
if f.Version < 2 {
if f.Version != 3 {
return nil, fmt.Errorf("%w version %v", ErrUnsupported, f.Version)
}

2
go.mod
View File

@@ -25,7 +25,6 @@ require (
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
golang.org/x/image v0.22.0
golang.org/x/tools v0.30.0
gonum.org/v1/gonum v0.15.0
)
require (
@@ -45,6 +44,7 @@ require (
github.com/xtgo/set v1.0.0 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
gonum.org/v1/gonum v0.15.0 // indirect
gorgonia.org/vecf32 v0.9.0 // indirect
gorgonia.org/vecf64 v0.9.0 // indirect
)

View File

@@ -19,6 +19,35 @@ import (
"github.com/ollama/ollama/format"
)
var (
started = time.Now()
chatModels = []string{
"granite3-moe:latest",
"granite-code:latest",
"nemotron-mini:latest",
"command-r:latest",
"gemma2:latest",
"gemma:latest",
"internlm2:latest",
"phi3.5:latest",
"phi3:latest",
// "phi:latest", // flaky, sometimes generates no response on first query
"stablelm2:latest", // Predictions are off, crashes on small VRAM GPUs
"falcon:latest",
"falcon2:latest",
"minicpm-v:latest",
"mistral:latest",
"orca-mini:latest",
"llama2:latest",
"llama3.1:latest",
"llama3.2:latest",
"llama3.2-vision:latest",
"qwen2.5-coder:latest",
"qwen:latest",
"solar-pro:latest",
}
)
func TestModelsGenerate(t *testing.T) {
softTimeout, hardTimeout := getTimeouts(t)
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
@@ -39,13 +68,6 @@ func TestModelsGenerate(t *testing.T) {
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
}
var chatModels []string
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
chatModels = ollamaEngineChatModels
} else {
chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
}
for _, model := range chatModels {
t.Run(model, func(t *testing.T) {
if time.Now().Sub(started) > softTimeout {

View File

@@ -1,266 +0,0 @@
//go:build integration && perf
package integration
import (
"context"
"fmt"
"io/ioutil"
"log/slog"
"math"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
)
var (
// Models that don't work reliably with the large context prompt in this test case
longContextFlakes = []string{
"granite-code:latest",
"nemotron-mini:latest",
"falcon:latest", // 2k model
"falcon2:latest", // 2k model
"minicpm-v:latest",
"qwen:latest",
"solar-pro:latest",
}
)
// Note: this test case can take a long time to run, particularly on models with
// large contexts. Run with -timeout set to a large value to get reasonable coverage
// Example usage:
//
// go test --tags=integration,perf -count 1 ./integration -v -timeout 90m -run TestModelsPerf 2>&1 | tee int.log
// cat int.log | grep MODEL_PERF_HEADER | head -1| cut -f2- -d: > perf.csv
// cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv
func TestModelsPerf(t *testing.T) {
softTimeout, hardTimeout := getTimeouts(t)
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// TODO use info API eventually
var maxVram uint64
var err error
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
maxVram, err = strconv.ParseUint(s, 10, 64)
if err != nil {
t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
}
} else {
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
}
data, err := ioutil.ReadFile(filepath.Join("testdata", "shakespeare.txt"))
if err != nil {
t.Fatalf("failed to open test data file: %s", err)
}
longPrompt := "summarize the following: " + string(data)
var chatModels []string
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
chatModels = ollamaEngineChatModels
} else {
chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
}
for _, model := range chatModels {
t.Run(model, func(t *testing.T) {
if time.Now().Sub(started) > softTimeout {
t.Skip("skipping remaining tests to avoid excessive runtime")
}
if err := PullIfMissing(ctx, client, model); err != nil {
t.Fatalf("pull failed %s", err)
}
var maxContext int
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
if err != nil {
t.Fatalf("show failed: %s", err)
}
arch := resp.ModelInfo["general.architecture"].(string)
maxContext = int(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))
if maxVram > 0 {
resp, err := client.List(ctx)
if err != nil {
t.Fatalf("list models failed %v", err)
}
for _, m := range resp.Models {
// For these tests we want to exercise a some amount of overflow on the CPU
if m.Name == model && float32(m.Size)*0.75 > float32(maxVram) {
t.Skipf("model %s is too large %s for available VRAM %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
}
}
}
slog.Info("scneario", "model", model, "max_context", maxContext)
loaded := false
defer func() {
// best effort unload once we're done with the model
if loaded {
client.Generate(ctx, &api.GenerateRequest{Model: model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
}
}()
// Some models don't handle the long context data well so skip them to avoid flaky test results
longContextFlake := false
for _, flake := range longContextFlakes {
if model == flake {
longContextFlake = true
break
}
}
// iterate through a few context sizes for coverage without excessive runtime
var contexts []int
keepGoing := true
if maxContext > 16384 {
contexts = []int{4096, 8192, 16384, maxContext}
} else if maxContext > 8192 {
contexts = []int{4096, 8192, maxContext}
} else if maxContext > 4096 {
contexts = []int{4096, maxContext}
} else if maxContext > 0 {
contexts = []int{maxContext}
} else {
t.Fatal("unknown max context size")
}
for _, numCtx := range contexts {
if !keepGoing && numCtx > 8192 { // Always try up to 8k before bailing out
break
}
skipLongPrompt := false
// Workaround bug 11172 temporarily...
maxPrompt := longPrompt
// If we fill the context too full with the prompt, many models
// quickly hit context shifting and go bad.
if len(maxPrompt) > numCtx*2 { // typically yields ~1/2 full context
maxPrompt = maxPrompt[:numCtx*2]
}
testCases := []struct {
prompt string
anyResp []string
}{
{"why is the sky blue?", []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}},
{maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy"}},
}
var gpuPercent int
for _, tc := range testCases {
if len(tc.prompt) > 100 && (longContextFlake || skipLongPrompt) {
slog.Info("skipping long prompt", "model", model, "num_ctx", numCtx, "gpu_percent", gpuPercent)
continue
}
req := api.GenerateRequest{
Model: model,
Prompt: tc.prompt,
KeepAlive: &api.Duration{Duration: 20 * time.Second}, // long enough to ensure a ps returns
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
"num_ctx": numCtx,
},
}
atLeastOne := false
var resp api.GenerateResponse
stream := false
req.Stream = &stream
// Avoid potentially getting stuck indefinitely
limit := 5 * time.Minute
genCtx, cancel := context.WithDeadlineCause(
ctx,
time.Now().Add(limit),
fmt.Errorf("generate on model %s with ctx %d took longer than %v", model, numCtx, limit),
)
defer cancel()
err = client.Generate(genCtx, &req, func(rsp api.GenerateResponse) error {
resp = rsp
return nil
})
if err != nil {
// Avoid excessive test runs, but don't consider a failure with massive context
if numCtx > 16384 && strings.Contains(err.Error(), "took longer") {
slog.Warn("max context was taking too long, skipping", "error", err)
keepGoing = false
skipLongPrompt = true
continue
}
t.Fatalf("generate error: ctx:%d err:%s", numCtx, err)
}
loaded = true
for _, expResp := range tc.anyResp {
if strings.Contains(strings.ToLower(resp.Response), expResp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("response didn't contain expected values: ctx:%d expected:%v response:%s ", numCtx, tc.anyResp, resp.Response)
}
models, err := client.ListRunning(ctx)
if err != nil {
slog.Warn("failed to list running models", "error", err)
continue
}
if len(models.Models) > 1 {
slog.Warn("multiple models loaded, may impact performance results", "loaded", models.Models)
}
for _, m := range models.Models {
if m.Name == model {
if m.SizeVRAM == 0 {
slog.Info("Model fully loaded into CPU")
gpuPercent = 0
keepGoing = false
skipLongPrompt = true
} else if m.SizeVRAM == m.Size {
slog.Info("Model fully loaded into GPU")
gpuPercent = 100
} else {
sizeCPU := m.Size - m.SizeVRAM
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
gpuPercent = int(100 - cpuPercent)
slog.Info("Model split between CPU/GPU", "CPU", cpuPercent, "GPU", gpuPercent)
keepGoing = false
// Heuristic to avoid excessive test run time
if gpuPercent < 90 {
skipLongPrompt = true
}
}
}
}
fmt.Fprintf(os.Stderr, "MODEL_PERF_HEADER:%s,%s,%s,%s,%s,%s,%s\n",
"MODEL",
"CONTEXT",
"GPU PERCENT",
"PROMPT COUNT",
"LOAD TIME",
"PROMPT EVAL TPS",
"EVAL TPS",
)
fmt.Fprintf(os.Stderr, "MODEL_PERF_DATA:%s,%d,%d,%d,%0.2f,%0.2f,%0.2f\n",
model,
numCtx,
gpuPercent,
resp.PromptEvalCount,
float64(resp.LoadDuration)/1000000000.0,
float64(resp.PromptEvalCount)/(float64(resp.PromptEvalDuration)/1000000000.0),
float64(resp.EvalCount)/(float64(resp.EvalDuration)/1000000000.0),
)
}
}
})
}
}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -32,48 +32,6 @@ const (
smol = "llama3.2:1b"
)
var (
started = time.Now()
// Note: add newer models at the top of the list to test them first
ollamaEngineChatModels = []string{
"gemma3n:e2b",
"mistral-small3.2:latest",
"deepseek-r1:1.5b",
"llama3.2-vision:latest",
"qwen2.5-coder:latest",
"qwen2.5vl:3b",
"qwen3:0.6b", // dense
"qwen3:30b", // MOE
"gemma3:1b",
"llama3.1:latest",
"llama3.2:latest",
"gemma2:latest",
"minicpm-v:latest", // arch=qwen2
"granite-code:latest", // arch=llama
}
llamaRunnerChatModels = []string{
"mistral:latest",
"falcon3:latest",
"granite3-moe:latest",
"command-r:latest",
"nemotron-mini:latest",
"phi3.5:latest",
"solar-pro:latest",
"internlm2:latest",
"codellama:latest", // arch=llama
"phi3:latest",
"falcon2:latest",
"gemma:latest",
"llama2:latest",
"nous-hermes:latest",
"orca-mini:latest",
"qwen:latest",
"stablelm2:latest", // Predictions are off, crashes on small VRAM GPUs
"falcon:latest",
}
)
func Init() {
lifecycle.InitLogging()
}

View File

@@ -150,7 +150,7 @@ index 4cce5166..7f6617fa 100644
llama_model_loader::llama_model_loader(
const std::string & fname,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 3a4e72a3..db62973f 100644
index 3a4e72a3..831b68c0 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1402,6 +1402,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {

View File

@@ -22,10 +22,10 @@ multiple batches of processing until everything is complete.
4 files changed, 59 insertions(+), 79 deletions(-)
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index dca22d8b..1f3a3956 100644
index c22687e4..c5948e8f 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -947,9 +947,12 @@ int llama_context::decode(llama_batch & inp_batch) {
@@ -950,9 +950,12 @@ int llama_context::decode(llama_batch & inp_batch) {
// find KV slot
if (!kv_self->find_slot(ubatch)) {
@@ -41,7 +41,7 @@ index dca22d8b..1f3a3956 100644
}
ggml_backend_sched_reset(sched.get());
@@ -1965,9 +1968,12 @@ void llama_context::opt_epoch_iter(
@@ -1967,9 +1970,12 @@ void llama_context::opt_epoch_iter(
// TODO: not sure if this is needed
if (!kv_self->find_slot(ubatch)) {

View File

@@ -10,10 +10,10 @@ Subject: [PATCH] add argsort and cuda copy for i32
3 files changed, 192 insertions(+), 2 deletions(-)
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 955fec59..654e2f28 100644
index becdae07..7a44b6cf 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -6822,6 +6822,45 @@ static void ggml_compute_forward_argsort_f32(
@@ -6890,6 +6890,45 @@ static void ggml_compute_forward_argsort_f32(
}
}
@@ -59,7 +59,7 @@ index 955fec59..654e2f28 100644
void ggml_compute_forward_argsort(
const ggml_compute_params * params,
ggml_tensor * dst) {
@@ -6833,6 +6872,10 @@ void ggml_compute_forward_argsort(
@@ -6901,6 +6940,10 @@ void ggml_compute_forward_argsort(
{
ggml_compute_forward_argsort_f32(params, dst);
} break;
@@ -195,7 +195,7 @@ index 607ded85..53b02634 100644
+ }
}
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index d027271f..4abd01d7 100644
index 2d46176e..47383486 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -38,6 +38,13 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
@@ -257,7 +257,7 @@ index d027271f..4abd01d7 100644
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q8_0 * dsti = (block_q8_0 *) cdsti;
@@ -633,6 +678,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
@@ -631,6 +676,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
@@ -266,7 +266,7 @@ index d027271f..4abd01d7 100644
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -688,6 +735,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
@@ -686,6 +733,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_f32_f16<cpy_1_f16_f32>;

View File

@@ -1,32 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen <daniel@ollama.com>
Date: Sun, 22 Jun 2025 09:22:05 -0700
Subject: [PATCH] temporary prevent rocm+cuda mixed loading
---
ggml/src/ggml-backend-reg.cpp | 12 ++++++++++--
1 file changed, 10 insertions(+), 2 deletions(-)
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index 4e67d243..8f49f084 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -573,8 +573,16 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
ggml_backend_load_best("blas", silent, dir_path);
ggml_backend_load_best("cann", silent, dir_path);
- ggml_backend_load_best("cuda", silent, dir_path);
- ggml_backend_load_best("hip", silent, dir_path);
+
+ // Avoid mixed hip+cuda configurations
+ const char * hip_devices = std::getenv("HIP_VISIBLE_DEVICES");
+ const char * rocr_devices = std::getenv("ROCR_VISIBLE_DEVICES");
+ if (!hip_devices && !rocr_devices) {
+ ggml_backend_load_best("cuda", silent, dir_path);
+ } else {
+ ggml_backend_load_best("hip", silent, dir_path);
+ }
+
ggml_backend_load_best("kompute", silent, dir_path);
ggml_backend_load_best("metal", silent, dir_path);
ggml_backend_load_best("rpc", silent, dir_path);

View File

@@ -1,169 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Thu, 19 Jun 2025 08:05:21 +0300
Subject: [PATCH] metal : add mean kernel (#14267)
* metal : add mean kernel
ggml-ci
* cont : dedup implementation
ggml-ci
---
ggml/src/ggml-metal/ggml-metal.m | 33 ++++++++++++++++---
ggml/src/ggml-metal/ggml-metal.metal | 48 ++++++++++++++++++++++------
2 files changed, 67 insertions(+), 14 deletions(-)
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index ee4f2dcb..f20f5615 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -489,6 +489,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_COS,
GGML_METAL_KERNEL_TYPE_NEG,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+ GGML_METAL_KERNEL_TYPE_MEAN,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -1436,6 +1437,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
@@ -1634,6 +1636,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_LOG:
return false; // TODO: implement
case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@@ -2362,11 +2365,30 @@ static bool ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
{
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (dst->op) {
+ case GGML_OP_SUM_ROWS:
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+ break;
+ case GGML_OP_MEAN:
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
+ break;
+ default:
+ GGML_ABORT("fatal error");
+ }
+
+ int nth = 32; // SIMD width
+
+ while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+ nth *= 2;
+ }
+ nth = MIN(nth, ne00);
ggml_metal_kargs_sum_rows args = {
/*.ne00 =*/ ne00,
@@ -2396,11 +2418,12 @@ static bool ggml_metal_encode_node(
};
[encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&args length:sizeof(args) atIndex:2];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_SOFT_MAX:
{
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 9cfddf45..08e8d807 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -956,31 +956,61 @@ kernel void kernel_neg(
dst[tpig] = -src0[tpig];
}
+template <bool norm>
kernel void kernel_sum_rows(
+ constant ggml_metal_kargs_sum_rows & args,
device const float * src0,
device float * dst,
- constant ggml_metal_kargs_sum_rows & args,
- uint3 tpig[[thread_position_in_grid]]) {
- int64_t i3 = tpig.z;
- int64_t i2 = tpig.y;
- int64_t i1 = tpig.x;
+ threadgroup float * shmem_f32 [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ int64_t i3 = tgpig.z;
+ int64_t i2 = tgpig.y;
+ int64_t i1 = tgpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
+ if (sgitg == 0) {
+ shmem_f32[tiisg] = 0.0f;
+ }
+
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
- float row_sum = 0;
+ float sumf = 0;
- for (int64_t i0 = 0; i0 < args.ne00; i0++) {
- row_sum += src_row[i0];
+ for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+ sumf += src_row[i0];
}
- dst_row[0] = row_sum;
+ sumf = simd_sum(sumf);
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ shmem_f32[sgitg] = sumf;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sumf = shmem_f32[tiisg];
+ sumf = simd_sum(sumf);
+
+ if (tpitg.x == 0) {
+ dst_row[0] = norm ? sumf / args.ne00 : sumf;
+ }
}
+typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
+
+template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
+template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
+
template<typename T>
kernel void kernel_soft_max(
device const char * src0,

View File

File diff suppressed because it is too large Load Diff

View File

@@ -151,12 +151,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
}
if graphPartialOffload == 0 {
headsKV := f.KV().HeadCountKVMin()
if headsKV == 0 {
headsKV = 1
}
gqa := f.KV().HeadCountMax() / headsKV
graphPartialOffload = gqa * kvTotal / 6
graphPartialOffload = f.KV().GQA() * kvTotal / 6
}
if graphFullOffload == 0 {
graphFullOffload = graphPartialOffload

View File

@@ -22,6 +22,7 @@ import (
"strings"
"sync"
"time"
"unicode/utf8"
"golang.org/x/sync/semaphore"
@@ -139,13 +140,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
gpus = discover.GetCPUInfo()
}
// Verify the requested context size is <= the model training size
trainCtx := f.KV().ContextLength()
if opts.NumCtx/numParallel > int(trainCtx) && trainCtx > 0 {
slog.Warn("requested context size too large for model", "num_ctx", opts.NumCtx, "num_parallel", numParallel, "n_ctx_train", trainCtx)
opts.NumCtx = int(trainCtx) * numParallel
}
estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
if len(gpus) > 1 || gpus[0].Library != "cpu" {
switch {
@@ -318,7 +312,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
params = append(params, "--mmproj", projectors[0])
}
// iterate through compatible GPU libraries such as 'cuda_v12', 'rocm', etc.
// iterate through compatible GPU libraries such as 'cuda_v12', 'cuda_v11', 'rocm', etc.
// adding each library's respective path to the LD_LIBRARY_PATH, until finally running
// without any LD_LIBRARY_PATH flags
for {
@@ -732,10 +726,68 @@ type CompletionResponse struct {
EvalDuration time.Duration `json:"eval_duration"`
}
// unicodeBufferHandler wraps a completion response callback to handle partial UTF-8 sequences.
// This function creates a stateful closure that is NOT safe for concurrent use.
// Each completion request should create its own handler instance.
func unicodeBufferHandler(fn func(CompletionResponse)) func(CompletionResponse) {
var pendingUTF8 string
return func(resp CompletionResponse) {
if resp.Content == "" && !resp.Done {
// No content to process, just pass through
fn(resp)
return
}
// Combine any pending UTF-8 with current content
combinedContent := pendingUTF8 + resp.Content
pendingUTF8 = ""
// Check if combined content is valid UTF-8
if utf8.ValidString(combinedContent) {
// Valid UTF-8, send it
resp.Content = combinedContent
fn(resp)
} else {
// Invalid UTF-8
if resp.Done {
// This is the final response, trim incomplete UTF-8
trimmedContent := combinedContent
for !utf8.ValidString(trimmedContent) && len(trimmedContent) > 0 {
trimmedContent = trimmedContent[:len(trimmedContent)-1]
}
resp.Content = trimmedContent
fn(resp)
} else {
// Not final response, split valid and invalid parts
validPrefix := combinedContent
for !utf8.ValidString(validPrefix) && len(validPrefix) > 0 {
validPrefix = validPrefix[:len(validPrefix)-1]
}
if len(validPrefix) > 0 {
// Send valid prefix
resp.Content = validPrefix
fn(resp)
// Buffer the remainder
pendingUTF8 = combinedContent[len(validPrefix):]
} else {
// No valid prefix, buffer everything
pendingUTF8 = combinedContent
// Don't send this response
}
}
}
}
}
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
slog.Debug("completion request", "images", len(req.Images), "prompt", len(req.Prompt), "format", string(req.Format))
slog.Log(ctx, logutil.LevelTrace, "completion request", "prompt", req.Prompt)
// Wrap the callback with unicode buffer handling
unicodeFn := unicodeBufferHandler(fn)
if len(req.Format) > 0 {
switch string(req.Format) {
case `null`, `""`:
@@ -861,13 +913,13 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
}
if c.Content != "" {
fn(CompletionResponse{
unicodeFn(CompletionResponse{
Content: c.Content,
})
}
if c.Done {
fn(c)
unicodeFn(c)
return nil
}
}

View File

@@ -70,3 +70,152 @@ func TestLLMServerCompletionFormat(t *testing.T) {
}, nil)
checkValid(err)
}
func TestUnicodeBufferHandler(t *testing.T) {
tests := []struct {
name string
inputResponses []CompletionResponse
expectedResponses []CompletionResponse
description string
}{
{
name: "complete_unicode",
inputResponses: []CompletionResponse{
{Content: "Hello", Done: false},
{Content: " world", Done: false},
{Content: "!", Done: true},
},
expectedResponses: []CompletionResponse{
{Content: "Hello", Done: false},
{Content: " world", Done: false},
{Content: "!", Done: true},
},
description: "All responses with valid unicode should pass through unchanged",
},
{
name: "incomplete_unicode_at_end_with_done",
inputResponses: []CompletionResponse{
{Content: "Hello", Done: false},
{Content: string([]byte{0xF0, 0x9F}), Done: true}, // Incomplete emoji with Done=true
},
expectedResponses: []CompletionResponse{
{Content: "Hello", Done: false},
{Content: "", Done: true}, // Content is trimmed but response is still sent with Done=true
},
description: "When Done=true, incomplete Unicode at the end should be trimmed",
},
{
name: "split_unicode_across_responses",
inputResponses: []CompletionResponse{
{Content: "Hello " + string([]byte{0xF0, 0x9F}), Done: false}, // First part of 😀
{Content: string([]byte{0x98, 0x80}) + " world!", Done: true}, // Second part of 😀 and more text
},
expectedResponses: []CompletionResponse{
{Content: "Hello ", Done: false}, // Incomplete Unicode trimmed
{Content: "😀 world!", Done: true}, // Complete emoji in second response
},
description: "Unicode split across responses should be handled correctly",
},
{
name: "incomplete_unicode_buffered",
inputResponses: []CompletionResponse{
{Content: "Test " + string([]byte{0xF0, 0x9F}), Done: false}, // Incomplete emoji
{Content: string([]byte{0x98, 0x80}), Done: false}, // Complete the emoji
{Content: " done", Done: true},
},
expectedResponses: []CompletionResponse{
{Content: "Test ", Done: false}, // First part without incomplete unicode
{Content: "😀", Done: false}, // Complete emoji
{Content: " done", Done: true},
},
description: "Incomplete unicode should be buffered and combined with next response",
},
{
name: "empty_response_with_done",
inputResponses: []CompletionResponse{
{Content: "Complete response", Done: false},
{Content: "", Done: true}, // Empty response with Done=true
},
expectedResponses: []CompletionResponse{
{Content: "Complete response", Done: false},
{Content: "", Done: true}, // Should still be sent because Done=true
},
description: "Empty final response with Done=true should still be sent",
},
{
name: "done_reason_preserved",
inputResponses: []CompletionResponse{
{Content: "Response", Done: false},
{Content: " complete", Done: true, DoneReason: DoneReasonStop},
},
expectedResponses: []CompletionResponse{
{Content: "Response", Done: false},
{Content: " complete", Done: true, DoneReason: DoneReasonStop},
},
description: "DoneReason should be preserved in the final response",
},
{
name: "only_incomplete_unicode_not_done",
inputResponses: []CompletionResponse{
{Content: string([]byte{0xF0, 0x9F}), Done: false}, // Only incomplete unicode
},
expectedResponses: []CompletionResponse{
// No response expected - should be buffered
},
description: "Response with only incomplete unicode should be buffered if not done",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var actualResponses []CompletionResponse
// Create a callback that collects responses
callback := func(resp CompletionResponse) {
actualResponses = append(actualResponses, resp)
}
// Create the unicode buffer handler
handler := unicodeBufferHandler(callback)
// Send all input responses through the handler
for _, resp := range tt.inputResponses {
handler(resp)
}
// Verify the number of responses
if len(actualResponses) != len(tt.expectedResponses) {
t.Fatalf("%s: got %d responses, want %d responses",
tt.description, len(actualResponses), len(tt.expectedResponses))
}
// Verify each response matches the expected one
for i, expected := range tt.expectedResponses {
if i >= len(actualResponses) {
t.Fatalf("%s: missing response at index %d", tt.description, i)
continue
}
actual := actualResponses[i]
// Verify content
if actual.Content != expected.Content {
t.Errorf("%s: response[%d].Content = %q, want %q",
tt.description, i, actual.Content, expected.Content)
}
// Verify Done flag
if actual.Done != expected.Done {
t.Errorf("%s: response[%d].Done = %v, want %v",
tt.description, i, actual.Done, expected.Done)
}
// Verify DoneReason if specified
if actual.DoneReason != expected.DoneReason {
t.Errorf("%s: response[%d].DoneReason = %v, want %v",
tt.description, i, actual.DoneReason, expected.DoneReason)
}
}
})
}
}

View File

@@ -253,7 +253,6 @@ type Tensor interface {
Neg(ctx Context) Tensor
Add(ctx Context, t2 Tensor) Tensor
Sub(ctx Context, t2 Tensor) Tensor
Mul(ctx Context, t2 Tensor) Tensor
Div(ctx Context, t2 Tensor) Tensor
@@ -277,7 +276,6 @@ type Tensor interface {
Tanh(ctx Context) Tensor
GELU(ctx Context) Tensor
SILU(ctx Context) Tensor
RELU(ctx Context) Tensor
Sigmoid(ctx Context) Tensor
Reshape(ctx Context, shape ...int) Tensor
@@ -299,12 +297,6 @@ type Tensor interface {
TopK(ctx Context, k int) Tensor
Argsort(ctx Context) Tensor
Mean(ctx Context) Tensor
Variance(ctx Context) Tensor
Stddev(ctx Context) Tensor
Sqr(ctx Context) Tensor
Sqrt(ctx Context) Tensor
Clamp(ctx Context, min, max float32) Tensor
}
// ScaledDotProductAttention implements a fused attention

View File

@@ -297,9 +297,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
if _, ok := meta.Tensors().GroupLayers()["output"]; !ok && t.Name == "token_embd.weight" {
createTensor(tensor{source: t, target: "output.weight"}, output.bts, blocks)
}
case contains(t.Name, "cls", "output", "output_norm",
"altup_proj", "altup_unembd_proj",
"per_layer_token_embd", "per_layer_model_proj", "per_layer_proj_norm"):
case contains(t.Name, "cls", "output", "output_norm"):
createTensor(tensor{source: t}, output.bts, blocks)
case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
// TODO: assign vision tensors to the gpu if possible
@@ -355,24 +353,6 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
bbs[c] = b
}
// Mimic llama runner logs summarizing layers and memory
slog.Info(fmt.Sprintf("offloading %d repeating layers to GPU", max(0, params.NumGPULayers-1)))
gpuLayers := 0
switch C.ggml_backend_dev_type(output.d) {
case 0: // CPU
slog.Info("offloading output layer to CPU")
case 1: // GPU
slog.Info("offloading output layer to GPU")
gpuLayers++
case 2: // ACCEL
slog.Info("offloading output layer to ACCEL")
}
for _, layer := range layers {
if C.ggml_backend_dev_type(layer.d) == 1 {
gpuLayers++
}
}
slog.Info(fmt.Sprintf("offloaded %d/%d layers to GPU", gpuLayers, len(layers)+1))
for bs := range maps.Values(bbs) {
slog.Info("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)), "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
}
@@ -622,9 +602,7 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
}
func (c *Context) Compute(tensors ...ml.Tensor) {
if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS {
panic(fmt.Errorf("error computing ggml graph: %v", status))
}
C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
C.ggml_backend_sched_reset(c.b.sched)
needSync := true
@@ -913,13 +891,6 @@ func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) Sub(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sub(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}
func (t *Tensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor {
if dim < 0 || dim >= C.GGML_MAX_DIMS {
panic("invalid dimension")
@@ -1227,13 +1198,6 @@ func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
}
}
func (t *Tensor) RELU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
return &Tensor{
b: t.b,
@@ -1309,42 +1273,3 @@ func (t *Tensor) Argsort(ctx ml.Context) ml.Tensor {
t: C.ggml_argsort(ctx.(*Context).ctx, t.t, C.GGML_SORT_ORDER_ASC),
}
}
func (t *Tensor) Mean(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_mean(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Variance(ctx ml.Context) ml.Tensor {
return t.Add(ctx, t.Mean(ctx).Scale(ctx, -1)).
Sqr(ctx).
SumRows(ctx).
Scale(ctx, 1/float64(t.Dim(0)))
}
func (t *Tensor) Stddev(ctx ml.Context) ml.Tensor {
return t.Variance(ctx).Sqrt(ctx)
}
func (t *Tensor) Sqr(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sqr(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sqrt(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
}
}

View File

@@ -573,16 +573,8 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
ggml_backend_load_best("blas", silent, dir_path);
ggml_backend_load_best("cann", silent, dir_path);
// Avoid mixed hip+cuda configurations
const char * hip_devices = std::getenv("HIP_VISIBLE_DEVICES");
const char * rocr_devices = std::getenv("ROCR_VISIBLE_DEVICES");
if (!hip_devices && !rocr_devices) {
ggml_backend_load_best("cuda", silent, dir_path);
} else {
ggml_backend_load_best("hip", silent, dir_path);
}
ggml_backend_load_best("cuda", silent, dir_path);
ggml_backend_load_best("hip", silent, dir_path);
ggml_backend_load_best("kompute", silent, dir_path);
ggml_backend_load_best("metal", silent, dir_path);
ggml_backend_load_best("rpc", silent, dir_path);

View File

@@ -362,26 +362,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#endif // FP16_AVAILABLE
}
// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
template<bool norm>
static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x;
const int col = threadIdx.x;
float sum = 0.0f;
for (int i = col; i < ncols; i += blockDim.x) {
sum += x[row * ncols + i];
}
sum = warp_reduce_sum(sum);
if (col != 0) {
return;
}
dst[row] = norm ? sum / ncols : sum;
}
template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll

View File

@@ -35,7 +35,6 @@
#include "ggml-cuda/ssm-scan.cuh"
#include "ggml-cuda/sum.cuh"
#include "ggml-cuda/sumrows.cuh"
#include "ggml-cuda/mean.cuh"
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
@@ -2323,9 +2322,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SUM_ROWS:
ggml_cuda_op_sum_rows(ctx, dst);
break;
case GGML_OP_MEAN:
ggml_cuda_op_mean(ctx, dst);
break;
case GGML_OP_SSM_CONV:
ggml_cuda_op_ssm_conv(ctx, dst);
break;
@@ -3215,7 +3211,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_POOL_2D:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_ARGSORT:
case GGML_OP_ACC:
return true;

View File

@@ -1,19 +0,0 @@
#include "mean.cuh"
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
}

View File

@@ -1,3 +0,0 @@
#include "common.cuh"
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -1,9 +1,25 @@
#include "sumrows.cuh"
static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x;
const int col = threadIdx.x;
float sum = 0.0f;
for (int i = col; i < ncols; i += blockDim.x) {
sum += x[row * ncols + i];
}
sum = warp_reduce_sum(sum);
if (col == 0) {
dst[row] = sum;
}
}
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -19,8 +35,5 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
}

View File

@@ -1,4 +1,5 @@
#include "common.cuh"
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -3434,61 +3434,31 @@ kernel void kernel_neg(
dst[tpig] = -src0[tpig];
}
template <bool norm>
kernel void kernel_sum_rows(
constant ggml_metal_kargs_sum_rows & args,
device const float * src0,
device float * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
int64_t i3 = tgpig.z;
int64_t i2 = tgpig.y;
int64_t i1 = tgpig.x;
constant ggml_metal_kargs_sum_rows & args,
uint3 tpig[[thread_position_in_grid]]) {
int64_t i3 = tpig.z;
int64_t i2 = tpig.y;
int64_t i1 = tpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
float sumf = 0;
float row_sum = 0;
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
sumf += src_row[i0];
for (int64_t i0 = 0; i0 < args.ne00; i0++) {
row_sum += src_row[i0];
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
if (tpitg.x == 0) {
dst_row[0] = norm ? sumf / args.ne00 : sumf;
}
dst_row[0] = row_sum;
}
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
template<typename T>
kernel void kernel_soft_max(
device const char * src0,

View File

@@ -489,7 +489,6 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_COS,
GGML_METAL_KERNEL_TYPE_NEG,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_KERNEL_TYPE_MEAN,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -1437,7 +1436,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
@@ -1636,7 +1634,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_LOG:
return false; // TODO: implement
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@@ -2365,30 +2362,11 @@ static bool ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
{
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
id<MTLComputePipelineState> pipeline = nil;
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
switch (dst->op) {
case GGML_OP_SUM_ROWS:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
break;
case GGML_OP_MEAN:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
break;
default:
GGML_ABORT("fatal error");
}
int nth = 32; // SIMD width
while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
nth *= 2;
}
nth = MIN(nth, ne00);
ggml_metal_kargs_sum_rows args = {
/*.ne00 =*/ ne00,
@@ -2418,12 +2396,11 @@ static bool ggml_metal_encode_node(
};
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&args length:sizeof(args) atIndex:2];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SOFT_MAX:
{

View File

@@ -956,61 +956,31 @@ kernel void kernel_neg(
dst[tpig] = -src0[tpig];
}
template <bool norm>
kernel void kernel_sum_rows(
constant ggml_metal_kargs_sum_rows & args,
device const float * src0,
device float * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
int64_t i3 = tgpig.z;
int64_t i2 = tgpig.y;
int64_t i1 = tgpig.x;
constant ggml_metal_kargs_sum_rows & args,
uint3 tpig[[thread_position_in_grid]]) {
int64_t i3 = tpig.z;
int64_t i2 = tpig.y;
int64_t i1 = tpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
float sumf = 0;
float row_sum = 0;
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
sumf += src_row[i0];
for (int64_t i0 = 0; i0 < args.ne00; i0++) {
row_sum += src_row[i0];
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
if (tpitg.x == 0) {
dst_row[0] = norm ? sumf / args.ne00 : sumf;
}
dst_row[0] = row_sum;
}
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
template<typename T>
kernel void kernel_soft_max(
device const char * src0,

View File

@@ -1,51 +0,0 @@
package gemma3n
import (
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.SentencePieceModel
*TextModel
}
// Forward implements model.Model.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
return m.TextModel.Forward(ctx, batch, m.Cache)
}
func New(c fs.Config) (model.Model, error) {
m := Model{
TextModel: newTextModel(c),
SentencePieceModel: model.NewSentencePieceModel(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
}
m.Cache = kvcache.NewWrapperCache(
kvcache.NewCausalCache(m.Shift),
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
)
return &m, nil
}
func init() {
model.Register("gemma3n", New)
}

View File

@@ -1,360 +0,0 @@
package gemma3n
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
)
type TextModel struct {
TokenEmbedding *TextScaledWordEmbedding `gguf:"token_embd"`
*PerLayerProjector
AltupEmbd *nn.Linear `gguf:"altup_proj"`
AltupUnembd *nn.Linear `gguf:"altup_unembd_proj"`
TextLayers []TextLayer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
TextOptions
}
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
// Create a tensor of a single float32 value of 1.0 to use for altup correction
one := ctx.Input().FromFloatSlice([]float32{1.0}, 1)
inputs := m.TokenEmbedding.Forward(ctx, batch.Inputs, math.Sqrt(float64(m.hiddenSize)))
inputsPerLayer := m.PerLayerProjector.Forward(ctx, batch, inputs, &m.TextOptions)
targetMagnitude := inputs.Sqr(ctx).Mean(ctx).Sqrt(ctx)
targetMagnitude = targetMagnitude.Repeat(ctx, 2, m.altupInputs-1)
hiddenState := inputs.Repeat(ctx, 2, m.altupInputs-1)
altupProj := m.AltupEmbd.Forward(ctx, hiddenState)
altupProj = altupProj.Mul(ctx, targetMagnitude.Div(ctx, altupProj.Sqr(ctx).Mean(ctx).Sqrt(ctx)))
hiddenStates := inputs.Concat(ctx, altupProj, 2)
firstSharedKeyValue := m.hiddenLayers - m.sharedKeyValueLayers
for i, layer := range m.TextLayers {
if i < firstSharedKeyValue {
cache.SetLayer(i)
} else if m.isLocal(i) {
cache.SetLayer(firstSharedKeyValue - 2)
} else {
cache.SetLayer(firstSharedKeyValue - 1)
}
var layerType int
ropeBase := m.ropeBase
if m.isLocal(i) {
layerType = 1
ropeBase = m.ropeBaseLocal
}
cache.(*kvcache.WrapperCache).SetLayerType(layerType)
// inputPerLayer = inputsPerLayer[:, i, :]
inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2))
hiddenStates = layer.Forward(ctx, hiddenStates, inputPerLayer, positions, one, cache, i >= firstSharedKeyValue, ropeBase, float64(m.activationSparsityScale[i]), &m.TextOptions)
}
// hiddenStates = hiddenStates[:, :, 0]
hiddenStates0 := hiddenStates.View(ctx, 0, hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1))
targetMagnitude = hiddenStates0.Sqr(ctx).Mean(ctx).Sqrt(ctx)
targetMagnitude = targetMagnitude.Repeat(ctx, 2, m.altupInputs-1)
// hiddenState = hiddenStates[:, :, 1:]
hiddenState = hiddenStates.View(ctx, hiddenStates.Stride(2), hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), m.altupInputs-1)
altupUnembdProj := m.AltupUnembd.Forward(ctx, hiddenState)
altupUnembdProj = altupUnembdProj.Mul(ctx, targetMagnitude.Div(ctx, altupUnembdProj.Sqr(ctx).Mean(ctx).Sqrt(ctx)))
hiddenStates = hiddenStates0.Concat(ctx, altupUnembdProj, 2)
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)))
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeBase := m.ropeBase
if m.isLocal(layer) {
ropeBase = m.ropeBaseLocal
}
return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
}
type TextScaledWordEmbedding struct {
*nn.Embedding
}
func (e TextScaledWordEmbedding) Forward(ctx ml.Context, inputIDs ml.Tensor, scale float64) ml.Tensor {
return e.Embedding.Forward(ctx, inputIDs).Scale(ctx, scale)
}
type PerLayerProjector struct {
TokenEmbedding *TextScaledWordEmbedding `gguf:"per_layer_token_embd"`
Projector *nn.Linear `gguf:"per_layer_model_proj"`
Norm *nn.RMSNorm `gguf:"per_layer_proj_norm"`
}
func (p PerLayerProjector) Forward(ctx ml.Context, batch input.Batch, inputs ml.Tensor, opts *TextOptions) ml.Tensor {
inputsPerLayer := p.TokenEmbedding.Forward(ctx, batch.Inputs, math.Sqrt(float64(opts.hiddenSizePerLayerInput)))
inputsPerLayer = inputsPerLayer.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, batch.Inputs.Dim(0), batch.Inputs.Dim(1))
perLayerProjection := p.Projector.Forward(ctx, inputs)
perLayerProjection = perLayerProjection.Scale(ctx, math.Sqrt(float64(opts.hiddenSize)))
perLayerProjection = perLayerProjection.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))
perLayerProjection = p.Norm.Forward(ctx, perLayerProjection, opts.eps)
if inputsPerLayer != nil {
perLayerProjection = perLayerProjection.Add(ctx, inputsPerLayer)
perLayerProjection = perLayerProjection.Scale(ctx, 1/math.Sqrt(2))
}
return perLayerProjection
}
type TextLayer struct {
*AltUp
*Laurel
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Attention *TextAttention
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *TextMLP
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
PerLayerInputGate *nn.Linear `gguf:"inp_gate"`
PerLayerProjection *nn.Linear `gguf:"proj"`
PostPerLayerNorm *nn.RMSNorm `gguf:"post_norm"`
}
func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, positions, one ml.Tensor, cache kvcache.Cache, sharedKV bool, ropeBase float32, activationSparsityScale float64, opts *TextOptions) ml.Tensor {
predictions := d.Predict(ctx, hiddenStates, opts)
active := opts.altupActive(ctx, predictions)
attn := d.AttentionNorm.Forward(ctx, active, opts.eps)
laurel := d.Laurel.Forward(ctx, attn, opts)
attn = d.Attention.Forward(ctx, attn, positions, cache, sharedKV, ropeBase, opts)
attn = d.PostAttentionNorm.Forward(ctx, attn, opts.eps)
attn = active.Add(ctx, attn)
attn = attn.Add(ctx, laurel).Scale(ctx, 1/math.Sqrt(2))
mlp := d.MLPNorm.Forward(ctx, attn, opts.eps)
mlp = d.MLP.Forward(ctx, mlp, activationSparsityScale)
mlp = d.PostMLPNorm.Forward(ctx, mlp, opts.eps)
mlp = attn.Add(ctx, mlp)
predictions = d.Correct(ctx, predictions, mlp, one, opts)
active = opts.altupActive(ctx, predictions)
if opts.altupCorrectScale {
active = d.ScaleCorrectedOutput(ctx, active)
}
active = d.PerLayerInputGate.Forward(ctx, active)
active = active.GELU(ctx)
active = active.Mul(ctx, perLayerInput)
active = d.PerLayerProjection.Forward(ctx, active)
active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)
// inactive := predictions[:, :, 1:]
inactive := predictions.View(ctx, predictions.Stride(2), predictions.Dim(0), predictions.Stride(1), predictions.Dim(1), predictions.Stride(2), predictions.Dim(2)-1)
active = inactive.Add(ctx, active)
predictions0 := predictions.View(ctx, 0, predictions.Dim(0), predictions.Stride(1), predictions.Dim(1))
return predictions0.Concat(ctx, active, 2)
}
type AltUp struct {
CorrectionScale ml.Tensor `gguf:"altup_correct_scale.weight"`
PredictionCoefficient *nn.Linear `gguf:"altup_predict_coef"`
CorrectionCoefficient *nn.Linear `gguf:"altup_correct_coef"`
Router *nn.Linear `gguf:"altup_router"`
RouterNorm *nn.RMSNorm `gguf:"altup_router_norm"`
}
func (a AltUp) computeRouterModalities(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
routerInputs := a.RouterNorm.Forward(ctx, hiddenStates, opts.eps).Scale(ctx, 1.0/float64(opts.hiddenSize))
return a.Router.Forward(ctx, routerInputs).Tanh(ctx)
}
func (a AltUp) Predict(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
modalities := a.computeRouterModalities(ctx, opts.altupActive(ctx, hiddenStates), opts)
coefficients := a.PredictionCoefficient.Forward(ctx, modalities)
coefficients = coefficients.Reshape(ctx, opts.altupInputs, opts.altupInputs, coefficients.Dim(1), coefficients.Dim(2))
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
predictions := coefficients.Mulmat(ctx, hiddenStates)
predictions = predictions.Add(ctx, hiddenStates)
return predictions.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
}
func (a AltUp) Correct(ctx ml.Context, predictions, activated, one ml.Tensor, opts *TextOptions) ml.Tensor {
innovation := activated.Sub(ctx, opts.altupActive(ctx, predictions))
innovation = innovation.Repeat(ctx, 2, opts.altupInputs)
modalities := a.computeRouterModalities(ctx, activated, opts)
coefficients := a.CorrectionCoefficient.Forward(ctx, modalities)
coefficients = coefficients.Add(ctx, one)
coefficients = coefficients.Reshape(ctx, 1, coefficients.Dim(0), coefficients.Dim(1))
coefficients = coefficients.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
corrected := innovation.Mul(ctx, coefficients)
corrected = corrected.Add(ctx, predictions)
return corrected
}
func (a AltUp) ScaleCorrectedOutput(ctx ml.Context, predictions ml.Tensor) ml.Tensor {
return predictions.Mul(ctx, a.CorrectionScale)
}
type Laurel struct {
LinearLeft *nn.Linear `gguf:"laurel_l"`
LinearRight *nn.Linear `gguf:"laurel_r"`
PostLaurelNorm *nn.RMSNorm `gguf:"laurel_post_norm"`
}
func (l Laurel) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = l.LinearLeft.Forward(ctx, hiddenStates)
hiddenStates = l.LinearRight.Forward(ctx, hiddenStates)
hiddenStates = l.PostLaurelNorm.Forward(ctx, hiddenStates, opts.eps)
return hiddenStates.Add(ctx, residual)
}
type TextAttention struct {
Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, sharedKV bool, ropeBase float32, opts *TextOptions) ml.Tensor {
batchSize := hiddenStates.Dim(1)
query := attn.Query.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
query = attn.QueryNorm.Forward(ctx, query, opts.eps)
query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
var key, value ml.Tensor
if !sharedKV {
key = attn.Key.Forward(ctx, hiddenStates)
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
key = attn.KeyNorm.Forward(ctx, key, opts.eps)
key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
value = attn.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
value = value.RMSNorm(ctx, nil, opts.eps)
}
attention := nn.Attention(ctx, query, key, value, 1., cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return attn.Output.Forward(ctx, attention)
}
type TextMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, activationSparsityScale float64) ml.Tensor {
upStates := mlp.Up.Forward(ctx, hiddenStates)
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates)
if activationSparsityScale > 0 {
mean := hiddenStates.Mean(ctx)
std := hiddenStates.Stddev(ctx).Scale(ctx, activationSparsityScale)
cutoff := mean.Add(ctx, std)
hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx)
}
hiddenStates = hiddenStates.GELU(ctx).Mul(ctx, upStates)
hiddenStates = mlp.Down.Forward(ctx, hiddenStates)
return hiddenStates
}
type TextOptions struct {
hiddenLayers int
hiddenSize int
hiddenSizePerLayerInput int
numHeads, numKVHeads int
keyLength, valueLength int
sharedKeyValueLayers int
altupActiveIndex int
altupInputs int
altupCorrectScale bool
eps float32
ropeBase float32
ropeBaseLocal float32
ropeScale float32
slidingWindowPattern []bool
activationSparsityScale []float32
}
func (o *TextOptions) altupActive(ctx ml.Context, t ml.Tensor) ml.Tensor {
// t[:, :, o.altupActiveIndex]
return t.View(ctx, o.altupActiveIndex*t.Stride(2), t.Dim(0), t.Stride(1), t.Dim(1))
}
func (o *TextOptions) headDim() int {
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
}
func (o *TextOptions) isLocal(i int) bool {
return o.slidingWindowPattern[i]
}
func newTextModel(c fs.Config) *TextModel {
return &TextModel{
TextLayers: make([]TextLayer, c.Uint("block_count")),
TextOptions: TextOptions{
hiddenLayers: int(c.Uint("block_count")),
hiddenSize: int(c.Uint("embedding_length")),
hiddenSizePerLayerInput: int(c.Uint("embedding_length_per_layer_input")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
keyLength: int(c.Uint("attention.key_length")),
valueLength: int(c.Uint("attention.value_length")),
sharedKeyValueLayers: int(c.Uint("attention.shared_kv_layers")),
altupActiveIndex: int(c.Uint("altup.active_idx")),
altupInputs: int(c.Uint("altup.num_inputs")),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeBase: c.Float("rope.freq_base", 1_000_000),
ropeBaseLocal: c.Float("rope.freq_base_local", 10_000),
ropeScale: c.Float("rope.freq_scale", 1.0),
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
activationSparsityScale: c.Floats("activation_sparsity_scale"),
},
}
}

View File

@@ -3,7 +3,6 @@ package models
import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3"

View File

@@ -87,7 +87,7 @@ func (v *Vocabulary) Decode(id int32) string {
func (v *Vocabulary) SpecialVocabulary() []string {
v.specialOnce.Do(func() {
for i := range v.Values {
if v.Types[i] == TOKEN_TYPE_CONTROL || v.Types[i] == TOKEN_TYPE_USER_DEFINED {
if v.Types[i] == TOKEN_TYPE_CONTROL {
v.special = append(v.special, v.Values[i])
}
}

View File

@@ -1,16 +0,0 @@
package model
import "testing"
func TestVocabulary_SpecialVocabulary(t *testing.T) {
vocab := &Vocabulary{
Values: []string{"<|startoftext|>", "<|endoftext|>", "<|tool_call_start|>", "<|tool_call_end|>", "hi"},
Types: []int32{TOKEN_TYPE_CONTROL, TOKEN_TYPE_CONTROL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_NORMAL},
}
specialVocab := vocab.SpecialVocabulary()
if len(specialVocab) != 4 {
t.Errorf("expected 4 special tokens, got %d", len(specialVocab))
}
}

View File

@@ -2,6 +2,8 @@ package common
import (
"strings"
"github.com/ollama/ollama/llm"
)
func FindStop(sequence string, stops []string) (bool, string) {
@@ -29,68 +31,41 @@ func ContainsStopSuffix(sequence string, stops []string) bool {
// truncateStop removes the provided stop string from pieces,
// returning the partial pieces with stop removed, including truncating
// the last piece if required (and signalling if this was the case)
func TruncateStop(pieces []string, stop string) ([]string, bool) {
joined := strings.Join(pieces, "")
index := strings.Index(joined, stop)
if index == -1 {
return pieces, false
func TruncateStop(resps []llm.CompletionResponse, stop string) ([]llm.CompletionResponse, bool) {
var sequence string
for _, resp := range resps {
sequence += resp.Content
}
joined = joined[:index]
// Split truncated string back into pieces of original lengths
lengths := make([]int, len(pieces))
for i, piece := range pieces {
lengths[i] = len(piece)
idx := strings.Index(sequence, stop)
if idx < 0 {
return resps, false
}
var result []string
tokenTruncated := false
start := 0
for _, length := range lengths {
if start >= len(joined) {
truncated := sequence[:idx]
if len(truncated) == 0 {
return nil, true
}
result := make([]llm.CompletionResponse, 0, len(resps))
// Track position in truncated sequence
pos := 0
truncationHappened := false
for _, resp := range resps {
if pos >= len(truncated) {
break
}
end := start + length
if end > len(joined) {
end = len(joined)
tokenTruncated = true
chunk := truncated[pos:min(pos+len(resp.Content), len(truncated))]
if len(chunk) < len(resp.Content) {
truncationHappened = true
}
result = append(result, joined[start:end])
start = end
if len(chunk) > 0 {
result = append(result, llm.CompletionResponse{Content: chunk})
}
pos += len(resp.Content)
}
return result, tokenTruncated
}
func IncompleteUnicode(token string) bool {
incomplete := false
// check if there is incomplete UTF-8 character at the end
for i := 1; i < 5 && i <= len(token); i++ {
c := token[len(token)-i]
if (c & 0xc0) == 0x80 {
// continuation byte: 10xxxxxx
continue
}
if (c & 0xe0) == 0xc0 {
// 2-byte character: 110xxxxx ...
incomplete = i < 2
} else if (c & 0xf0) == 0xe0 {
// 3-byte character: 1110xxxx ...
incomplete = i < 3
} else if (c & 0xf8) == 0xf0 {
// 4-byte character: 11110xxx ...
incomplete = i < 4
}
// else 1-byte character or invalid byte
break
}
return incomplete
return result, truncationHappened
}

View File

@@ -1,51 +1,84 @@
package common
import (
"fmt"
"reflect"
"testing"
"github.com/ollama/ollama/llm"
)
func TestTruncateStop(t *testing.T) {
tests := []struct {
name string
pieces []string
pieces []llm.CompletionResponse
stop string
expected []string
expected []llm.CompletionResponse
expectedTrunc bool
}{
{
name: "Single word",
pieces: []string{"hello", "world"},
stop: "world",
expected: []string{"hello"},
name: "Single word",
pieces: []llm.CompletionResponse{
{Content: "Hello"},
{Content: "world"},
},
stop: "world",
expected: []llm.CompletionResponse{
{Content: "Hello"},
},
expectedTrunc: false,
},
{
name: "Partial",
pieces: []string{"hello", "wor"},
stop: "or",
expected: []string{"hello", "w"},
name: "Partial",
pieces: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " wor"},
},
stop: "or",
expected: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " w"},
},
expectedTrunc: true,
},
{
name: "Suffix",
pieces: []string{"Hello", " there", "!"},
stop: "!",
expected: []string{"Hello", " there"},
name: "Suffix",
pieces: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " there"},
{Content: "!"},
},
stop: "!",
expected: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " there"},
},
expectedTrunc: false,
},
{
name: "Suffix partial",
pieces: []string{"Hello", " the", "re!"},
stop: "there!",
expected: []string{"Hello", " "},
name: "Suffix partial",
pieces: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " the"},
{Content: "re!"},
},
stop: "there!",
expected: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " "},
},
expectedTrunc: true,
},
{
name: "Middle",
pieces: []string{"hello", " wor"},
stop: "llo w",
expected: []string{"he"},
name: "Middle",
pieces: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " wo"},
},
stop: "llo w",
expected: []llm.CompletionResponse{
{Content: "He"},
},
expectedTrunc: true,
},
}
@@ -54,76 +87,23 @@ func TestTruncateStop(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
result, resultTrunc := TruncateStop(tt.pieces, tt.stop)
if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
t.Errorf("truncateStop(%v, %v):\n%shave truncated %v\nwant truncated %v",
tt.pieces, tt.stop, formatContentDiff(result, tt.expected), resultTrunc, tt.expectedTrunc)
}
})
}
}
func TestIncompleteUnicode(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{
name: "Basic",
input: "hi",
expected: false,
},
{
name: "Two byte",
input: "hi" + string([]byte{0xc2, 0xa3}),
expected: false,
},
{
name: "Two byte - missing last",
input: "hi" + string([]byte{0xc2}),
expected: true,
},
{
name: "Three byte",
input: "hi" + string([]byte{0xe0, 0xA0, 0x80}),
expected: false,
},
{
name: "Three byte - missing last",
input: "hi" + string([]byte{0xe0, 0xA0}),
expected: true,
},
{
name: "Three byte - missing last 2",
input: "hi" + string([]byte{0xe0}),
expected: true,
},
{
name: "Four byte",
input: "hi" + string([]byte{0xf0, 0x92, 0x8a, 0xb7}),
expected: false,
},
{
name: "Four byte - missing last",
input: "hi" + string([]byte{0xf0, 0x92, 0x8a}),
expected: true,
},
{
name: "Four byte - missing last 2",
input: "hi" + string([]byte{0xf0, 0x92}),
expected: true,
},
{
name: "Four byte - missing last 3",
input: "hi" + string([]byte{0xf0}),
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IncompleteUnicode(tt.input)
if result != tt.expected {
t.Errorf("incompleteUnicode(%s): have %v; want %v", tt.input, result, tt.expected)
}
})
func formatContentDiff(result, expected []llm.CompletionResponse) string {
var s string
for i := 0; i < len(result) || i < len(expected); i++ {
if i < len(result) && i < len(expected) && result[i].Content != expected[i].Content {
s += fmt.Sprintf("[%d] %q vs %q\n", i, result[i].Content, expected[i].Content)
} else if i < len(result) && i >= len(expected) {
s += fmt.Sprintf("[%d] extra %q\n", i, result[i].Content)
} else if i >= len(result) && i < len(expected) {
s += fmt.Sprintf("[%d] missing %q\n", i, expected[i].Content)
}
}
return s
}

View File

@@ -17,7 +17,6 @@ import (
"strings"
"sync"
"time"
"unicode/utf8"
"golang.org/x/sync/semaphore"
@@ -52,13 +51,13 @@ type Sequence struct {
pendingInputs []input
// tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string
pendingResponses []llm.CompletionResponse
// input cache being used by this sequence
cache *InputCacheSlot
// channel to send responses over
responses chan string
responses chan llm.CompletionResponse
// channel to stop decoding (such as if the remote connection is closed)
quit chan bool
@@ -89,6 +88,19 @@ type Sequence struct {
numPromptInputs int
}
func (seq *Sequence) send(resp llm.CompletionResponse) bool {
if len(resp.Content) > 0 || resp.Done {
select {
case seq.responses <- resp:
// Successfully sent
return true
case <-seq.quit:
return false
}
}
return true
}
type NewSequenceParams struct {
numPredict int
stop []string
@@ -147,8 +159,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
numPromptInputs: len(inputs),
startProcessingTime: startTime,
numPredict: params.numPredict,
pendingResponses: make([]string, 0),
responses: make(chan string, 100),
pendingResponses: make([]llm.CompletionResponse, 0),
responses: make(chan llm.CompletionResponse, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
samplingCtx: sc,
@@ -272,36 +284,15 @@ func (s *Server) allNil() bool {
return true
}
func flushPending(seq *Sequence) bool {
joined := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = []string{}
// Check if there are any partial UTF-8 characters remaining.
// We already check and queue as we are generating but some may
// still make it here:
// - Sequence is ending, e.g. generation limit has been hit
// - Invalid characters in the middle of a string
// This is a stricter check to ensure we never output invalid Unicode.
for !utf8.ValidString(joined) {
joined = joined[:len(joined)-1]
}
if len(joined) == 0 {
return true
}
select {
case seq.responses <- joined:
return true
case <-seq.quit:
return false
}
}
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
seq := s.seqs[seqIndex]
flushPending(seq)
// Send any remaining pending responses
for _, resp := range seq.pendingResponses {
seq.send(resp)
}
seq.pendingResponses = []llm.CompletionResponse{}
seq.doneReason = reason
close(seq.responses)
close(seq.embedding)
@@ -490,8 +481,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
seq.inputs = []input{{token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})
sequence := ""
for _, r := range seq.pendingResponses {
sequence += r.Content
}
if ok, stop := common.FindStop(sequence, seq.stop); ok {
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
@@ -523,13 +517,13 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
continue
}
if common.IncompleteUnicode(sequence) {
continue
}
if !flushPending(seq) {
s.removeSequence(i, llm.DoneReasonConnectionClosed)
for _, resp := range seq.pendingResponses {
if !seq.send(resp) {
s.removeSequence(i, llm.DoneReasonConnectionClosed)
break
}
}
seq.pendingResponses = []llm.CompletionResponse{}
}
return nil
@@ -627,9 +621,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
case content, ok := <-seq.responses:
if ok {
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Content: content,
}); err != nil {
if err := json.NewEncoder(w).Encode(&content); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
close(seq.quit)
return

View File

@@ -20,7 +20,6 @@ import (
"strings"
"sync"
"time"
"unicode/utf8"
"golang.org/x/image/bmp"
"golang.org/x/sync/semaphore"
@@ -56,13 +55,13 @@ type Sequence struct {
pendingInputs []input.Input
// tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string
pendingResponses []llm.CompletionResponse
// input cache being used by this sequence
cache *InputCacheSlot
// channel to send responses over
responses chan string
responses chan llm.CompletionResponse
// channel to stop decoding (such as if the remote connection is closed)
quit chan bool
@@ -94,6 +93,19 @@ type Sequence struct {
numPromptInputs int
}
func (seq *Sequence) send(resp llm.CompletionResponse) bool {
if len(resp.Content) > 0 || resp.Done {
select {
case seq.responses <- resp:
// Successfully sent
return true
case <-seq.quit:
return false
}
}
return true
}
type NewSequenceParams struct {
numPredict int
stop []string
@@ -167,8 +179,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
numPromptInputs: len(inputs),
startProcessingTime: startTime,
numPredict: params.numPredict,
pendingResponses: make([]string, 0),
responses: make(chan string, 100),
pendingResponses: make([]llm.CompletionResponse, 0),
responses: make(chan llm.CompletionResponse, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
sampler: params.sampler,
@@ -313,36 +325,15 @@ func (s *Server) allNil() bool {
return true
}
func flushPending(seq *Sequence) bool {
joined := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = []string{}
// Check if there are any partial UTF-8 characters remaining.
// We already check and queue as we are generating but some may
// still make it here:
// - Sequence is ending, e.g. generation limit has been hit
// - Invalid characters in the middle of a string
// This is a stricter check to ensure we never output invalid Unicode.
for !utf8.ValidString(joined) {
joined = joined[:len(joined)-1]
}
if len(joined) == 0 {
return true
}
select {
case seq.responses <- joined:
return true
case <-seq.quit:
return false
}
}
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
seq := s.seqs[seqIndex]
flushPending(seq)
// Send any remaining pending responses
for _, resp := range seq.pendingResponses {
seq.send(resp)
}
seq.pendingResponses = []llm.CompletionResponse{}
seq.doneReason = reason
close(seq.responses)
close(seq.embedding)
@@ -541,8 +532,11 @@ func (s *Server) processBatch() error {
seq.inputs = []input.Input{{Token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})
sequence := ""
for _, r := range seq.pendingResponses {
sequence += r.Content
}
if ok, stop := common.FindStop(sequence, seq.stop); ok {
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
@@ -574,13 +568,14 @@ func (s *Server) processBatch() error {
continue
}
if common.IncompleteUnicode(sequence) {
continue
}
if !flushPending(seq) {
s.removeSequence(i, llm.DoneReasonConnectionClosed)
// Send all pending responses directly without unicode checking
for _, resp := range seq.pendingResponses {
if !seq.send(resp) {
s.removeSequence(i, llm.DoneReasonConnectionClosed)
break
}
}
seq.pendingResponses = []llm.CompletionResponse{}
}
return nil
@@ -683,9 +678,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
case content, ok := <-seq.responses:
if ok {
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Content: content,
}); err != nil {
if err := json.NewEncoder(w).Encode(&content); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
close(seq.quit)
return

View File

@@ -27,6 +27,7 @@ function checkEnv() {
$env:VCToolsRedistDir=(get-item "${MSVC_INSTALL}\VC\Redist\MSVC\*")[0]
}
# Locate CUDA versions
# Note: this assumes every version found will be built
$cudaList=(get-item "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v*\bin\" -ea 'silentlycontinue')
if ($cudaList.length -eq 0) {
$d=(get-command -ea 'silentlycontinue' nvcc).path
@@ -93,6 +94,19 @@ function buildOllama() {
$hashEnv = @{}
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
if ("$script:CUDA_DIRS".Contains("v11")) {
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V11")) { $v11="$_" }}
$env:CUDAToolkit_ROOT=$hashEnv[$v11]
write-host "Building CUDA v11 backend libraries"
# Note: cuda v11 requires msvc 2019 so force the older generator
# to avoid 2022 (or newer) from being used as the default
& cmake --fresh --preset "CUDA 11" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset "CUDA 11" --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
if ("$script:CUDA_DIRS".Contains("v12")) {
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12")) { $v12="$_" }}
$env:CUDAToolkit_ROOT=$hashEnv[$v12]
@@ -113,17 +127,12 @@ function buildOllama() {
$env:HIPCXX="${env:HIP_PATH}\bin\clang++.exe"
$env:HIP_PLATFORM="amd"
$env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
& cmake --fresh --preset "ROCm 6" -G Ninja `
-DCMAKE_C_COMPILER=clang `
-DCMAKE_CXX_COMPILER=clang++ `
-DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
-DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
--install-prefix $script:DIST_DIR
& cmake --fresh --preset "ROCm 6" -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ --install-prefix $script:DIST_DIR
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
$env:HIPCXX=""
$env:HIP_PLATFORM=""
$env:CMAKE_PREFIX_PATH=""
& cmake --build --preset "ROCm 6" --config Release --parallel $script:JOBS
& cmake --build --preset "ROCm" --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "HIP" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}

View File

@@ -10,7 +10,9 @@ OLLAMA_COMMON_BUILD_ARGS="--build-arg=VERSION \
--build-arg=GOFLAGS \
--build-arg=OLLAMA_CUSTOM_CPU_DEFS \
--build-arg=OLLAMA_SKIP_CUDA_GENERATE \
--build-arg=OLLAMA_SKIP_CUDA_11_GENERATE \
--build-arg=OLLAMA_SKIP_CUDA_12_GENERATE \
--build-arg=CUDA_V11_ARCHITECTURES \
--build-arg=CUDA_V12_ARCHITECTURES \
--build-arg=OLLAMA_SKIP_ROCM_GENERATE \
--build-arg=OLLAMA_FAST_BUILD \

View File

@@ -59,7 +59,7 @@ type DiskCache struct {
testHookBeforeFinalWrite func(f *os.File)
}
// PutBytes is a convenience function for c.Put(d, strings.NewReader(s), int64(len(s))).
// PutString is a convenience function for c.Put(d, strings.NewReader(s), int64(len(s))).
func PutBytes[S string | []byte](c *DiskCache, d Digest, data S) error {
return c.Put(d, bytes.NewReader([]byte(data)), int64(len(data)))
}

View File

@@ -231,8 +231,6 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
// do not quantize relative position bias (T5)
quantize = quantize && !strings.Contains(name, "attn_rel_b.weight")
quantize = quantize && !strings.Contains(name, "per_layer_token_embd.weight")
newType := fsggml.TensorType(t.Kind)
if quantize {
// get more optimal quantization type based on the tensor shape, layer, etc.

View File

@@ -191,7 +191,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
}
// Load model for fitting
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
ggml, err := llm.LoadModel(pending.model.ModelPath, 0)
if err != nil {
pending.errCh <- err
break

View File

@@ -310,23 +310,21 @@ func (t *Template) Execute(w io.Writer, v Values) error {
}
// collate messages based on role. consecutive messages of the same role are merged
// into a single message (except for tool messages which preserve individual metadata).
// collate also collects and returns all system messages.
// into a single message. collate also collects and returns all system messages.
// collate mutates message content adding image tags ([img-%d]) as needed
// todo(parthsareen): revisit for contextual image support
func collate(msgs []api.Message) (string, []*api.Message) {
var system []string
var collated []*api.Message
for i := range msgs {
if msgs[i].Role == "system" {
system = append(system, msgs[i].Content)
msg := msgs[i]
if msg.Role == "system" {
system = append(system, msg.Content)
}
// merges consecutive messages of the same role into a single message (except for tool messages)
if len(collated) > 0 && collated[len(collated)-1].Role == msgs[i].Role && msgs[i].Role != "tool" {
collated[len(collated)-1].Content += "\n\n" + msgs[i].Content
if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
collated[len(collated)-1].Content += "\n\n" + msg.Content
} else {
collated = append(collated, &msgs[i])
collated = append(collated, &msg)
}
}

View File

@@ -163,12 +163,10 @@ func TestParse(t *testing.T) {
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
{"{{ range .Messages }}{{ if eq .Role \"tool\" }}Tool Result: {{ .ToolName }} {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role", "toolname"}},
{`{{- range .Messages }}
{{- if eq .Role "system" }}SYSTEM:
{{- else if eq .Role "user" }}USER:
{{- else if eq .Role "assistant" }}ASSISTANT:
{{- else if eq .Role "tool" }}TOOL:
{{- end }} {{ .Content }}
{{- end }}`, []string{"content", "messages", "role"}},
{`{{- if .Messages }}
@@ -378,99 +376,3 @@ func TestExecuteWithSuffix(t *testing.T) {
})
}
}
func TestCollate(t *testing.T) {
cases := []struct {
name string
msgs []api.Message
expected []*api.Message
system string
}{
{
name: "consecutive user messages are merged",
msgs: []api.Message{
{Role: "user", Content: "Hello"},
{Role: "user", Content: "How are you?"},
},
expected: []*api.Message{
{Role: "user", Content: "Hello\n\nHow are you?"},
},
system: "",
},
{
name: "consecutive tool messages are NOT merged",
msgs: []api.Message{
{Role: "tool", Content: "sunny", ToolName: "get_weather"},
{Role: "tool", Content: "72F", ToolName: "get_temperature"},
},
expected: []*api.Message{
{Role: "tool", Content: "sunny", ToolName: "get_weather"},
{Role: "tool", Content: "72F", ToolName: "get_temperature"},
},
system: "",
},
{
name: "tool messages preserve all fields",
msgs: []api.Message{
{Role: "user", Content: "What's the weather?"},
{Role: "tool", Content: "sunny", ToolName: "get_conditions"},
{Role: "tool", Content: "72F", ToolName: "get_temperature"},
},
expected: []*api.Message{
{Role: "user", Content: "What's the weather?"},
{Role: "tool", Content: "sunny", ToolName: "get_conditions"},
{Role: "tool", Content: "72F", ToolName: "get_temperature"},
},
system: "",
},
{
name: "mixed messages with system",
msgs: []api.Message{
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "Hi there!"},
{Role: "user", Content: "What's the weather?"},
{Role: "tool", Content: "sunny", ToolName: "get_weather"},
{Role: "tool", Content: "72F", ToolName: "get_temperature"},
{Role: "user", Content: "Thanks"},
},
expected: []*api.Message{
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "Hi there!"},
{Role: "user", Content: "What's the weather?"},
{Role: "tool", Content: "sunny", ToolName: "get_weather"},
{Role: "tool", Content: "72F", ToolName: "get_temperature"},
{Role: "user", Content: "Thanks"},
},
system: "You are helpful",
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
system, collated := collate(tt.msgs)
if diff := cmp.Diff(system, tt.system); diff != "" {
t.Errorf("system mismatch (-got +want):\n%s", diff)
}
// Compare the messages
if len(collated) != len(tt.expected) {
t.Errorf("expected %d messages, got %d", len(tt.expected), len(collated))
return
}
for i := range collated {
if collated[i].Role != tt.expected[i].Role {
t.Errorf("message %d role mismatch: got %q, want %q", i, collated[i].Role, tt.expected[i].Role)
}
if collated[i].Content != tt.expected[i].Content {
t.Errorf("message %d content mismatch: got %q, want %q", i, collated[i].Content, tt.expected[i].Content)
}
if collated[i].ToolName != tt.expected[i].ToolName {
t.Errorf("message %d tool name mismatch: got %q, want %q", i, collated[i].ToolName, tt.expected[i].ToolName)
}
}
})
}
}

View File

@@ -18,8 +18,9 @@ const (
)
type Parser struct {
tag string
tools []api.Tool
tag string
names []string
properties []string
state toolsState
buffer []byte
@@ -33,10 +34,15 @@ func NewParser(tmpl *template.Template, tools []api.Tool) *Parser {
}
func NewParserWithTag(tools []api.Tool, tag string) *Parser {
return &Parser{
tag: tag,
tools: tools,
var p Parser
for _, t := range tools {
p.names = append(p.names, t.Function.Name)
for r := range t.Function.Parameters.Properties {
p.properties = append(p.properties, r)
}
}
p.tag = tag
return &p
}
// Add processes a string input to parse tool calls and content that
@@ -115,40 +121,36 @@ func (p *Parser) findTag() (int, bool) {
// parseToolCall finds the next complete tool call in the buffer
// incrementing n and advancing the buffer.
func (p *Parser) parseToolCall() *api.ToolCall {
var tool *api.Tool
var name string
var args map[string]any
var end int = len(p.buffer)
var i int
// find tool name
for _, t := range p.tools {
n := t.Function.Name
var i int
for _, n := range p.names {
if i = bytes.Index(p.buffer, []byte(n)); i != -1 {
if i+len(n) < end {
tool = &t
name = n
end = i + len(n)
}
}
}
if tool == nil {
if name == "" {
return nil
}
// only look for arguments after the tool name if the tool has parameters
// TODO (jmorganca): while probably uncommon, this doesn't support
// parsing arguments before the tool name, which may be needed in the future
args := map[string]any{}
if len(tool.Function.Parameters.Properties) > 0 {
if args, i = findArguments(*tool, p.buffer[end:]); args == nil {
return nil
}
if args, i = p.findArguments(); args == nil {
return nil
}
end += i
if i > end {
end = i
}
tc := &api.ToolCall{
Function: api.ToolCallFunction{
Name: tool.Function.Name,
Name: name,
Arguments: args,
Index: p.n,
},
@@ -160,14 +162,10 @@ func (p *Parser) parseToolCall() *api.ToolCall {
}
// findArguments returns the first object that appears to be
// arguments for the provided tool in the provided buffer,
// returning nil if no arguments are found.
// TODO (jmorganca): this does not support parsing omitted arguments
// objects for functions that have all-optional parameters
// e.g. `{"name": "get_conditions", "arguments": {}}` will work but
// `{"name": "get_conditions"}` will not currently work
func findArguments(tool api.Tool, buffer []byte) (map[string]any, int) {
if len(buffer) == 0 {
// arguments and the position where the arguments end, returning nil and 0 if
// an invalid JSON object or non-arguments object is found first
func (p *Parser) findArguments() (map[string]any, int) {
if len(p.buffer) == 0 {
return nil, 0
}
@@ -177,7 +175,7 @@ func findArguments(tool api.Tool, buffer []byte) (map[string]any, int) {
var object []byte
// find any outer json object
for i, c := range buffer {
for i, c := range p.buffer {
if c == '{' {
braces++
if start == -1 {
@@ -186,13 +184,11 @@ func findArguments(tool api.Tool, buffer []byte) (map[string]any, int) {
}
if c == '}' {
if start != -1 {
braces--
if braces == 0 {
end = i + 1
object = buffer[start:end]
break
}
braces--
if braces == 0 && start != -1 {
end = i + 1
object = p.buffer[start:end]
break
}
}
}
@@ -202,45 +198,32 @@ func findArguments(tool api.Tool, buffer []byte) (map[string]any, int) {
}
var data map[string]any
// not valid json
if err := json.Unmarshal(object, &data); err != nil {
return nil, 0
}
var find func(obj any) map[string]any
find = func(obj any) map[string]any {
switch obj := obj.(type) {
switch v := obj.(type) {
case map[string]any:
valid := true
// check if all keys in the object exist in the tool's parameters
for key := range obj {
if _, exists := tool.Function.Parameters.Properties[key]; !exists {
valid = false
break
// check if the object keys are valid tool properties
// TODO (jmorganca): check only sets of properties that
// go together instead of the entire set
for _, prop := range p.properties {
if _, exists := v[prop]; exists {
return v
}
}
// check for required parameters
// TODO (jmorganca): this should error instead of silently failing
if valid {
for _, required := range tool.Function.Parameters.Required {
if _, exists := obj[required]; !exists {
valid = false
break
}
}
}
if valid {
return obj
}
for _, value := range obj {
for _, value := range v {
if result := find(value); result != nil {
return result
}
}
case []any:
for _, item := range obj {
for _, item := range v {
if result := find(item); result != nil {
return result
}

View File

@@ -52,8 +52,7 @@ func TestParser(t *testing.T) {
Enum []any `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Required: []string{"city"},
Type: "object",
Properties: map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
@@ -105,13 +104,6 @@ func TestParser(t *testing.T) {
},
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "say_hello",
Description: "Say hello",
},
},
}
tests := []struct {
@@ -152,35 +144,6 @@ func TestParser(t *testing.T) {
},
},
},
{
name: "invalid arguments",
inputs: []string{`<tool_call>{"name": "get_conditions", "arguments": {"city": "San Francisco"}}</tool_call>`},
content: "",
tmpl: qwen,
calls: nil,
},
{
name: "empty args",
inputs: []string{`<tool_call>{"name": "get_conditions", "arguments": {}}</tool_call>`},
content: "",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{},
},
},
},
},
{
name: "missing required args",
inputs: []string{`<tool_call>{"name": "get_temperature", "arguments": {}}</tool_call>`},
content: "",
tmpl: qwen,
calls: nil,
},
{
name: "text before tool call",
inputs: []string{`Let me check the weather. <tool_call>{"name": "get_temperature", "arguments": {"city": "New York"}}</tool_call>`},
@@ -198,28 +161,6 @@ func TestParser(t *testing.T) {
},
},
},
{
name: "qwen no args tool call",
inputs: []string{`Let me say hello to the user. I'll use the say_hello tool <tool_call>{"name": "say_hello"}</tool_call>`},
content: "Let me say hello to the user. I'll use the say_hello tool ",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{},
},
},
},
},
{
name: "qwen no args with text",
inputs: []string{"Let me say hello to the user. I'll use the say_hello tool. "},
content: "Let me say hello to the user. I'll use the say_hello tool. ",
tmpl: qwen,
calls: nil,
},
{
name: "two tool calls in a list",
inputs: []string{`[TOOL_CALLS] [{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}, {"name": "get_conditions", "arguments": {"location": "Tokyo"}}][/TOOL_CALLS]`},
@@ -248,7 +189,7 @@ func TestParser(t *testing.T) {
},
},
{
name: "qwen two tool calls",
name: "two tool calls",
inputs: []string{`Okay, let's call both tools! <tool_call>{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}</tool_call><tool_call>{"name": "get_conditions", "arguments": {"location": "Tokyo"}}</tool_call>`},
content: "Okay, let's call both tools! ",
tmpl: qwen,
@@ -274,55 +215,6 @@ func TestParser(t *testing.T) {
},
},
},
{
name: "empty args followed by args",
inputs: []string{`Let me say hello and check the weather. <tool_call>{"name": "say_hello", "arguments": {}}</tool_call><tool_call>{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}</tool_call>`},
content: "Let me say hello and check the weather. ",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{},
},
},
{
Function: api.ToolCallFunction{
Index: 1,
Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{
"city": "London",
"format": "fahrenheit",
},
},
},
},
},
{
name: "qwen empty followed by args",
inputs: []string{`Let me check the weather. <tool_call>{"name": "get_conditions", "arguments": {}}</tool_call><tool_call>{"name": "get_conditions", "arguments": {"location": "Tokyo"}}`},
content: "Let me check the weather. ",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{},
},
},
{
Function: api.ToolCallFunction{
Index: 1,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
},
},
},
},
{
name: "deepseek",
inputs: []string{"<think>Wait, I need to call a tool</think><|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_temperature\n```json\n{\"city\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"},
@@ -446,52 +338,6 @@ func TestParser(t *testing.T) {
content: "for { fmt.Println(\"hello\") }",
tmpl: json,
},
{
name: "json no args tool call",
inputs: []string{
"{\"name\": \"say_hello\"}",
},
content: "",
tmpl: json,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{},
},
},
},
},
{
name: "json no args no tool call",
inputs: []string{
"I'll use the say_hello tool to say hello to the user.",
},
content: "I'll use the say_hello tool to say hello to the user.",
tmpl: json,
calls: nil,
},
// TODO (jmorganca): this is a false positive, we should
// not be parsing this as a tool call
{
name: "json no args false positive",
inputs: []string{
`{say_hello!!!}`,
},
content: "",
tmpl: json,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{},
},
},
},
},
{
name: "list multiple",
inputs: []string{
@@ -534,30 +380,6 @@ func TestParser(t *testing.T) {
},
{
name: "list partial",
inputs: []string{
"[{",
"\"name\": \"get_conditions\", ",
"\"arguments\": {",
"\"location\": \"Tokyo\"",
"}",
"}",
},
content: "",
tmpl: list,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
},
},
},
},
{
name: "list invalid",
inputs: []string{
"[",
"{",
@@ -571,33 +393,6 @@ func TestParser(t *testing.T) {
tmpl: list,
calls: nil,
},
{
name: "list trailing ]",
inputs: []string{
"[",
"{",
"\"name\": \"get_conditions\", ",
"\"arguments\": {",
"\"location\": \"Tokyo\"",
"}",
"}",
"]",
"]",
},
content: "",
tmpl: list,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
},
},
},
},
{
name: "list not a tool call",
inputs: []string{
@@ -609,26 +404,6 @@ func TestParser(t *testing.T) {
tmpl: list,
calls: nil,
},
{
name: "list with no arguments",
inputs: []string{
"[",
"{",
"\"name\": \"say_hello\"",
"}",
},
content: "",
tmpl: list,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{},
},
},
},
},
}
for _, tt := range tests {
@@ -925,75 +700,25 @@ func TestFindTag(t *testing.T) {
}
func TestFindArguments(t *testing.T) {
tool := api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: "get_temperature",
Description: "Retrieve the temperature for a given location",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Properties: map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
}{
"format": {
Type: api.PropertyType{"string"},
Description: "The format to return the temperature in",
Enum: []any{"fahrenheit", "celsius"},
},
"location": {
Type: api.PropertyType{"string"},
Description: "The location to get the temperature for",
},
},
},
},
}
tool2 := api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: "say_hello",
Description: "Say hello to the user",
},
}
tests := []struct {
name string
buffer []byte
want map[string]any
tool api.Tool
}{
{
name: "empty string",
buffer: []byte{},
want: nil,
tool: tool,
},
{
name: "whitespace only",
buffer: []byte(" \n\t "),
want: nil,
tool: tool,
},
{
name: "unbalanced braces - missing closing",
buffer: []byte(`{"format": "fahrenheit", "location": "San Francisco"`),
want: nil,
tool: tool,
},
{
name: "unbalanced braces - extra closing",
@@ -1001,13 +726,11 @@ func TestFindArguments(t *testing.T) {
want: map[string]any{
"format": "fahrenheit",
},
tool: tool,
},
{
name: "invalid JSON",
buffer: []byte(`{format: fahrenheit, location: "San Francisco"}`),
want: nil,
tool: tool,
},
{
name: "valid json",
@@ -1016,7 +739,6 @@ func TestFindArguments(t *testing.T) {
"format": "fahrenheit",
"location": "San Francisco, CA",
},
tool: tool,
},
{
name: "valid arguments with special tokens",
@@ -1025,7 +747,6 @@ func TestFindArguments(t *testing.T) {
"format": "fahrenheit",
"location": "San Francisco, CA",
},
tool: tool,
},
{
name: "valid arguments in array",
@@ -1034,7 +755,6 @@ func TestFindArguments(t *testing.T) {
"format": "fahrenheit",
"location": "San Francisco, CA",
},
tool: tool,
},
{
name: "nested deep",
@@ -1043,52 +763,39 @@ func TestFindArguments(t *testing.T) {
"format": "fahrenheit",
"location": "San Francisco, CA",
},
tool: tool,
},
{
name: "one arg",
buffer: []byte(`get_temperature({"location": "San Francisco, CA"})`),
buffer: []byte(`get_weather({"location": "San Francisco, CA"})`),
want: map[string]any{
"location": "San Francisco, CA",
},
tool: tool,
},
{
name: "two args",
buffer: []byte(`[{"name": "get_temperature", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}]`),
buffer: []byte(`[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}]`),
want: map[string]any{
"location": "San Francisco, CA",
"format": "fahrenheit",
},
tool: tool,
},
{
name: "no args",
buffer: []byte(`{"name": "say_hello"}`),
want: nil,
tool: tool2,
},
{
name: "deepseek",
buffer: []byte("<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_temperature\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"),
buffer: []byte("<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"),
want: map[string]any{
"location": "Tokyo",
},
tool: tool,
},
{
name: "deepseek",
buffer: []byte(`", "arguments": {"location": "Tokyo"}}</tool_call>`),
want: map[string]any{
"location": "Tokyo",
},
tool: tool,
},
}
for _, tt := range tests {
parser := &Parser{
buffer: tt.buffer,
properties: []string{"format", "location"},
}
t.Run(tt.name, func(t *testing.T) {
got, _ := findArguments(tt.tool, tt.buffer)
got, _ := parser.findArguments()
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff)