mirror of
https://github.com/ollama/ollama.git
synced 2026-02-23 10:45:08 -05:00
Compare commits
52 Commits
brucemacd/
...
v0.9.4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
44b17d2bfa | ||
|
|
3b8b692218 | ||
|
|
4129af9205 | ||
|
|
45f216a9c7 | ||
|
|
d0b32def60 | ||
|
|
11ffc36157 | ||
|
|
ba04902670 | ||
|
|
3944602f51 | ||
|
|
73b642e6f3 | ||
|
|
ad118d8b13 | ||
|
|
f08534137b | ||
|
|
4b4a90f233 | ||
|
|
03274a6b2f | ||
|
|
cc6463ebca | ||
|
|
405d2f628f | ||
|
|
a3f7dd3e98 | ||
|
|
c85c0ebf89 | ||
|
|
10a8e04a8d | ||
|
|
1c6669e64c | ||
|
|
b2b270ad5d | ||
|
|
2bb69b40c7 | ||
|
|
65bff664cb | ||
|
|
c088ac0e79 | ||
|
|
0a066cfd91 | ||
|
|
87b7af6cee | ||
|
|
f2527b08fb | ||
|
|
8bcb3125c1 | ||
|
|
6baf1e31e2 | ||
|
|
ed567ef43b | ||
|
|
a6e64fbdf2 | ||
|
|
60cfa2a203 | ||
|
|
55bbf3b4a1 | ||
|
|
6bda1d2479 | ||
|
|
9e125d884c | ||
|
|
a6fbfc880c | ||
|
|
502028968d | ||
|
|
5a8eb0e151 | ||
|
|
9f8a18ec05 | ||
|
|
6b04cad7e8 | ||
|
|
45f56355d5 | ||
|
|
0dabb4ef6a | ||
|
|
2e77aa1ae7 | ||
|
|
deaabe292d | ||
|
|
af21a5ac39 | ||
|
|
f63d7f68eb | ||
|
|
82ad1dbc07 | ||
|
|
feeabdadd2 | ||
|
|
fc0309615e | ||
|
|
09d308d6b6 | ||
|
|
20c5fd39c8 | ||
|
|
d2ee599dcf | ||
|
|
6ed8898590 |
174
.github/workflows/release.yaml
vendored
174
.github/workflows/release.yaml
vendored
@@ -54,48 +54,6 @@ jobs:
|
||||
name: build-${{ matrix.os }}-${{ matrix.arch }}
|
||||
path: dist/*
|
||||
|
||||
darwin-sign:
|
||||
runs-on: macos-13
|
||||
environment: release
|
||||
needs: darwin-build
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- run: |
|
||||
echo $MACOS_SIGNING_KEY | base64 --decode > certificate.p12
|
||||
security create-keychain -p password build.keychain
|
||||
security default-keychain -s build.keychain
|
||||
security unlock-keychain -p password build.keychain
|
||||
security import certificate.p12 -k build.keychain -P $MACOS_SIGNING_KEY_PASSWORD -T /usr/bin/codesign
|
||||
security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k password build.keychain
|
||||
security set-keychain-settings -lut 3600 build.keychain
|
||||
env:
|
||||
MACOS_SIGNING_KEY: ${{ secrets.MACOS_SIGNING_KEY }}
|
||||
MACOS_SIGNING_KEY_PASSWORD: ${{ secrets.MACOS_SIGNING_KEY_PASSWORD }}
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: build-darwin-amd64
|
||||
path: dist/darwin-amd64
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: build-darwin-arm64
|
||||
path: dist/darwin-arm64
|
||||
- run: |
|
||||
export VERSION=${GITHUB_REF_NAME#v}
|
||||
./scripts/build_darwin.sh sign macapp
|
||||
env:
|
||||
APPLE_IDENTITY: ${{ secrets.APPLE_IDENTITY }}
|
||||
APPLE_PASSWORD: ${{ secrets.APPLE_PASSWORD }}
|
||||
APPLE_TEAM_ID: ${{ vars.APPLE_TEAM_ID }}
|
||||
APPLE_ID: ${{ vars.APPLE_ID }}
|
||||
SDKROOT: /Applications/Xcode_14.1.0.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
|
||||
DEVELOPER_DIR: /Applications/Xcode_14.1.0.app/Contents/Developer
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist-darwin
|
||||
path: |
|
||||
dist/Ollama-darwin.zip
|
||||
dist/ollama-darwin.tgz
|
||||
|
||||
windows-depends:
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -103,21 +61,18 @@ jobs:
|
||||
arch: [amd64]
|
||||
preset: ['CPU']
|
||||
include:
|
||||
- os: windows
|
||||
arch: amd64
|
||||
preset: 'CUDA 11'
|
||||
install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
|
||||
cuda-version: '11.3'
|
||||
- os: windows
|
||||
arch: amd64
|
||||
preset: 'CUDA 12'
|
||||
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
|
||||
cuda-version: '12.8'
|
||||
flags: ''
|
||||
- os: windows
|
||||
arch: amd64
|
||||
preset: 'ROCm 6'
|
||||
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
|
||||
rocm-version: '6.2'
|
||||
flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
||||
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
||||
environment: release
|
||||
env:
|
||||
@@ -160,6 +115,9 @@ jobs:
|
||||
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
- if: matrix.preset == 'CPU'
|
||||
run: |
|
||||
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
@@ -178,9 +136,9 @@ jobs:
|
||||
key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}
|
||||
- name: Build target "${{ matrix.preset }}"
|
||||
run: |
|
||||
Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||
cmake --preset "${{ matrix.preset }}"
|
||||
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
|
||||
cmake --build --parallel --preset "${{ matrix.preset }}"
|
||||
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8
|
||||
env:
|
||||
@@ -230,61 +188,11 @@ jobs:
|
||||
go-version-file: go.mod
|
||||
- run: |
|
||||
go build -o dist/${{ matrix.os }}-${{ matrix.arch }}/ .
|
||||
- if: matrix.arch == 'arm64'
|
||||
run: |
|
||||
Invoke-WebRequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "dist\windows-arm64\vc_redist.arm64.exe"
|
||||
- run: |
|
||||
$env:VERSION='${{ github.ref_name }}' -Replace "v(.*)", '$1'
|
||||
& .\scripts\build_windows.ps1 buildApp
|
||||
env:
|
||||
VCToolsRedistDir: stub
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: build-${{ matrix.os }}-${{ matrix.arch }}
|
||||
path: |
|
||||
dist\${{ matrix.os }}-${{ matrix.arch }}\*.exe
|
||||
dist\${{ matrix.os }}-${{ matrix.arch }}-app.exe
|
||||
|
||||
windows-sign:
|
||||
runs-on: windows-2022
|
||||
environment: release
|
||||
needs: [windows-depends, windows-build]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: google-github-actions/auth@v2
|
||||
with:
|
||||
project_id: ollama
|
||||
credentials_json: ${{ secrets.GOOGLE_SIGNING_CREDENTIALS }}
|
||||
- run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
Invoke-WebRequest -Uri "https://go.microsoft.com/fwlink/p/?LinkId=323507" -OutFile "${{ runner.temp }}\sdksetup.exe"
|
||||
Start-Process "${{ runner.temp }}\sdksetup.exe" -ArgumentList @("/q") -NoNewWindow -Wait
|
||||
|
||||
Invoke-WebRequest -Uri "https://github.com/GoogleCloudPlatform/kms-integrations/releases/download/cng-v1.0/kmscng-1.0-windows-amd64.zip" -OutFile "${{ runner.temp }}\plugin.zip"
|
||||
Expand-Archive -Path "${{ runner.temp }}\plugin.zip" -DestinationPath "${{ runner.temp }}\plugin\"
|
||||
& "${{ runner.temp }}\plugin\*\kmscng.msi" /quiet
|
||||
|
||||
echo "${{ vars.OLLAMA_CERT }}" >ollama_inc.crt
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: build-windows-*
|
||||
path: dist\
|
||||
merge-multiple: true
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: depends-windows-amd64-*
|
||||
path: dist\windows-amd64\
|
||||
merge-multiple: true
|
||||
- run: |
|
||||
& .\scripts\build_windows.ps1 gatherDependencies sign buildInstaller distZip
|
||||
env:
|
||||
KEY_CONTAINER: ${{ vars.KEY_CONTAINER }}
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist-windows
|
||||
path: |
|
||||
dist\OllamaSetup.exe
|
||||
dist\ollama-windows-*.zip
|
||||
|
||||
linux-build:
|
||||
strategy:
|
||||
@@ -322,16 +230,21 @@ jobs:
|
||||
- run: |
|
||||
for COMPONENT in bin/* lib/ollama/*; do
|
||||
case "$COMPONENT" in
|
||||
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/*.so) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_v11) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_v12) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
||||
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
||||
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
||||
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_sbsa) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
||||
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
||||
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
||||
esac
|
||||
done
|
||||
working-directory: dist/${{ matrix.os }}-${{ matrix.arch }}
|
||||
- run: |
|
||||
echo "Manifests"
|
||||
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in ; do
|
||||
echo $ARCHIVE
|
||||
cat $ARCHIVE
|
||||
done
|
||||
- run: |
|
||||
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
|
||||
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
|
||||
@@ -436,48 +349,16 @@ jobs:
|
||||
trigger:
|
||||
runs-on: ubuntu-latest
|
||||
environment: release
|
||||
needs: [darwin-build, windows-build, windows-depends]
|
||||
steps:
|
||||
- name: Trigger downstream release process
|
||||
run: |
|
||||
curl -L \
|
||||
-X POST \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-H "Authorization: Bearer ${{ secrets.RELEASE_TOKEN }}" \
|
||||
-H "X-GitHub-Api-Version: 2022-11-28" \
|
||||
https://api.github.com/repos/ollama/${{ vars.RELEASE_REPO }}/dispatches \
|
||||
-d "{\"event_type\": \"trigger-workflow\", \"client_payload\": {\"run_id\": \"${GITHUB_RUN_ID}\", \"version\": \"${GITHUB_REF_NAME#v}\"}}"
|
||||
|
||||
# Aggregate all the assets and ship a release
|
||||
release:
|
||||
needs: [darwin-sign, windows-sign, linux-build]
|
||||
runs-on: linux
|
||||
environment: release
|
||||
needs: [darwin-build, windows-build, windows-depends, linux-build]
|
||||
permissions:
|
||||
contents: write
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: dist-darwin
|
||||
path: dist
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: dist-windows
|
||||
path: dist
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: dist-linux-*
|
||||
path: dist
|
||||
merge-multiple: true
|
||||
- run: find . -type f -not -name 'sha256sum.txt' | xargs sha256sum | tee sha256sum.txt
|
||||
working-directory: dist
|
||||
- name: Create or update Release
|
||||
- name: Create or update Release for tag
|
||||
run: |
|
||||
RELEASE_VERSION="$(echo ${GITHUB_REF_NAME} | cut -f1 -d-)"
|
||||
|
||||
echo "Looking for existing release for ${RELEASE_VERSION}"
|
||||
OLD_TAG=$(gh release ls --json name,tagName | jq -r ".[] | select(.name == \"${RELEASE_VERSION}\") | .tagName")
|
||||
if [ -n "$OLD_TAG" ]; then
|
||||
@@ -491,5 +372,12 @@ jobs:
|
||||
--generate-notes \
|
||||
--prerelease
|
||||
fi
|
||||
echo "Uploading artifacts for tag ${GITHUB_REF_NAME}"
|
||||
gh release upload ${GITHUB_REF_NAME} dist/* --clobber
|
||||
- name: Trigger downstream release process
|
||||
run: |
|
||||
curl -L \
|
||||
-X POST \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-H "Authorization: Bearer ${{ secrets.RELEASE_TOKEN }}" \
|
||||
-H "X-GitHub-Api-Version: 2022-11-28" \
|
||||
https://api.github.com/repos/ollama/${{ vars.RELEASE_REPO }}/dispatches \
|
||||
-d "{\"event_type\": \"trigger-workflow\", \"client_payload\": {\"run_id\": \"${GITHUB_RUN_ID}\", \"version\": \"${GITHUB_REF_NAME#v}\", \"publish\": \"1\"}}"
|
||||
|
||||
17
.github/workflows/test.yaml
vendored
17
.github/workflows/test.yaml
vendored
@@ -36,7 +36,7 @@ jobs:
|
||||
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
|
||||
}
|
||||
|
||||
echo changed=$(changed 'llama/llama.cpp/**' 'ml/backend/ggml/ggml/**') | tee -a $GITHUB_OUTPUT
|
||||
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
|
||||
|
||||
linux:
|
||||
needs: [changes]
|
||||
@@ -46,7 +46,7 @@ jobs:
|
||||
include:
|
||||
- preset: CPU
|
||||
- preset: CUDA
|
||||
container: nvidia/cuda:11.8.0-devel-ubuntu22.04
|
||||
container: nvidia/cuda:12.8.1-devel-ubuntu22.04
|
||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
|
||||
- preset: ROCm
|
||||
container: rocm/dev-ubuntu-22.04:6.1.2
|
||||
@@ -78,11 +78,11 @@ jobs:
|
||||
include:
|
||||
- preset: CPU
|
||||
- preset: CUDA
|
||||
install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
|
||||
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
|
||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
|
||||
- preset: ROCm
|
||||
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
|
||||
flags: '-DAMDGPU_TARGETS=gfx1010'
|
||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
||||
runs-on: windows
|
||||
steps:
|
||||
- run: |
|
||||
@@ -102,7 +102,7 @@ jobs:
|
||||
$ErrorActionPreference = "Stop"
|
||||
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
|
||||
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_11.3", "nvcc_11.3", "cublas_11.3", "cublas_dev_11.3")) -NoNewWindow -Wait
|
||||
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_12.8", "nvcc_12.8", "cublas_12.8", "cublas_dev_12.8")) -NoNewWindow -Wait
|
||||
}
|
||||
|
||||
$cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path
|
||||
@@ -120,6 +120,9 @@ jobs:
|
||||
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
@@ -133,8 +136,8 @@ jobs:
|
||||
path: ${{ github.workspace }}\.ccache
|
||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
|
||||
- run: |
|
||||
Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
|
||||
cmake --build --parallel --preset "${{ matrix.preset }}"
|
||||
env:
|
||||
|
||||
@@ -78,14 +78,13 @@ if(CMAKE_CUDA_COMPILER)
|
||||
|
||||
find_package(CUDAToolkit)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
|
||||
set(OLLAMA_CUDA_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/cuda_v${CUDAToolkit_VERSION_MAJOR})
|
||||
install(TARGETS ggml-cuda
|
||||
RUNTIME_DEPENDENCIES
|
||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR}
|
||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
|
||||
LIBRARY DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -116,7 +115,11 @@ if(CMAKE_HIP_COMPILER)
|
||||
|
||||
set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
|
||||
install(TARGETS ggml-hip
|
||||
RUNTIME_DEPENDENCIES
|
||||
RUNTIME_DEPENDENCY_SET rocm
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
|
||||
)
|
||||
install(RUNTIME_DEPENDENCY_SET rocm
|
||||
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
|
||||
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
|
||||
@@ -17,20 +17,12 @@
|
||||
"name": "CUDA",
|
||||
"inherits": [ "Default" ]
|
||||
},
|
||||
{
|
||||
"name": "CUDA 11",
|
||||
"inherits": [ "CUDA" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86",
|
||||
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "CUDA 12",
|
||||
"inherits": [ "CUDA" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120",
|
||||
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
|
||||
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -58,6 +50,7 @@
|
||||
"name": "ROCm 6",
|
||||
"inherits": [ "ROCm" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
|
||||
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
|
||||
}
|
||||
}
|
||||
@@ -78,11 +71,6 @@
|
||||
"configurePreset": "CUDA",
|
||||
"targets": [ "ggml-cuda" ]
|
||||
},
|
||||
{
|
||||
"name": "CUDA 11",
|
||||
"inherits": [ "CUDA" ],
|
||||
"configurePreset": "CUDA 11"
|
||||
},
|
||||
{
|
||||
"name": "CUDA 12",
|
||||
"inherits": [ "CUDA" ],
|
||||
|
||||
24
Dockerfile
24
Dockerfile
@@ -7,12 +7,13 @@ ARG JETPACK5VERSION=r35.4.1
|
||||
ARG JETPACK6VERSION=r36.4.0
|
||||
ARG CMAKEVERSION=3.31.2
|
||||
|
||||
# CUDA v11 requires gcc v10. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
|
||||
# We require gcc v10 minimum. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
|
||||
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
|
||||
RUN yum install -y yum-utils \
|
||||
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
|
||||
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
|
||||
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
|
||||
&& dnf install -y ccache \
|
||||
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
|
||||
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
|
||||
|
||||
@@ -38,15 +39,6 @@ RUN --mount=type=cache,target=/root/.ccache \
|
||||
&& cmake --build --parallel --preset 'CPU' \
|
||||
&& cmake --install build --component CPU --strip --parallel 8
|
||||
|
||||
FROM base AS cuda-11
|
||||
ARG CUDA11VERSION=11.3
|
||||
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
||||
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CUDA 11' \
|
||||
&& cmake --build --parallel --preset 'CUDA 11' \
|
||||
&& cmake --install build --component CUDA --strip --parallel 8
|
||||
|
||||
FROM base AS cuda-12
|
||||
ARG CUDA12VERSION=12.8
|
||||
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
||||
@@ -98,17 +90,15 @@ RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -trimpath -buildmode=pie -o /bin/ollama .
|
||||
|
||||
FROM --platform=linux/amd64 scratch AS amd64
|
||||
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
|
||||
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
|
||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama
|
||||
|
||||
FROM --platform=linux/arm64 scratch AS arm64
|
||||
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
|
||||
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
|
||||
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_jetpack5
|
||||
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_jetpack6
|
||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/cuda_sbsa
|
||||
COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5
|
||||
COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6
|
||||
|
||||
FROM scratch AS rocm
|
||||
COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm
|
||||
COPY --from=rocm-6 dist/lib/ollama /lib/ollama
|
||||
|
||||
FROM ${FLAVOR} AS archive
|
||||
COPY --from=cpu dist/lib/ollama /lib/ollama
|
||||
|
||||
@@ -40,10 +40,10 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
|
||||
|
||||
## Quickstart
|
||||
|
||||
To run and chat with [Llama 3.2](https://ollama.com/library/llama3.2):
|
||||
To run and chat with [Gemma 3](https://ollama.com/library/gemma3):
|
||||
|
||||
```shell
|
||||
ollama run llama3.2
|
||||
ollama run gemma3
|
||||
```
|
||||
|
||||
## Model library
|
||||
@@ -407,6 +407,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
|
||||
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
|
||||
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
|
||||
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
|
||||
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
|
||||
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
|
||||
|
||||
### Cloud
|
||||
|
||||
@@ -451,6 +454,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull, and download models from Ollama Registry in your terminal.
|
||||
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
|
||||
- [AWS-Strands-With-Ollama](https://github.com/rapidarchitect/ollama_strands) - AWS Strands Agents with Ollama Examples
|
||||
- [ollama-multirun](https://github.com/attogram/ollama-multirun) - A bash shell script to run a single prompt against any or all of your locally installed ollama models, saving the output and performance statistics as easily navigable web pages. ([Demo](https://attogram.github.io/ai_test_zone/))
|
||||
- [ollama-bash-toolshed](https://github.com/attogram/ollama-bash-toolshed) - Bash scripts to chat with tool using models. Add new tools to your shed with ease. Runs on Ollama.
|
||||
|
||||
### Apple Vision Pro
|
||||
|
||||
|
||||
13
api/types.go
13
api/types.go
@@ -457,13 +457,12 @@ type ProcessResponse struct {
|
||||
|
||||
// ListModelResponse is a single model description in [ListResponse].
|
||||
type ListModelResponse struct {
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Capabilities []model.Capability `json:"capabilities,omitempty"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// ProcessModelResponse is a single model description in [ProcessResponse].
|
||||
|
||||
@@ -1,178 +0,0 @@
|
||||
package benchmark
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Command line flags
|
||||
var modelFlag string
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
|
||||
flag.Lookup("m").DefValue = "model"
|
||||
}
|
||||
|
||||
// modelName returns the model name from flags, failing the test if not set
|
||||
func modelName(b *testing.B) string {
|
||||
if modelFlag == "" {
|
||||
b.Fatal("Error: -m flag is required for benchmark tests")
|
||||
}
|
||||
return modelFlag
|
||||
}
|
||||
|
||||
type TestCase struct {
|
||||
name string
|
||||
prompt string
|
||||
maxTokens int
|
||||
}
|
||||
|
||||
// runGenerateBenchmark contains the common generate and metrics logic
|
||||
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
|
||||
start := time.Now()
|
||||
var ttft time.Duration
|
||||
var metrics api.Metrics
|
||||
|
||||
err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||
if ttft == 0 && resp.Response != "" {
|
||||
ttft = time.Since(start)
|
||||
}
|
||||
if resp.Done {
|
||||
metrics = resp.Metrics
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Report custom metrics as part of the benchmark results
|
||||
b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
|
||||
b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
|
||||
|
||||
// Token throughput metrics
|
||||
promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
|
||||
genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
|
||||
b.ReportMetric(promptThroughput, "prompt_tok/s")
|
||||
b.ReportMetric(genThroughput, "gen_tok/s")
|
||||
|
||||
// Token counts
|
||||
b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
|
||||
b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkColdStart runs benchmarks with model loading from cold state
|
||||
func BenchmarkColdStart(b *testing.B) {
|
||||
client := setup(b)
|
||||
tests := []TestCase{
|
||||
{"short_prompt", "Write a long story", 100},
|
||||
{"medium_prompt", "Write a detailed economic analysis", 500},
|
||||
{"long_prompt", "Write a comprehensive AI research paper", 1000},
|
||||
}
|
||||
m := modelName(b)
|
||||
|
||||
for _, tt := range tests {
|
||||
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
|
||||
ctx := b.Context()
|
||||
|
||||
// Set number of tokens as our throughput metric
|
||||
b.SetBytes(int64(tt.maxTokens))
|
||||
|
||||
for b.Loop() {
|
||||
b.StopTimer()
|
||||
// Ensure model is unloaded before each iteration
|
||||
unload(client, m, b)
|
||||
b.StartTimer()
|
||||
|
||||
req := &api.GenerateRequest{
|
||||
Model: m,
|
||||
Prompt: tt.prompt,
|
||||
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
|
||||
}
|
||||
|
||||
runGenerateBenchmark(b, ctx, client, req)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkWarmStart runs benchmarks with pre-loaded model
|
||||
func BenchmarkWarmStart(b *testing.B) {
|
||||
client := setup(b)
|
||||
tests := []TestCase{
|
||||
{"short_prompt", "Write a long story", 100},
|
||||
{"medium_prompt", "Write a detailed economic analysis", 500},
|
||||
{"long_prompt", "Write a comprehensive AI research paper", 1000},
|
||||
}
|
||||
m := modelName(b)
|
||||
|
||||
for _, tt := range tests {
|
||||
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
|
||||
ctx := b.Context()
|
||||
|
||||
// Pre-warm the model
|
||||
warmup(client, m, tt.prompt, b)
|
||||
|
||||
// Set number of tokens as our throughput metric
|
||||
b.SetBytes(int64(tt.maxTokens))
|
||||
|
||||
for b.Loop() {
|
||||
req := &api.GenerateRequest{
|
||||
Model: m,
|
||||
Prompt: tt.prompt,
|
||||
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
|
||||
}
|
||||
|
||||
runGenerateBenchmark(b, ctx, client, req)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// setup verifies server and model availability
|
||||
func setup(b *testing.B) *api.Client {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if _, err := client.Show(b.Context(), &api.ShowRequest{Model: modelName(b)}); err != nil {
|
||||
b.Fatalf("Model unavailable: %v", err)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
// warmup ensures the model is loaded and warmed up
|
||||
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
|
||||
for range 3 {
|
||||
err := client.Generate(
|
||||
context.Background(),
|
||||
&api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: prompt,
|
||||
Options: map[string]any{"num_predict": 50, "temperature": 0.1},
|
||||
},
|
||||
func(api.GenerateResponse) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
b.Logf("Error during model warm-up: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// unload forces model unloading using KeepAlive: 0 parameter
|
||||
func unload(client *api.Client, model string, b *testing.B) {
|
||||
req := &api.GenerateRequest{
|
||||
Model: model,
|
||||
KeepAlive: &api.Duration{Duration: 0},
|
||||
}
|
||||
if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
|
||||
b.Logf("Unload error: %v", err)
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"errors"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"regexp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
@@ -19,11 +19,12 @@ func startApp(ctx context.Context, client *api.Client) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !strings.Contains(link, "Ollama.app") {
|
||||
r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
|
||||
m := r.FindStringSubmatch(link)
|
||||
if len(m) != 1 {
|
||||
return errors.New("could not find ollama app")
|
||||
}
|
||||
path := strings.Split(link, "Ollama.app")
|
||||
if err := exec.Command("/usr/bin/open", "-j", "-a", path[0]+"Ollama.app").Run(); err != nil {
|
||||
if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil {
|
||||
return err
|
||||
}
|
||||
return waitForServer(ctx, client)
|
||||
|
||||
@@ -47,7 +47,7 @@ func startApp(ctx context.Context, client *api.Client) error {
|
||||
}
|
||||
|
||||
cmd_path := "c:\\Windows\\system32\\cmd.exe"
|
||||
cmd := exec.Command(cmd_path, "/c", appExe, "hidden")
|
||||
cmd := exec.Command(cmd_path, "/c", appExe, "--hide", "--fast-startup")
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{CreationFlags: 0x08000000, HideWindow: true}
|
||||
|
||||
cmd.Stdin = strings.NewReader("")
|
||||
|
||||
@@ -190,6 +190,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
conv = &gemma2Model{}
|
||||
case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration":
|
||||
conv = &gemma3Model{Architecture: p.Architectures[0]}
|
||||
case "Gemma3nForConditionalGeneration":
|
||||
conv = &gemma3nModel{}
|
||||
case "Phi3ForCausalLM":
|
||||
conv = &phi3Model{}
|
||||
case "Qwen2ForCausalLM":
|
||||
|
||||
165
convert/convert_gemma3n.go
Normal file
165
convert/convert_gemma3n.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
"gonum.org/v1/gonum/stat/distuv"
|
||||
)
|
||||
|
||||
type gemma3nModel struct {
|
||||
ModelParameters
|
||||
|
||||
TextModel struct {
|
||||
ActivationSparsityPattern []float32 `json:"activation_sparsity_pattern"`
|
||||
AltupActiveIdx uint32 `json:"altup_active_idx"`
|
||||
AltupCoefClip float32 `json:"altup_coef_clip"`
|
||||
AltupCorrectScale bool `json:"altup_correct_scale"`
|
||||
AltupLRMultiplier float32 `json:"altup_lr_multiplier"`
|
||||
AltupNumInputs uint32 `json:"altup_num_inputs"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
HiddenSizePerLayerInput uint32 `json:"hidden_size_per_layer_input"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
NumKVSharedLayers uint32 `json:"num_kv_shared_layers"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
LayerTypes []string `json:"layer_types"`
|
||||
} `json:"text_config"`
|
||||
VisionModel struct{} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma3n"
|
||||
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {
|
||||
norm := distuv.Normal{Mu: 0, Sigma: 1}
|
||||
for _, v := range m.TextModel.ActivationSparsityPattern {
|
||||
if !yield(float32(norm.Quantile(float64(v)))) {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
kv["gemma3n.altup.active_idx"] = m.TextModel.AltupActiveIdx
|
||||
kv["gemma3n.altup.correct_scale"] = m.TextModel.AltupCorrectScale
|
||||
kv["gemma3n.altup.lr_multiplier"] = m.TextModel.AltupLRMultiplier
|
||||
kv["gemma3n.altup.num_inputs"] = m.TextModel.AltupNumInputs
|
||||
kv["gemma3n.attention.head_count_kv"] = m.TextModel.NumKeyValueHeads
|
||||
kv["gemma3n.attention.head_count"] = m.TextModel.NumAttentionHeads
|
||||
kv["gemma3n.attention.layer_norm_rms_epsilon"] = m.TextModel.RMSNormEPS
|
||||
kv["gemma3n.attention.sliding_window"] = m.TextModel.SlidingWindow
|
||||
kv["gemma3n.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
|
||||
for _, t := range m.TextModel.LayerTypes {
|
||||
if !yield(t == "sliding_attention") {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
kv["gemma3n.attention.shared_kv_layers"] = m.TextModel.NumKVSharedLayers
|
||||
kv["gemma3n.block_count"] = m.TextModel.NumHiddenLayers
|
||||
kv["gemma3n.context_length"] = m.TextModel.MaxPositionEmbeddings
|
||||
kv["gemma3n.embedding_length_per_layer_input"] = m.TextModel.HiddenSizePerLayerInput
|
||||
kv["gemma3n.embedding_length"] = m.TextModel.HiddenSize
|
||||
kv["gemma3n.feed_forward_length"] = m.TextModel.IntermediateSize
|
||||
kv["gemma3n.head_dim"] = m.TextModel.HeadDim
|
||||
kv["gemma3n.rope.freq_base_local"] = m.TextModel.RopeLocalBaseFreq
|
||||
kv["gemma3n.rope.freq_base"] = m.TextModel.RopeTheta
|
||||
return kv
|
||||
}
|
||||
|
||||
func (m *gemma3nModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
out, ts := mergeTensors(ts,
|
||||
merge{"altup_proj.*.weight", "altup_proj.weight"},
|
||||
merge{"altup_unembd_proj.*.weight", "altup_unembd_proj.weight"},
|
||||
)
|
||||
|
||||
for _, t := range ts {
|
||||
switch {
|
||||
case strings.Contains(t.Name(), "audio_tower"),
|
||||
strings.Contains(t.Name(), "embed_audio"),
|
||||
strings.Contains(t.Name(), "vision_tower"),
|
||||
strings.Contains(t.Name(), "embed_vision"):
|
||||
// TODO: handle audio and vision towers
|
||||
continue
|
||||
case strings.Contains(t.Name(), "altup_predict_coef"),
|
||||
strings.Contains(t.Name(), "altup_correct_coef"):
|
||||
if m.TextModel.AltupCoefClip > 0 {
|
||||
t.SetRepacker(func(name string, data []float32, shape []uint64) (_ []float32, err error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
|
||||
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
|
||||
t, err = tensor.Clamp(t, -m.TextModel.AltupCoefClip, m.TextModel.AltupCoefClip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return native.VectorF32(t.(*tensor.Dense))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *gemma3nModel) Replacements() []string {
|
||||
return []string{
|
||||
"model.language_model.embed_tokens_per_layer", "per_layer_token_embd",
|
||||
"model.language_model.embed_tokens", "token_embd",
|
||||
"model.language_model.per_layer_model_projection", "per_layer_model_proj",
|
||||
"model.language_model.per_layer_projection_norm", "per_layer_proj_norm", "model.language_model.altup_projections", "altup_proj",
|
||||
"model.language_model.altup_unembed_projections", "altup_unembd_proj",
|
||||
"model.language_model.norm", "output_norm",
|
||||
"model.language_model.layers", "blk",
|
||||
|
||||
"input_layernorm", "attn_norm",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.q_norm", "attn_q_norm",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.k_norm", "attn_k_norm",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"post_attention_layernorm", "post_attention_norm",
|
||||
"pre_feedforward_layernorm", "ffn_norm",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"post_feedforward_layernorm", "post_ffw_norm",
|
||||
"per_layer_input_gate", "inp_gate",
|
||||
"per_layer_projection", "proj",
|
||||
"post_per_layer_input_norm", "post_norm",
|
||||
"altup.", "altup_",
|
||||
"modality_router", "router",
|
||||
"prediction_coefs", "predict_coef",
|
||||
"correction_coefs", "correct_coef",
|
||||
"correct_output_scale", "correct_scale.weight",
|
||||
"laurel.", "laurel_",
|
||||
"linear_left", "l",
|
||||
"linear_right", "r",
|
||||
"post_laurel_norm", "post_norm",
|
||||
}
|
||||
}
|
||||
@@ -2,9 +2,6 @@ package convert
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
@@ -30,65 +27,38 @@ func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
|
||||
}
|
||||
|
||||
func (p *mixtralModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
oldnew := []string{
|
||||
"model.layers", "blk",
|
||||
"w1", "ffn_gate_exps",
|
||||
"w2", "ffn_down_exps",
|
||||
"w3", "ffn_up_exps",
|
||||
}
|
||||
|
||||
for i := range p.NumLocalExperts {
|
||||
oldnew = append(oldnew, fmt.Sprintf(".block_sparse_moe.experts.%d.", i), ".")
|
||||
}
|
||||
|
||||
// group experts of the same layer (model.layers.%d) and type (w[123]) into a single tensor
|
||||
namer := strings.NewReplacer(oldnew...)
|
||||
experts := make(map[string]experts)
|
||||
|
||||
// merge experts into a single tensor while removing them from ts
|
||||
ts = slices.DeleteFunc(ts, func(t Tensor) bool {
|
||||
if !strings.Contains(t.Name(), ".block_sparse_moe.experts.") {
|
||||
return false
|
||||
}
|
||||
|
||||
name := namer.Replace(t.Name())
|
||||
experts[name] = append(experts[name], t)
|
||||
return true
|
||||
})
|
||||
|
||||
var out []*ggml.Tensor
|
||||
for n, e := range experts {
|
||||
// TODO(mxyng): sanity check experts
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: n,
|
||||
Kind: e[0].Kind(),
|
||||
Shape: append([]uint64{uint64(len(e))}, e[0].Shape()...),
|
||||
WriterTo: e,
|
||||
merges := make([]merge, 0, p.NumHiddenLayers*6)
|
||||
for i := range p.NumHiddenLayers {
|
||||
merges = append(merges, merge{
|
||||
fmt.Sprintf("blk.%d.*.w1.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.*.w1.bias", i),
|
||||
fmt.Sprintf("blk.%d.ffn_gate_exps.bias", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.*.w2.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.*.w2.bias", i),
|
||||
fmt.Sprintf("blk.%d.ffn_up_exps.bias", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.*.w3.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.*.w3.bias", i),
|
||||
fmt.Sprintf("blk.%d.ffn_down_exps.bias", i),
|
||||
})
|
||||
}
|
||||
|
||||
out, ts := mergeTensors(ts, merges...)
|
||||
return append(out, p.llamaModel.Tensors(ts)...)
|
||||
}
|
||||
|
||||
func (p *mixtralModel) Replacements() []string {
|
||||
return append(
|
||||
p.llamaModel.Replacements(),
|
||||
"model.layers", "blk",
|
||||
"block_sparse_moe.gate", "ffn_gate_inp",
|
||||
"block_sparse_moe.experts.", ".",
|
||||
)
|
||||
}
|
||||
|
||||
type experts []Tensor
|
||||
|
||||
func (e experts) WriteTo(w io.Writer) (int64, error) {
|
||||
// TODO(mxyng): experts _should_ be numerically sorted by expert but this should check
|
||||
for _, t := range e {
|
||||
// the canonical merged experts tensor stacks all experts along a new, 0 axis,
|
||||
// e.g. `tensor.Stack(0, e[0], e[1:]...)`, which requires allocating temporary buffers
|
||||
// this accomplishes the same thing by writing each expert tensor in sequence
|
||||
if _, err := t.WriteTo(w); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
@@ -65,17 +65,17 @@ func (q *qwen25VLModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
for _, t := range ts {
|
||||
if strings.Contains(t.Name(), "patch_embed.proj") {
|
||||
for t := range splitDim(t, 2,
|
||||
strings.NewReplacer("patch_embed.proj", "patch_embd_0"),
|
||||
strings.NewReplacer("patch_embed.proj", "patch_embd_1"),
|
||||
split{Replacer: strings.NewReplacer("patch_embed.proj", "patch_embd_0")},
|
||||
split{Replacer: strings.NewReplacer("patch_embed.proj", "patch_embd_1")},
|
||||
) {
|
||||
t.Shape = slices.DeleteFunc(t.Shape, func(i uint64) bool { return i == 1 })
|
||||
out = append(out, t)
|
||||
}
|
||||
} else if strings.Contains(t.Name(), "attn.qkv") {
|
||||
out = append(out, slices.Collect(splitDim(t, 0,
|
||||
strings.NewReplacer("attn.qkv", "attn_q"),
|
||||
strings.NewReplacer("attn.qkv", "attn_k"),
|
||||
strings.NewReplacer("attn.qkv", "attn_v"),
|
||||
split{Replacer: strings.NewReplacer("attn.qkv", "attn_q")},
|
||||
split{Replacer: strings.NewReplacer("attn.qkv", "attn_k")},
|
||||
split{Replacer: strings.NewReplacer("attn.qkv", "attn_v")},
|
||||
))...)
|
||||
} else {
|
||||
out = append(out, &ggml.Tensor{
|
||||
|
||||
@@ -1,56 +1,129 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"io"
|
||||
"iter"
|
||||
"path"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type split struct {
|
||||
*strings.Replacer
|
||||
dim int
|
||||
|
||||
// fn is an optional function to apply to the tensor after slicing
|
||||
fn func(tensor.Tensor) (tensor.Tensor, error)
|
||||
}
|
||||
|
||||
// splitDim splits a tensor along a specified dimension into multiple tensors. The dimension
|
||||
// is split evenly based on the number of replacers provided.
|
||||
func splitDim(t Tensor, dim int, replacers ...*strings.Replacer) iter.Seq[*ggml.Tensor] {
|
||||
// is split evenly based on the number of replacers provided unless a specific count is given.
|
||||
func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
|
||||
return func(yield func(*ggml.Tensor) bool) {
|
||||
for i, replacer := range replacers {
|
||||
var offset int
|
||||
for _, split := range splits {
|
||||
t := t.Clone()
|
||||
shape := slices.Clone(t.Shape())
|
||||
shape[dim] = shape[dim] / uint64(len(replacers))
|
||||
shape[dim] = cmp.Or(uint64(split.dim), shape[dim]/uint64(len(splits)))
|
||||
|
||||
slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
|
||||
slice[dim] = tensor.S(i*int(shape[dim]), (i+1)*int(shape[dim]))
|
||||
slice[dim] = tensor.S(offset, offset+int(shape[dim]))
|
||||
offset += int(shape[dim])
|
||||
|
||||
tt := t.Clone()
|
||||
tt.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
|
||||
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
t, err := t.Slice(slice...)
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
tt, err := tt.Slice(slice...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t = tensor.Materialize(t)
|
||||
tt = tensor.Materialize(tt)
|
||||
|
||||
if split.fn != nil {
|
||||
tt, err = split.fn(tt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// flatten tensor so it can be written as a vector
|
||||
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
|
||||
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return native.VectorF32(t.(*tensor.Dense))
|
||||
return native.VectorF32(tt.(*tensor.Dense))
|
||||
})
|
||||
|
||||
if !yield(&ggml.Tensor{
|
||||
Name: replacer.Replace(t.Name()),
|
||||
Name: split.Replace(t.Name()),
|
||||
Kind: t.Kind(),
|
||||
Shape: shape,
|
||||
WriterTo: tt,
|
||||
WriterTo: t,
|
||||
}) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type merge struct {
|
||||
pattern, name string
|
||||
}
|
||||
|
||||
// mergeTensors merges tensors that match a given pattern into a single tensor.
|
||||
func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []Tensor) {
|
||||
var matched []Tensor
|
||||
for i := range merges {
|
||||
matched, unmatched = slicesSplitFunc(unmatched, func(t Tensor) bool {
|
||||
matched, _ := path.Match(merges[i].pattern, t.Name())
|
||||
return matched
|
||||
})
|
||||
|
||||
if len(matched) > 0 {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: merges[i].name,
|
||||
Kind: matched[0].Kind(),
|
||||
Shape: append([]uint64{uint64(len(matched))}, matched[0].Shape()...),
|
||||
WriterTo: mergeGroup(matched),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return out, unmatched
|
||||
}
|
||||
|
||||
// slicesSplitFunc splits a slice into two slices based on a predicate function.
|
||||
func slicesSplitFunc[S ~[]E, E comparable](s S, fn func(e E) bool) (matched, unmatched S) {
|
||||
for _, e := range s {
|
||||
if fn(e) {
|
||||
matched = append(matched, e)
|
||||
} else {
|
||||
unmatched = append(unmatched, e)
|
||||
}
|
||||
}
|
||||
|
||||
return matched, unmatched
|
||||
}
|
||||
|
||||
type mergeGroup []Tensor
|
||||
|
||||
func (g mergeGroup) WriteTo(w io.Writer) (int64, error) {
|
||||
for _, t := range g {
|
||||
if _, err := t.WriteTo(w); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
402
convert/tensor_test.go
Normal file
402
convert/tensor_test.go
Normal file
@@ -0,0 +1,402 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"iter"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
)
|
||||
|
||||
type fakeTensor struct {
|
||||
name string
|
||||
shape []uint64
|
||||
data []float32
|
||||
|
||||
repacker Repacker
|
||||
}
|
||||
|
||||
func (f fakeTensor) Name() string {
|
||||
return f.name
|
||||
}
|
||||
|
||||
func (f fakeTensor) Shape() []uint64 {
|
||||
return f.shape
|
||||
}
|
||||
|
||||
func (f fakeTensor) Kind() uint32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (f *fakeTensor) SetRepacker(fn Repacker) {
|
||||
f.repacker = fn
|
||||
}
|
||||
|
||||
func (f fakeTensor) Clone() Tensor {
|
||||
return &fakeTensor{
|
||||
name: f.name,
|
||||
shape: slices.Clone(f.shape),
|
||||
data: slices.Clone(f.data),
|
||||
repacker: f.repacker,
|
||||
}
|
||||
}
|
||||
|
||||
func (f fakeTensor) WriteTo(w io.Writer) (n int64, err error) {
|
||||
data := f.data
|
||||
if f.repacker != nil {
|
||||
data, err = f.repacker(f.name, data, f.shape)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := binary.Write(w, binary.LittleEndian, data); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return int64(len(data) * 4), nil
|
||||
}
|
||||
|
||||
func mul(shape []uint64) int {
|
||||
n := 1
|
||||
for _, dim := range shape {
|
||||
n *= int(dim)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func TestSplitDim(t *testing.T) {
|
||||
r := fakeTensor{
|
||||
name: "a.b",
|
||||
shape: []uint64{3, 4},
|
||||
data: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
|
||||
}
|
||||
|
||||
t.Run("no split", func(t *testing.T) {
|
||||
for tt := range splitDim(&r, 0, split{Replacer: strings.NewReplacer("a", "x")}) {
|
||||
if tt.Name != "x.b" {
|
||||
t.Fatalf("expected name 'x', got '%s'", tt.Name)
|
||||
}
|
||||
|
||||
if !slices.Equal(tt.Shape, []uint64{3, 4}) {
|
||||
t.Fatalf("expected shape [3, 4], got %v", tt.Shape)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !slices.Equal(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}) {
|
||||
t.Fatalf("expected data [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], got %v", f32s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("even split", func(t *testing.T) {
|
||||
next, stop := iter.Pull(splitDim(&r, 1,
|
||||
split{Replacer: strings.NewReplacer("a", "x")},
|
||||
split{Replacer: strings.NewReplacer("b", "y")},
|
||||
))
|
||||
defer stop()
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "x.b" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
|
||||
t.Fatal("expected shape [3, 2], got", tt.Shape)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !slices.Equal(f32s, []float32{0, 1, 4, 5, 8, 9}) {
|
||||
t.Fatal("expected data [0, 1, 4, 5, 8, 9], got", f32s)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "a.y" {
|
||||
t.Fatal("expected name 'a.y', got", tt.Name)
|
||||
}
|
||||
|
||||
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
|
||||
t.Fatal("expected shape [3, 2], got", tt.Shape)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !slices.Equal(f32s, []float32{2, 3, 6, 7, 10, 11}) {
|
||||
t.Fatal("expected data [2, 3, 6, 7, 10, 11], got", f32s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uneven split", func(t *testing.T) {
|
||||
next, stop := iter.Pull(splitDim(&r, 0,
|
||||
split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
|
||||
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
|
||||
))
|
||||
defer stop()
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "x.b" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if !slices.Equal(tt.Shape, []uint64{2, 4}) {
|
||||
t.Fatal("expected shape [2, 4], got", tt.Shape)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !slices.Equal(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7}) {
|
||||
t.Fatal("expected data [0, 1, 2, 3, 4, 5, 6, 7], got", f32s)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "a.y" {
|
||||
t.Fatal("expected name 'a.y', got", tt.Name)
|
||||
}
|
||||
|
||||
if !slices.Equal(tt.Shape, []uint64{1, 4}) {
|
||||
t.Fatal("expected shape [1, 4], got", tt.Shape)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !slices.Equal(f32s, []float32{8, 9, 10, 11}) {
|
||||
t.Fatal("expected data [8, 9, 10, 11], got", f32s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("split with transpose", func(t *testing.T) {
|
||||
next, stop := iter.Pull(splitDim(&r, 1,
|
||||
split{Replacer: strings.NewReplacer("a", "x")},
|
||||
split{Replacer: strings.NewReplacer("b", "y"), fn: func(tt tensor.Tensor) (tensor.Tensor, error) {
|
||||
return tensor.Transpose(tt, 1, 0)
|
||||
}},
|
||||
))
|
||||
defer stop()
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "x.b" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
|
||||
t.Fatal("expected shape [3, 2], got", tt.Shape)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !slices.Equal(f32s, []float32{0, 1, 4, 5, 8, 9}) {
|
||||
t.Fatal("expected data [0, 1, 4, 5, 8, 9], got", f32s)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "a.y" {
|
||||
t.Fatal("expected name 'a.y', got", tt.Name)
|
||||
}
|
||||
|
||||
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
|
||||
t.Fatal("expected shape [3, 2], got", tt.Shape)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !slices.Equal(f32s, []float32{2, 6, 10, 3, 7, 11}) {
|
||||
t.Fatal("expected data [2, 6, 10, 3, 7, 11], got", f32s)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMerge(t *testing.T) {
|
||||
unmatched := []Tensor{
|
||||
&fakeTensor{
|
||||
name: "a.0.b",
|
||||
shape: []uint64{5, 2},
|
||||
data: []float32{10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
|
||||
},
|
||||
&fakeTensor{
|
||||
name: "a.1.b",
|
||||
shape: []uint64{5, 2},
|
||||
data: []float32{20, 21, 22, 23, 24, 25, 26, 27, 28, 29},
|
||||
},
|
||||
&fakeTensor{
|
||||
name: "c.0.d",
|
||||
shape: []uint64{5, 2},
|
||||
data: []float32{30, 31, 32, 33, 34, 35, 36, 37, 38, 39},
|
||||
},
|
||||
&fakeTensor{
|
||||
name: "c.1.d",
|
||||
shape: []uint64{5, 2},
|
||||
data: []float32{40, 41, 42, 43, 44, 45, 46, 47, 48, 49},
|
||||
},
|
||||
&fakeTensor{
|
||||
name: "e.0.f",
|
||||
shape: []uint64{5, 2},
|
||||
data: []float32{50, 51, 52, 53, 54, 55, 56, 57, 58, 59},
|
||||
},
|
||||
}
|
||||
|
||||
checkMatched := func(t *testing.T, n int, matched []*ggml.Tensor) {
|
||||
for i := range n {
|
||||
got := matched[i]
|
||||
if diff := cmp.Diff([]uint64{2, 5, 2}, got.Shape); diff != "" {
|
||||
t.Errorf("unexpected (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := got.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, 20)
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
offset := 10 + (i * 20)
|
||||
want := make([]float32, 20)
|
||||
for j := range 20 {
|
||||
want[j] = float32(offset + j)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, f32s); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("single merge", func(t *testing.T) {
|
||||
matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"})
|
||||
if len(unmatched) != 3 {
|
||||
t.Error("expected 3 remaining tensors, got", len(unmatched))
|
||||
}
|
||||
|
||||
if len(matched) != 1 {
|
||||
t.Error("expected 1 merged tensor, got", len(matched))
|
||||
}
|
||||
|
||||
checkMatched(t, 1, matched)
|
||||
})
|
||||
|
||||
t.Run("multiple merges", func(t *testing.T) {
|
||||
matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"}, merge{"c.*.d", "c.d"})
|
||||
if len(unmatched) != 1 {
|
||||
t.Error("expected 1 remaining tensors, got", len(unmatched))
|
||||
}
|
||||
|
||||
if len(matched) != 2 {
|
||||
t.Error("expected 2 merged tensor, got", len(matched))
|
||||
}
|
||||
|
||||
checkMatched(t, 2, matched)
|
||||
})
|
||||
|
||||
t.Run("no match", func(t *testing.T) {
|
||||
matched, unmatched := mergeTensors(unmatched, merge{"x.*.y", "x.y"})
|
||||
if len(unmatched) != 5 {
|
||||
t.Error("expected 5 remaining tensors, got", len(unmatched))
|
||||
}
|
||||
|
||||
if len(matched) != 0 {
|
||||
t.Error("expected no merged tensors, got", len(matched))
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -3,6 +3,7 @@
|
||||
package discover
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"regexp"
|
||||
@@ -55,10 +56,13 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
|
||||
}
|
||||
}
|
||||
}
|
||||
return "sbsa"
|
||||
}
|
||||
|
||||
// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers
|
||||
if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
|
||||
// The detected driver is older than Feb 2023
|
||||
slog.Warn("old CUDA driver detected - please upgrade to a newer driver", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
|
||||
return "v11"
|
||||
}
|
||||
return "v12"
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
// '../lib/ollama' on Linux and the executable's directory on macOS
|
||||
// note: distribution builds, additional GPU-specific libraries are
|
||||
// found in subdirectories of the returned path, such as
|
||||
// 'cuda_v11', 'cuda_v12', 'rocm', etc.
|
||||
// 'cuda_v12', 'rocm', etc.
|
||||
var LibOllamaPath string = func() string {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
|
||||
29
docs/api.md
29
docs/api.md
@@ -1157,15 +1157,11 @@ A single JSON object will be returned.
|
||||
{
|
||||
"models": [
|
||||
{
|
||||
|
||||
"model": "codellama:13b",
|
||||
"modified_at": "2023-11-04T14:56:49.277302595-07:00",
|
||||
"size": 7365960935,
|
||||
"digest": "9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697",
|
||||
"capabilities": [
|
||||
"completion"
|
||||
],
|
||||
|
||||
"name": "deepseek-r1:latest",
|
||||
"model": "deepseek-r1:latest",
|
||||
"modified_at": "2025-05-10T08:06:48.639712648-07:00",
|
||||
"size": 4683075271,
|
||||
"digest": "0a8c266910232fd3291e71e5ba1e058cc5af9d411192cf88b6d30e92b6e73163",
|
||||
"details": {
|
||||
"parent_model": "",
|
||||
"format": "gguf",
|
||||
@@ -1178,16 +1174,11 @@ A single JSON object will be returned.
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
"model": "llama4:latest",
|
||||
"modified_at": "2023-12-07T09:32:18.757212583-08:00",
|
||||
"size": 3825819519,
|
||||
"digest": "fe938a131f40e6f6d40083c9f0f430a515233eb2edaa6d72eb85c50d64f2300e",
|
||||
"capabilities": [
|
||||
"completion",
|
||||
"vision"
|
||||
],
|
||||
|
||||
"name": "llama3.2:latest",
|
||||
"model": "llama3.2:latest",
|
||||
"modified_at": "2025-05-04T17:37:44.706015396-07:00",
|
||||
"size": 2019393189,
|
||||
"digest": "a80c4f17acd55265feec403c7aef86be0c25983ab279d83f3bcd3abbcb5b8b72",
|
||||
"details": {
|
||||
"parent_model": "",
|
||||
"format": "gguf",
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
# Benchmark
|
||||
|
||||
Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.
|
||||
|
||||
## When to use
|
||||
|
||||
Run these benchmarks when:
|
||||
- Making changes to the model inference engine
|
||||
- Modifying model loading/unloading logic
|
||||
- Changing prompt processing or token generation code
|
||||
- Implementing a new model architecture
|
||||
- Testing performance across different hardware setups
|
||||
|
||||
## Prerequisites
|
||||
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
|
||||
## Usage and Examples
|
||||
|
||||
>[!NOTE]
|
||||
>All commands must be run from the root directory of the Ollama project.
|
||||
|
||||
Basic syntax:
|
||||
```bash
|
||||
go test -bench=. ./benchmark/... -m $MODEL_NAME
|
||||
```
|
||||
|
||||
Required flags:
|
||||
- `-bench=.`: Run all benchmarks
|
||||
- `-m`: Model name to benchmark
|
||||
|
||||
Optional flags:
|
||||
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
|
||||
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)
|
||||
|
||||
Common usage patterns:
|
||||
|
||||
Single benchmark run with a model specified:
|
||||
```bash
|
||||
go test -bench=. ./benchmark/... -m llama3.3
|
||||
```
|
||||
|
||||
## Output metrics
|
||||
|
||||
The benchmark reports several key metrics:
|
||||
|
||||
- `gen_tok/s`: Generated tokens per second
|
||||
- `prompt_tok/s`: Prompt processing tokens per second
|
||||
- `ttft_ms`: Time to first token in milliseconds
|
||||
- `load_ms`: Model load time in milliseconds
|
||||
- `gen_tokens`: Total tokens generated
|
||||
- `prompt_tokens`: Total prompt tokens processed
|
||||
|
||||
Each benchmark runs two scenarios:
|
||||
- Cold start: Model is loaded from disk for each test
|
||||
- Warm start: Model is pre-loaded in memory
|
||||
|
||||
Three prompt lengths are tested for each scenario:
|
||||
- Short prompt (100 tokens)
|
||||
- Medium prompt (500 tokens)
|
||||
- Long prompt (1000 tokens)
|
||||
@@ -1,6 +1,6 @@
|
||||
# GPU
|
||||
## Nvidia
|
||||
Ollama supports Nvidia GPUs with compute capability 5.0+.
|
||||
Ollama supports Nvidia GPUs with compute capability 5.0+ and driver version 531 and newer.
|
||||
|
||||
Check your compute compatibility to see if your card is supported:
|
||||
[https://developer.nvidia.com/cuda-gpus](https://developer.nvidia.com/cuda-gpus)
|
||||
|
||||
@@ -112,8 +112,8 @@ sudo systemctl status ollama
|
||||
> While AMD has contributed the `amdgpu` driver upstream to the official linux
|
||||
> kernel source, the version is older and may not support all ROCm features. We
|
||||
> recommend you install the latest driver from
|
||||
> https://www.amd.com/en/support/linux-drivers for best support of your Radeon
|
||||
> GPU.
|
||||
> [AMD](https://www.amd.com/en/support/download/linux-drivers.html) for best support
|
||||
> of your Radeon GPU.
|
||||
|
||||
## Customizing
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ Ollama includes multiple LLM libraries compiled for different GPUs and CPU vecto
|
||||
In the server log, you will see a message that looks something like this (varies from release to release):
|
||||
|
||||
```
|
||||
Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v11 rocm_v5]
|
||||
Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v12 rocm_v5]
|
||||
```
|
||||
|
||||
**Experimental LLM Library Override**
|
||||
|
||||
@@ -10,4 +10,5 @@ type Config interface {
|
||||
Strings(string, ...[]string) []string
|
||||
Ints(string, ...[]int32) []int32
|
||||
Floats(string, ...[]float32) []float32
|
||||
Bools(string, ...[]bool) []bool
|
||||
}
|
||||
|
||||
116
fs/ggml/ggml.go
116
fs/ggml/ggml.go
@@ -34,7 +34,8 @@ func (kv KV) Kind() string {
|
||||
}
|
||||
|
||||
func (kv KV) ParameterCount() uint64 {
|
||||
return keyValue(kv, "general.parameter_count", uint64(0))
|
||||
val, _ := keyValue(kv, "general.parameter_count", uint64(0))
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) FileType() FileType {
|
||||
@@ -53,16 +54,27 @@ func (kv KV) EmbeddingLength() uint64 {
|
||||
return uint64(kv.Uint("embedding_length"))
|
||||
}
|
||||
|
||||
func (kv KV) HeadCount() uint64 {
|
||||
return uint64(kv.Uint("attention.head_count"))
|
||||
func (kv KV) HeadCountMax() uint64 {
|
||||
// TODO(drifkin): using the max value can cause an overestimation. In the
|
||||
// future if array values become more popular, we can adapt the more invasive
|
||||
// <https://github.com/ollama/ollama/pull/10225>
|
||||
return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1))
|
||||
}
|
||||
|
||||
func (kv KV) HeadCountKV() uint64 {
|
||||
return uint64(kv.Uint("attention.head_count_kv", 1))
|
||||
func (kv KV) HeadCountMin() uint64 {
|
||||
return uint64(kv.UintOrMinArrayValue("attention.head_count", 1))
|
||||
}
|
||||
|
||||
func (kv KV) EmbeddingHeadCount() uint64 {
|
||||
if heads := kv.HeadCount(); heads > 0 {
|
||||
func (kv KV) HeadCountKVMax() uint64 {
|
||||
return uint64(kv.UintOrMaxArrayValue("attention.head_count_kv", 1))
|
||||
}
|
||||
|
||||
func (kv KV) HeadCountKVMin() uint64 {
|
||||
return uint64(kv.UintOrMinArrayValue("attention.head_count_kv", 1))
|
||||
}
|
||||
|
||||
func (kv KV) EmbeddingHeadCountMax() uint64 {
|
||||
if heads := kv.HeadCountMin(); heads > 0 {
|
||||
return kv.EmbeddingLength() / heads
|
||||
}
|
||||
|
||||
@@ -70,15 +82,11 @@ func (kv KV) EmbeddingHeadCount() uint64 {
|
||||
}
|
||||
|
||||
func (kv KV) EmbeddingHeadCountK() uint64 {
|
||||
return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCount())))
|
||||
return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCountMax())))
|
||||
}
|
||||
|
||||
func (kv KV) EmbeddingHeadCountV() uint64 {
|
||||
return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCount())))
|
||||
}
|
||||
|
||||
func (kv KV) GQA() uint64 {
|
||||
return kv.HeadCount() / kv.HeadCountKV()
|
||||
return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCountMax())))
|
||||
}
|
||||
|
||||
func (kv KV) ContextLength() uint64 {
|
||||
@@ -90,40 +98,83 @@ func (kv KV) ChatTemplate() string {
|
||||
}
|
||||
|
||||
func (kv KV) String(key string, defaultValue ...string) string {
|
||||
return keyValue(kv, key, append(defaultValue, "")...)
|
||||
val, _ := keyValue(kv, key, append(defaultValue, "")...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
|
||||
return keyValue(kv, key, append(defaultValue, 0)...)
|
||||
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Float(key string, defaultValue ...float32) float32 {
|
||||
return keyValue(kv, key, append(defaultValue, 0)...)
|
||||
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Bool(key string, defaultValue ...bool) bool {
|
||||
return keyValue(kv, key, append(defaultValue, false)...)
|
||||
val, _ := keyValue(kv, key, append(defaultValue, false)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) UintOrMaxArrayValue(key string, defaultValue uint32) uint32 {
|
||||
_, max := kv.UintOrArrayValue(key, defaultValue)
|
||||
return max
|
||||
}
|
||||
|
||||
func (kv KV) UintOrMinArrayValue(key string, defaultValue uint32) uint32 {
|
||||
min, _ := kv.UintOrArrayValue(key, defaultValue)
|
||||
return min
|
||||
}
|
||||
|
||||
func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) {
|
||||
if u32, ok := keyValue(kv, key, uint32(0)); ok {
|
||||
return u32, u32
|
||||
} else if u32s, ok := keyValue(kv, key, &array[uint32]{}); ok {
|
||||
min := slices.Min(u32s.values)
|
||||
max := slices.Max(u32s.values)
|
||||
return min, max
|
||||
} else if i32s, ok := keyValue(kv, key, &array[int32]{}); ok {
|
||||
min := slices.Min(i32s.values)
|
||||
max := slices.Max(i32s.values)
|
||||
if min < 0 || max < 0 {
|
||||
slog.Warn("array values are unexpectedly negative", "key", key, "min", min, "max", max)
|
||||
}
|
||||
return uint32(min), uint32(max)
|
||||
}
|
||||
|
||||
return defaultValue, defaultValue
|
||||
}
|
||||
|
||||
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
|
||||
return keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]}).values
|
||||
val, _ := keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]})
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
|
||||
return keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]}).values
|
||||
val, _ := keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]})
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
|
||||
return keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]}).values
|
||||
val, _ := keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]})
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
|
||||
return keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]}).values
|
||||
val, _ := keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]})
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
|
||||
val, _ := keyValue(kv, key, &array[bool]{values: append(defaultValue, []bool(nil))[0]})
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) OllamaEngineRequired() bool {
|
||||
return slices.Contains([]string{
|
||||
"gemma3",
|
||||
"gemma3n",
|
||||
"mistral3",
|
||||
"llama4",
|
||||
"mllama",
|
||||
@@ -143,17 +194,17 @@ type arrayValueTypes interface {
|
||||
*array[string] | *array[float32] | *array[float64] | *array[bool]
|
||||
}
|
||||
|
||||
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) T {
|
||||
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
|
||||
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
|
||||
key = kv.Architecture() + "." + key
|
||||
}
|
||||
|
||||
if val, ok := kv[key]; ok {
|
||||
return val.(T)
|
||||
if val, ok := kv[key].(T); ok {
|
||||
return val, true
|
||||
}
|
||||
|
||||
slog.Debug("key not found", "key", key, "default", defaultValue[0])
|
||||
return defaultValue[0]
|
||||
slog.Debug("key with type not found", "key", key, "default", defaultValue[0])
|
||||
return defaultValue[0], false
|
||||
}
|
||||
|
||||
type Tensors struct {
|
||||
@@ -425,11 +476,11 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
|
||||
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
embedding := f.KV().EmbeddingLength()
|
||||
heads := f.KV().HeadCount()
|
||||
headsKV := f.KV().HeadCountKV()
|
||||
heads := f.KV().HeadCountMax()
|
||||
headsKV := f.KV().HeadCountKVMax()
|
||||
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
|
||||
|
||||
embeddingHeads := f.KV().EmbeddingHeadCount()
|
||||
embeddingHeads := f.KV().EmbeddingHeadCountMax()
|
||||
embeddingHeadsK := f.KV().EmbeddingHeadCountK()
|
||||
embeddingHeadsV := f.KV().EmbeddingHeadCountV()
|
||||
|
||||
@@ -504,7 +555,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
// vocab graph
|
||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||
)
|
||||
case "gemma", "gemma2", "gemma3":
|
||||
case "gemma", "gemma2", "gemma3", "gemma3n":
|
||||
fullOffload = max(
|
||||
4*batch*(embedding+vocab),
|
||||
4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
|
||||
@@ -517,6 +568,11 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
embedding*embeddingHeadsK*heads*9/16,
|
||||
)
|
||||
|
||||
if f.KV().Architecture() == "gemma3n" {
|
||||
fullOffload *= 4
|
||||
partialOffload *= 4
|
||||
}
|
||||
|
||||
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
|
||||
// engine. Gemma3 always uses the Ollama engine.
|
||||
if f.KV().Architecture() == "gemma3" {
|
||||
|
||||
@@ -269,3 +269,33 @@ func TestKeyValue(t *testing.T) {
|
||||
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadCount(t *testing.T) {
|
||||
valuesArray := []int32{1, 5, 3, 4}
|
||||
cases := []struct {
|
||||
kv KV
|
||||
want uint64
|
||||
}{
|
||||
{
|
||||
kv: KV{
|
||||
"general.architecture": "abc",
|
||||
"abc.attention.head_count": &array[int32]{values: valuesArray, size: len(valuesArray)},
|
||||
},
|
||||
want: uint64(5),
|
||||
},
|
||||
{
|
||||
kv: KV{
|
||||
"general.architecture": "abc",
|
||||
"abc.attention.head_count": uint32(3),
|
||||
},
|
||||
want: uint64(3),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
got := tt.kv.HeadCountMax()
|
||||
if got != tt.want {
|
||||
t.Errorf("unexpected max value: got=%d want=%d", got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -527,23 +527,17 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
return err
|
||||
}
|
||||
|
||||
keys := slices.Collect(maps.Keys(kv))
|
||||
slices.Sort(keys)
|
||||
|
||||
for _, key := range keys {
|
||||
for _, key := range slices.Sorted(maps.Keys(kv)) {
|
||||
if err := ggufWriteKV(f, key, kv[key]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
slices.SortStableFunc(ts, func(a, b *Tensor) int {
|
||||
if i, j := a.block(), b.block(); i < 0 && j > 0 {
|
||||
return 1
|
||||
} else if i > 0 && j < 0 {
|
||||
return -1
|
||||
} else {
|
||||
if i, j := a.block(), b.block(); i > 0 && j > 0 {
|
||||
return cmp.Compare(i, j)
|
||||
}
|
||||
return cmp.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
var s uint64
|
||||
@@ -615,6 +609,10 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
|
||||
err = writeGGUFArray(ws, ggufTypeString, v)
|
||||
case *array[string]:
|
||||
err = writeGGUFArray(ws, ggufTypeString, v.values)
|
||||
case []bool:
|
||||
err = writeGGUFArray(ws, ggufTypeBool, v)
|
||||
case *array[bool]:
|
||||
err = writeGGUFArray(ws, ggufTypeBool, v.values)
|
||||
default:
|
||||
return fmt.Errorf("improper type for '%s'", k)
|
||||
}
|
||||
|
||||
@@ -2,62 +2,82 @@ package ggml
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math/rand/v2"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestWriteGGUF(t *testing.T) {
|
||||
w, err := os.CreateTemp(t.TempDir(), "*.bin")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer w.Close()
|
||||
r := rand.New(rand.NewPCG(0, 0))
|
||||
for range 8 {
|
||||
t.Run("shuffle", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if err := WriteGGUF(w, KV{
|
||||
"general.alignment": uint32(16),
|
||||
}, []*Tensor{
|
||||
{Name: "test.0", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.1", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.2", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.3", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.4", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.5", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ts := []*Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.1.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.2.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.3.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.4.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.5.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
|
||||
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
|
||||
}
|
||||
|
||||
r, err := os.Open(w.Name())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
r.Shuffle(len(ts), func(i, j int) {
|
||||
ts[i], ts[j] = ts[j], ts[i]
|
||||
})
|
||||
|
||||
ff, err := Decode(r, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
w, err := os.CreateTemp(t.TempDir(), strings.ReplaceAll(t.Name(), "/", "_")+"*.bin")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer w.Close()
|
||||
|
||||
if diff := cmp.Diff(ff.KV(), KV{
|
||||
"general.alignment": uint32(16),
|
||||
"general.parameter_count": uint64(36),
|
||||
}); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if err := WriteGGUF(w, KV{
|
||||
"general.alignment": uint32(16),
|
||||
}, ts); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(ff.Tensors(), Tensors{
|
||||
Offset: 336,
|
||||
items: []*Tensor{
|
||||
{Name: "test.0", Offset: 0, Shape: []uint64{2, 3}},
|
||||
{Name: "test.1", Offset: 32, Shape: []uint64{2, 3}},
|
||||
{Name: "test.2", Offset: 64, Shape: []uint64{2, 3}},
|
||||
{Name: "test.3", Offset: 96, Shape: []uint64{2, 3}},
|
||||
{Name: "test.4", Offset: 128, Shape: []uint64{2, 3}},
|
||||
{Name: "test.5", Offset: 160, Shape: []uint64{2, 3}},
|
||||
},
|
||||
}, cmp.AllowUnexported(Tensors{})); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
r, err := os.Open(w.Name())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
ff, err := Decode(r, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(KV{
|
||||
"general.alignment": uint32(16),
|
||||
"general.parameter_count": uint64(54),
|
||||
}, ff.KV()); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(Tensors{
|
||||
Offset: 608,
|
||||
items: []*Tensor{
|
||||
{Name: "blk.0.attn_norm.weight", Offset: 0, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.1.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.2.attn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.3.attn_norm.weight", Offset: 96, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.4.attn_norm.weight", Offset: 128, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.5.attn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
|
||||
{Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
|
||||
{Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
|
||||
{Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},
|
||||
},
|
||||
}, ff.Tensors(), cmp.AllowUnexported(Tensors{})); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
347
fs/gguf/gguf.go
Normal file
347
fs/gguf/gguf.go
Normal file
@@ -0,0 +1,347 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
typeUint8 uint32 = iota
|
||||
typeInt8
|
||||
typeUint16
|
||||
typeInt16
|
||||
typeUint32
|
||||
typeInt32
|
||||
typeFloat32
|
||||
typeBool
|
||||
typeString
|
||||
typeArray
|
||||
typeUint64
|
||||
typeInt64
|
||||
typeFloat64
|
||||
)
|
||||
|
||||
var ErrUnsupported = errors.New("unsupported")
|
||||
|
||||
type File struct {
|
||||
Magic [4]byte
|
||||
Version uint32
|
||||
|
||||
keyValues *lazy[KeyValue]
|
||||
tensors *lazy[TensorInfo]
|
||||
offset int64
|
||||
|
||||
file *os.File
|
||||
reader *bufferedReader
|
||||
bts []byte
|
||||
}
|
||||
|
||||
func Open(path string) (f *File, err error) {
|
||||
f = &File{bts: make([]byte, 4096)}
|
||||
f.file, err = os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f.reader = newBufferedReader(f.file, 32<<10)
|
||||
|
||||
if err := binary.Read(f.reader, binary.LittleEndian, &f.Magic); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if bytes.Equal(f.Magic[:], []byte("gguf")) {
|
||||
return nil, fmt.Errorf("%w file type %v", ErrUnsupported, f.Magic)
|
||||
}
|
||||
|
||||
if err := binary.Read(f.reader, binary.LittleEndian, &f.Version); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if f.Version < 2 {
|
||||
return nil, fmt.Errorf("%w version %v", ErrUnsupported, f.Version)
|
||||
}
|
||||
|
||||
f.tensors, err = newLazy(f, f.readTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f.tensors.successFunc = func() error {
|
||||
offset := f.reader.offset
|
||||
|
||||
alignment := cmp.Or(f.KeyValue("general.alignment").Int(), 32)
|
||||
f.offset = offset + (alignment-offset%alignment)%alignment
|
||||
return nil
|
||||
}
|
||||
|
||||
f.keyValues, err = newLazy(f, f.readKeyValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (f *File) readTensor() (TensorInfo, error) {
|
||||
name, err := readString(f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
dims, err := read[uint32](f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
shape := make([]uint64, dims)
|
||||
for i := range dims {
|
||||
shape[i], err = read[uint64](f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
}
|
||||
|
||||
type_, err := read[uint32](f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
offset, err := read[uint64](f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
return TensorInfo{
|
||||
Name: name,
|
||||
Offset: offset,
|
||||
Shape: shape,
|
||||
Type: TensorType(type_),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *File) readKeyValue() (KeyValue, error) {
|
||||
key, err := readString(f)
|
||||
if err != nil {
|
||||
return KeyValue{}, err
|
||||
}
|
||||
|
||||
t, err := read[uint32](f)
|
||||
if err != nil {
|
||||
return KeyValue{}, err
|
||||
}
|
||||
|
||||
value, err := func() (any, error) {
|
||||
switch t {
|
||||
case typeUint8:
|
||||
return read[uint8](f)
|
||||
case typeInt8:
|
||||
return read[int8](f)
|
||||
case typeUint16:
|
||||
return read[uint16](f)
|
||||
case typeInt16:
|
||||
return read[int16](f)
|
||||
case typeUint32:
|
||||
return read[uint32](f)
|
||||
case typeInt32:
|
||||
return read[int32](f)
|
||||
case typeUint64:
|
||||
return read[uint64](f)
|
||||
case typeInt64:
|
||||
return read[int64](f)
|
||||
case typeFloat32:
|
||||
return read[float32](f)
|
||||
case typeFloat64:
|
||||
return read[float64](f)
|
||||
case typeBool:
|
||||
return read[bool](f)
|
||||
case typeString:
|
||||
return readString(f)
|
||||
case typeArray:
|
||||
return readArray(f)
|
||||
default:
|
||||
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
return KeyValue{}, err
|
||||
}
|
||||
|
||||
return KeyValue{
|
||||
Key: key,
|
||||
Value: Value{value},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func read[T any](f *File) (t T, err error) {
|
||||
err = binary.Read(f.reader, binary.LittleEndian, &t)
|
||||
return t, err
|
||||
}
|
||||
|
||||
func readString(f *File) (string, error) {
|
||||
n, err := read[uint64](f)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if int(n) > len(f.bts) {
|
||||
f.bts = make([]byte, n)
|
||||
}
|
||||
|
||||
bts := f.bts[:n]
|
||||
if _, err := io.ReadFull(f.reader, bts); err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer clear(bts)
|
||||
|
||||
return string(bts), nil
|
||||
}
|
||||
|
||||
func readArray(f *File) (any, error) {
|
||||
t, err := read[uint32](f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
n, err := read[uint64](f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch t {
|
||||
case typeUint8:
|
||||
return readArrayData[uint8](f, n)
|
||||
case typeInt8:
|
||||
return readArrayData[int8](f, n)
|
||||
case typeUint16:
|
||||
return readArrayData[uint16](f, n)
|
||||
case typeInt16:
|
||||
return readArrayData[int16](f, n)
|
||||
case typeUint32:
|
||||
return readArrayData[uint32](f, n)
|
||||
case typeInt32:
|
||||
return readArrayData[int32](f, n)
|
||||
case typeUint64:
|
||||
return readArrayData[uint64](f, n)
|
||||
case typeInt64:
|
||||
return readArrayData[int64](f, n)
|
||||
case typeFloat32:
|
||||
return readArrayData[float32](f, n)
|
||||
case typeFloat64:
|
||||
return readArrayData[float64](f, n)
|
||||
case typeBool:
|
||||
return readArrayData[bool](f, n)
|
||||
case typeString:
|
||||
return readArrayString(f, n)
|
||||
default:
|
||||
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
|
||||
}
|
||||
}
|
||||
|
||||
func readArrayData[T any](f *File, n uint64) (s []T, err error) {
|
||||
s = make([]T, n)
|
||||
for i := range n {
|
||||
e, err := read[T](f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s[i] = e
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func readArrayString(f *File, n uint64) (s []string, err error) {
|
||||
s = make([]string, n)
|
||||
for i := range n {
|
||||
e, err := readString(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s[i] = e
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (f *File) Close() error {
|
||||
f.keyValues.stop()
|
||||
f.tensors.stop()
|
||||
return f.file.Close()
|
||||
}
|
||||
|
||||
func (f *File) KeyValue(key string) KeyValue {
|
||||
if !strings.HasPrefix(key, "general.") && !strings.HasPrefix(key, "tokenizer.") {
|
||||
key = f.KeyValue("general.architecture").String() + "." + key
|
||||
}
|
||||
|
||||
if index := slices.IndexFunc(f.keyValues.values, func(kv KeyValue) bool {
|
||||
return kv.Key == key
|
||||
}); index >= 0 {
|
||||
return f.keyValues.values[index]
|
||||
}
|
||||
|
||||
for keyValue, ok := f.keyValues.next(); ok; keyValue, ok = f.keyValues.next() {
|
||||
if keyValue.Key == key {
|
||||
return keyValue
|
||||
}
|
||||
}
|
||||
|
||||
return KeyValue{}
|
||||
}
|
||||
|
||||
func (f *File) NumKeyValues() int {
|
||||
return int(f.keyValues.count)
|
||||
}
|
||||
|
||||
func (f *File) KeyValues() iter.Seq2[int, KeyValue] {
|
||||
return f.keyValues.All()
|
||||
}
|
||||
|
||||
func (f *File) TensorInfo(name string) TensorInfo {
|
||||
if index := slices.IndexFunc(f.tensors.values, func(t TensorInfo) bool {
|
||||
return t.Name == name
|
||||
}); index >= 0 {
|
||||
return f.tensors.values[index]
|
||||
}
|
||||
|
||||
// fast-forward through key values if we haven't already
|
||||
_ = f.keyValues.rest()
|
||||
for tensor, ok := f.tensors.next(); ok; tensor, ok = f.tensors.next() {
|
||||
if tensor.Name == name {
|
||||
return tensor
|
||||
}
|
||||
}
|
||||
|
||||
return TensorInfo{}
|
||||
}
|
||||
|
||||
func (f *File) NumTensors() int {
|
||||
return int(f.tensors.count)
|
||||
}
|
||||
|
||||
func (f *File) TensorInfos() iter.Seq2[int, TensorInfo] {
|
||||
// fast forward through key values if we haven't already
|
||||
f.keyValues.rest()
|
||||
return f.tensors.All()
|
||||
}
|
||||
|
||||
func (f *File) TensorReader(name string) (TensorInfo, io.Reader, error) {
|
||||
t := f.TensorInfo(name)
|
||||
if t.NumBytes() == 0 {
|
||||
return TensorInfo{}, nil, fmt.Errorf("tensor %s not found", name)
|
||||
}
|
||||
|
||||
// fast forward through tensor info if we haven't already
|
||||
_ = f.tensors.rest()
|
||||
return t, io.NewSectionReader(f.file, f.offset+int64(t.Offset), t.NumBytes()), nil
|
||||
}
|
||||
249
fs/gguf/gguf_test.go
Normal file
249
fs/gguf/gguf_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package gguf_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/fs/gguf"
|
||||
)
|
||||
|
||||
func createBinFile(tb testing.TB) string {
|
||||
tb.Helper()
|
||||
f, err := os.CreateTemp(tb.TempDir(), "")
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
kv := ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(8),
|
||||
"llama.embedding_length": uint32(3),
|
||||
"llama.attention.head_count": uint32(2),
|
||||
"llama.attention.head_count_kv": uint32(2),
|
||||
"llama.attention.key_length": uint32(3),
|
||||
"llama.rope.dimension_count": uint32(4),
|
||||
"llama.rope.freq_base": float32(10000.0),
|
||||
"llama.rope.freq_scale": float32(1.0),
|
||||
"llama.attention.layer_norm_rms_epsilon": float32(1e-6),
|
||||
"tokenizer.ggml.eos_token_id": uint32(0),
|
||||
"tokenizer.ggml.eos_token_ids": []int32{1, 2, 3},
|
||||
"tokenizer.ggml.tokens": []string{"hello", "world"},
|
||||
"tokenizer.ggml.scores": []float32{0, 1},
|
||||
}
|
||||
|
||||
tensors := []*ggml.Tensor{
|
||||
{
|
||||
Name: "token_embd.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{2, 3},
|
||||
WriterTo: bytes.NewBuffer(make([]byte, 4*2*3)),
|
||||
},
|
||||
{
|
||||
Name: "output.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{3, 2},
|
||||
WriterTo: bytes.NewBuffer(make([]byte, 4*3*2)),
|
||||
},
|
||||
}
|
||||
|
||||
for i := range 8 {
|
||||
tensors = append(tensors, &ggml.Tensor{
|
||||
Name: "blk." + strconv.Itoa(i) + ".attn_q.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{3, 3},
|
||||
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
|
||||
}, &ggml.Tensor{
|
||||
Name: "blk." + strconv.Itoa(i) + ".attn_k.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{3, 3},
|
||||
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
|
||||
}, &ggml.Tensor{
|
||||
Name: "blk." + strconv.Itoa(i) + ".attn_v.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{3, 3},
|
||||
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
|
||||
}, &ggml.Tensor{
|
||||
Name: "blk." + strconv.Itoa(i) + ".attn_output.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{3, 3},
|
||||
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
|
||||
})
|
||||
}
|
||||
|
||||
if err := ggml.WriteGGUF(f, kv, tensors); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
return f.Name()
|
||||
}
|
||||
|
||||
func TestRead(t *testing.T) {
|
||||
f, err := gguf.Open(createBinFile(t))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if got := f.KeyValue("does.not.exist").Valid(); got {
|
||||
t.Errorf(`KeyValue("does.not.exist").Exists() = %v, want false`, got)
|
||||
}
|
||||
|
||||
if got := f.KeyValue("general.architecture").String(); got != "llama" {
|
||||
t.Errorf(`KeyValue("general.architecture").String() = %q, want %q`, got, "llama")
|
||||
}
|
||||
|
||||
if got := f.TensorInfo("token_embd.weight"); got.Name != "token_embd.weight" {
|
||||
t.Errorf(`TensorInfo("token_embd.weight").Name = %q, want %q`, got.Name, "token_embd.weight")
|
||||
} else if diff := cmp.Diff(got.Shape, []uint64{2, 3}); diff != "" {
|
||||
t.Errorf(`TensorInfo("token_embd.weight").Shape mismatch (-got +want):\n%s`, diff)
|
||||
} else if got.Type != gguf.TensorTypeF32 {
|
||||
t.Errorf(`TensorInfo("token_embd.weight").Type = %d, want %d`, got.Type, gguf.TensorTypeF32)
|
||||
}
|
||||
|
||||
if got := f.KeyValue("block_count").Uint(); got != 8 {
|
||||
t.Errorf(`KeyValue("block_count").Uint() = %d, want %d`, got, 8)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f.KeyValue("tokenizer.ggml.tokens").Strings(), []string{"hello", "world"}); diff != "" {
|
||||
t.Errorf("KeyValue(\"tokenizer.ggml.tokens\").Strings() mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f.KeyValue("tokenizer.ggml.scores").Floats(), []float64{0, 1}); diff != "" {
|
||||
t.Errorf("KeyValue(\"tokenizer.ggml.scores\").Ints() mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
var kvs []string
|
||||
for _, kv := range f.KeyValues() {
|
||||
if !kv.Valid() {
|
||||
t.Error("found invalid key-value pair:", kv)
|
||||
}
|
||||
|
||||
kvs = append(kvs, kv.Key)
|
||||
}
|
||||
|
||||
if len(kvs) != f.NumKeyValues() {
|
||||
t.Errorf("iterated key count = %d, want %d", len(kvs), f.NumKeyValues())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(kvs, []string{
|
||||
"general.architecture",
|
||||
"llama.block_count",
|
||||
"llama.embedding_length",
|
||||
"llama.attention.head_count",
|
||||
"llama.attention.head_count_kv",
|
||||
"llama.attention.key_length",
|
||||
"llama.rope.dimension_count",
|
||||
"llama.rope.freq_base",
|
||||
"llama.rope.freq_scale",
|
||||
"llama.attention.layer_norm_rms_epsilon",
|
||||
"tokenizer.ggml.eos_token_id",
|
||||
"tokenizer.ggml.eos_token_ids",
|
||||
"tokenizer.ggml.tokens",
|
||||
"tokenizer.ggml.scores",
|
||||
}, cmpopts.SortSlices(strings.Compare)); diff != "" {
|
||||
t.Errorf("KeyValues() mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
var tis []string
|
||||
for _, ti := range f.TensorInfos() {
|
||||
if !ti.Valid() {
|
||||
t.Error("found invalid tensor info:", ti)
|
||||
}
|
||||
|
||||
tis = append(tis, ti.Name)
|
||||
}
|
||||
|
||||
if len(tis) != f.NumTensors() {
|
||||
t.Errorf("iterated tensor count = %d, want %d", len(tis), f.NumTensors())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tis, []string{
|
||||
"token_embd.weight",
|
||||
"output.weight",
|
||||
"blk.0.attn_q.weight",
|
||||
"blk.0.attn_k.weight",
|
||||
"blk.0.attn_v.weight",
|
||||
"blk.0.attn_output.weight",
|
||||
"blk.1.attn_q.weight",
|
||||
"blk.1.attn_k.weight",
|
||||
"blk.1.attn_v.weight",
|
||||
"blk.1.attn_output.weight",
|
||||
"blk.2.attn_q.weight",
|
||||
"blk.2.attn_k.weight",
|
||||
"blk.2.attn_v.weight",
|
||||
"blk.2.attn_output.weight",
|
||||
"blk.3.attn_q.weight",
|
||||
"blk.3.attn_k.weight",
|
||||
"blk.3.attn_v.weight",
|
||||
"blk.3.attn_output.weight",
|
||||
"blk.4.attn_q.weight",
|
||||
"blk.4.attn_k.weight",
|
||||
"blk.4.attn_v.weight",
|
||||
"blk.4.attn_output.weight",
|
||||
"blk.5.attn_q.weight",
|
||||
"blk.5.attn_k.weight",
|
||||
"blk.5.attn_v.weight",
|
||||
"blk.5.attn_output.weight",
|
||||
"blk.6.attn_q.weight",
|
||||
"blk.6.attn_k.weight",
|
||||
"blk.6.attn_v.weight",
|
||||
"blk.6.attn_output.weight",
|
||||
"blk.7.attn_q.weight",
|
||||
"blk.7.attn_k.weight",
|
||||
"blk.7.attn_v.weight",
|
||||
"blk.7.attn_output.weight",
|
||||
}, cmpopts.SortSlices(strings.Compare)); diff != "" {
|
||||
t.Errorf("TensorInfos() mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
ti, r, err := f.TensorReader("output.weight")
|
||||
if err != nil {
|
||||
t.Fatalf(`TensorReader("output.weight") error: %v`, err)
|
||||
}
|
||||
|
||||
if ti.Name != "output.weight" {
|
||||
t.Errorf(`TensorReader("output.weight").Name = %q, want %q`, ti.Name, "output.weight")
|
||||
} else if diff := cmp.Diff(ti.Shape, []uint64{3, 2}); diff != "" {
|
||||
t.Errorf(`TensorReader("output.weight").Shape mismatch (-got +want):\n%s`, diff)
|
||||
} else if ti.Type != gguf.TensorTypeF32 {
|
||||
t.Errorf(`TensorReader("output.weight").Type = %d, want %d`, ti.Type, gguf.TensorTypeF32)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := b.ReadFrom(r); err != nil {
|
||||
t.Fatalf(`ReadFrom TensorReader("output.weight") error: %v`, err)
|
||||
}
|
||||
|
||||
if b.Len() != int(ti.NumBytes()) {
|
||||
t.Errorf(`ReadFrom TensorReader("output.weight") length = %d, want %d`, b.Len(), ti.NumBytes())
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRead(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
p := createBinFile(b)
|
||||
for b.Loop() {
|
||||
f, err := gguf.Open(p)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
if got := f.KeyValue("general.architecture").String(); got != "llama" {
|
||||
b.Errorf("got = %q, want %q", got, "llama")
|
||||
}
|
||||
|
||||
// Iterate through some tensors
|
||||
for range f.TensorInfos() {
|
||||
}
|
||||
|
||||
f.Close()
|
||||
}
|
||||
}
|
||||
90
fs/gguf/keyvalue.go
Normal file
90
fs/gguf/keyvalue.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"slices"
|
||||
)
|
||||
|
||||
type KeyValue struct {
|
||||
Key string
|
||||
Value
|
||||
}
|
||||
|
||||
func (kv KeyValue) Valid() bool {
|
||||
return kv.Key != "" && kv.Value.value != nil
|
||||
}
|
||||
|
||||
type Value struct {
|
||||
value any
|
||||
}
|
||||
|
||||
func value[T any](v Value, kinds ...reflect.Kind) (t T) {
|
||||
vv := reflect.ValueOf(v.value)
|
||||
if slices.Contains(kinds, vv.Kind()) {
|
||||
t = vv.Convert(reflect.TypeOf(t)).Interface().(T)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func values[T any](v Value, kinds ...reflect.Kind) (ts []T) {
|
||||
switch vv := reflect.ValueOf(v.value); vv.Kind() {
|
||||
case reflect.Slice:
|
||||
if slices.Contains(kinds, vv.Type().Elem().Kind()) {
|
||||
ts = make([]T, vv.Len())
|
||||
for i := range vv.Len() {
|
||||
ts[i] = vv.Index(i).Convert(reflect.TypeOf(ts[i])).Interface().(T)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Int returns Value as a signed integer. If it is not a signed integer, it returns 0.
|
||||
func (v Value) Int() int64 {
|
||||
return value[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
|
||||
}
|
||||
|
||||
// Ints returns Value as a signed integer slice. If it is not a signed integer slice, it returns nil.
|
||||
func (v Value) Ints() (i64s []int64) {
|
||||
return values[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
|
||||
}
|
||||
|
||||
// Uint converts an unsigned integer value to uint64. If the value is not a unsigned integer, it returns 0.
|
||||
func (v Value) Uint() uint64 {
|
||||
return value[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
|
||||
}
|
||||
|
||||
// Uints returns Value as a unsigned integer slice. If it is not a unsigned integer slice, it returns nil.
|
||||
func (v Value) Uints() (u64s []uint64) {
|
||||
return values[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
|
||||
}
|
||||
|
||||
// Float returns Value as a float. If it is not a float, it returns 0.
|
||||
func (v Value) Float() float64 {
|
||||
return value[float64](v, reflect.Float32, reflect.Float64)
|
||||
}
|
||||
|
||||
// Floats returns Value as a float slice. If it is not a float slice, it returns nil.
|
||||
func (v Value) Floats() (f64s []float64) {
|
||||
return values[float64](v, reflect.Float32, reflect.Float64)
|
||||
}
|
||||
|
||||
// Bool returns Value as a boolean. If it is not a boolean, it returns false.
|
||||
func (v Value) Bool() bool {
|
||||
return value[bool](v, reflect.Bool)
|
||||
}
|
||||
|
||||
// Bools returns Value as a boolean slice. If it is not a boolean slice, it returns nil.
|
||||
func (v Value) Bools() (bools []bool) {
|
||||
return values[bool](v, reflect.Bool)
|
||||
}
|
||||
|
||||
// String returns Value as a string. If it is not a string, it returns an empty string.
|
||||
func (v Value) String() string {
|
||||
return value[string](v, reflect.String)
|
||||
}
|
||||
|
||||
// Strings returns Value as a string slice. If it is not a string slice, it returns nil.
|
||||
func (v Value) Strings() (strings []string) {
|
||||
return values[string](v, reflect.String)
|
||||
}
|
||||
208
fs/gguf/keyvalue_test.go
Normal file
208
fs/gguf/keyvalue_test.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func split(name string, values map[string][]any) (matched []any, unmatched []any) {
|
||||
for key, value := range values {
|
||||
if key == name {
|
||||
matched = value
|
||||
} else {
|
||||
unmatched = append(unmatched, value...)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func TestValue(t *testing.T) {
|
||||
values := map[string][]any{
|
||||
"int64": {int(42), int8(42), int16(42), int32(42), int64(42)},
|
||||
"uint64": {uint(42), uint8(42), uint16(42), uint32(42), uint64(42)},
|
||||
"float64": {float32(42), float64(42)},
|
||||
"string": {"42", "hello"},
|
||||
"bool": {true, false},
|
||||
}
|
||||
|
||||
t.Run("int64", func(t *testing.T) {
|
||||
matched, unmatched := split("int64", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if i64 := kv.Int(); i64 != 42 {
|
||||
t.Errorf("expected 42, got %d", i64)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if i64 := kv.Int(); i64 != 0 {
|
||||
t.Errorf("expected 42, got %d", i64)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uint64", func(t *testing.T) {
|
||||
matched, unmatched := split("uint64", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if u64 := kv.Uint(); u64 != 42 {
|
||||
t.Errorf("expected 42, got %d", u64)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if u64 := kv.Uint(); u64 != 0 {
|
||||
t.Errorf("expected 42, got %d", u64)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("float64", func(t *testing.T) {
|
||||
matched, unmatched := split("float64", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if f64 := kv.Float(); f64 != 42 {
|
||||
t.Errorf("expected 42, got %f", f64)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if f64 := kv.Float(); f64 != 0 {
|
||||
t.Errorf("expected 42, got %f", f64)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("string", func(t *testing.T) {
|
||||
matched, unmatched := split("string", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if s := kv.String(); s != v {
|
||||
t.Errorf("expected 42, got %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if s := kv.String(); s != "" {
|
||||
t.Errorf("expected 42, got %s", s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bool", func(t *testing.T) {
|
||||
matched, unmatched := split("bool", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if b := kv.Bool(); b != v {
|
||||
t.Errorf("expected true, got %v", b)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if b := kv.Bool(); b != false {
|
||||
t.Errorf("expected false, got %v", b)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestValues(t *testing.T) {
|
||||
values := map[string][]any{
|
||||
"int64s": {[]int{42}, []int8{42}, []int16{42}, []int32{42}, []int64{42}},
|
||||
"uint64s": {[]uint{42}, []uint8{42}, []uint16{42}, []uint32{42}, []uint64{42}},
|
||||
"float64s": {[]float32{42}, []float64{42}},
|
||||
"strings": {[]string{"42"}, []string{"hello"}},
|
||||
"bools": {[]bool{true}, []bool{false}},
|
||||
}
|
||||
|
||||
t.Run("int64s", func(t *testing.T) {
|
||||
matched, unmatched := split("int64s", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Ints(), []int64{42}); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if i64s := kv.Ints(); i64s != nil {
|
||||
t.Errorf("expected nil, got %v", i64s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uint64s", func(t *testing.T) {
|
||||
matched, unmatched := split("uint64s", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Uints(), []uint64{42}); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if u64s := kv.Uints(); u64s != nil {
|
||||
t.Errorf("expected nil, got %v", u64s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("float64s", func(t *testing.T) {
|
||||
matched, unmatched := split("float64s", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Floats(), []float64{42}); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if f64s := kv.Floats(); f64s != nil {
|
||||
t.Errorf("expected nil, got %v", f64s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("strings", func(t *testing.T) {
|
||||
matched, unmatched := split("strings", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Strings(), v); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if s := kv.Strings(); s != nil {
|
||||
t.Errorf("expected nil, got %v", s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bools", func(t *testing.T) {
|
||||
matched, unmatched := split("bools", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Bools(), v); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if b := kv.Bools(); b != nil {
|
||||
t.Errorf("expected nil, got %v", b)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
89
fs/gguf/lazy.go
Normal file
89
fs/gguf/lazy.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"iter"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
type lazy[T any] struct {
|
||||
count uint64
|
||||
next func() (T, bool)
|
||||
stop func()
|
||||
values []T
|
||||
|
||||
// successFunc is called when all values have been successfully read.
|
||||
successFunc func() error
|
||||
}
|
||||
|
||||
func newLazy[T any](f *File, fn func() (T, error)) (*lazy[T], error) {
|
||||
it := lazy[T]{}
|
||||
if err := binary.Read(f.reader, binary.LittleEndian, &it.count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
it.values = make([]T, 0)
|
||||
it.next, it.stop = iter.Pull(func(yield func(T) bool) {
|
||||
for i := range it.count {
|
||||
t, err := fn()
|
||||
if err != nil {
|
||||
slog.Error("error reading tensor", "index", i, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
it.values = append(it.values, t)
|
||||
if !yield(t) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if it.successFunc != nil {
|
||||
it.successFunc()
|
||||
}
|
||||
})
|
||||
|
||||
return &it, nil
|
||||
}
|
||||
|
||||
func (g *lazy[T]) Values() iter.Seq[T] {
|
||||
return func(yield func(T) bool) {
|
||||
for _, v := range g.All() {
|
||||
if !yield(v) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *lazy[T]) All() iter.Seq2[int, T] {
|
||||
return func(yield func(int, T) bool) {
|
||||
for i := range int(g.count) {
|
||||
if i < len(g.values) {
|
||||
if !yield(i, g.values[i]) {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
t, ok := g.next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
if !yield(i, t) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *lazy[T]) rest() (collected bool) {
|
||||
for {
|
||||
_, ok := g.next()
|
||||
collected = collected || ok
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return collected
|
||||
}
|
||||
23
fs/gguf/reader.go
Normal file
23
fs/gguf/reader.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
)
|
||||
|
||||
type bufferedReader struct {
|
||||
offset int64
|
||||
*bufio.Reader
|
||||
}
|
||||
|
||||
func newBufferedReader(rs io.ReadSeeker, size int) *bufferedReader {
|
||||
return &bufferedReader{
|
||||
Reader: bufio.NewReaderSize(rs, size),
|
||||
}
|
||||
}
|
||||
|
||||
func (rs *bufferedReader) Read(p []byte) (n int, err error) {
|
||||
n, err = rs.Reader.Read(p)
|
||||
rs.offset += int64(n)
|
||||
return n, err
|
||||
}
|
||||
288
fs/gguf/tensor.go
Normal file
288
fs/gguf/tensor.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type TensorInfo struct {
|
||||
Name string
|
||||
Offset uint64
|
||||
Shape []uint64
|
||||
Type TensorType
|
||||
}
|
||||
|
||||
func (ti TensorInfo) Valid() bool {
|
||||
return ti.Name != "" && ti.NumBytes() > 0
|
||||
}
|
||||
|
||||
func (ti TensorInfo) NumValues() int64 {
|
||||
var numItems int64 = 1
|
||||
for _, dim := range ti.Shape {
|
||||
numItems *= int64(dim)
|
||||
}
|
||||
return numItems
|
||||
}
|
||||
|
||||
// NumBytes returns the number of bytes in the tensor.
|
||||
func (ti TensorInfo) NumBytes() int64 {
|
||||
return int64(float64(ti.NumValues()) * ti.Type.NumBytes())
|
||||
}
|
||||
|
||||
func (ti TensorInfo) LogValue() slog.Value {
|
||||
return slog.GroupValue(
|
||||
slog.String("name", ti.Name),
|
||||
slog.Int64("offset", int64(ti.Offset)),
|
||||
slog.Any("shape", ti.Shape),
|
||||
slog.Int64("num_values", ti.NumValues()),
|
||||
slog.Int64("num_bytes", ti.NumBytes()),
|
||||
slog.Any("type", ti.Type),
|
||||
)
|
||||
}
|
||||
|
||||
type TensorType uint32
|
||||
|
||||
const (
|
||||
TensorTypeF32 TensorType = iota
|
||||
TensorTypeF16
|
||||
TensorTypeQ4_0
|
||||
TensorTypeQ4_1
|
||||
|
||||
// unexported // unused in gguf
|
||||
tensorTypeQ4_2
|
||||
tensorTypeQ4_3
|
||||
|
||||
TensorTypeQ5_0
|
||||
TensorTypeQ5_1
|
||||
TensorTypeQ8_0
|
||||
TensorTypeQ8_1
|
||||
TensorTypeQ2_K
|
||||
TensorTypeQ3_K
|
||||
TensorTypeQ4_K
|
||||
TensorTypeQ5_K
|
||||
TensorTypeQ6_K
|
||||
TensorTypeQ8_K
|
||||
|
||||
// unexported // unquantizable by ollama
|
||||
tensorTypeIQ2_XXS
|
||||
tensorTypeIQ2_XS
|
||||
tensorTypeIQ3_XXS
|
||||
tensorTypeIQ1_S
|
||||
tensorTypeIQ4_NL
|
||||
tensorTypeIQ3_S
|
||||
tensorTypeIQ2_S
|
||||
tensorTypeIQ4_XS
|
||||
|
||||
TensorTypeI8
|
||||
TensorTypeI16
|
||||
TensorTypeI32
|
||||
TensorTypeI64
|
||||
TensorTypeF64
|
||||
|
||||
// unexported // unquantizable by ollama
|
||||
tensorTypeIQ1_M
|
||||
|
||||
TensorTypeBF16
|
||||
|
||||
// unexported // unused in gguf
|
||||
tensorTypeQ4_0_4_4
|
||||
tensorTypeQ4_0_4_8
|
||||
tensorTypeQ4_0_8_8
|
||||
|
||||
// unexported // unquantizable by ollama
|
||||
tensorTypeTQ1_0
|
||||
tensorTypeTQ2_0
|
||||
|
||||
// unexported // unused in gguf
|
||||
tensorTypeIQ4_NL_4_4
|
||||
tensorTypeIQ4_NL_4_8
|
||||
tensorTypeIQ4_NL_8_8
|
||||
)
|
||||
|
||||
func (tt TensorType) NumBytes() float64 {
|
||||
return float64(tt.typeSize()) / float64(tt.blockSize())
|
||||
}
|
||||
|
||||
func (tt TensorType) typeSize() int64 {
|
||||
switch tt {
|
||||
case TensorTypeF32:
|
||||
return 4
|
||||
case TensorTypeF16:
|
||||
return 2
|
||||
case TensorTypeQ4_0:
|
||||
return 2 + tt.blockSize()/2
|
||||
case TensorTypeQ4_1:
|
||||
return 2 + 2 + tt.blockSize()/2
|
||||
case TensorTypeQ5_0:
|
||||
return 2 + 4 + tt.blockSize()/2
|
||||
case TensorTypeQ5_1:
|
||||
return 2 + 2 + 4 + tt.blockSize()/2
|
||||
case TensorTypeQ8_0:
|
||||
return 2 + tt.blockSize()
|
||||
case TensorTypeQ8_1:
|
||||
return 2 + 2 + tt.blockSize()
|
||||
case TensorTypeQ2_K:
|
||||
return tt.blockSize()/16 + tt.blockSize()/4 + 2 + 2
|
||||
case TensorTypeQ3_K:
|
||||
return tt.blockSize()/8 + tt.blockSize()/4 + 12 + 2
|
||||
case TensorTypeQ4_K:
|
||||
return 2 + 2 + 12 + tt.blockSize()/2
|
||||
case TensorTypeQ5_K:
|
||||
return 2 + 2 + 12 + tt.blockSize()/8 + tt.blockSize()/2
|
||||
case TensorTypeQ6_K:
|
||||
return tt.blockSize()/2 + tt.blockSize()/4 + tt.blockSize()/16 + 2
|
||||
case TensorTypeQ8_K:
|
||||
return 4 + tt.blockSize() + 2*tt.blockSize()/16
|
||||
case tensorTypeIQ2_XXS:
|
||||
return 2 + 2*tt.blockSize()/8
|
||||
case tensorTypeIQ2_XS:
|
||||
return 2 + 2*tt.blockSize()/8 + tt.blockSize()/32
|
||||
case tensorTypeIQ3_XXS:
|
||||
return 2 + tt.blockSize()/4 + tt.blockSize()/8
|
||||
case tensorTypeIQ1_S:
|
||||
return 2 + tt.blockSize()/8 + tt.blockSize()/16
|
||||
case tensorTypeIQ4_NL:
|
||||
return 2 + tt.blockSize()/2
|
||||
case tensorTypeIQ3_S:
|
||||
return 2 + tt.blockSize()/4 + tt.blockSize()/8 + tt.blockSize()/32 + 4
|
||||
case tensorTypeIQ2_S:
|
||||
return 2 + tt.blockSize()/4 + tt.blockSize()/16
|
||||
case tensorTypeIQ4_XS:
|
||||
return 2 + 2 + tt.blockSize()/2 + tt.blockSize()/64
|
||||
case TensorTypeI8:
|
||||
return 1
|
||||
case TensorTypeI16:
|
||||
return 2
|
||||
case TensorTypeI32:
|
||||
return 4
|
||||
case TensorTypeI64:
|
||||
return 8
|
||||
case TensorTypeF64:
|
||||
return 8
|
||||
case tensorTypeIQ1_M:
|
||||
return tt.blockSize()/8 + tt.blockSize()/16 + tt.blockSize()/32
|
||||
case TensorTypeBF16:
|
||||
return 2
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func (tt TensorType) blockSize() int64 {
|
||||
switch tt {
|
||||
case TensorTypeF32,
|
||||
TensorTypeF16,
|
||||
TensorTypeI8,
|
||||
TensorTypeI16,
|
||||
TensorTypeI32,
|
||||
TensorTypeI64,
|
||||
TensorTypeF64,
|
||||
TensorTypeBF16:
|
||||
return 1
|
||||
case TensorTypeQ4_0,
|
||||
TensorTypeQ4_1,
|
||||
TensorTypeQ5_0,
|
||||
TensorTypeQ5_1,
|
||||
TensorTypeQ8_0,
|
||||
TensorTypeQ8_1,
|
||||
tensorTypeIQ4_NL:
|
||||
return 32
|
||||
default:
|
||||
return 256
|
||||
}
|
||||
}
|
||||
|
||||
func (tt TensorType) String() string {
|
||||
switch tt {
|
||||
case TensorTypeF32:
|
||||
return "f32"
|
||||
case TensorTypeF16:
|
||||
return "f16"
|
||||
case TensorTypeQ4_0:
|
||||
return "q4_0"
|
||||
case TensorTypeQ4_1:
|
||||
return "q4_1"
|
||||
case tensorTypeQ4_2:
|
||||
return "q4_2"
|
||||
case tensorTypeQ4_3:
|
||||
return "q4_3"
|
||||
case TensorTypeQ5_0:
|
||||
return "q5_0"
|
||||
case TensorTypeQ5_1:
|
||||
return "q5_1"
|
||||
case TensorTypeQ8_0:
|
||||
return "q8_0"
|
||||
case TensorTypeQ8_1:
|
||||
return "q8_1"
|
||||
case TensorTypeQ2_K:
|
||||
return "q2_k"
|
||||
case TensorTypeQ3_K:
|
||||
return "q3_k"
|
||||
case TensorTypeQ4_K:
|
||||
return "q4_k"
|
||||
case TensorTypeQ5_K:
|
||||
return "q5_k"
|
||||
case TensorTypeQ6_K:
|
||||
return "q6_k"
|
||||
case TensorTypeQ8_K:
|
||||
return "q8_k"
|
||||
case tensorTypeIQ2_XXS:
|
||||
return "iq2_xxs"
|
||||
case tensorTypeIQ2_XS:
|
||||
return "iq2_xs"
|
||||
case tensorTypeIQ3_XXS:
|
||||
return "iq3_xxs"
|
||||
case tensorTypeIQ1_S:
|
||||
return "iq1_s"
|
||||
case tensorTypeIQ4_NL:
|
||||
return "iq4_nl"
|
||||
case tensorTypeIQ3_S:
|
||||
return "iq3_s"
|
||||
case tensorTypeIQ2_S:
|
||||
return "iq2_s"
|
||||
case tensorTypeIQ4_XS:
|
||||
return "iq4_xs"
|
||||
case TensorTypeI8:
|
||||
return "i8"
|
||||
case TensorTypeI16:
|
||||
return "i16"
|
||||
case TensorTypeI32:
|
||||
return "i32"
|
||||
case TensorTypeI64:
|
||||
return "i64"
|
||||
case TensorTypeF64:
|
||||
return "f64"
|
||||
case tensorTypeIQ1_M:
|
||||
return "iq1_m"
|
||||
case TensorTypeBF16:
|
||||
return "bf16"
|
||||
case tensorTypeQ4_0_4_4:
|
||||
return "q4_0_4_4"
|
||||
case tensorTypeQ4_0_4_8:
|
||||
return "q4_0_4_8"
|
||||
case tensorTypeQ4_0_8_8:
|
||||
return "q4_0_8_8"
|
||||
case tensorTypeTQ1_0:
|
||||
return "tq1_0"
|
||||
case tensorTypeTQ2_0:
|
||||
return "tq2_0"
|
||||
case tensorTypeIQ4_NL_4_4:
|
||||
return "iq4_nl_4_4"
|
||||
case tensorTypeIQ4_NL_4_8:
|
||||
return "iq4_nl_4_8"
|
||||
case tensorTypeIQ4_NL_8_8:
|
||||
return "iq4_nl_8_8"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (tt TensorType) LogValue() slog.Value {
|
||||
return slog.GroupValue(
|
||||
slog.Uint64("value", uint64(tt)),
|
||||
slog.String("name", strings.ToUpper(tt.String())),
|
||||
slog.Int64("size", tt.typeSize()),
|
||||
slog.Int64("block_size", tt.blockSize()),
|
||||
slog.Float64("num_bytes", tt.NumBytes()),
|
||||
)
|
||||
}
|
||||
4
go.mod
4
go.mod
@@ -19,12 +19,13 @@ require (
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
|
||||
github.com/dlclark/regexp2 v1.11.4
|
||||
github.com/emirpasic/gods/v2 v2.0.0-alpha
|
||||
github.com/google/go-cmp v0.6.0
|
||||
github.com/google/go-cmp v0.7.0
|
||||
github.com/mattn/go-runewidth v0.0.14
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
golang.org/x/image v0.22.0
|
||||
golang.org/x/tools v0.30.0
|
||||
gonum.org/v1/gonum v0.15.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -44,7 +45,6 @@ require (
|
||||
github.com/xtgo/set v1.0.0 // indirect
|
||||
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
gonum.org/v1/gonum v0.15.0 // indirect
|
||||
gorgonia.org/vecf32 v0.9.0 // indirect
|
||||
gorgonia.org/vecf64 v0.9.0 // indirect
|
||||
)
|
||||
|
||||
4
go.sum
4
go.sum
@@ -112,8 +112,8 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
|
||||
@@ -45,6 +45,8 @@ var (
|
||||
"qwen2.5-coder:latest",
|
||||
"qwen:latest",
|
||||
"solar-pro:latest",
|
||||
"codellama:latest",
|
||||
"nous-hermes:latest",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -150,7 +150,7 @@ index 4cce5166..7f6617fa 100644
|
||||
llama_model_loader::llama_model_loader(
|
||||
const std::string & fname,
|
||||
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
|
||||
index 3a4e72a3..831b68c0 100644
|
||||
index 3a4e72a3..db62973f 100644
|
||||
--- a/src/llama-model.cpp
|
||||
+++ b/src/llama-model.cpp
|
||||
@@ -1402,6 +1402,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
|
||||
@@ -22,10 +22,10 @@ multiple batches of processing until everything is complete.
|
||||
4 files changed, 59 insertions(+), 79 deletions(-)
|
||||
|
||||
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
|
||||
index c22687e4..c5948e8f 100644
|
||||
index dca22d8b..1f3a3956 100644
|
||||
--- a/src/llama-context.cpp
|
||||
+++ b/src/llama-context.cpp
|
||||
@@ -950,9 +950,12 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
@@ -947,9 +947,12 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
|
||||
// find KV slot
|
||||
if (!kv_self->find_slot(ubatch)) {
|
||||
@@ -41,7 +41,7 @@ index c22687e4..c5948e8f 100644
|
||||
}
|
||||
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
@@ -1967,9 +1970,12 @@ void llama_context::opt_epoch_iter(
|
||||
@@ -1965,9 +1968,12 @@ void llama_context::opt_epoch_iter(
|
||||
|
||||
// TODO: not sure if this is needed
|
||||
if (!kv_self->find_slot(ubatch)) {
|
||||
|
||||
@@ -10,10 +10,10 @@ Subject: [PATCH] add argsort and cuda copy for i32
|
||||
3 files changed, 192 insertions(+), 2 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
|
||||
index becdae07..7a44b6cf 100644
|
||||
index 955fec59..654e2f28 100644
|
||||
--- a/ggml/src/ggml-cpu/ops.cpp
|
||||
+++ b/ggml/src/ggml-cpu/ops.cpp
|
||||
@@ -6890,6 +6890,45 @@ static void ggml_compute_forward_argsort_f32(
|
||||
@@ -6822,6 +6822,45 @@ static void ggml_compute_forward_argsort_f32(
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,7 +59,7 @@ index becdae07..7a44b6cf 100644
|
||||
void ggml_compute_forward_argsort(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
@@ -6901,6 +6940,10 @@ void ggml_compute_forward_argsort(
|
||||
@@ -6833,6 +6872,10 @@ void ggml_compute_forward_argsort(
|
||||
{
|
||||
ggml_compute_forward_argsort_f32(params, dst);
|
||||
} break;
|
||||
@@ -195,7 +195,7 @@ index 607ded85..53b02634 100644
|
||||
+ }
|
||||
}
|
||||
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
|
||||
index 2d46176e..47383486 100644
|
||||
index d027271f..4abd01d7 100644
|
||||
--- a/ggml/src/ggml-cuda/cpy.cu
|
||||
+++ b/ggml/src/ggml-cuda/cpy.cu
|
||||
@@ -38,6 +38,13 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
|
||||
@@ -257,7 +257,7 @@ index 2d46176e..47383486 100644
|
||||
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
|
||||
const float * xi = (const float *) cxi;
|
||||
block_q8_0 * dsti = (block_q8_0 *) cdsti;
|
||||
@@ -631,6 +676,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
@@ -633,6 +678,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
@@ -266,7 +266,7 @@ index 2d46176e..47383486 100644
|
||||
} else {
|
||||
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
@@ -686,6 +733,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
@@ -688,6 +735,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Daniel Hiltgen <daniel@ollama.com>
|
||||
Date: Sun, 22 Jun 2025 09:22:05 -0700
|
||||
Subject: [PATCH] temporary prevent rocm+cuda mixed loading
|
||||
|
||||
---
|
||||
ggml/src/ggml-backend-reg.cpp | 12 ++++++++++--
|
||||
1 file changed, 10 insertions(+), 2 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
|
||||
index 4e67d243..8f49f084 100644
|
||||
--- a/ggml/src/ggml-backend-reg.cpp
|
||||
+++ b/ggml/src/ggml-backend-reg.cpp
|
||||
@@ -573,8 +573,16 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
|
||||
|
||||
ggml_backend_load_best("blas", silent, dir_path);
|
||||
ggml_backend_load_best("cann", silent, dir_path);
|
||||
- ggml_backend_load_best("cuda", silent, dir_path);
|
||||
- ggml_backend_load_best("hip", silent, dir_path);
|
||||
+
|
||||
+ // Avoid mixed hip+cuda configurations
|
||||
+ const char * hip_devices = std::getenv("HIP_VISIBLE_DEVICES");
|
||||
+ const char * rocr_devices = std::getenv("ROCR_VISIBLE_DEVICES");
|
||||
+ if (!hip_devices && !rocr_devices) {
|
||||
+ ggml_backend_load_best("cuda", silent, dir_path);
|
||||
+ } else {
|
||||
+ ggml_backend_load_best("hip", silent, dir_path);
|
||||
+ }
|
||||
+
|
||||
ggml_backend_load_best("kompute", silent, dir_path);
|
||||
ggml_backend_load_best("metal", silent, dir_path);
|
||||
ggml_backend_load_best("rpc", silent, dir_path);
|
||||
169
llama/patches/0019-metal-add-mean-kernel-14267.patch
Normal file
169
llama/patches/0019-metal-add-mean-kernel-14267.patch
Normal file
@@ -0,0 +1,169 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Georgi Gerganov <ggerganov@gmail.com>
|
||||
Date: Thu, 19 Jun 2025 08:05:21 +0300
|
||||
Subject: [PATCH] metal : add mean kernel (#14267)
|
||||
|
||||
* metal : add mean kernel
|
||||
|
||||
ggml-ci
|
||||
|
||||
* cont : dedup implementation
|
||||
|
||||
ggml-ci
|
||||
---
|
||||
ggml/src/ggml-metal/ggml-metal.m | 33 ++++++++++++++++---
|
||||
ggml/src/ggml-metal/ggml-metal.metal | 48 ++++++++++++++++++++++------
|
||||
2 files changed, 67 insertions(+), 14 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
|
||||
index ee4f2dcb..f20f5615 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.m
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.m
|
||||
@@ -489,6 +489,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_COS,
|
||||
GGML_METAL_KERNEL_TYPE_NEG,
|
||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||
+ GGML_METAL_KERNEL_TYPE_MEAN,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ARGMAX,
|
||||
@@ -1436,6 +1437,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
||||
@@ -1634,6 +1636,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_OP_LOG:
|
||||
return false; // TODO: implement
|
||||
case GGML_OP_SUM_ROWS:
|
||||
+ case GGML_OP_MEAN:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
||||
@@ -2362,11 +2365,30 @@ static bool ggml_metal_encode_node(
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
+ case GGML_OP_MEAN:
|
||||
{
|
||||
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
||||
|
||||
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
||||
+ id<MTLComputePipelineState> pipeline = nil;
|
||||
+
|
||||
+ switch (dst->op) {
|
||||
+ case GGML_OP_SUM_ROWS:
|
||||
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
||||
+ break;
|
||||
+ case GGML_OP_MEAN:
|
||||
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
|
||||
+ break;
|
||||
+ default:
|
||||
+ GGML_ABORT("fatal error");
|
||||
+ }
|
||||
+
|
||||
+ int nth = 32; // SIMD width
|
||||
+
|
||||
+ while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
||||
+ nth *= 2;
|
||||
+ }
|
||||
|
||||
+ nth = MIN(nth, ne00);
|
||||
|
||||
ggml_metal_kargs_sum_rows args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -2396,11 +2418,12 @@ static bool ggml_metal_encode_node(
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
- [encoder setBytes:&args length:sizeof(args) atIndex:2];
|
||||
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
{
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index 9cfddf45..08e8d807 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -956,31 +956,61 @@ kernel void kernel_neg(
|
||||
dst[tpig] = -src0[tpig];
|
||||
}
|
||||
|
||||
+template <bool norm>
|
||||
kernel void kernel_sum_rows(
|
||||
+ constant ggml_metal_kargs_sum_rows & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
- constant ggml_metal_kargs_sum_rows & args,
|
||||
- uint3 tpig[[thread_position_in_grid]]) {
|
||||
- int64_t i3 = tpig.z;
|
||||
- int64_t i2 = tpig.y;
|
||||
- int64_t i1 = tpig.x;
|
||||
+ threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
+ ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
+ ushort tiisg[[thread_index_in_simdgroup]],
|
||||
+ ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
+ int64_t i3 = tgpig.z;
|
||||
+ int64_t i2 = tgpig.y;
|
||||
+ int64_t i1 = tgpig.x;
|
||||
|
||||
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
+ if (sgitg == 0) {
|
||||
+ shmem_f32[tiisg] = 0.0f;
|
||||
+ }
|
||||
+
|
||||
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
||||
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
||||
|
||||
- float row_sum = 0;
|
||||
+ float sumf = 0;
|
||||
|
||||
- for (int64_t i0 = 0; i0 < args.ne00; i0++) {
|
||||
- row_sum += src_row[i0];
|
||||
+ for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
||||
+ sumf += src_row[i0];
|
||||
}
|
||||
|
||||
- dst_row[0] = row_sum;
|
||||
+ sumf = simd_sum(sumf);
|
||||
+
|
||||
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+
|
||||
+ if (tiisg == 0) {
|
||||
+ shmem_f32[sgitg] = sumf;
|
||||
+ }
|
||||
+
|
||||
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+
|
||||
+ sumf = shmem_f32[tiisg];
|
||||
+ sumf = simd_sum(sumf);
|
||||
+
|
||||
+ if (tpitg.x == 0) {
|
||||
+ dst_row[0] = norm ? sumf / args.ne00 : sumf;
|
||||
+ }
|
||||
}
|
||||
|
||||
+typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
|
||||
+
|
||||
+template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
|
||||
+template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
|
||||
+
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max(
|
||||
device const char * src0,
|
||||
5089
llama/patches/0020-CUDA-add-mean-operation-14313.patch
Normal file
5089
llama/patches/0020-CUDA-add-mean-operation-14313.patch
Normal file
File diff suppressed because it is too large
Load Diff
@@ -151,7 +151,12 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
||||
}
|
||||
|
||||
if graphPartialOffload == 0 {
|
||||
graphPartialOffload = f.KV().GQA() * kvTotal / 6
|
||||
headsKV := f.KV().HeadCountKVMin()
|
||||
if headsKV == 0 {
|
||||
headsKV = 1
|
||||
}
|
||||
gqa := f.KV().HeadCountMax() / headsKV
|
||||
graphPartialOffload = gqa * kvTotal / 6
|
||||
}
|
||||
if graphFullOffload == 0 {
|
||||
graphFullOffload = graphPartialOffload
|
||||
|
||||
@@ -139,6 +139,13 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
||||
gpus = discover.GetCPUInfo()
|
||||
}
|
||||
|
||||
// Verify the requested context size is <= the model training size
|
||||
trainCtx := f.KV().ContextLength()
|
||||
if opts.NumCtx/numParallel > int(trainCtx) && trainCtx > 0 {
|
||||
slog.Warn("requested context size too large for model", "num_ctx", opts.NumCtx, "num_parallel", numParallel, "n_ctx_train", trainCtx)
|
||||
opts.NumCtx = int(trainCtx) * numParallel
|
||||
}
|
||||
|
||||
estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
|
||||
if len(gpus) > 1 || gpus[0].Library != "cpu" {
|
||||
switch {
|
||||
@@ -311,7 +318,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
||||
params = append(params, "--mmproj", projectors[0])
|
||||
}
|
||||
|
||||
// iterate through compatible GPU libraries such as 'cuda_v12', 'cuda_v11', 'rocm', etc.
|
||||
// iterate through compatible GPU libraries such as 'cuda_v12', 'rocm', etc.
|
||||
// adding each library's respective path to the LD_LIBRARY_PATH, until finally running
|
||||
// without any LD_LIBRARY_PATH flags
|
||||
for {
|
||||
|
||||
@@ -253,6 +253,7 @@ type Tensor interface {
|
||||
|
||||
Neg(ctx Context) Tensor
|
||||
Add(ctx Context, t2 Tensor) Tensor
|
||||
Sub(ctx Context, t2 Tensor) Tensor
|
||||
Mul(ctx Context, t2 Tensor) Tensor
|
||||
Div(ctx Context, t2 Tensor) Tensor
|
||||
|
||||
@@ -276,6 +277,7 @@ type Tensor interface {
|
||||
Tanh(ctx Context) Tensor
|
||||
GELU(ctx Context) Tensor
|
||||
SILU(ctx Context) Tensor
|
||||
RELU(ctx Context) Tensor
|
||||
Sigmoid(ctx Context) Tensor
|
||||
|
||||
Reshape(ctx Context, shape ...int) Tensor
|
||||
@@ -297,6 +299,12 @@ type Tensor interface {
|
||||
|
||||
TopK(ctx Context, k int) Tensor
|
||||
Argsort(ctx Context) Tensor
|
||||
Mean(ctx Context) Tensor
|
||||
Variance(ctx Context) Tensor
|
||||
Stddev(ctx Context) Tensor
|
||||
Sqr(ctx Context) Tensor
|
||||
Sqrt(ctx Context) Tensor
|
||||
Clamp(ctx Context, min, max float32) Tensor
|
||||
}
|
||||
|
||||
// ScaledDotProductAttention implements a fused attention
|
||||
|
||||
@@ -138,7 +138,10 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
requiredMemory.CPU.Name = C.GoString(C.ggml_backend_dev_name(cpuDeviceBufferType.d))
|
||||
var props C.struct_ggml_backend_dev_props
|
||||
C.ggml_backend_dev_get_props(cpuDeviceBufferType.d, &props)
|
||||
requiredMemory.CPU.UUID = C.GoString(props.uuid)
|
||||
|
||||
// Bug #11211: Reporting of UUIDs is temporarily disabled due to causing segfaults
|
||||
// This only affects debug information until the new memory management code is in place
|
||||
// requiredMemory.CPU.UUID = C.GoString(props.uuid)
|
||||
requiredMemory.CPU.Weights = make([]ml.Memory, blocks+1)
|
||||
requiredMemory.CPU.Cache = make([]ml.Memory, blocks+1)
|
||||
|
||||
@@ -155,7 +158,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d))
|
||||
var props C.struct_ggml_backend_dev_props
|
||||
C.ggml_backend_dev_get_props(d, &props)
|
||||
requiredMemory.GPUs[i].UUID = C.GoString(props.uuid)
|
||||
// requiredMemory.GPUs[i].UUID = C.GoString(props.uuid)
|
||||
requiredMemory.GPUs[i].Weights = make([]ml.Memory, blocks+1)
|
||||
requiredMemory.GPUs[i].Cache = make([]ml.Memory, blocks+1)
|
||||
}
|
||||
@@ -297,7 +300,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
if _, ok := meta.Tensors().GroupLayers()["output"]; !ok && t.Name == "token_embd.weight" {
|
||||
createTensor(tensor{source: t, target: "output.weight"}, output.bts, blocks)
|
||||
}
|
||||
case contains(t.Name, "cls", "output", "output_norm"):
|
||||
case contains(t.Name, "cls", "output", "output_norm",
|
||||
"altup_proj", "altup_unembd_proj",
|
||||
"per_layer_token_embd", "per_layer_model_proj", "per_layer_proj_norm"):
|
||||
createTensor(tensor{source: t}, output.bts, blocks)
|
||||
case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
|
||||
// TODO: assign vision tensors to the gpu if possible
|
||||
@@ -602,7 +607,9 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
|
||||
}
|
||||
|
||||
func (c *Context) Compute(tensors ...ml.Tensor) {
|
||||
C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
|
||||
if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS {
|
||||
panic(fmt.Errorf("error computing ggml graph: %v", status))
|
||||
}
|
||||
C.ggml_backend_sched_reset(c.b.sched)
|
||||
|
||||
needSync := true
|
||||
@@ -891,6 +898,13 @@ func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Sub(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_sub(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor {
|
||||
if dim < 0 || dim >= C.GGML_MAX_DIMS {
|
||||
panic("invalid dimension")
|
||||
@@ -1198,6 +1212,13 @@ func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) RELU(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
@@ -1273,3 +1294,42 @@ func (t *Tensor) Argsort(ctx ml.Context) ml.Tensor {
|
||||
t: C.ggml_argsort(ctx.(*Context).ctx, t.t, C.GGML_SORT_ORDER_ASC),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Mean(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_mean(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Variance(ctx ml.Context) ml.Tensor {
|
||||
return t.Add(ctx, t.Mean(ctx).Scale(ctx, -1)).
|
||||
Sqr(ctx).
|
||||
SumRows(ctx).
|
||||
Scale(ctx, 1/float64(t.Dim(0)))
|
||||
}
|
||||
|
||||
func (t *Tensor) Stddev(ctx ml.Context) ml.Tensor {
|
||||
return t.Variance(ctx).Sqrt(ctx)
|
||||
}
|
||||
|
||||
func (t *Tensor) Sqr(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_sqr(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_sqrt(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
|
||||
}
|
||||
}
|
||||
|
||||
12
ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
vendored
12
ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
vendored
@@ -573,8 +573,16 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
|
||||
|
||||
ggml_backend_load_best("blas", silent, dir_path);
|
||||
ggml_backend_load_best("cann", silent, dir_path);
|
||||
ggml_backend_load_best("cuda", silent, dir_path);
|
||||
ggml_backend_load_best("hip", silent, dir_path);
|
||||
|
||||
// Avoid mixed hip+cuda configurations
|
||||
const char * hip_devices = std::getenv("HIP_VISIBLE_DEVICES");
|
||||
const char * rocr_devices = std::getenv("ROCR_VISIBLE_DEVICES");
|
||||
if (!hip_devices && !rocr_devices) {
|
||||
ggml_backend_load_best("cuda", silent, dir_path);
|
||||
} else {
|
||||
ggml_backend_load_best("hip", silent, dir_path);
|
||||
}
|
||||
|
||||
ggml_backend_load_best("kompute", silent, dir_path);
|
||||
ggml_backend_load_best("metal", silent, dir_path);
|
||||
ggml_backend_load_best("rpc", silent, dir_path);
|
||||
|
||||
20
ml/backend/ggml/ggml/src/ggml-cuda/common.cuh
vendored
20
ml/backend/ggml/ggml/src/ggml-cuda/common.cuh
vendored
@@ -362,6 +362,26 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
||||
#endif // FP16_AVAILABLE
|
||||
}
|
||||
|
||||
// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
|
||||
template<bool norm>
|
||||
static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
|
||||
const int row = blockIdx.x;
|
||||
const int col = threadIdx.x;
|
||||
|
||||
float sum = 0.0f;
|
||||
for (int i = col; i < ncols; i += blockDim.x) {
|
||||
sum += x[row * ncols + i];
|
||||
}
|
||||
|
||||
sum = warp_reduce_sum(sum);
|
||||
|
||||
if (col != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst[row] = norm ? sum / ncols : sum;
|
||||
}
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ float warp_reduce_max(float x) {
|
||||
#pragma unroll
|
||||
|
||||
@@ -35,6 +35,7 @@
|
||||
#include "ggml-cuda/ssm-scan.cuh"
|
||||
#include "ggml-cuda/sum.cuh"
|
||||
#include "ggml-cuda/sumrows.cuh"
|
||||
#include "ggml-cuda/mean.cuh"
|
||||
#include "ggml-cuda/tsembd.cuh"
|
||||
#include "ggml-cuda/unary.cuh"
|
||||
#include "ggml-cuda/upscale.cuh"
|
||||
@@ -2322,6 +2323,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_SUM_ROWS:
|
||||
ggml_cuda_op_sum_rows(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_MEAN:
|
||||
ggml_cuda_op_mean(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SSM_CONV:
|
||||
ggml_cuda_op_ssm_conv(ctx, dst);
|
||||
break;
|
||||
@@ -3211,6 +3215,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_ACC:
|
||||
return true;
|
||||
|
||||
19
ml/backend/ggml/ggml/src/ggml-cuda/mean.cu
vendored
Normal file
19
ml/backend/ggml/ggml/src/ggml-cuda/mean.cu
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
#include "mean.cuh"
|
||||
|
||||
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
const int64_t ncols = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dim3 block_nums(nrows, 1, 1);
|
||||
reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
|
||||
}
|
||||
3
ml/backend/ggml/ggml/src/ggml-cuda/mean.cuh
vendored
Normal file
3
ml/backend/ggml/ggml/src/ggml-cuda/mean.cuh
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
23
ml/backend/ggml/ggml/src/ggml-cuda/sumrows.cu
vendored
23
ml/backend/ggml/ggml/src/ggml-cuda/sumrows.cu
vendored
@@ -1,25 +1,9 @@
|
||||
#include "sumrows.cuh"
|
||||
|
||||
static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
|
||||
const int row = blockIdx.x;
|
||||
const int col = threadIdx.x;
|
||||
|
||||
float sum = 0.0f;
|
||||
for (int i = col; i < ncols; i += blockDim.x) {
|
||||
sum += x[row * ncols + i];
|
||||
}
|
||||
|
||||
sum = warp_reduce_sum(sum);
|
||||
|
||||
if (col == 0) {
|
||||
dst[row] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dim3 block_nums(nrows, 1, 1);
|
||||
k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
|
||||
reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
@@ -35,5 +19,8 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const int64_t ncols = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dim3 block_nums(nrows, 1, 1);
|
||||
|
||||
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
|
||||
|
||||
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
@@ -3434,31 +3434,61 @@ kernel void kernel_neg(
|
||||
dst[tpig] = -src0[tpig];
|
||||
}
|
||||
|
||||
template <bool norm>
|
||||
kernel void kernel_sum_rows(
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
uint3 tpig[[thread_position_in_grid]]) {
|
||||
int64_t i3 = tpig.z;
|
||||
int64_t i2 = tpig.y;
|
||||
int64_t i1 = tpig.x;
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
int64_t i3 = tgpig.z;
|
||||
int64_t i2 = tgpig.y;
|
||||
int64_t i1 = tgpig.x;
|
||||
|
||||
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
||||
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
||||
|
||||
float row_sum = 0;
|
||||
float sumf = 0;
|
||||
|
||||
for (int64_t i0 = 0; i0 < args.ne00; i0++) {
|
||||
row_sum += src_row[i0];
|
||||
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
||||
sumf += src_row[i0];
|
||||
}
|
||||
|
||||
dst_row[0] = row_sum;
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sumf = shmem_f32[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
dst_row[0] = norm ? sumf / args.ne00 : sumf;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
|
||||
|
||||
template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
|
||||
template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max(
|
||||
device const char * src0,
|
||||
|
||||
33
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m
vendored
33
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m
vendored
@@ -489,6 +489,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_COS,
|
||||
GGML_METAL_KERNEL_TYPE_NEG,
|
||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||
GGML_METAL_KERNEL_TYPE_MEAN,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ARGMAX,
|
||||
@@ -1436,6 +1437,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
||||
@@ -1634,6 +1636,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_OP_LOG:
|
||||
return false; // TODO: implement
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
||||
@@ -2362,11 +2365,30 @@ static bool ggml_metal_encode_node(
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
{
|
||||
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
switch (dst->op) {
|
||||
case GGML_OP_SUM_ROWS:
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
||||
break;
|
||||
case GGML_OP_MEAN:
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = MIN(nth, ne00);
|
||||
|
||||
ggml_metal_kargs_sum_rows args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -2396,11 +2418,12 @@ static bool ggml_metal_encode_node(
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
{
|
||||
|
||||
@@ -956,31 +956,61 @@ kernel void kernel_neg(
|
||||
dst[tpig] = -src0[tpig];
|
||||
}
|
||||
|
||||
template <bool norm>
|
||||
kernel void kernel_sum_rows(
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
uint3 tpig[[thread_position_in_grid]]) {
|
||||
int64_t i3 = tpig.z;
|
||||
int64_t i2 = tpig.y;
|
||||
int64_t i1 = tpig.x;
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
int64_t i3 = tgpig.z;
|
||||
int64_t i2 = tgpig.y;
|
||||
int64_t i1 = tgpig.x;
|
||||
|
||||
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
||||
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
||||
|
||||
float row_sum = 0;
|
||||
float sumf = 0;
|
||||
|
||||
for (int64_t i0 = 0; i0 < args.ne00; i0++) {
|
||||
row_sum += src_row[i0];
|
||||
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
||||
sumf += src_row[i0];
|
||||
}
|
||||
|
||||
dst_row[0] = row_sum;
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sumf = shmem_f32[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
dst_row[0] = norm ? sumf / args.ne00 : sumf;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
|
||||
|
||||
template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
|
||||
template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max(
|
||||
device const char * src0,
|
||||
|
||||
51
model/models/gemma3n/model.go
Normal file
51
model/models/gemma3n/model.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package gemma3n
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePieceModel
|
||||
|
||||
*TextModel
|
||||
}
|
||||
|
||||
// Forward implements model.Model.
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
return m.TextModel.Forward(ctx, batch, m.Cache)
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
TextModel: newTextModel(c),
|
||||
SentencePieceModel: model.NewSentencePieceModel(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewWrapperCache(
|
||||
kvcache.NewCausalCache(m.Shift),
|
||||
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
|
||||
)
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("gemma3n", New)
|
||||
}
|
||||
360
model/models/gemma3n/model_text.go
Normal file
360
model/models/gemma3n/model_text.go
Normal file
@@ -0,0 +1,360 @@
|
||||
package gemma3n
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/fast"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type TextModel struct {
|
||||
TokenEmbedding *TextScaledWordEmbedding `gguf:"token_embd"`
|
||||
|
||||
*PerLayerProjector
|
||||
|
||||
AltupEmbd *nn.Linear `gguf:"altup_proj"`
|
||||
AltupUnembd *nn.Linear `gguf:"altup_unembd_proj"`
|
||||
|
||||
TextLayers []TextLayer `gguf:"blk"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
TextOptions
|
||||
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
// Create a tensor of a single float32 value of 1.0 to use for altup correction
|
||||
one := ctx.Input().FromFloatSlice([]float32{1.0}, 1)
|
||||
|
||||
inputs := m.TokenEmbedding.Forward(ctx, batch.Inputs, math.Sqrt(float64(m.hiddenSize)))
|
||||
inputsPerLayer := m.PerLayerProjector.Forward(ctx, batch, inputs, &m.TextOptions)
|
||||
|
||||
targetMagnitude := inputs.Sqr(ctx).Mean(ctx).Sqrt(ctx)
|
||||
targetMagnitude = targetMagnitude.Repeat(ctx, 2, m.altupInputs-1)
|
||||
|
||||
hiddenState := inputs.Repeat(ctx, 2, m.altupInputs-1)
|
||||
altupProj := m.AltupEmbd.Forward(ctx, hiddenState)
|
||||
altupProj = altupProj.Mul(ctx, targetMagnitude.Div(ctx, altupProj.Sqr(ctx).Mean(ctx).Sqrt(ctx)))
|
||||
|
||||
hiddenStates := inputs.Concat(ctx, altupProj, 2)
|
||||
|
||||
firstSharedKeyValue := m.hiddenLayers - m.sharedKeyValueLayers
|
||||
for i, layer := range m.TextLayers {
|
||||
if i < firstSharedKeyValue {
|
||||
cache.SetLayer(i)
|
||||
} else if m.isLocal(i) {
|
||||
cache.SetLayer(firstSharedKeyValue - 2)
|
||||
} else {
|
||||
cache.SetLayer(firstSharedKeyValue - 1)
|
||||
}
|
||||
|
||||
var layerType int
|
||||
ropeBase := m.ropeBase
|
||||
if m.isLocal(i) {
|
||||
layerType = 1
|
||||
ropeBase = m.ropeBaseLocal
|
||||
}
|
||||
|
||||
cache.(*kvcache.WrapperCache).SetLayerType(layerType)
|
||||
|
||||
// inputPerLayer = inputsPerLayer[:, i, :]
|
||||
inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2))
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, inputPerLayer, positions, one, cache, i >= firstSharedKeyValue, ropeBase, float64(m.activationSparsityScale[i]), &m.TextOptions)
|
||||
}
|
||||
|
||||
// hiddenStates = hiddenStates[:, :, 0]
|
||||
hiddenStates0 := hiddenStates.View(ctx, 0, hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1))
|
||||
targetMagnitude = hiddenStates0.Sqr(ctx).Mean(ctx).Sqrt(ctx)
|
||||
targetMagnitude = targetMagnitude.Repeat(ctx, 2, m.altupInputs-1)
|
||||
|
||||
// hiddenState = hiddenStates[:, :, 1:]
|
||||
hiddenState = hiddenStates.View(ctx, hiddenStates.Stride(2), hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), m.altupInputs-1)
|
||||
altupUnembdProj := m.AltupUnembd.Forward(ctx, hiddenState)
|
||||
altupUnembdProj = altupUnembdProj.Mul(ctx, targetMagnitude.Div(ctx, altupUnembdProj.Sqr(ctx).Mean(ctx).Sqrt(ctx)))
|
||||
|
||||
hiddenStates = hiddenStates0.Concat(ctx, altupUnembdProj, 2)
|
||||
|
||||
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
|
||||
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
||||
hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)))
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
ropeBase := m.ropeBase
|
||||
if m.isLocal(layer) {
|
||||
ropeBase = m.ropeBaseLocal
|
||||
}
|
||||
|
||||
return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
type TextScaledWordEmbedding struct {
|
||||
*nn.Embedding
|
||||
}
|
||||
|
||||
func (e TextScaledWordEmbedding) Forward(ctx ml.Context, inputIDs ml.Tensor, scale float64) ml.Tensor {
|
||||
return e.Embedding.Forward(ctx, inputIDs).Scale(ctx, scale)
|
||||
}
|
||||
|
||||
type PerLayerProjector struct {
|
||||
TokenEmbedding *TextScaledWordEmbedding `gguf:"per_layer_token_embd"`
|
||||
Projector *nn.Linear `gguf:"per_layer_model_proj"`
|
||||
Norm *nn.RMSNorm `gguf:"per_layer_proj_norm"`
|
||||
}
|
||||
|
||||
func (p PerLayerProjector) Forward(ctx ml.Context, batch input.Batch, inputs ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
inputsPerLayer := p.TokenEmbedding.Forward(ctx, batch.Inputs, math.Sqrt(float64(opts.hiddenSizePerLayerInput)))
|
||||
inputsPerLayer = inputsPerLayer.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, batch.Inputs.Dim(0), batch.Inputs.Dim(1))
|
||||
|
||||
perLayerProjection := p.Projector.Forward(ctx, inputs)
|
||||
perLayerProjection = perLayerProjection.Scale(ctx, math.Sqrt(float64(opts.hiddenSize)))
|
||||
perLayerProjection = perLayerProjection.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))
|
||||
perLayerProjection = p.Norm.Forward(ctx, perLayerProjection, opts.eps)
|
||||
|
||||
if inputsPerLayer != nil {
|
||||
perLayerProjection = perLayerProjection.Add(ctx, inputsPerLayer)
|
||||
perLayerProjection = perLayerProjection.Scale(ctx, 1/math.Sqrt(2))
|
||||
}
|
||||
|
||||
return perLayerProjection
|
||||
}
|
||||
|
||||
type TextLayer struct {
|
||||
*AltUp
|
||||
*Laurel
|
||||
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
Attention *TextAttention
|
||||
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
|
||||
|
||||
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
MLP *TextMLP
|
||||
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
|
||||
|
||||
PerLayerInputGate *nn.Linear `gguf:"inp_gate"`
|
||||
PerLayerProjection *nn.Linear `gguf:"proj"`
|
||||
PostPerLayerNorm *nn.RMSNorm `gguf:"post_norm"`
|
||||
}
|
||||
|
||||
func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, positions, one ml.Tensor, cache kvcache.Cache, sharedKV bool, ropeBase float32, activationSparsityScale float64, opts *TextOptions) ml.Tensor {
|
||||
predictions := d.Predict(ctx, hiddenStates, opts)
|
||||
active := opts.altupActive(ctx, predictions)
|
||||
|
||||
attn := d.AttentionNorm.Forward(ctx, active, opts.eps)
|
||||
laurel := d.Laurel.Forward(ctx, attn, opts)
|
||||
|
||||
attn = d.Attention.Forward(ctx, attn, positions, cache, sharedKV, ropeBase, opts)
|
||||
attn = d.PostAttentionNorm.Forward(ctx, attn, opts.eps)
|
||||
attn = active.Add(ctx, attn)
|
||||
attn = attn.Add(ctx, laurel).Scale(ctx, 1/math.Sqrt(2))
|
||||
|
||||
mlp := d.MLPNorm.Forward(ctx, attn, opts.eps)
|
||||
mlp = d.MLP.Forward(ctx, mlp, activationSparsityScale)
|
||||
mlp = d.PostMLPNorm.Forward(ctx, mlp, opts.eps)
|
||||
mlp = attn.Add(ctx, mlp)
|
||||
|
||||
predictions = d.Correct(ctx, predictions, mlp, one, opts)
|
||||
active = opts.altupActive(ctx, predictions)
|
||||
if opts.altupCorrectScale {
|
||||
active = d.ScaleCorrectedOutput(ctx, active)
|
||||
}
|
||||
|
||||
active = d.PerLayerInputGate.Forward(ctx, active)
|
||||
active = active.GELU(ctx)
|
||||
active = active.Mul(ctx, perLayerInput)
|
||||
|
||||
active = d.PerLayerProjection.Forward(ctx, active)
|
||||
active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)
|
||||
|
||||
// inactive := predictions[:, :, 1:]
|
||||
inactive := predictions.View(ctx, predictions.Stride(2), predictions.Dim(0), predictions.Stride(1), predictions.Dim(1), predictions.Stride(2), predictions.Dim(2)-1)
|
||||
active = inactive.Add(ctx, active)
|
||||
|
||||
predictions0 := predictions.View(ctx, 0, predictions.Dim(0), predictions.Stride(1), predictions.Dim(1))
|
||||
return predictions0.Concat(ctx, active, 2)
|
||||
}
|
||||
|
||||
type AltUp struct {
|
||||
CorrectionScale ml.Tensor `gguf:"altup_correct_scale.weight"`
|
||||
PredictionCoefficient *nn.Linear `gguf:"altup_predict_coef"`
|
||||
CorrectionCoefficient *nn.Linear `gguf:"altup_correct_coef"`
|
||||
Router *nn.Linear `gguf:"altup_router"`
|
||||
RouterNorm *nn.RMSNorm `gguf:"altup_router_norm"`
|
||||
}
|
||||
|
||||
func (a AltUp) computeRouterModalities(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
routerInputs := a.RouterNorm.Forward(ctx, hiddenStates, opts.eps).Scale(ctx, 1.0/float64(opts.hiddenSize))
|
||||
return a.Router.Forward(ctx, routerInputs).Tanh(ctx)
|
||||
}
|
||||
|
||||
func (a AltUp) Predict(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
modalities := a.computeRouterModalities(ctx, opts.altupActive(ctx, hiddenStates), opts)
|
||||
|
||||
coefficients := a.PredictionCoefficient.Forward(ctx, modalities)
|
||||
coefficients = coefficients.Reshape(ctx, opts.altupInputs, opts.altupInputs, coefficients.Dim(1), coefficients.Dim(2))
|
||||
|
||||
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
predictions := coefficients.Mulmat(ctx, hiddenStates)
|
||||
predictions = predictions.Add(ctx, hiddenStates)
|
||||
return predictions.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
||||
}
|
||||
|
||||
func (a AltUp) Correct(ctx ml.Context, predictions, activated, one ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
innovation := activated.Sub(ctx, opts.altupActive(ctx, predictions))
|
||||
innovation = innovation.Repeat(ctx, 2, opts.altupInputs)
|
||||
|
||||
modalities := a.computeRouterModalities(ctx, activated, opts)
|
||||
coefficients := a.CorrectionCoefficient.Forward(ctx, modalities)
|
||||
coefficients = coefficients.Add(ctx, one)
|
||||
|
||||
coefficients = coefficients.Reshape(ctx, 1, coefficients.Dim(0), coefficients.Dim(1))
|
||||
coefficients = coefficients.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
|
||||
corrected := innovation.Mul(ctx, coefficients)
|
||||
corrected = corrected.Add(ctx, predictions)
|
||||
return corrected
|
||||
}
|
||||
|
||||
func (a AltUp) ScaleCorrectedOutput(ctx ml.Context, predictions ml.Tensor) ml.Tensor {
|
||||
return predictions.Mul(ctx, a.CorrectionScale)
|
||||
}
|
||||
|
||||
type Laurel struct {
|
||||
LinearLeft *nn.Linear `gguf:"laurel_l"`
|
||||
LinearRight *nn.Linear `gguf:"laurel_r"`
|
||||
PostLaurelNorm *nn.RMSNorm `gguf:"laurel_post_norm"`
|
||||
}
|
||||
|
||||
func (l Laurel) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
residual := hiddenStates
|
||||
hiddenStates = l.LinearLeft.Forward(ctx, hiddenStates)
|
||||
hiddenStates = l.LinearRight.Forward(ctx, hiddenStates)
|
||||
hiddenStates = l.PostLaurelNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
return hiddenStates.Add(ctx, residual)
|
||||
}
|
||||
|
||||
type TextAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
|
||||
func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, sharedKV bool, ropeBase float32, opts *TextOptions) ml.Tensor {
|
||||
batchSize := hiddenStates.Dim(1)
|
||||
|
||||
query := attn.Query.Forward(ctx, hiddenStates)
|
||||
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
|
||||
query = attn.QueryNorm.Forward(ctx, query, opts.eps)
|
||||
query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
var key, value ml.Tensor
|
||||
if !sharedKV {
|
||||
key = attn.Key.Forward(ctx, hiddenStates)
|
||||
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
|
||||
key = attn.KeyNorm.Forward(ctx, key, opts.eps)
|
||||
key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
value = attn.Value.Forward(ctx, hiddenStates)
|
||||
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
|
||||
value = value.RMSNorm(ctx, nil, opts.eps)
|
||||
}
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1., cache)
|
||||
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
|
||||
return attn.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
type TextMLP struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, activationSparsityScale float64) ml.Tensor {
|
||||
upStates := mlp.Up.Forward(ctx, hiddenStates)
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates)
|
||||
if activationSparsityScale > 0 {
|
||||
mean := hiddenStates.Mean(ctx)
|
||||
std := hiddenStates.Stddev(ctx).Scale(ctx, activationSparsityScale)
|
||||
cutoff := mean.Add(ctx, std)
|
||||
hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx)
|
||||
}
|
||||
|
||||
hiddenStates = hiddenStates.GELU(ctx).Mul(ctx, upStates)
|
||||
hiddenStates = mlp.Down.Forward(ctx, hiddenStates)
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
type TextOptions struct {
|
||||
hiddenLayers int
|
||||
hiddenSize int
|
||||
hiddenSizePerLayerInput int
|
||||
numHeads, numKVHeads int
|
||||
keyLength, valueLength int
|
||||
sharedKeyValueLayers int
|
||||
|
||||
altupActiveIndex int
|
||||
altupInputs int
|
||||
altupCorrectScale bool
|
||||
|
||||
eps float32
|
||||
ropeBase float32
|
||||
ropeBaseLocal float32
|
||||
ropeScale float32
|
||||
|
||||
slidingWindowPattern []bool
|
||||
activationSparsityScale []float32
|
||||
}
|
||||
|
||||
func (o *TextOptions) altupActive(ctx ml.Context, t ml.Tensor) ml.Tensor {
|
||||
// t[:, :, o.altupActiveIndex]
|
||||
return t.View(ctx, o.altupActiveIndex*t.Stride(2), t.Dim(0), t.Stride(1), t.Dim(1))
|
||||
}
|
||||
|
||||
func (o *TextOptions) headDim() int {
|
||||
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
|
||||
}
|
||||
|
||||
func (o *TextOptions) isLocal(i int) bool {
|
||||
return o.slidingWindowPattern[i]
|
||||
}
|
||||
|
||||
func newTextModel(c fs.Config) *TextModel {
|
||||
return &TextModel{
|
||||
TextLayers: make([]TextLayer, c.Uint("block_count")),
|
||||
TextOptions: TextOptions{
|
||||
hiddenLayers: int(c.Uint("block_count")),
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
hiddenSizePerLayerInput: int(c.Uint("embedding_length_per_layer_input")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
keyLength: int(c.Uint("attention.key_length")),
|
||||
valueLength: int(c.Uint("attention.value_length")),
|
||||
sharedKeyValueLayers: int(c.Uint("attention.shared_kv_layers")),
|
||||
|
||||
altupActiveIndex: int(c.Uint("altup.active_idx")),
|
||||
altupInputs: int(c.Uint("altup.num_inputs")),
|
||||
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
||||
ropeBase: c.Float("rope.freq_base", 1_000_000),
|
||||
ropeBaseLocal: c.Float("rope.freq_base_local", 10_000),
|
||||
ropeScale: c.Float("rope.freq_scale", 1.0),
|
||||
|
||||
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
|
||||
activationSparsityScale: c.Floats("activation_sparsity_scale"),
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -63,9 +63,9 @@ func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOp
|
||||
}
|
||||
|
||||
type TextExperts struct {
|
||||
Gate ml.Tensor `gguf:"ffn_gate_exps.weight"`
|
||||
Up ml.Tensor `gguf:"ffn_up_exps.weight"`
|
||||
Down ml.Tensor `gguf:"ffn_down_exps.weight"`
|
||||
Gate *nn.Linear `gguf:"ffn_gate_exps"`
|
||||
Up *nn.Linear `gguf:"ffn_up_exps"`
|
||||
Down *nn.Linear `gguf:"ffn_down_exps"`
|
||||
}
|
||||
|
||||
func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
@@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
|
||||
hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
|
||||
hiddenStates = hiddenStates.Mul(ctx, scores)
|
||||
|
||||
upStates := e.Up.MulmatID(ctx, hiddenStates, experts)
|
||||
gateStates := e.Gate.MulmatID(ctx, hiddenStates, experts)
|
||||
downStates := e.Down.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
|
||||
upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts)
|
||||
gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts)
|
||||
downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
|
||||
|
||||
nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
|
||||
for i := 1; i < opts.numExpertsUsed; i++ {
|
||||
|
||||
@@ -3,6 +3,7 @@ package models
|
||||
import (
|
||||
_ "github.com/ollama/ollama/model/models/gemma2"
|
||||
_ "github.com/ollama/ollama/model/models/gemma3"
|
||||
_ "github.com/ollama/ollama/model/models/gemma3n"
|
||||
_ "github.com/ollama/ollama/model/models/llama"
|
||||
_ "github.com/ollama/ollama/model/models/llama4"
|
||||
_ "github.com/ollama/ollama/model/models/mistral3"
|
||||
|
||||
@@ -66,9 +66,9 @@ type MLP interface {
|
||||
|
||||
type sparse struct {
|
||||
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
||||
Gate ml.Tensor `gguf:"ffn_gate_exps.weight"`
|
||||
Up ml.Tensor `gguf:"ffn_up_exps.weight"`
|
||||
Down ml.Tensor `gguf:"ffn_down_exps.weight"`
|
||||
Gate *nn.Linear `gguf:"ffn_gate_exps"`
|
||||
Up *nn.Linear `gguf:"ffn_up_exps"`
|
||||
Down *nn.Linear `gguf:"ffn_down_exps"`
|
||||
}
|
||||
|
||||
func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||
@@ -87,13 +87,13 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options
|
||||
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
|
||||
|
||||
upStates := mlp.Up.MulmatID(ctx, hiddenStates, selectedExperts)
|
||||
upStates := mlp.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
|
||||
|
||||
hiddenStates = mlp.Gate.MulmatID(ctx, hiddenStates, selectedExperts)
|
||||
hiddenStates = mlp.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
|
||||
hiddenStates = hiddenStates.SILU(ctx)
|
||||
hiddenStates = hiddenStates.Mul(ctx, upStates)
|
||||
|
||||
experts := mlp.Down.MulmatID(ctx, hiddenStates, selectedExperts)
|
||||
experts := mlp.Down.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
|
||||
experts = experts.Mul(ctx, routingWeights)
|
||||
|
||||
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
|
||||
|
||||
@@ -87,7 +87,7 @@ func (v *Vocabulary) Decode(id int32) string {
|
||||
func (v *Vocabulary) SpecialVocabulary() []string {
|
||||
v.specialOnce.Do(func() {
|
||||
for i := range v.Values {
|
||||
if v.Types[i] == TOKEN_TYPE_CONTROL {
|
||||
if v.Types[i] == TOKEN_TYPE_CONTROL || v.Types[i] == TOKEN_TYPE_USER_DEFINED {
|
||||
v.special = append(v.special, v.Values[i])
|
||||
}
|
||||
}
|
||||
|
||||
16
model/vocabulary_test.go
Normal file
16
model/vocabulary_test.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package model
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestVocabulary_SpecialVocabulary(t *testing.T) {
|
||||
vocab := &Vocabulary{
|
||||
Values: []string{"<|startoftext|>", "<|endoftext|>", "<|tool_call_start|>", "<|tool_call_end|>", "hi"},
|
||||
Types: []int32{TOKEN_TYPE_CONTROL, TOKEN_TYPE_CONTROL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_NORMAL},
|
||||
}
|
||||
|
||||
specialVocab := vocab.SpecialVocabulary()
|
||||
|
||||
if len(specialVocab) != 4 {
|
||||
t.Errorf("expected 4 special tokens, got %d", len(specialVocab))
|
||||
}
|
||||
}
|
||||
@@ -292,13 +292,18 @@ func filesForModel(path string) ([]string, error) {
|
||||
}
|
||||
files = append(files, js...)
|
||||
|
||||
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
|
||||
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
|
||||
// tokenizer.model might be a unresolved git lfs reference; error if it is
|
||||
files = append(files, tks...)
|
||||
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
|
||||
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
|
||||
files = append(files, tks...)
|
||||
// only include tokenizer.model is tokenizer.json is not present
|
||||
if !slices.ContainsFunc(files, func(s string) bool {
|
||||
return slices.Contains(strings.Split(s, string(os.PathSeparator)), "tokenizer.json")
|
||||
}) {
|
||||
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
|
||||
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
|
||||
// tokenizer.model might be a unresolved git lfs reference; error if it is
|
||||
files = append(files, tks...)
|
||||
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
|
||||
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
|
||||
files = append(files, tks...)
|
||||
}
|
||||
}
|
||||
|
||||
return files, nil
|
||||
|
||||
@@ -27,7 +27,6 @@ function checkEnv() {
|
||||
$env:VCToolsRedistDir=(get-item "${MSVC_INSTALL}\VC\Redist\MSVC\*")[0]
|
||||
}
|
||||
# Locate CUDA versions
|
||||
# Note: this assumes every version found will be built
|
||||
$cudaList=(get-item "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v*\bin\" -ea 'silentlycontinue')
|
||||
if ($cudaList.length -eq 0) {
|
||||
$d=(get-command -ea 'silentlycontinue' nvcc).path
|
||||
@@ -94,19 +93,6 @@ function buildOllama() {
|
||||
|
||||
$hashEnv = @{}
|
||||
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
|
||||
if ("$script:CUDA_DIRS".Contains("v11")) {
|
||||
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V11")) { $v11="$_" }}
|
||||
$env:CUDAToolkit_ROOT=$hashEnv[$v11]
|
||||
write-host "Building CUDA v11 backend libraries"
|
||||
# Note: cuda v11 requires msvc 2019 so force the older generator
|
||||
# to avoid 2022 (or newer) from being used as the default
|
||||
& cmake --fresh --preset "CUDA 11" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
& cmake --build --preset "CUDA 11" --config Release --parallel $script:JOBS
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
& cmake --install build --component "CUDA" --strip
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
}
|
||||
if ("$script:CUDA_DIRS".Contains("v12")) {
|
||||
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12")) { $v12="$_" }}
|
||||
$env:CUDAToolkit_ROOT=$hashEnv[$v12]
|
||||
@@ -127,12 +113,17 @@ function buildOllama() {
|
||||
$env:HIPCXX="${env:HIP_PATH}\bin\clang++.exe"
|
||||
$env:HIP_PLATFORM="amd"
|
||||
$env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
|
||||
& cmake --fresh --preset "ROCm 6" -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ --install-prefix $script:DIST_DIR
|
||||
& cmake --fresh --preset "ROCm 6" -G Ninja `
|
||||
-DCMAKE_C_COMPILER=clang `
|
||||
-DCMAKE_CXX_COMPILER=clang++ `
|
||||
-DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
|
||||
-DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
|
||||
--install-prefix $script:DIST_DIR
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
$env:HIPCXX=""
|
||||
$env:HIP_PLATFORM=""
|
||||
$env:CMAKE_PREFIX_PATH=""
|
||||
& cmake --build --preset "ROCm" --config Release --parallel $script:JOBS
|
||||
& cmake --build --preset "ROCm 6" --config Release --parallel $script:JOBS
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
& cmake --install build --component "HIP" --strip
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
|
||||
@@ -10,9 +10,7 @@ OLLAMA_COMMON_BUILD_ARGS="--build-arg=VERSION \
|
||||
--build-arg=GOFLAGS \
|
||||
--build-arg=OLLAMA_CUSTOM_CPU_DEFS \
|
||||
--build-arg=OLLAMA_SKIP_CUDA_GENERATE \
|
||||
--build-arg=OLLAMA_SKIP_CUDA_11_GENERATE \
|
||||
--build-arg=OLLAMA_SKIP_CUDA_12_GENERATE \
|
||||
--build-arg=CUDA_V11_ARCHITECTURES \
|
||||
--build-arg=CUDA_V12_ARCHITECTURES \
|
||||
--build-arg=OLLAMA_SKIP_ROCM_GENERATE \
|
||||
--build-arg=OLLAMA_FAST_BUILD \
|
||||
|
||||
@@ -23,7 +23,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/fs/gguf"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/thinking"
|
||||
@@ -73,22 +73,18 @@ func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities := []model.Capability{}
|
||||
|
||||
// Check for completion capability
|
||||
r, err := os.Open(m.ModelPath)
|
||||
f, err := gguf.Open(m.ModelPath)
|
||||
if err == nil {
|
||||
defer r.Close()
|
||||
defer f.Close()
|
||||
|
||||
f, err := ggml.Decode(r, 1024)
|
||||
if err == nil {
|
||||
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
|
||||
capabilities = append(capabilities, model.CapabilityEmbedding)
|
||||
} else {
|
||||
capabilities = append(capabilities, model.CapabilityCompletion)
|
||||
}
|
||||
if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok {
|
||||
capabilities = append(capabilities, model.CapabilityVision)
|
||||
}
|
||||
if f.KeyValue("pooling_type").Valid() {
|
||||
capabilities = append(capabilities, model.CapabilityEmbedding)
|
||||
} else {
|
||||
slog.Error("couldn't decode ggml", "error", err)
|
||||
// If no embedding is specified, we assume the model supports completion
|
||||
capabilities = append(capabilities, model.CapabilityCompletion)
|
||||
}
|
||||
if f.KeyValue("vision.block_count").Valid() {
|
||||
capabilities = append(capabilities, model.CapabilityVision)
|
||||
}
|
||||
} else {
|
||||
slog.Error("couldn't open model file", "error", err)
|
||||
|
||||
@@ -1,123 +1,42 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// Constants for GGUF magic bytes and version
|
||||
var (
|
||||
ggufMagic = []byte{0x47, 0x47, 0x55, 0x46} // "GGUF"
|
||||
ggufVer = uint32(3) // Version 3
|
||||
)
|
||||
|
||||
// Helper function to create mock GGUF data
|
||||
func createMockGGUFData(architecture string, vision bool) []byte {
|
||||
var buf bytes.Buffer
|
||||
|
||||
// Write GGUF header
|
||||
buf.Write(ggufMagic)
|
||||
binary.Write(&buf, binary.LittleEndian, ggufVer)
|
||||
|
||||
// Write tensor count (0 for our test)
|
||||
var numTensors uint64 = 0
|
||||
binary.Write(&buf, binary.LittleEndian, numTensors)
|
||||
|
||||
// Calculate number of metadata entries
|
||||
numMetaEntries := uint64(1) // architecture entry
|
||||
if vision {
|
||||
numMetaEntries++
|
||||
}
|
||||
// Add embedding entry if architecture is "bert"
|
||||
if architecture == "bert" {
|
||||
numMetaEntries++
|
||||
}
|
||||
binary.Write(&buf, binary.LittleEndian, numMetaEntries)
|
||||
|
||||
// Write architecture metadata
|
||||
archKey := "general.architecture"
|
||||
keyLen := uint64(len(archKey))
|
||||
binary.Write(&buf, binary.LittleEndian, keyLen)
|
||||
buf.WriteString(archKey)
|
||||
|
||||
// String type (8)
|
||||
var strType uint32 = 8
|
||||
binary.Write(&buf, binary.LittleEndian, strType)
|
||||
|
||||
// String length
|
||||
strLen := uint64(len(architecture))
|
||||
binary.Write(&buf, binary.LittleEndian, strLen)
|
||||
buf.WriteString(architecture)
|
||||
|
||||
if vision {
|
||||
visionKey := architecture + ".vision.block_count"
|
||||
keyLen = uint64(len(visionKey))
|
||||
binary.Write(&buf, binary.LittleEndian, keyLen)
|
||||
buf.WriteString(visionKey)
|
||||
|
||||
// uint32 type (4)
|
||||
var uint32Type uint32 = 4
|
||||
binary.Write(&buf, binary.LittleEndian, uint32Type)
|
||||
|
||||
// uint32 value (1)
|
||||
var countVal uint32 = 1
|
||||
binary.Write(&buf, binary.LittleEndian, countVal)
|
||||
}
|
||||
// Write embedding metadata if architecture is "bert"
|
||||
if architecture == "bert" {
|
||||
poolKey := architecture + ".pooling_type"
|
||||
keyLen = uint64(len(poolKey))
|
||||
binary.Write(&buf, binary.LittleEndian, keyLen)
|
||||
buf.WriteString(poolKey)
|
||||
|
||||
// uint32 type (4)
|
||||
var uint32Type uint32 = 4
|
||||
binary.Write(&buf, binary.LittleEndian, uint32Type)
|
||||
|
||||
// uint32 value (1)
|
||||
var poolingVal uint32 = 1
|
||||
binary.Write(&buf, binary.LittleEndian, poolingVal)
|
||||
}
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestModelCapabilities(t *testing.T) {
|
||||
// Create a temporary directory for test files
|
||||
tempDir := t.TempDir()
|
||||
// Create completion model (llama architecture without vision)
|
||||
completionModelPath, _ := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
}, []*ggml.Tensor{})
|
||||
|
||||
// Create different types of mock model files
|
||||
completionModelPath := filepath.Join(tempDir, "model.bin")
|
||||
visionModelPath := filepath.Join(tempDir, "vision_model.bin")
|
||||
embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
|
||||
// Create a simple model file for tests that don't depend on GGUF content
|
||||
simpleModelPath := filepath.Join(tempDir, "simple_model.bin")
|
||||
// Create vision model (llama architecture with vision block count)
|
||||
visionModelPath, _ := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.vision.block_count": uint32(1),
|
||||
}, []*ggml.Tensor{})
|
||||
|
||||
if err := errors.Join(
|
||||
os.WriteFile(completionModelPath, createMockGGUFData("llama", false), 0o644),
|
||||
os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644),
|
||||
os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644),
|
||||
os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644),
|
||||
); err != nil {
|
||||
t.Fatalf("Failed to create model files: %v", err)
|
||||
}
|
||||
// Create embedding model (bert architecture with pooling type)
|
||||
embeddingModelPath, _ := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "bert",
|
||||
"bert.pooling_type": uint32(1),
|
||||
}, []*ggml.Tensor{})
|
||||
|
||||
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
chatTemplate, err := template.Parse("{{ .prompt }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
@@ -145,21 +64,13 @@ func TestModelCapabilities(t *testing.T) {
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
|
||||
},
|
||||
{
|
||||
name: "model with tools and insert capability",
|
||||
model: Model{
|
||||
ModelPath: simpleModelPath,
|
||||
Template: toolsInsertTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
|
||||
},
|
||||
{
|
||||
name: "model with tools capability",
|
||||
model: Model{
|
||||
ModelPath: simpleModelPath,
|
||||
ModelPath: completionModelPath,
|
||||
Template: toolsTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityTools},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools},
|
||||
},
|
||||
{
|
||||
name: "model with vision capability",
|
||||
@@ -224,29 +135,33 @@ func TestModelCapabilities(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelCheckCapabilities(t *testing.T) {
|
||||
// Create a temporary directory for test files
|
||||
tempDir := t.TempDir()
|
||||
// Create simple model file for tests that don't depend on GGUF content
|
||||
completionModelPath, _ := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
}, []*ggml.Tensor{})
|
||||
|
||||
visionModelPath := filepath.Join(tempDir, "vision_model.bin")
|
||||
simpleModelPath := filepath.Join(tempDir, "model.bin")
|
||||
embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
|
||||
// Create vision model (llama architecture with vision block count)
|
||||
visionModelPath, _ := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.vision.block_count": uint32(1),
|
||||
}, []*ggml.Tensor{})
|
||||
|
||||
if err := errors.Join(
|
||||
os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644),
|
||||
os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644),
|
||||
os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644),
|
||||
); err != nil {
|
||||
t.Fatalf("Failed to create model files: %v", err)
|
||||
}
|
||||
// Create embedding model (bert architecture with pooling type)
|
||||
embeddingModelPath, _ := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "bert",
|
||||
"bert.pooling_type": uint32(1),
|
||||
}, []*ggml.Tensor{})
|
||||
|
||||
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
chatTemplate, err := template.Parse("{{ .prompt }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
@@ -261,7 +176,7 @@ func TestModelCheckCapabilities(t *testing.T) {
|
||||
{
|
||||
name: "completion model without tools capability",
|
||||
model: Model{
|
||||
ModelPath: simpleModelPath,
|
||||
ModelPath: completionModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityTools},
|
||||
@@ -270,7 +185,7 @@ func TestModelCheckCapabilities(t *testing.T) {
|
||||
{
|
||||
name: "model with all needed capabilities",
|
||||
model: Model{
|
||||
ModelPath: simpleModelPath,
|
||||
ModelPath: completionModelPath,
|
||||
Template: toolsInsertTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
|
||||
@@ -278,7 +193,7 @@ func TestModelCheckCapabilities(t *testing.T) {
|
||||
{
|
||||
name: "model missing insert capability",
|
||||
model: Model{
|
||||
ModelPath: simpleModelPath,
|
||||
ModelPath: completionModelPath,
|
||||
Template: toolsTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityInsert},
|
||||
@@ -287,7 +202,7 @@ func TestModelCheckCapabilities(t *testing.T) {
|
||||
{
|
||||
name: "model missing vision capability",
|
||||
model: Model{
|
||||
ModelPath: simpleModelPath,
|
||||
ModelPath: completionModelPath,
|
||||
Template: toolsTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityVision},
|
||||
@@ -312,7 +227,7 @@ func TestModelCheckCapabilities(t *testing.T) {
|
||||
{
|
||||
name: "unknown capability",
|
||||
model: Model{
|
||||
ModelPath: simpleModelPath,
|
||||
ModelPath: completionModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{"unknown"},
|
||||
|
||||
2
server/internal/cache/blob/cache.go
vendored
2
server/internal/cache/blob/cache.go
vendored
@@ -59,7 +59,7 @@ type DiskCache struct {
|
||||
testHookBeforeFinalWrite func(f *os.File)
|
||||
}
|
||||
|
||||
// PutString is a convenience function for c.Put(d, strings.NewReader(s), int64(len(s))).
|
||||
// PutBytes is a convenience function for c.Put(d, strings.NewReader(s), int64(len(s))).
|
||||
func PutBytes[S string | []byte](c *DiskCache, d Digest, data S) error {
|
||||
return c.Put(d, bytes.NewReader([]byte(data)), int64(len(data)))
|
||||
}
|
||||
|
||||
@@ -231,6 +231,8 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
|
||||
// do not quantize relative position bias (T5)
|
||||
quantize = quantize && !strings.Contains(name, "attn_rel_b.weight")
|
||||
|
||||
quantize = quantize && !strings.Contains(name, "per_layer_token_embd.weight")
|
||||
|
||||
newType := fsggml.TensorType(t.Kind)
|
||||
if quantize {
|
||||
// get more optimal quantization type based on the tensor shape, layer, etc.
|
||||
|
||||
@@ -257,16 +257,8 @@ func TestQuantizeModel(t *testing.T) {
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f, err := os.CreateTemp(t.TempDir(), tt.name)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
defer f.Close()
|
||||
err = fsggml.WriteGGUF(f, tt.kv, tt.tensors)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create initial model: %s", err)
|
||||
}
|
||||
fp, err := os.Open(f.Name())
|
||||
p, _ := createBinFile(t, tt.kv, tt.tensors)
|
||||
fp, err := os.Open(p)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
@@ -929,7 +929,8 @@ func (s *Server) ListHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
r := api.ListModelResponse{
|
||||
// tag should never be masked
|
||||
models = append(models, api.ListModelResponse{
|
||||
Model: n.DisplayShortest(),
|
||||
Name: n.DisplayShortest(),
|
||||
Size: m.Size(),
|
||||
@@ -942,16 +943,7 @@ func (s *Server) ListHandler(c *gin.Context) {
|
||||
ParameterSize: cf.ModelType,
|
||||
QuantizationLevel: cf.FileType,
|
||||
},
|
||||
}
|
||||
|
||||
model, err := GetModel(n.String())
|
||||
if err != nil {
|
||||
slog.Warn("bad model details", "name", n, "error", err)
|
||||
} else {
|
||||
r.Capabilities = model.Capabilities()
|
||||
}
|
||||
|
||||
models = append(models, r)
|
||||
})
|
||||
}
|
||||
|
||||
slices.SortStableFunc(models, func(i, j api.ListModelResponse) int {
|
||||
@@ -1534,12 +1526,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
|
||||
var toolParser *tools.Parser
|
||||
if len(req.Tools) > 0 {
|
||||
toolParser, err = tools.NewParser(m.Template.Template)
|
||||
if err != nil {
|
||||
slog.Error("failed to create tool parser", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
toolParser = tools.NewParser(m.Template.Template, req.Tools)
|
||||
}
|
||||
|
||||
ch := make(chan any)
|
||||
@@ -1592,6 +1579,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
// don't return
|
||||
} else {
|
||||
if r.Done {
|
||||
res.Message.Content = toolParser.Content()
|
||||
ch <- res
|
||||
}
|
||||
return
|
||||
|
||||
@@ -191,7 +191,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
}
|
||||
|
||||
// Load model for fitting
|
||||
ggml, err := llm.LoadModel(pending.model.ModelPath, 0)
|
||||
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
|
||||
if err != nil {
|
||||
pending.errCh <- err
|
||||
break
|
||||
|
||||
@@ -112,11 +112,7 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
|
||||
b.ctx, b.ctxDone = context.WithCancel(ctx)
|
||||
t.Helper()
|
||||
|
||||
f, err := os.CreateTemp(t.TempDir(), modelName)
|
||||
require.NoError(t, err)
|
||||
defer f.Close()
|
||||
|
||||
require.NoError(t, ggml.WriteGGUF(f, ggml.KV{
|
||||
p, _ := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.context_length": uint32(32),
|
||||
"llama.embedding_length": uint32(4096),
|
||||
@@ -129,14 +125,14 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
|
||||
}, []*ggml.Tensor{
|
||||
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
|
||||
{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
|
||||
fname := f.Name()
|
||||
model := &Model{Name: modelName, ModelPath: fname}
|
||||
b.f, err = llm.LoadModel(model.ModelPath, 0)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
model := &Model{Name: modelName, ModelPath: p}
|
||||
f, err := llm.LoadModel(model.ModelPath, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
b.f = f
|
||||
if duration == nil {
|
||||
duration = &api.Duration{Duration: 5 * time.Millisecond}
|
||||
}
|
||||
|
||||
156
tools/template.go
Normal file
156
tools/template.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"text/template"
|
||||
"text/template/parse"
|
||||
)
|
||||
|
||||
// parseTag finds the tool calling tag from a Go template
|
||||
// often <tool_call> [TOOL_CALL] or similar by finding the
|
||||
// first text node after .ToolCalls and returning the content
|
||||
// if no tag is found, return "{" to indicate that json objects
|
||||
// should be attempted to be parsed as tool calls
|
||||
func parseTag(tmpl *template.Template) string {
|
||||
if tmpl == nil || tmpl.Tree == nil {
|
||||
slog.Debug("template or tree is nil")
|
||||
return "{"
|
||||
}
|
||||
|
||||
tc := findToolCallNode(tmpl.Tree.Root.Nodes)
|
||||
if tc == nil {
|
||||
return "{"
|
||||
}
|
||||
|
||||
tn := findTextNode(tc.List.Nodes)
|
||||
if tn == nil {
|
||||
return "{"
|
||||
}
|
||||
|
||||
tag := string(tn.Text)
|
||||
tag = strings.ReplaceAll(tag, "\r\n", "\n")
|
||||
|
||||
// avoid parsing { onwards as this may be a tool call
|
||||
// however keep '{' as a prefix if there is no tag
|
||||
// so that all json objects will be attempted to
|
||||
// be parsed as tool calls
|
||||
tag, _, _ = strings.Cut(tag, "{")
|
||||
tag = strings.TrimSpace(tag)
|
||||
if tag == "" {
|
||||
tag = "{"
|
||||
}
|
||||
|
||||
return tag
|
||||
}
|
||||
|
||||
// findToolCallNode searches for and returns an IfNode with .ToolCalls
|
||||
func findToolCallNode(nodes []parse.Node) *parse.IfNode {
|
||||
isToolCallsNode := func(n *parse.IfNode) bool {
|
||||
for _, cmd := range n.Pipe.Cmds {
|
||||
for _, arg := range cmd.Args {
|
||||
if field, ok := arg.(*parse.FieldNode); ok {
|
||||
if slices.Contains(field.Ident, "ToolCalls") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
switch n := node.(type) {
|
||||
case *parse.IfNode:
|
||||
if isToolCallsNode(n) {
|
||||
return n
|
||||
}
|
||||
// Recursively search in nested IfNodes
|
||||
if result := findToolCallNode(n.List.Nodes); result != nil {
|
||||
return result
|
||||
}
|
||||
if n.ElseList != nil {
|
||||
if result := findToolCallNode(n.ElseList.Nodes); result != nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
case *parse.ListNode:
|
||||
if result := findToolCallNode(n.Nodes); result != nil {
|
||||
return result
|
||||
}
|
||||
case *parse.RangeNode:
|
||||
if result := findToolCallNode(n.List.Nodes); result != nil {
|
||||
return result
|
||||
}
|
||||
if n.ElseList != nil {
|
||||
if result := findToolCallNode(n.ElseList.Nodes); result != nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
case *parse.WithNode:
|
||||
if result := findToolCallNode(n.List.Nodes); result != nil {
|
||||
return result
|
||||
}
|
||||
if n.ElseList != nil {
|
||||
if result := findToolCallNode(n.ElseList.Nodes); result != nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// findTextNode does a depth-first search for the first text content in nodes,
|
||||
// stopping at template constructs to avoid parsing text after the tool calls
|
||||
func findTextNode(nodes []parse.Node) *parse.TextNode {
|
||||
for _, node := range nodes {
|
||||
switch n := node.(type) {
|
||||
case *parse.TextNode:
|
||||
// skip whitespace-only text nodes
|
||||
if len(bytes.TrimSpace(n.Text)) == 0 {
|
||||
continue
|
||||
}
|
||||
return n
|
||||
case *parse.IfNode:
|
||||
if text := findTextNode(n.List.Nodes); text != nil {
|
||||
return text
|
||||
}
|
||||
if n.ElseList != nil {
|
||||
if text := findTextNode(n.ElseList.Nodes); text != nil {
|
||||
return text
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case *parse.ListNode:
|
||||
if text := findTextNode(n.Nodes); text != nil {
|
||||
return text
|
||||
}
|
||||
case *parse.RangeNode:
|
||||
if text := findTextNode(n.List.Nodes); text != nil {
|
||||
return text
|
||||
}
|
||||
if n.ElseList != nil {
|
||||
if text := findTextNode(n.ElseList.Nodes); text != nil {
|
||||
return text
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case *parse.WithNode:
|
||||
if text := findTextNode(n.List.Nodes); text != nil {
|
||||
return text
|
||||
}
|
||||
if n.ElseList != nil {
|
||||
if text := findTextNode(n.ElseList.Nodes); text != nil {
|
||||
return text
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case *parse.ActionNode:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
139
tools/template_test.go
Normal file
139
tools/template_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
func TestParseTag(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
template string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
template: "",
|
||||
want: "{",
|
||||
},
|
||||
{
|
||||
name: "no tag",
|
||||
template: "{{if .ToolCalls}}{{end}}",
|
||||
want: "{",
|
||||
},
|
||||
{
|
||||
name: "no tag with range",
|
||||
template: "{{if .ToolCalls}}{{range .ToolCalls}}{{ . }}{{end}}{{end}}",
|
||||
want: "{",
|
||||
},
|
||||
{
|
||||
name: "tool call with json format",
|
||||
template: "{{if .ToolCalls}}```json\n{{end}}",
|
||||
want: "```json",
|
||||
},
|
||||
{
|
||||
name: "square brackets",
|
||||
template: "{{if .ToolCalls}}[{{range .ToolCalls}}{{ . }}{{end}}]{{end}}",
|
||||
want: "[",
|
||||
},
|
||||
{
|
||||
name: "square brackets with whitespace",
|
||||
template: "{{if .ToolCalls}}\n [ {{range .ToolCalls}}{{ . }}{{end}}]{{end}}",
|
||||
want: "[",
|
||||
},
|
||||
{
|
||||
name: "tailing ]",
|
||||
template: "{{if .ToolCalls}}{{range .ToolCalls}}{{ . }}{{end}}]{{end}}",
|
||||
want: "{",
|
||||
},
|
||||
{
|
||||
name: "whitespace only",
|
||||
template: "{{if .ToolCalls}} {{range .ToolCalls}}{{ . }}{{end}}{{end}}",
|
||||
want: "{",
|
||||
},
|
||||
{
|
||||
name: "whitespace only in range",
|
||||
template: "{{if .ToolCalls}}{{range .ToolCalls}}\n{{ . }}\n{{end}}{{end}}",
|
||||
want: "{",
|
||||
},
|
||||
{
|
||||
name: "json objects",
|
||||
template: `{{if .ToolCalls}}{{range .ToolCalls}}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{end}}{{end}}`,
|
||||
want: "{",
|
||||
},
|
||||
{
|
||||
name: "json objects with whitespace",
|
||||
template: "{{if .ToolCalls}}{{range .ToolCalls}}\n{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}{{end}}{{end}}",
|
||||
want: "{",
|
||||
},
|
||||
{
|
||||
name: "json objects with CRLF",
|
||||
template: "{{if .ToolCalls}}{{range .ToolCalls}}\r\n{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}{{end}}{{end}}",
|
||||
want: "{",
|
||||
},
|
||||
{
|
||||
name: "json objects with whitespace before and after range",
|
||||
template: "{{if .ToolCalls}}\n{{range .ToolCalls}}\n{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}\r\n{{end}}\r\n{{end}}",
|
||||
want: "{",
|
||||
},
|
||||
{
|
||||
name: "before and after range",
|
||||
template: "{{if .ToolCalls}}<|tool▁calls▁begin|>{{range .ToolCalls}}<|tool▁call▁begin|>functionget_current_weather\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|>\n{{end}}<|tool▁calls▁end|>{{end}}",
|
||||
want: "<|tool▁calls▁begin|>",
|
||||
},
|
||||
{
|
||||
name: "after range",
|
||||
template: "{{if .ToolCalls}}{{range .ToolCalls}}<tool_call>{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}</tool_call>{{end}}{{end}}",
|
||||
want: "<tool_call>",
|
||||
},
|
||||
{
|
||||
name: "after range with leading whitespace before range",
|
||||
template: "{{if .ToolCalls}}\n{{range .ToolCalls}}<tool_call>{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}</tool_call>{{end}}{{end}}",
|
||||
want: "<tool_call>",
|
||||
},
|
||||
{
|
||||
name: "tool call in range with {",
|
||||
template: `{{if .ToolCalls}}{{range .ToolCalls}}<tool_call>{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}<tool_call>{{end}}{{end}}`,
|
||||
want: "<tool_call>",
|
||||
},
|
||||
{
|
||||
name: "tool call with multiple text nodes",
|
||||
template: "{{if .ToolCalls}}First text{{if .Something}}inner{{end}}Second text{{end}}",
|
||||
want: "First text",
|
||||
},
|
||||
{
|
||||
name: "action tag",
|
||||
template: "{{if .ToolCalls}}Action: ```json{{end}}",
|
||||
want: "Action: ```json",
|
||||
},
|
||||
{
|
||||
name: "incomplete functools bracket",
|
||||
template: "{{if .ToolCalls}}functools[{{end}}",
|
||||
want: "functools[",
|
||||
},
|
||||
{
|
||||
name: "uppercase tool call with incomplete bracket",
|
||||
template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}",
|
||||
want: "[TOOL_CALL] [",
|
||||
},
|
||||
{
|
||||
name: "uppercase tool call with adjacent bracket",
|
||||
template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}",
|
||||
want: "[TOOL_CALL][",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.template)
|
||||
if err != nil && tc.template != "" {
|
||||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
got := parseTag(tmpl)
|
||||
if got != tc.want {
|
||||
t.Errorf("got text %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
67
tools/testdata/command-r-plus.gotmpl
vendored
67
tools/testdata/command-r-plus.gotmpl
vendored
@@ -1,67 +0,0 @@
|
||||
{{- if or .Tools .System }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>
|
||||
{{- if .Tools }}# Safety Preamble
|
||||
The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
|
||||
|
||||
# System Preamble
|
||||
## Basic Rules
|
||||
You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
|
||||
|
||||
{{ if .System }}# User Preamble
|
||||
{{ .System }}
|
||||
{{- end }}
|
||||
|
||||
## Available Tools
|
||||
Here is a list of tools that you have available to you:
|
||||
{{- range .Tools }}
|
||||
|
||||
```python
|
||||
def {{ .Function.Name }}(
|
||||
{{- range $name, $property := .Function.Parameters.Properties }}{{ $name }}: {{ $property.Type }}, {{ end }}) -> List[Dict]:
|
||||
"""{{ .Function.Description }}
|
||||
|
||||
{{- if .Function.Parameters.Properties }}
|
||||
|
||||
Args:
|
||||
{{- range $name, $property := .Function.Parameters.Properties }}
|
||||
{{ $name }} ({{ $property.Type }}): {{ $property.Description }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
"""
|
||||
pass
|
||||
```
|
||||
{{- end }}
|
||||
{{- else if .System }}{{ .System }}
|
||||
{{- end }}<|END_OF_TURN_TOKEN|>
|
||||
{{- end }}
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "system" }}
|
||||
{{- continue }}
|
||||
{{- end }}<|START_OF_TURN_TOKEN|>
|
||||
{{- if eq .Role "user" }}<|USER_TOKEN|>{{ .Content }}
|
||||
{{- else if eq .Role "assistant" }}<|CHATBOT_TOKEN|>
|
||||
{{- if .Content }}{{ .Content }}
|
||||
{{- else if .ToolCalls }}
|
||||
Action: ```json
|
||||
[
|
||||
{{- range .ToolCalls }}
|
||||
{
|
||||
"tool_name": "{{ .Function.Name }}",
|
||||
"parameters": {{ .Function.Arguments }}
|
||||
}
|
||||
{{- end }}
|
||||
]```
|
||||
{{ continue }}
|
||||
{{ end }}
|
||||
{{- else if eq .Role "tool" }}<|SYSTEM_TOKEN|><results>
|
||||
{{ .Content }}</results>
|
||||
{{- end }}<|END_OF_TURN_TOKEN|>
|
||||
{{- end }}
|
||||
{{- if .Tools }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"tool_name": title of the tool in the specification,
|
||||
"parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
|
||||
}
|
||||
]```
|
||||
{{- end }}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
||||
39
tools/testdata/command-r-plus.out
vendored
39
tools/testdata/command-r-plus.out
vendored
@@ -1,39 +0,0 @@
|
||||
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># Safety Preamble
|
||||
The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
|
||||
|
||||
# System Preamble
|
||||
## Basic Rules
|
||||
You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
|
||||
|
||||
# User Preamble
|
||||
You are a knowledgeable assistant. You can answer questions and perform tasks.
|
||||
|
||||
## Available Tools
|
||||
Here is a list of tools that you have available to you:
|
||||
|
||||
```python
|
||||
def get_current_weather(format: string, location: string, ) -> List[Dict]:
|
||||
"""Get the current weather
|
||||
|
||||
Args:
|
||||
format (string): The temperature unit to use. Infer this from the user's location.
|
||||
location (string): The city and state, e.g. San Francisco, CA
|
||||
"""
|
||||
pass
|
||||
```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in Paris?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
||||
Action: ```json
|
||||
[
|
||||
{
|
||||
"tool_name": "get_current_weather",
|
||||
"parameters": {"format":"celsius","location":"Paris, France"}
|
||||
}
|
||||
]```
|
||||
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><results>
|
||||
22</results><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>The current temperature in Paris, France is 22 degrees Celsius.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in San Francisco and Toronto?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"tool_name": title of the tool in the specification,
|
||||
"parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
|
||||
}
|
||||
]```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
||||
31
tools/testdata/firefunction.gotmpl
vendored
31
tools/testdata/firefunction.gotmpl
vendored
@@ -1,31 +0,0 @@
|
||||
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
|
||||
{{- if .System }}
|
||||
{{ .System }}
|
||||
{{- end }}
|
||||
In addition to plain text responses, you can chose to call one or more of the provided functions.
|
||||
|
||||
Use the following rule to decide when to call a function:
|
||||
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
|
||||
* if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls
|
||||
|
||||
If you decide to call functions:
|
||||
* prefix function calls with functools marker (no closing marker required)
|
||||
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
|
||||
* follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
|
||||
* respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
|
||||
* make sure you pick the right functions that match the user intent
|
||||
|
||||
Available functions as JSON spec:
|
||||
{{- if .Tools }}
|
||||
{{ .Tools }}
|
||||
{{- end }}<|eot_id|>
|
||||
{{- end }}
|
||||
{{- range .Messages }}<|start_header_id|>
|
||||
{{- if or (eq .Role "user") (eq .Role "assistant") (eq .Role "tool") }}{{ .Role }}
|
||||
{{- end }}<|end_header_id|>
|
||||
{{- if .Content }}{{ .Content }}
|
||||
{{- else if .ToolCalls }} functools[
|
||||
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}{{ "}" }}
|
||||
{{- end }}]
|
||||
{{- end }}<|eot_id|>
|
||||
{{- end }}<|start_header_id|>assistant<|end_header_id|>
|
||||
17
tools/testdata/firefunction.out
vendored
17
tools/testdata/firefunction.out
vendored
@@ -1,17 +0,0 @@
|
||||
<|start_header_id|>system<|end_header_id|>
|
||||
You are a knowledgeable assistant. You can answer questions and perform tasks.
|
||||
In addition to plain text responses, you can chose to call one or more of the provided functions.
|
||||
|
||||
Use the following rule to decide when to call a function:
|
||||
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
|
||||
* if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls
|
||||
|
||||
If you decide to call functions:
|
||||
* prefix function calls with functools marker (no closing marker required)
|
||||
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
|
||||
* follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
|
||||
* respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
|
||||
* make sure you pick the right functions that match the user intent
|
||||
|
||||
Available functions as JSON spec:
|
||||
[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]<|eot_id|><|start_header_id|><|end_header_id|>You are a knowledgeable assistant. You can answer questions and perform tasks.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> functools[{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]<|eot_id|><|start_header_id|>tool<|end_header_id|>22<|eot_id|><|start_header_id|>assistant<|end_header_id|>The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
43
tools/testdata/llama3-groq-tool-use.gotmpl
vendored
43
tools/testdata/llama3-groq-tool-use.gotmpl
vendored
@@ -1,43 +0,0 @@
|
||||
{{- if .Messages }}
|
||||
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
{{ .System }}
|
||||
{{- if .Tools }} You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
|
||||
<tool_call>
|
||||
{"name": <function-name>,"arguments": <args-dict>}
|
||||
</tool_call>
|
||||
|
||||
Here are the available tools:
|
||||
<tools>
|
||||
{{- range .Tools }} {{ .Function }}
|
||||
{{- end }} </tools>
|
||||
{{- end }}
|
||||
{{- end }}<|eot_id|>
|
||||
{{- range .Messages }}
|
||||
{{- if ne .Role "system" }}<|start_header_id|>{{ .Role }}<|end_header_id|>
|
||||
|
||||
{{ if eq .Role "user" }}{{ .Content }}
|
||||
{{- else if eq .Role "assistant" }}
|
||||
{{- if .Content }}{{ .Content }}
|
||||
{{- else if .ToolCalls }}<tool_call>
|
||||
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{- end }}
|
||||
</tool_call>
|
||||
{{- end }}
|
||||
{{- else if eq .Role "tool" }}<tool_response>
|
||||
{{ .Content }}
|
||||
</tool_response>
|
||||
{{- end }}<|eot_id|>
|
||||
{{- end }}
|
||||
{{- end }}<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{{ else }}
|
||||
{{ if .System }}<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
|
||||
|
||||
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{{ end }}{{ .Response }}
|
||||
{{- if .Response }}<|eot_id|>
|
||||
{{- end }}
|
||||
24
tools/testdata/llama3-groq-tool-use.out
vendored
24
tools/testdata/llama3-groq-tool-use.out
vendored
@@ -1,24 +0,0 @@
|
||||
<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
You are a knowledgeable assistant. You can answer questions and perform tasks. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
|
||||
<tool_call>
|
||||
{"name": <function-name>,"arguments": <args-dict>}
|
||||
</tool_call>
|
||||
|
||||
Here are the available tools:
|
||||
<tools> {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}} </tools><|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
<tool_call>
|
||||
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
|
||||
</tool_call><|eot_id|><|start_header_id|>tool<|end_header_id|>
|
||||
|
||||
<tool_response>
|
||||
22
|
||||
</tool_response><|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
44
tools/testdata/llama3.2.gotmpl
vendored
44
tools/testdata/llama3.2.gotmpl
vendored
@@ -1,44 +0,0 @@
|
||||
<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
Cutting Knowledge Date: December 2023
|
||||
|
||||
{{ if .System }}{{ .System }}
|
||||
{{- end }}
|
||||
{{- if .Tools }}When you receive a tool call response, use the output to format an answer to the orginal user question.
|
||||
|
||||
You are a helpful assistant with tool calling capabilities.
|
||||
{{- end }}<|eot_id|>
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 }}
|
||||
{{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>
|
||||
{{- if and $.Tools $last }}
|
||||
|
||||
Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
|
||||
|
||||
Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
|
||||
|
||||
{{ range $.Tools }}
|
||||
{{- . }}
|
||||
{{ end }}
|
||||
{{ .Content }}<|eot_id|>
|
||||
{{- else }}
|
||||
|
||||
{{ .Content }}<|eot_id|>
|
||||
{{- end }}{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{{ end }}
|
||||
{{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|>
|
||||
{{- if .ToolCalls }}
|
||||
{{ range .ToolCalls }}
|
||||
{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}
|
||||
{{- else }}
|
||||
|
||||
{{ .Content }}
|
||||
{{- end }}{{ if not $last }}<|eot_id|>{{ end }}
|
||||
{{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|>
|
||||
|
||||
{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{{ end }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
24
tools/testdata/llama3.2.out
vendored
24
tools/testdata/llama3.2.out
vendored
@@ -1,24 +0,0 @@
|
||||
<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
Cutting Knowledge Date: December 2023
|
||||
|
||||
You are a knowledgeable assistant. You can answer questions and perform tasks.When you receive a tool call response, use the output to format an answer to the orginal user question.
|
||||
|
||||
You are a helpful assistant with tool calling capabilities.<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{"name": "get_current_weather", "parameters": {"format":"celsius","location":"Paris, France"}}<|eot_id|><|start_header_id|>ipython<|end_header_id|>
|
||||
|
||||
22<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
|
||||
|
||||
Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
|
||||
|
||||
{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}
|
||||
|
||||
What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
39
tools/testdata/messages.json
vendored
39
tools/testdata/messages.json
vendored
@@ -1,39 +0,0 @@
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a knowledgeable assistant. You can answer questions and perform tasks."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like today in Paris?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "89a1e453-0bce-4de3-a456-c54bed09c520",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"arguments": {
|
||||
"location": "Paris, France",
|
||||
"format": "celsius"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "89a1e453-0bce-4de3-a456-c54bed09c520",
|
||||
"content": "22"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The current temperature in Paris, France is 22 degrees Celsius."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like today in San Francisco and Toronto?"
|
||||
}
|
||||
]
|
||||
15
tools/testdata/mistral.gotmpl
vendored
15
tools/testdata/mistral.gotmpl
vendored
@@ -1,15 +0,0 @@
|
||||
{{- range $index, $_ := .Messages }}
|
||||
{{- if eq .Role "user" }}
|
||||
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS]
|
||||
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
|
||||
|
||||
{{ end }}{{ .Content }}[/INST]
|
||||
{{- else if eq .Role "assistant" }}
|
||||
{{- if .Content }} {{ .Content }}</s>
|
||||
{{- else if .ToolCalls }}[TOOL_CALLS] [
|
||||
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{- end }}]</s>
|
||||
{{- end }}
|
||||
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
3
tools/testdata/mistral.out
vendored
3
tools/testdata/mistral.out
vendored
@@ -1,3 +0,0 @@
|
||||
[INST] What's the weather like today in Paris?[/INST][TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]</s>[TOOL_RESULTS] {"content": 22}[/TOOL_RESULTS] The current temperature in Paris, France is 22 degrees Celsius.</s>[AVAILABLE_TOOLS] [{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}][/AVAILABLE_TOOLS][INST] You are a knowledgeable assistant. You can answer questions and perform tasks.
|
||||
|
||||
What's the weather like today in San Francisco and Toronto?[/INST]
|
||||
33
tools/testdata/nemotron.gotmpl
vendored
33
tools/testdata/nemotron.gotmpl
vendored
@@ -1,33 +0,0 @@
|
||||
{{- if (or .Tools .System) }}<extra_id_0>System
|
||||
{{ if .System }}{{ .System }}
|
||||
|
||||
|
||||
{{ end }}
|
||||
{{- if .Tools }}
|
||||
{{- range .Tools }}<tool> {{ . }} </tool>{{ end }}
|
||||
|
||||
|
||||
{{ end }}
|
||||
{{- end }}
|
||||
{{- range $i, $m := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||
{{- if eq .Role "user" }}<extra_id_1>User
|
||||
{{ .Content }}
|
||||
{{- if $last }}
|
||||
<extra_id_1>Assistant
|
||||
{{- end }}
|
||||
{{ else if eq .Role "tool" }}<extra_id_1>Tool
|
||||
{{ .Content }}
|
||||
{{- if $last }}
|
||||
<extra_id_1>Assistant
|
||||
{{- end }}
|
||||
{{ else if eq .Role "assistant" }}<extra_id_1>Assistant
|
||||
{{- if .ToolCalls }}
|
||||
{{ range .ToolCalls }}<toolcall> {"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} </toolcall> {{ end }}
|
||||
{{ else }}
|
||||
{{ .Content }}
|
||||
{{- if not $last }}
|
||||
{{ end }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
18
tools/testdata/nemotron.out
vendored
18
tools/testdata/nemotron.out
vendored
@@ -1,18 +0,0 @@
|
||||
<extra_id_0>System
|
||||
You are a knowledgeable assistant. You can answer questions and perform tasks.
|
||||
|
||||
|
||||
<tool> {"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} </tool>
|
||||
|
||||
|
||||
<extra_id_1>User
|
||||
What's the weather like today in Paris?
|
||||
<extra_id_1>Assistant
|
||||
<toolcall> {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} </toolcall>
|
||||
<extra_id_1>Tool
|
||||
22
|
||||
<extra_id_1>Assistant
|
||||
The current temperature in Paris, France is 22 degrees Celsius.
|
||||
<extra_id_1>User
|
||||
What's the weather like today in San Francisco and Toronto?
|
||||
<extra_id_1>Assistant
|
||||
51
tools/testdata/qwen2.5.gotmpl
vendored
51
tools/testdata/qwen2.5.gotmpl
vendored
@@ -1,51 +0,0 @@
|
||||
{{- if .Suffix }}<|fim_prefix|>{{ .Prompt }}<|fim_suffix|>{{ .Suffix }}<|fim_middle|>
|
||||
{{- else if .Messages }}
|
||||
{{- if or .System .Tools }}<|im_start|>system
|
||||
{{- if .System }}
|
||||
{{ .System }}
|
||||
{{- end }}
|
||||
{{- if .Tools }}
|
||||
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{{- range .Tools }}
|
||||
{"type": "function", "function": {{ .Function }}}
|
||||
{{- end }}
|
||||
</tools>
|
||||
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
{{- end }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||
{{- if eq .Role "user" }}<|im_start|>user
|
||||
{{ .Content }}<|im_end|>
|
||||
{{ else if eq .Role "assistant" }}<|im_start|>assistant
|
||||
{{ if .Content }}{{ .Content }}
|
||||
{{- else if .ToolCalls }}<tool_call>
|
||||
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{ end }}</tool_call>
|
||||
{{- end }}{{ if not $last }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- else if eq .Role "tool" }}<|im_start|>user
|
||||
<tool_response>
|
||||
{{ .Content }}
|
||||
</tool_response><|im_end|>
|
||||
{{ end }}
|
||||
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
|
||||
{{ end }}
|
||||
{{- end }}
|
||||
{{- else }}
|
||||
{{- if .System }}<|im_start|>system
|
||||
{{ .System }}<|im_end|>
|
||||
{{ end }}{{ if .Prompt }}<|im_start|>user
|
||||
{{ .Prompt }}<|im_end|>
|
||||
{{ end }}<|im_start|>assistant
|
||||
{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}
|
||||
31
tools/testdata/qwen2.5.out
vendored
31
tools/testdata/qwen2.5.out
vendored
@@ -1,31 +0,0 @@
|
||||
<|im_start|>system
|
||||
You are a knowledgeable assistant. You can answer questions and perform tasks.
|
||||
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}
|
||||
</tools>
|
||||
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
What's the weather like today in Paris?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
<tool_call>
|
||||
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
<tool_response>
|
||||
22
|
||||
</tool_response><|im_end|>
|
||||
<|im_start|>assistant
|
||||
The current temperature in Paris, France is 22 degrees Celsius.<|im_end|>
|
||||
<|im_start|>user
|
||||
What's the weather like today in San Francisco and Toronto?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
50
tools/testdata/qwen3.gotmpl
vendored
50
tools/testdata/qwen3.gotmpl
vendored
@@ -1,50 +0,0 @@
|
||||
{{- if .Messages }}
|
||||
{{- if or .System .Tools }}<|im_start|>system
|
||||
{{- if .System }}
|
||||
{{ .System }}
|
||||
{{- end }}
|
||||
{{- if .Tools }}
|
||||
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{{- range .Tools }}
|
||||
{"type": "function", "function": {{ .Function }}}
|
||||
{{- end }}
|
||||
</tools>
|
||||
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
{{- end }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||
{{- if eq .Role "user" }}<|im_start|>user
|
||||
{{ .Content }}<|im_end|>
|
||||
{{ else if eq .Role "assistant" }}<|im_start|>assistant
|
||||
{{ if .Content }}{{ .Content }}
|
||||
{{- else if .ToolCalls }}<tool_call>
|
||||
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{ end }}</tool_call>
|
||||
{{- end }}{{ if not $last }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- else if eq .Role "tool" }}<|im_start|>user
|
||||
<tool_response>
|
||||
{{ .Content }}
|
||||
</tool_response><|im_end|>
|
||||
{{ end }}
|
||||
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
|
||||
{{ end }}
|
||||
{{- end }}
|
||||
{{- else }}
|
||||
{{- if .System }}<|im_start|>system
|
||||
{{ .System }}<|im_end|>
|
||||
{{ end }}{{ if .Prompt }}<|im_start|>user
|
||||
{{ .Prompt }}<|im_end|>
|
||||
{{ end }}<|im_start|>assistant
|
||||
{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}
|
||||
31
tools/testdata/qwen3.out
vendored
31
tools/testdata/qwen3.out
vendored
@@ -1,31 +0,0 @@
|
||||
<|im_start|>system
|
||||
You are a knowledgeable assistant. You can answer questions and perform tasks.
|
||||
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}
|
||||
</tools>
|
||||
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
What's the weather like today in Paris?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
<tool_call>
|
||||
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
<tool_response>
|
||||
22
|
||||
</tool_response><|im_end|>
|
||||
<|im_start|>assistant
|
||||
The current temperature in Paris, France is 22 degrees Celsius.<|im_end|>
|
||||
<|im_start|>user
|
||||
What's the weather like today in San Francisco and Toronto?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
30
tools/testdata/tools.json
vendored
30
tools/testdata/tools.json
vendored
@@ -1,30 +0,0 @@
|
||||
[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA"
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"celsius",
|
||||
"fahrenheit"
|
||||
],
|
||||
"description": "The temperature unit to use. Infer this from the user's location."
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"location",
|
||||
"format"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
45
tools/testdata/xlam.gotmpl
vendored
45
tools/testdata/xlam.gotmpl
vendored
@@ -1,45 +0,0 @@
|
||||
{{- if .System }}{{ .System }}
|
||||
{{ end }}
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- if eq .Role "user" }}### Instruction:
|
||||
{{- if and $.Tools (le (len (slice $.Messages $i)) 2) }}
|
||||
[BEGIN OF TASK INSTRUCTION]
|
||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||
If none of the functions can be used, point it out and refuse to answer.
|
||||
If the given question lacks the parameters required by the function, also point it out.
|
||||
[END OF TASK INSTRUCTION]
|
||||
|
||||
[BEGIN OF AVAILABLE TOOLS]
|
||||
{{ $.Tools }}
|
||||
[END OF AVAILABLE TOOLS]
|
||||
|
||||
[BEGIN OF FORMAT INSTRUCTION]
|
||||
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
|
||||
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
|
||||
```
|
||||
{
|
||||
"tool_calls": [
|
||||
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
|
||||
... (more tool calls as required)
|
||||
]
|
||||
}
|
||||
```
|
||||
[END OF FORMAT INSTRUCTION]
|
||||
|
||||
[BEGIN OF QUERY]
|
||||
{{ .Content }}
|
||||
[END OF QUERY]
|
||||
|
||||
|
||||
{{ else }}
|
||||
{{ .Content }}
|
||||
{{ end }}
|
||||
{{- else if .ToolCalls }}### Response:
|
||||
{"tool_calls": [{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{ end }}]}
|
||||
<|EOT|>
|
||||
{{ else if eq .Role "assistant" }}### Response:
|
||||
{{ .Content }}
|
||||
<|EOT|>
|
||||
{{ end }}
|
||||
{{- end }}### Response:
|
||||
40
tools/testdata/xlam.out
vendored
40
tools/testdata/xlam.out
vendored
@@ -1,40 +0,0 @@
|
||||
You are a knowledgeable assistant. You can answer questions and perform tasks.
|
||||
### Instruction:
|
||||
What's the weather like today in Paris?
|
||||
### Response:
|
||||
{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]}
|
||||
<|EOT|>
|
||||
### Response:
|
||||
The current temperature in Paris, France is 22 degrees Celsius.
|
||||
<|EOT|>
|
||||
### Instruction:
|
||||
[BEGIN OF TASK INSTRUCTION]
|
||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||
If none of the functions can be used, point it out and refuse to answer.
|
||||
If the given question lacks the parameters required by the function, also point it out.
|
||||
[END OF TASK INSTRUCTION]
|
||||
|
||||
[BEGIN OF AVAILABLE TOOLS]
|
||||
[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]
|
||||
[END OF AVAILABLE TOOLS]
|
||||
|
||||
[BEGIN OF FORMAT INSTRUCTION]
|
||||
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
|
||||
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
|
||||
```
|
||||
{
|
||||
"tool_calls": [
|
||||
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
|
||||
... (more tool calls as required)
|
||||
]
|
||||
}
|
||||
```
|
||||
[END OF FORMAT INSTRUCTION]
|
||||
|
||||
[BEGIN OF QUERY]
|
||||
What's the weather like today in San Francisco and Toronto?
|
||||
[END OF QUERY]
|
||||
|
||||
|
||||
### Response:
|
||||
521
tools/tools.go
521
tools/tools.go
@@ -1,253 +1,304 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strings"
|
||||
gotmpl "text/template"
|
||||
"text/template"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidToolCall = errors.New("invalid tool call format")
|
||||
errAccumulateMore = errors.New("need to accumulate more content")
|
||||
type toolsState int
|
||||
|
||||
const (
|
||||
toolsState_LookingForTag toolsState = iota
|
||||
toolsState_ToolCalling
|
||||
toolsState_Done
|
||||
)
|
||||
|
||||
type Parser struct {
|
||||
greedyParseJSON bool
|
||||
prefix string
|
||||
prefixFound bool
|
||||
tmpl gotmpl.Template
|
||||
sb strings.Builder
|
||||
index int
|
||||
name string
|
||||
arguments string
|
||||
tag string
|
||||
tools []api.Tool
|
||||
|
||||
state toolsState
|
||||
buffer []byte
|
||||
n int
|
||||
}
|
||||
|
||||
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||
//
|
||||
// Parameters:
|
||||
// - s: The string to parse
|
||||
// - name: The field name from template that identifies the tool call name
|
||||
// - arguments: The field name from template that identifies the tool call arguments
|
||||
//
|
||||
// Returns:
|
||||
// - []api.ToolCall: The parsed tool calls if successful
|
||||
// - error: ErrAccumulateMore if braces unbalanced, ErrInvalidToolCall if invalid, or nil if successful
|
||||
func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api.ToolCall, error) {
|
||||
// Check for balanced braces before attempting to parse
|
||||
braceCount := 0
|
||||
squareCount := 0
|
||||
startIndex := -1
|
||||
var rawToolCalls []string
|
||||
s = strings.TrimSpace(s)
|
||||
|
||||
// Only track these if we don't have a prefix as it will be cut off from the prefix. Also track in the parseLeadingJSON case.
|
||||
trackSquareBrackets := prefix == "" || !strings.HasSuffix(prefix, "[") || strings.HasPrefix(s, "[")
|
||||
for i, c := range s {
|
||||
switch c {
|
||||
case '{':
|
||||
braceCount++
|
||||
if startIndex == -1 {
|
||||
startIndex = i
|
||||
}
|
||||
case '}':
|
||||
braceCount--
|
||||
if braceCount == 0 {
|
||||
rawToolCalls = append(rawToolCalls, s[startIndex:i+1])
|
||||
startIndex = -1
|
||||
}
|
||||
case '[':
|
||||
if trackSquareBrackets {
|
||||
squareCount++
|
||||
}
|
||||
case ']':
|
||||
if trackSquareBrackets {
|
||||
squareCount--
|
||||
}
|
||||
}
|
||||
|
||||
// Negative means we have an extra closing brace/bracket
|
||||
if braceCount < 0 || squareCount < 0 {
|
||||
return nil, errInvalidToolCall
|
||||
}
|
||||
}
|
||||
|
||||
// If braces/brackets aren't balanced, need more input
|
||||
if braceCount > 0 || squareCount > 0 {
|
||||
return nil, errAccumulateMore
|
||||
}
|
||||
|
||||
t := strings.TrimSpace(s)
|
||||
if len(t) == 0 {
|
||||
return nil, errAccumulateMore
|
||||
}
|
||||
// If the input is a single square bracket, it's not a valid tool call
|
||||
if t[0] == '[' && len(t) == 1 {
|
||||
return nil, errAccumulateMore
|
||||
}
|
||||
|
||||
// Attempt full unmarshal of the JSON
|
||||
var toolCalls []api.ToolCall
|
||||
for _, rawToolCall := range rawToolCalls {
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal([]byte(rawToolCall), &resp); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Collect nested objects that could contain tool calls
|
||||
objs := collect(resp)
|
||||
if len(objs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract tool calls from objects
|
||||
for _, kv := range objs {
|
||||
n, nok := kv[name].(string)
|
||||
a, aok := kv[arguments].(map[string]any)
|
||||
if nok && aok {
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: n,
|
||||
Arguments: a,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
slog.Debug("No valid tool call found in object.", "object", kv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Valid JSON, no tool calls found
|
||||
if len(toolCalls) == 0 {
|
||||
slog.Debug("No valid tool calls found in any raw tool calls.", "rawToolCalls", rawToolCalls)
|
||||
return nil, errInvalidToolCall
|
||||
}
|
||||
|
||||
return toolCalls, nil
|
||||
// NewParser creates a new tool call parser from a model's chat
|
||||
// template and a list of provided tools.
|
||||
func NewParser(tmpl *template.Template, tools []api.Tool) *Parser {
|
||||
return NewParserWithTag(tools, parseTag(tmpl))
|
||||
}
|
||||
|
||||
// checkPrefix processes a string to find and handle a prefix pattern.
|
||||
//
|
||||
// Returns:
|
||||
// - The processed string with prefix removed if found
|
||||
// - error: ErrAccumulateMore if prefix is incomplete, or nil if successful
|
||||
func (p *Parser) checkPrefix(s string) (string, error) {
|
||||
if s == "" || p.prefix == "" {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Check for prefix at start of string
|
||||
if cut, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix {
|
||||
// Found prefix at start - accumulate for potential tool
|
||||
p.prefixFound = true
|
||||
return cut, nil
|
||||
}
|
||||
|
||||
// Check if prefix overlaps end of string
|
||||
if idx := suffixOverlap(s, p.prefix); idx != -1 {
|
||||
// Return everything except overlapping portion
|
||||
p.sb.Reset()
|
||||
p.sb.WriteString(s[idx:])
|
||||
return s[:idx], errAccumulateMore
|
||||
}
|
||||
|
||||
// Check if prefix appears in middle of string
|
||||
if idx := strings.Index(s, p.prefix); idx != -1 {
|
||||
// Save remainder starting at prefix for next pass
|
||||
p.sb.Reset()
|
||||
p.sb.WriteString(strings.TrimSpace(s[idx:]))
|
||||
// Return everything before prefix
|
||||
return s[:idx], errAccumulateMore
|
||||
}
|
||||
|
||||
// No partial prefix found
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Add processes a string input to parse tool calls and content.
|
||||
// It handles prefix detection and JSON parsing to extract tool calls.
|
||||
//
|
||||
// Returns:
|
||||
// - tools: Any parsed tool calls
|
||||
// - content: Non-tool call content
|
||||
func (p *Parser) Add(s string) (tools []api.ToolCall, content string) {
|
||||
p.sb.WriteString(s)
|
||||
s = p.sb.String()
|
||||
|
||||
// Check for prefix pattern in input
|
||||
s, err := p.checkPrefix(s)
|
||||
if err != nil {
|
||||
// Need more input to complete prefix
|
||||
return nil, s
|
||||
}
|
||||
|
||||
// Exit if prefix exists in template, greedy parsing is off, and prefix not found
|
||||
if !p.greedyParseJSON && !p.prefixFound {
|
||||
p.sb.Reset()
|
||||
return nil, s
|
||||
}
|
||||
|
||||
toolCalls, err := parseJSONToolCalls(s, p.name, p.arguments, p.prefix)
|
||||
if err != nil {
|
||||
if errors.Is(err, errAccumulateMore) {
|
||||
return nil, ""
|
||||
}
|
||||
p.sb.Reset()
|
||||
// Only do greedy JSON parsing if there is no prefix from template
|
||||
if p.prefix != "" {
|
||||
p.greedyParseJSON = false
|
||||
}
|
||||
if p.index != 0 && p.prefix == "" {
|
||||
return nil, ""
|
||||
}
|
||||
if p.prefixFound {
|
||||
// Drop tokens since prefix was found
|
||||
return nil, ""
|
||||
}
|
||||
return nil, s
|
||||
}
|
||||
|
||||
for _, tc := range toolCalls {
|
||||
tc.Function.Index = p.index
|
||||
p.index++
|
||||
}
|
||||
|
||||
p.sb.Reset()
|
||||
return toolCalls, ""
|
||||
}
|
||||
|
||||
// NewParser creates a new tool call parser from a template. It extracts the tool call format,
|
||||
// prefix, and field names from the template to use for parsing tool calls from model output.
|
||||
//
|
||||
// Returns an error if the template does not contain valid tool call formatting.
|
||||
func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
|
||||
parsed, err := template.Parse(templateToProcess.Root.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tt, err := toolTemplate(parsed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tp := toolPrefix(templateToProcess)
|
||||
|
||||
name, arguments, err := extractToolArgs(tt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func NewParserWithTag(tools []api.Tool, tag string) *Parser {
|
||||
return &Parser{
|
||||
tmpl: *tt,
|
||||
sb: strings.Builder{},
|
||||
prefix: tp,
|
||||
greedyParseJSON: true,
|
||||
name: name,
|
||||
arguments: arguments,
|
||||
}, nil
|
||||
tag: tag,
|
||||
tools: tools,
|
||||
}
|
||||
}
|
||||
|
||||
// Add processes a string input to parse tool calls and content that
|
||||
// should be sent back to the user.
|
||||
func (p *Parser) Add(s string) (calls []api.ToolCall, content string) {
|
||||
if p.state == toolsState_Done {
|
||||
return nil, s
|
||||
}
|
||||
|
||||
p.buffer = append(p.buffer, s...)
|
||||
|
||||
if p.state == toolsState_LookingForTag {
|
||||
i, found := p.findTag()
|
||||
if i == -1 {
|
||||
content = string(p.buffer)
|
||||
p.buffer = []byte{}
|
||||
} else {
|
||||
content = string(p.buffer[:i])
|
||||
p.buffer = p.buffer[i:]
|
||||
}
|
||||
|
||||
// for models where { or [ are used as tool calling
|
||||
// tags, we only support parsing tools if the first non-
|
||||
// whitespace character is { or [
|
||||
if p.tag == "{" || p.tag == "[" {
|
||||
if strings.TrimSpace(content) != "" {
|
||||
p.state = toolsState_Done
|
||||
return nil, content + string(p.buffer)
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil, content
|
||||
}
|
||||
|
||||
p.state = toolsState_ToolCalling
|
||||
}
|
||||
|
||||
for {
|
||||
call := p.parseToolCall()
|
||||
if call == nil {
|
||||
break
|
||||
}
|
||||
|
||||
calls = append(calls, *call)
|
||||
}
|
||||
|
||||
if p.done() {
|
||||
p.state = toolsState_Done
|
||||
content = string(p.buffer)
|
||||
p.buffer = []byte{}
|
||||
}
|
||||
|
||||
return calls, content
|
||||
}
|
||||
|
||||
// findTag searches the buffer to find and handle a tool calling tag
|
||||
// returning true if the tag was found and false otherwise, and
|
||||
// a string content signaling any content that should be sent back to the user
|
||||
func (p *Parser) findTag() (int, bool) {
|
||||
// First check for complete substring anywhere in s
|
||||
if i := bytes.Index(p.buffer, []byte(p.tag)); i > -1 {
|
||||
return i, true
|
||||
}
|
||||
|
||||
// Then check for partial suffix overlap
|
||||
max := min(len(p.buffer), len(p.tag))
|
||||
for i := max; i > 0; i-- {
|
||||
if bytes.HasSuffix(p.buffer, []byte(p.tag[:i])) {
|
||||
return len(p.buffer) - i, false
|
||||
}
|
||||
}
|
||||
return -1, false
|
||||
}
|
||||
|
||||
// parseToolCall finds the next complete tool call in the buffer
|
||||
// incrementing n and advancing the buffer.
|
||||
func (p *Parser) parseToolCall() *api.ToolCall {
|
||||
var tool *api.Tool
|
||||
var end int = len(p.buffer)
|
||||
var i int
|
||||
|
||||
// find tool name
|
||||
for _, t := range p.tools {
|
||||
n := t.Function.Name
|
||||
if i = bytes.Index(p.buffer, []byte(n)); i != -1 {
|
||||
if i+len(n) < end {
|
||||
tool = &t
|
||||
end = i + len(n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tool == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// only look for arguments after the tool name if the tool has parameters
|
||||
// TODO (jmorganca): while probably uncommon, this doesn't support
|
||||
// parsing arguments before the tool name, which may be needed in the future
|
||||
args := map[string]any{}
|
||||
if len(tool.Function.Parameters.Properties) > 0 {
|
||||
if args, i = findArguments(*tool, p.buffer[end:]); args == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
end += i
|
||||
}
|
||||
|
||||
tc := &api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: tool.Function.Name,
|
||||
Arguments: args,
|
||||
Index: p.n,
|
||||
},
|
||||
}
|
||||
|
||||
p.n++
|
||||
p.buffer = p.buffer[end:]
|
||||
return tc
|
||||
}
|
||||
|
||||
// findArguments returns the first object that appears to be
|
||||
// arguments for the provided tool in the provided buffer,
|
||||
// returning nil if no arguments are found.
|
||||
// TODO (jmorganca): this does not support parsing omitted arguments
|
||||
// objects for functions that have all-optional parameters
|
||||
// e.g. `{"name": "get_conditions", "arguments": {}}` will work but
|
||||
// `{"name": "get_conditions"}` will not currently work
|
||||
func findArguments(tool api.Tool, buffer []byte) (map[string]any, int) {
|
||||
if len(buffer) == 0 {
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
var braces int
|
||||
var start int = -1
|
||||
var end int
|
||||
var object []byte
|
||||
|
||||
// find any outer json object
|
||||
for i, c := range buffer {
|
||||
if c == '{' {
|
||||
braces++
|
||||
if start == -1 {
|
||||
start = i
|
||||
}
|
||||
}
|
||||
|
||||
if c == '}' {
|
||||
if start != -1 {
|
||||
braces--
|
||||
if braces == 0 {
|
||||
end = i + 1
|
||||
object = buffer[start:end]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if braces > 0 {
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal(object, &data); err != nil {
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
var find func(obj any) map[string]any
|
||||
find = func(obj any) map[string]any {
|
||||
switch obj := obj.(type) {
|
||||
case map[string]any:
|
||||
valid := true
|
||||
// check if all keys in the object exist in the tool's parameters
|
||||
for key := range obj {
|
||||
if _, exists := tool.Function.Parameters.Properties[key]; !exists {
|
||||
valid = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// check for required parameters
|
||||
// TODO (jmorganca): this should error instead of silently failing
|
||||
if valid {
|
||||
for _, required := range tool.Function.Parameters.Required {
|
||||
if _, exists := obj[required]; !exists {
|
||||
valid = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if valid {
|
||||
return obj
|
||||
}
|
||||
|
||||
for _, value := range obj {
|
||||
if result := find(value); result != nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
case []any:
|
||||
for _, item := range obj {
|
||||
if result := find(item); result != nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
result := find(data)
|
||||
if result != nil {
|
||||
return result, end
|
||||
}
|
||||
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
// done checks if the parser is done parsing by looking
|
||||
// for closing tag. currently only } and ] are supported
|
||||
// for closing tags as {} or [] pairs may not always
|
||||
// represent tool calls and we need to send the content back
|
||||
func (p *Parser) done() bool {
|
||||
var open, close rune
|
||||
switch p.tag {
|
||||
case "{":
|
||||
open, close = '{', '}'
|
||||
case "[":
|
||||
open, close = '[', ']'
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
var count int
|
||||
for _, c := range p.buffer {
|
||||
if c == byte(open) {
|
||||
count++
|
||||
} else if c == byte(close) {
|
||||
count--
|
||||
if count == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Content returns any remaining content that
|
||||
// should be sent to the user. This should be the empty string
|
||||
// string unless the tag is { or [ and a tool call was not found
|
||||
func (p *Parser) Content() string {
|
||||
if p.n > 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if p.tag == "{" || p.tag == "[" {
|
||||
return string(p.buffer)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
1527
tools/tools_test.go
1527
tools/tools_test.go
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user