Mirror of https://github.com/ollama/ollama.git, synced 2025-12-31 11:39:00 -05:00

Compare commits: brucemacd/ ... pdevine/in (345 commits)
`.github/workflows/release.yaml` (vendored, 266 lines changed)
@@ -23,7 +23,7 @@ jobs:
|
||||
echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" >>$GITHUB_OUTPUT
|
||||
|
||||
darwin-build:
|
||||
runs-on: macos-13
|
||||
runs-on: macos-13-xlarge
|
||||
environment: release
|
||||
needs: setup-environment
|
||||
strategy:
|
||||
@@ -54,48 +54,6 @@ jobs:
|
||||
name: build-${{ matrix.os }}-${{ matrix.arch }}
|
||||
path: dist/*
|
||||
|
||||
darwin-sign:
|
||||
runs-on: macos-13
|
||||
environment: release
|
||||
needs: darwin-build
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- run: |
|
||||
echo $MACOS_SIGNING_KEY | base64 --decode > certificate.p12
|
||||
security create-keychain -p password build.keychain
|
||||
security default-keychain -s build.keychain
|
||||
security unlock-keychain -p password build.keychain
|
||||
security import certificate.p12 -k build.keychain -P $MACOS_SIGNING_KEY_PASSWORD -T /usr/bin/codesign
|
||||
security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k password build.keychain
|
||||
security set-keychain-settings -lut 3600 build.keychain
|
||||
env:
|
||||
MACOS_SIGNING_KEY: ${{ secrets.MACOS_SIGNING_KEY }}
|
||||
MACOS_SIGNING_KEY_PASSWORD: ${{ secrets.MACOS_SIGNING_KEY_PASSWORD }}
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: build-darwin-amd64
|
||||
path: dist/darwin-amd64
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: build-darwin-arm64
|
||||
path: dist/darwin-arm64
|
||||
- run: |
|
||||
export VERSION=${GITHUB_REF_NAME#v}
|
||||
./scripts/build_darwin.sh sign macapp
|
||||
env:
|
||||
APPLE_IDENTITY: ${{ secrets.APPLE_IDENTITY }}
|
||||
APPLE_PASSWORD: ${{ secrets.APPLE_PASSWORD }}
|
||||
APPLE_TEAM_ID: ${{ vars.APPLE_TEAM_ID }}
|
||||
APPLE_ID: ${{ vars.APPLE_ID }}
|
||||
SDKROOT: /Applications/Xcode_14.1.0.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
|
||||
DEVELOPER_DIR: /Applications/Xcode_14.1.0.app/Contents/Developer
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist-darwin
|
||||
path: |
|
||||
dist/Ollama-darwin.zip
|
||||
dist/ollama-darwin.tgz
|
||||
|
||||
windows-depends:
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -103,21 +61,40 @@ jobs:
|
||||
arch: [amd64]
|
||||
preset: ['CPU']
|
||||
include:
|
||||
- os: windows
|
||||
arch: amd64
|
||||
preset: 'CUDA 11'
|
||||
install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
|
||||
cuda-version: '11.3'
|
||||
- os: windows
|
||||
arch: amd64
|
||||
preset: 'CUDA 12'
|
||||
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
|
||||
cuda-components:
|
||||
- '"cudart"'
|
||||
- '"nvcc"'
|
||||
- '"cublas"'
|
||||
- '"cublas_dev"'
|
||||
cuda-version: '12.8'
|
||||
flags: ''
|
||||
runner_dir: 'cuda_v12'
|
||||
- os: windows
|
||||
arch: amd64
|
||||
preset: 'CUDA 13'
|
||||
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
|
||||
cuda-components:
|
||||
- '"cudart"'
|
||||
- '"nvcc"'
|
||||
- '"cublas"'
|
||||
- '"cublas_dev"'
|
||||
- '"crt"'
|
||||
- '"nvvm"'
|
||||
- '"nvptxcompiler"'
|
||||
cuda-version: '13.0'
|
||||
flags: ''
|
||||
runner_dir: 'cuda_v13'
|
||||
- os: windows
|
||||
arch: amd64
|
||||
preset: 'ROCm 6'
|
||||
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
|
||||
rocm-version: '6.2'
|
||||
flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
||||
runner_dir: 'rocm'
|
||||
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
||||
environment: release
|
||||
env:
|
||||
@@ -141,7 +118,7 @@ jobs:
|
||||
$ErrorActionPreference = "Stop"
|
||||
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
|
||||
$subpackages = @("cudart", "nvcc", "cublas", "cublas_dev") | Foreach-Object {"${_}_${{ matrix.cuda-version }}"}
|
||||
$subpackages = @(${{ join(matrix.cuda-components, ', ') }}) | Foreach-Object {"${_}_${{ matrix.cuda-version }}"}
|
||||
Start-Process -FilePath .\install.exe -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait
|
||||
}
|
||||
|
||||
@@ -160,6 +137,9 @@ jobs:
|
||||
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
- if: matrix.preset == 'CPU'
|
||||
run: |
|
||||
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
@@ -178,11 +158,12 @@ jobs:
|
||||
key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}
|
||||
- name: Build target "${{ matrix.preset }}"
|
||||
run: |
|
||||
Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||
cmake --preset "${{ matrix.preset }}"
|
||||
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} -DOLLAMA_RUNNER_DIR="${{ matrix.runner_dir }}"
|
||||
cmake --build --parallel --preset "${{ matrix.preset }}"
|
||||
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8
|
||||
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
|
||||
env:
|
||||
CMAKE_GENERATOR: Ninja
|
||||
- uses: actions/upload-artifact@v4
|
||||
@@ -195,19 +176,19 @@ jobs:
|
||||
matrix:
|
||||
os: [windows]
|
||||
arch: [amd64, arm64]
|
||||
include:
|
||||
- os: windows
|
||||
arch: amd64
|
||||
llvmarch: x86_64
|
||||
- os: windows
|
||||
arch: arm64
|
||||
llvmarch: aarch64
|
||||
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
||||
environment: release
|
||||
needs: [setup-environment]
|
||||
env:
|
||||
GOFLAGS: ${{ needs.setup-environment.outputs.GOFLAGS }}
|
||||
steps:
|
||||
- name: Install AMD64 system dependencies
|
||||
if: matrix.arch == 'amd64'
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
Start-Process "C:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait
|
||||
echo "C:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
- name: Install ARM64 system dependencies
|
||||
if: matrix.arch == 'arm64'
|
||||
run: |
|
||||
@@ -219,72 +200,36 @@ jobs:
|
||||
|
||||
choco install -y --no-progress git gzip
|
||||
echo "C:\Program Files\Git\cmd" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
|
||||
Invoke-WebRequest -Uri "https://github.com/mstorsjo/llvm-mingw/releases/download/20240619/llvm-mingw-20240619-ucrt-aarch64.zip" -OutFile "${{ runner.temp }}\llvm-mingw-ucrt-aarch64.zip"
|
||||
Expand-Archive -Path ${{ runner.temp }}\llvm-mingw-ucrt-aarch64.zip -DestinationPath "C:\Program Files\"
|
||||
$installPath=(Resolve-Path -Path "C:\Program Files\llvm-mingw-*-ucrt-aarch64").path
|
||||
echo $installPath\bin | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
- name: Install clang and gcc-compat
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
Set-ExecutionPolicy Bypass -Scope Process -Force
|
||||
Invoke-WebRequest -Uri "https://github.com/mstorsjo/llvm-mingw/releases/download/20240619/llvm-mingw-20240619-ucrt-${{ matrix.llvmarch }}.zip" -OutFile "${{ runner.temp }}\llvm-mingw-ucrt.zip"
|
||||
Expand-Archive -Path ${{ runner.temp }}\llvm-mingw-ucrt.zip -DestinationPath "C:\Program Files\"
|
||||
$installPath=(Resolve-Path -Path "C:\Program Files\llvm-mingw-*-ucrt*").path
|
||||
echo "$installPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
- name: Verify gcc is actually clang
|
||||
run: |
|
||||
$ErrorActionPreference='Continue'
|
||||
$version=& gcc -v 2>&1
|
||||
$version=$version -join "`n"
|
||||
echo "gcc is $version"
|
||||
if ($version -notmatch 'clang') {
|
||||
echo "ERROR: GCC must be clang for proper utf16 handling"
|
||||
exit 1
|
||||
}
|
||||
$ErrorActionPreference='Stop'
|
||||
- run: |
|
||||
go build -o dist/${{ matrix.os }}-${{ matrix.arch }}/ .
|
||||
- if: matrix.arch == 'arm64'
|
||||
run: |
|
||||
Invoke-WebRequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "dist\windows-arm64\vc_redist.arm64.exe"
|
||||
- run: |
|
||||
$env:VERSION='${{ github.ref_name }}' -Replace "v(.*)", '$1'
|
||||
& .\scripts\build_windows.ps1 buildApp
|
||||
env:
|
||||
VCToolsRedistDir: stub
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: build-${{ matrix.os }}-${{ matrix.arch }}
|
||||
path: |
|
||||
dist\${{ matrix.os }}-${{ matrix.arch }}\*.exe
|
||||
dist\${{ matrix.os }}-${{ matrix.arch }}-app.exe
|
||||
|
||||
windows-sign:
|
||||
runs-on: windows-2022
|
||||
environment: release
|
||||
needs: [windows-depends, windows-build]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: google-github-actions/auth@v2
|
||||
with:
|
||||
project_id: ollama
|
||||
credentials_json: ${{ secrets.GOOGLE_SIGNING_CREDENTIALS }}
|
||||
- run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
Invoke-WebRequest -Uri "https://go.microsoft.com/fwlink/p/?LinkId=323507" -OutFile "${{ runner.temp }}\sdksetup.exe"
|
||||
Start-Process "${{ runner.temp }}\sdksetup.exe" -ArgumentList @("/q") -NoNewWindow -Wait
|
||||
|
||||
Invoke-WebRequest -Uri "https://github.com/GoogleCloudPlatform/kms-integrations/releases/download/cng-v1.0/kmscng-1.0-windows-amd64.zip" -OutFile "${{ runner.temp }}\plugin.zip"
|
||||
Expand-Archive -Path "${{ runner.temp }}\plugin.zip" -DestinationPath "${{ runner.temp }}\plugin\"
|
||||
& "${{ runner.temp }}\plugin\*\kmscng.msi" /quiet
|
||||
|
||||
echo "${{ vars.OLLAMA_CERT }}" >ollama_inc.crt
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: build-windows-*
|
||||
path: dist\
|
||||
merge-multiple: true
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: depends-windows-amd64-*
|
||||
path: dist\windows-amd64\
|
||||
merge-multiple: true
|
||||
- run: |
|
||||
& .\scripts\build_windows.ps1 gatherDependencies sign buildInstaller distZip
|
||||
env:
|
||||
KEY_CONTAINER: ${{ vars.KEY_CONTAINER }}
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist-windows
|
||||
path: |
|
||||
dist\OllamaSetup.exe
|
||||
dist\ollama-windows-*.zip
|
||||
|
||||
linux-build:
|
||||
strategy:
|
||||
@@ -292,13 +237,13 @@ jobs:
|
||||
include:
|
||||
- os: linux
|
||||
arch: amd64
|
||||
target: archive
|
||||
target: archive_novulkan
|
||||
- os: linux
|
||||
arch: amd64
|
||||
target: rocm
|
||||
- os: linux
|
||||
arch: arm64
|
||||
target: archive
|
||||
target: archive_novulkan
|
||||
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
||||
environment: release
|
||||
needs: setup-environment
|
||||
@@ -317,21 +262,26 @@ jobs:
|
||||
CGO_CFLAGS=${{ env.CGO_CFLAGS }}
|
||||
CGO_CXXFLAGS=${{ env.CGO_CXXFLAGS }}
|
||||
outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
|
||||
cache-from: type=registry,ref=ollama/ollama:latest
|
||||
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
|
||||
cache-to: type=inline
|
||||
- run: |
|
||||
for COMPONENT in bin/* lib/ollama/*; do
|
||||
case "$COMPONENT" in
|
||||
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/*.so) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_v11) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_v12) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
||||
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
||||
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
||||
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
||||
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
||||
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
||||
esac
|
||||
done
|
||||
working-directory: dist/${{ matrix.os }}-${{ matrix.arch }}
|
||||
- run: |
|
||||
echo "Manifests"
|
||||
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in ; do
|
||||
echo $ARCHIVE
|
||||
cat $ARCHIVE
|
||||
done
|
||||
- run: |
|
||||
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
|
||||
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
|
||||
@@ -349,12 +299,14 @@ jobs:
|
||||
include:
|
||||
- os: linux
|
||||
arch: arm64
|
||||
target: novulkan
|
||||
build-args: |
|
||||
CGO_CFLAGS
|
||||
CGO_CXXFLAGS
|
||||
GOFLAGS
|
||||
- os: linux
|
||||
arch: amd64
|
||||
target: novulkan
|
||||
build-args: |
|
||||
CGO_CFLAGS
|
||||
CGO_CXXFLAGS
|
||||
@@ -367,6 +319,14 @@ jobs:
|
||||
CGO_CXXFLAGS
|
||||
GOFLAGS
|
||||
FLAVOR=rocm
|
||||
- os: linux
|
||||
arch: amd64
|
||||
suffix: '-vulkan'
|
||||
target: default
|
||||
build-args: |
|
||||
CGO_CFLAGS
|
||||
CGO_CXXFLAGS
|
||||
GOFLAGS
|
||||
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
||||
environment: release
|
||||
needs: setup-environment
|
||||
@@ -384,9 +344,10 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.os }}/${{ matrix.arch }}
|
||||
target: ${{ matrix.target }}
|
||||
build-args: ${{ matrix.build-args }}
|
||||
outputs: type=image,name=ollama/ollama,push-by-digest=true,name-canonical=true,push=true
|
||||
cache-from: type=registry,ref=ollama/ollama:latest
|
||||
outputs: type=image,name=${{ vars.DOCKER_REPO }},push-by-digest=true,name-canonical=true,push=true
|
||||
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
|
||||
cache-to: type=inline
|
||||
- run: |
|
||||
mkdir -p ${{ matrix.os }}-${{ matrix.arch }}
|
||||
@@ -418,7 +379,7 @@ jobs:
|
||||
latest=false
|
||||
suffix=${{ matrix.suffix }}
|
||||
images: |
|
||||
ollama/ollama
|
||||
${{ vars.DOCKER_REPO }}
|
||||
tags: |
|
||||
type=ref,enable=true,priority=600,prefix=pr-,event=pr
|
||||
type=semver,pattern={{version}}
|
||||
@@ -428,56 +389,24 @@ jobs:
|
||||
path: ${{ runner.temp }}
|
||||
merge-multiple: true
|
||||
- run: |
|
||||
docker buildx imagetools create $(echo '${{ steps.metadata.outputs.json }}' | jq -cr '.tags | map("-t", .) | join(" ")') $(cat *-${{ matrix.suffix }}.txt | xargs printf 'ollama/ollama@%s ')
|
||||
docker buildx imagetools inspect ollama/ollama:${{ steps.metadata.outputs.version }}
|
||||
docker buildx imagetools create $(echo '${{ steps.metadata.outputs.json }}' | jq -cr '.tags | map("-t", .) | join(" ")') $(cat *-${{ matrix.suffix }}.txt | xargs printf '${{ vars.DOCKER_REPO }}@%s ')
|
||||
docker buildx imagetools inspect ${{ vars.DOCKER_REPO }}:${{ steps.metadata.outputs.version }}
|
||||
working-directory: ${{ runner.temp }}
|
||||
|
||||
# Trigger downstream release process
|
||||
trigger:
|
||||
runs-on: ubuntu-latest
|
||||
environment: release
|
||||
needs: [darwin-build, windows-build, windows-depends]
|
||||
steps:
|
||||
- name: Trigger downstream release process
|
||||
run: |
|
||||
curl -L \
|
||||
-X POST \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-H "Authorization: Bearer ${{ secrets.RELEASE_TOKEN }}" \
|
||||
-H "X-GitHub-Api-Version: 2022-11-28" \
|
||||
https://api.github.com/repos/ollama/${{ vars.RELEASE_REPO }}/dispatches \
|
||||
-d "{\"event_type\": \"trigger-workflow\", \"client_payload\": {\"run_id\": \"${GITHUB_RUN_ID}\", \"version\": \"${GITHUB_REF_NAME#v}\"}}"
|
||||
|
||||
# Aggregate all the assets and ship a release
|
||||
release:
|
||||
needs: [darwin-sign, windows-sign, linux-build]
|
||||
runs-on: linux
|
||||
environment: release
|
||||
needs: [darwin-build, windows-build, windows-depends, linux-build]
|
||||
permissions:
|
||||
contents: write
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: dist-darwin
|
||||
path: dist
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: dist-windows
|
||||
path: dist
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: dist-linux-*
|
||||
path: dist
|
||||
merge-multiple: true
|
||||
- run: find . -type f -not -name 'sha256sum.txt' | xargs sha256sum | tee sha256sum.txt
|
||||
working-directory: dist
|
||||
- name: Create or update Release
|
||||
- name: Create or update Release for tag
|
||||
run: |
|
||||
RELEASE_VERSION="$(echo ${GITHUB_REF_NAME} | cut -f1 -d-)"
|
||||
|
||||
echo "Looking for existing release for ${RELEASE_VERSION}"
|
||||
OLD_TAG=$(gh release ls --json name,tagName | jq -r ".[] | select(.name == \"${RELEASE_VERSION}\") | .tagName")
|
||||
if [ -n "$OLD_TAG" ]; then
|
||||
@@ -491,5 +420,12 @@ jobs:
|
||||
--generate-notes \
|
||||
--prerelease
|
||||
fi
|
||||
echo "Uploading artifacts for tag ${GITHUB_REF_NAME}"
|
||||
gh release upload ${GITHUB_REF_NAME} dist/* --clobber
|
||||
- name: Trigger downstream release process
|
||||
run: |
|
||||
curl -L \
|
||||
-X POST \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-H "Authorization: Bearer ${{ secrets.RELEASE_TOKEN }}" \
|
||||
-H "X-GitHub-Api-Version: 2022-11-28" \
|
||||
https://api.github.com/repos/ollama/${{ vars.RELEASE_REPO }}/dispatches \
|
||||
-d "{\"event_type\": \"trigger-workflow\", \"client_payload\": {\"run_id\": \"${GITHUB_RUN_ID}\", \"version\": \"${GITHUB_REF_NAME#v}\", \"origin\": \"${GITHUB_REPOSITORY}\", \"publish\": \"1\"}}"
|
||||
|
||||
`.github/workflows/test.yaml` (vendored, 62 lines changed)
@@ -36,7 +36,7 @@ jobs:
|
||||
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
|
||||
}
|
||||
|
||||
echo changed=$(changed 'llama/llama.cpp/**' 'ml/backend/ggml/ggml/**') | tee -a $GITHUB_OUTPUT
|
||||
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
|
||||
|
||||
linux:
|
||||
needs: [changes]
|
||||
@@ -46,12 +46,18 @@ jobs:
|
||||
include:
|
||||
- preset: CPU
|
||||
- preset: CUDA
|
||||
container: nvidia/cuda:11.8.0-devel-ubuntu22.04
|
||||
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
|
||||
- preset: ROCm
|
||||
container: rocm/dev-ubuntu-22.04:6.1.2
|
||||
extra-packages: rocm-libs
|
||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
|
||||
- preset: Vulkan
|
||||
container: ubuntu:22.04
|
||||
extra-packages: >
|
||||
mesa-vulkan-drivers vulkan-tools
|
||||
libvulkan1 libvulkan-dev
|
||||
vulkan-sdk cmake ccache g++ make
|
||||
runs-on: linux
|
||||
container: ${{ matrix.container }}
|
||||
steps:
|
||||
@@ -59,7 +65,19 @@ jobs:
|
||||
- run: |
|
||||
[ -n "${{ matrix.container }}" ] || sudo=sudo
|
||||
$sudo apt-get update
|
||||
# Add LunarG Vulkan SDK apt repo for Ubuntu 22.04
|
||||
if [ "${{ matrix.preset }}" = "Vulkan" ]; then
|
||||
$sudo apt-get install -y --no-install-recommends wget gnupg ca-certificates software-properties-common
|
||||
wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | $sudo gpg --dearmor -o /usr/share/keyrings/lunarg-archive-keyring.gpg
|
||||
# Use signed-by to bind the repo to the installed keyring to avoid NO_PUBKEY
|
||||
echo "deb [signed-by=/usr/share/keyrings/lunarg-archive-keyring.gpg] https://packages.lunarg.com/vulkan/1.4.313 jammy main" | $sudo tee /etc/apt/sources.list.d/lunarg-vulkan-1.4.313-jammy.list > /dev/null
|
||||
$sudo apt-get update
|
||||
fi
|
||||
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
|
||||
# Export VULKAN_SDK if provided by LunarG package (defensive)
|
||||
if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then
|
||||
echo "VULKAN_SDK=/usr" >> $GITHUB_ENV
|
||||
fi
|
||||
env:
|
||||
DEBIAN_FRONTEND: noninteractive
|
||||
- uses: actions/cache@v4
|
||||
@@ -78,23 +96,35 @@ jobs:
|
||||
include:
|
||||
- preset: CPU
|
||||
- preset: CUDA
|
||||
install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
|
||||
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
|
||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
|
||||
cuda-components:
|
||||
- '"cudart"'
|
||||
- '"nvcc"'
|
||||
- '"cublas"'
|
||||
- '"cublas_dev"'
|
||||
- '"crt"'
|
||||
- '"nvvm"'
|
||||
- '"nvptxcompiler"'
|
||||
cuda-version: '13.0'
|
||||
- preset: ROCm
|
||||
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
|
||||
flags: '-DAMDGPU_TARGETS=gfx1010'
|
||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
||||
- preset: Vulkan
|
||||
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
||||
runs-on: windows
|
||||
steps:
|
||||
- run: |
|
||||
choco install -y --no-progress ccache ninja
|
||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm'
|
||||
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan'
|
||||
id: cache-install
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||
C:\Program Files\AMD\ROCm
|
||||
C:\VulkanSDK
|
||||
key: ${{ matrix.install }}
|
||||
- if: matrix.preset == 'CUDA'
|
||||
name: Install CUDA ${{ matrix.cuda-version }}
|
||||
@@ -102,7 +132,8 @@ jobs:
|
||||
$ErrorActionPreference = "Stop"
|
||||
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
|
||||
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_11.3", "nvcc_11.3", "cublas_11.3", "cublas_dev_11.3")) -NoNewWindow -Wait
|
||||
$subpackages = @(${{ join(matrix.cuda-components, ', ') }}) | Foreach-Object {"${_}_${{ matrix.cuda-version }}"}
|
||||
Start-Process -FilePath .\install.exe -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait
|
||||
}
|
||||
|
||||
$cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path
|
||||
@@ -120,6 +151,21 @@ jobs:
|
||||
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
- if: matrix.preset == 'Vulkan'
|
||||
name: Install Vulkan ${{ matrix.rocm-version }}
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
|
||||
Start-Process -FilePath .\install.exe -ArgumentList "-c","--am","--al","in" -NoNewWindow -Wait
|
||||
}
|
||||
|
||||
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
|
||||
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
|
||||
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
@@ -133,8 +179,8 @@ jobs:
|
||||
path: ${{ github.workspace }}\.ccache
|
||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
|
||||
- run: |
|
||||
Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
|
||||
cmake --build --parallel --preset "${{ matrix.preset }}"
|
||||
env:
|
||||
|
||||
`.gitignore` (vendored, 1 line changed)

```
@@ -6,6 +6,7 @@
dist
build
.cache
.gocache
*.exe
.idea
test_data
```
`CMakeLists.txt`

```
@@ -3,6 +3,7 @@ cmake_minimum_required(VERSION 3.21)
project(Ollama C CXX)

include(CheckLanguage)
include(GNUInstallDirs)

find_package(Threads REQUIRED)

@@ -37,7 +38,7 @@ if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
endif()

set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama)
set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama/${OLLAMA_RUNNER_DIR})

set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
@@ -51,7 +52,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)

add_compile_definitions(NDEBUG)
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)

set(GGML_CPU ON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
@@ -78,34 +79,36 @@ if(CMAKE_CUDA_COMPILER)

find_package(CUDAToolkit)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
set(OLLAMA_CUDA_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/cuda_v${CUDAToolkit_VERSION_MAJOR})
install(TARGETS ggml-cuda
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR}
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
LIBRARY DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA
)
endif()

set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a|1200|1201):xnack[+-]$"
set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(908|90a|1200|1201):xnack[+-]$"
CACHE STRING
"Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a|1200|1201):xnack[+-]$\"."
"Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(908|90a|1200|1201):xnack[+-]$\"."
)

check_language(HIP)
if(CMAKE_HIP_COMPILER)
set(HIP_PLATFORM "amd")

find_package(hip REQUIRED)
if(NOT AMDGPU_TARGETS)
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012]|120[01])$")
elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
find_package(hip REQUIRED)
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(94[012]|101[02]|1030|110[012]|120[01])$")
endif()

if(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
endif()

if(AMDGPU_TARGETS)
find_package(hip REQUIRED)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)

if (WIN32)
@@ -114,22 +117,37 @@ if(CMAKE_HIP_COMPILER)

target_compile_definitions(ggml-hip PRIVATE GGML_HIP_NO_VMM)

set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
install(TARGETS ggml-hip
RUNTIME_DEPENDENCIES
RUNTIME_DEPENDENCY_SET rocm
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
)
install(RUNTIME_DEPENDENCY_SET rocm
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
PRE_EXCLUDE_REGEXES ".*"
POST_EXCLUDE_REGEXES "system32"
RUNTIME DESTINATION ${OLLAMA_HIP_INSTALL_DIR} COMPONENT HIP
LIBRARY DESTINATION ${OLLAMA_HIP_INSTALL_DIR} COMPONENT HIP
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
)

foreach(HIP_LIB_BIN_INSTALL_DIR IN ITEMS ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR})
if(EXISTS ${HIP_LIB_BIN_INSTALL_DIR}/rocblas)
install(DIRECTORY ${HIP_LIB_BIN_INSTALL_DIR}/rocblas DESTINATION ${OLLAMA_HIP_INSTALL_DIR} COMPONENT HIP)
install(DIRECTORY ${HIP_LIB_BIN_INSTALL_DIR}/rocblas DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP)
break()
endif()
endforeach()
endif()
endif()

find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
endif()
```
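The `OLLAMA_RUNNER_DIR` change above lets each GPU backend be installed into its own subdirectory under `lib/ollama`. As a rough sketch of how that is driven, using the same commands the release workflow and Dockerfile in this diff run, and assuming the presets' default `installDir` of `${sourceDir}/dist`:

```sh
# Sketch: install the CUDA 13 runner into its own runner directory.
# Commands mirror the workflow/Dockerfile invocations elsewhere in this diff.
cmake --preset 'CUDA 13' -DOLLAMA_RUNNER_DIR="cuda_v13"
cmake --build --parallel --preset 'CUDA 13'
cmake --install build --component CUDA --strip
# Result: CUDA libraries land under dist/lib/ollama/cuda_v13 instead of dist/lib/ollama.
```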
`CMakePresets.json`

```
@@ -6,7 +6,8 @@
"binaryDir": "${sourceDir}/build",
"installDir": "${sourceDir}/dist",
"cacheVariables": {
"CMAKE_BUILD_TYPE": "Release"
"CMAKE_BUILD_TYPE": "Release",
"CMAKE_MSVC_RUNTIME_LIBRARY": "MultiThreaded"
}
},
{
@@ -21,16 +22,24 @@
"name": "CUDA 11",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
"CMAKE_CUDA_ARCHITECTURES": "50-virtual;60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
}
},
{
"name": "CUDA 12",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
"CMAKE_CUDA_ARCHITECTURES": "50;52;60;61;70;75;80;86;89;90;90a;120",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
}
},
{
"name": "CUDA 13",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
"CMAKE_CUDA_FLAGS": "-t 2"
}
},
{
@@ -58,8 +67,13 @@
"name": "ROCm 6",
"inherits": [ "ROCm" ],
"cacheVariables": {
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
"AMDGPU_TARGETS": "gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
}
},
{
"name": "Vulkan",
"inherits": [ "Default" ]
}
],
"buildPresets": [
@@ -88,6 +102,11 @@
"inherits": [ "CUDA" ],
"configurePreset": "CUDA 12"
},
{
"name": "CUDA 13",
"inherits": [ "CUDA" ],
"configurePreset": "CUDA 13"
},
{
"name": "JetPack 5",
"inherits": [ "CUDA" ],
@@ -107,6 +126,11 @@
"name": "ROCm 6",
"inherits": [ "ROCm" ],
"configurePreset": "ROCm 6"
},
{
"name": "Vulkan",
"targets": [ "ggml-vulkan" ],
"configurePreset": "Vulkan"
}
]
}
```
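The new `Vulkan` build preset covers only the `ggml-vulkan` target. A minimal sketch of how it is exercised, mirroring the `vulkan` stage added to the Dockerfile below and assuming `find_package(Vulkan)` succeeds on the build host:

```sh
# Sketch: configure, build, and install just the Vulkan backend via the new presets.
cmake --preset 'Vulkan' -DOLLAMA_RUNNER_DIR="vulkan"
cmake --build --parallel --preset 'Vulkan'        # build preset targets only ggml-vulkan
cmake --install build --component Vulkan --strip  # installs into lib/ollama/vulkan under the prefix
```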
`CONTRIBUTING.md`

```
@@ -65,7 +65,8 @@ continuation of the sentence:
Examples:

llm/backend/mlx: support the llama architecture
CONTRIBUTING: provide clairity on good commit messages, and bad
CONTRIBUTING: provide clarity on good commit messages, and bad
docs: simplify manual installation with shorter curl commands

Bad Examples:
```
`Dockerfile` (124 lines changed)
@@ -1,20 +1,33 @@
|
||||
# vim: filetype=dockerfile
|
||||
|
||||
ARG FLAVOR=${TARGETARCH}
|
||||
ARG PARALLEL=8
|
||||
|
||||
ARG ROCMVERSION=6.3.3
|
||||
ARG JETPACK5VERSION=r35.4.1
|
||||
ARG JETPACK6VERSION=r36.4.0
|
||||
ARG CMAKEVERSION=3.31.2
|
||||
ARG VULKANVERSION=1.4.321.1
|
||||
|
||||
# CUDA v11 requires gcc v10. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
|
||||
# We require gcc v10 minimum. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
|
||||
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
|
||||
RUN yum install -y yum-utils \
|
||||
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
|
||||
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
|
||||
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
|
||||
&& dnf install -y ccache \
|
||||
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
|
||||
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
|
||||
ARG VULKANVERSION
|
||||
RUN wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
|
||||
&& tar xvf /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
|
||||
&& dnf -y install ninja-build \
|
||||
&& ln -s /usr/bin/python3 /usr/bin/python \
|
||||
&& /${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
|
||||
&& /${VULKANVERSION}/vulkansdk -j 8 shaderc
|
||||
RUN cp -r /${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
|
||||
&& cp -r /${VULKANVERSION}/x86_64/lib/* /usr/local/lib
|
||||
ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
|
||||
|
||||
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
|
||||
# install epel-release for ccache
|
||||
@@ -33,35 +46,52 @@ ENV LDFLAGS=-s
|
||||
FROM base AS cpu
|
||||
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
|
||||
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CPU' \
|
||||
&& cmake --build --parallel --preset 'CPU' \
|
||||
&& cmake --install build --component CPU --strip --parallel 8
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CPU' \
|
||||
&& cmake --install build --component CPU --strip --parallel ${PARALLEL}
|
||||
|
||||
FROM base AS cuda-11
|
||||
ARG CUDA11VERSION=11.3
|
||||
ARG CUDA11VERSION=11.8
|
||||
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
||||
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CUDA 11' \
|
||||
&& cmake --build --parallel --preset 'CUDA 11' \
|
||||
&& cmake --install build --component CUDA --strip --parallel 8
|
||||
cmake --preset 'CUDA 11' -DOLLAMA_RUNNER_DIR="cuda_v11" \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
|
||||
FROM base AS cuda-12
|
||||
ARG CUDA12VERSION=12.8
|
||||
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
||||
ENV PATH=/usr/local/cuda-12/bin:$PATH
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CUDA 12' \
|
||||
&& cmake --build --parallel --preset 'CUDA 12' \
|
||||
&& cmake --install build --component CUDA --strip --parallel 8
|
||||
cmake --preset 'CUDA 12' -DOLLAMA_RUNNER_DIR="cuda_v12"\
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
|
||||
|
||||
FROM base AS cuda-13
|
||||
ARG CUDA13VERSION=13.0
|
||||
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
|
||||
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CUDA 13' -DOLLAMA_RUNNER_DIR="cuda_v13" \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
|
||||
|
||||
FROM base AS rocm-6
|
||||
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'ROCm 6' \
|
||||
&& cmake --build --parallel --preset 'ROCm 6' \
|
||||
&& cmake --install build --component HIP --strip --parallel 8
|
||||
cmake --preset 'ROCm 6' -DOLLAMA_RUNNER_DIR="rocm" \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
|
||||
&& cmake --install build --component HIP --strip --parallel ${PARALLEL}
|
||||
RUN rm -f dist/lib/ollama/rocm/rocblas/library/*gfx90[06]*
|
||||
|
||||
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
|
||||
ARG CMAKEVERSION
|
||||
@@ -69,10 +99,11 @@ RUN apt-get update && apt-get install -y curl ccache \
|
||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'JetPack 5' \
|
||||
&& cmake --build --parallel --preset 'JetPack 5' \
|
||||
&& cmake --install build --component CUDA --strip --parallel 8
|
||||
cmake --preset 'JetPack 5' -DOLLAMA_RUNNER_DIR="cuda_jetpack5" \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
|
||||
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
|
||||
ARG CMAKEVERSION
|
||||
@@ -80,10 +111,18 @@ RUN apt-get update && apt-get install -y curl ccache \
|
||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'JetPack 6' \
|
||||
&& cmake --build --parallel --preset 'JetPack 6' \
|
||||
&& cmake --install build --component CUDA --strip --parallel 8
|
||||
cmake --preset 'JetPack 6' -DOLLAMA_RUNNER_DIR="cuda_jetpack6" \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
|
||||
FROM base AS vulkan
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'Vulkan' -DOLLAMA_RUNNER_DIR="vulkan" \
|
||||
&& cmake --build --parallel --preset 'Vulkan' \
|
||||
&& cmake --install build --component Vulkan --strip --parallel 8
|
||||
|
||||
|
||||
FROM base AS build
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
@@ -94,31 +133,62 @@ RUN go mod download
|
||||
COPY . .
|
||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -trimpath -buildmode=pie -o /bin/ollama .
|
||||
|
||||
FROM --platform=linux/amd64 scratch AS amd64
|
||||
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
|
||||
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=vulkan dist/lib/ollama /lib/ollama/
|
||||
|
||||
FROM --platform=linux/arm64 scratch AS arm64
|
||||
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
|
||||
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
|
||||
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_jetpack5
|
||||
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_jetpack6
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/
|
||||
COPY --from=jetpack-5 dist/lib/ollama/ /lib/ollama/
|
||||
COPY --from=jetpack-6 dist/lib/ollama/ /lib/ollama/
|
||||
|
||||
FROM scratch AS rocm
|
||||
COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm
|
||||
COPY --from=rocm-6 dist/lib/ollama /lib/ollama
|
||||
|
||||
FROM ${FLAVOR} AS archive
|
||||
ARG VULKANVERSION
|
||||
COPY --from=cpu dist/lib/ollama /lib/ollama
|
||||
COPY --from=build /bin/ollama /bin/ollama
|
||||
|
||||
FROM ubuntu:20.04
|
||||
# Temporary opt-out stages for Vulkan
|
||||
FROM --platform=linux/amd64 scratch AS amd64_novulkan
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
|
||||
FROM arm64 AS arm64_novulkan
|
||||
FROM ${FLAVOR}_novulkan AS archive_novulkan
|
||||
COPY --from=cpu dist/lib/ollama /lib/ollama
|
||||
COPY --from=build /bin/ollama /bin/ollama
|
||||
FROM ubuntu:24.04 AS novulkan
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y ca-certificates \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
COPY --from=archive_novulkan /bin /usr/bin
|
||||
ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
|
||||
COPY --from=archive_novulkan /lib/ollama /usr/lib/ollama
|
||||
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
|
||||
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
|
||||
ENV NVIDIA_VISIBLE_DEVICES=all
|
||||
ENV OLLAMA_HOST=0.0.0.0:11434
|
||||
EXPOSE 11434
|
||||
ENTRYPOINT ["/bin/ollama"]
|
||||
CMD ["serve"]
|
||||
|
||||
FROM ubuntu:24.04 AS default
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y ca-certificates libvulkan1 \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
COPY --from=archive /bin /usr/bin
|
||||
ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
|
||||
COPY --from=archive /lib/ollama /usr/lib/ollama
|
||||
|
||||
`Makefile.sync`

```
@@ -1,6 +1,6 @@
UPSTREAM=https://github.com/ggerganov/llama.cpp.git
UPSTREAM=https://github.com/ggml-org/llama.cpp.git
WORKDIR=llama/vendor
FETCH_HEAD=de4c07f93783a1a96456a44dc16b9db538ee1618
FETCH_HEAD=7049736b2dd9011bf819e298b844ebbc4b5afdc9

.PHONY: help
help:
@@ -12,7 +12,7 @@ help:
@echo " clean Clean local repository"
@echo
@echo "Example:"
@echo " make -f $(lastword $(MAKEFILE_LIST)) clean sync"
@echo " make -f $(lastword $(MAKEFILE_LIST)) clean apply-patches sync"

.PHONY: sync
sync: llama/build-info.cpp ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
@@ -24,12 +24,12 @@ ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal: ml/backend/ggml/ggml
go generate ./$(@D)

.PHONY: llama/llama.cpp
llama/llama.cpp: llama/vendor/
rsync -arvzc -f "merge $@/.rsync-filter" $< $@
llama/llama.cpp: llama/vendor
rsync -arvzc --delete -f "include LICENSE" -f "merge $@/.rsync-filter" $(addprefix $<,/LICENSE /) $@

.PHONY: ml/backend/ggml/ggml
ml/backend/ggml/ggml: llama/vendor/ggml/
rsync -arvzc -f "merge $@/.rsync-filter" $< $@
ml/backend/ggml/ggml: llama/vendor
rsync -arvzc --delete -f "include LICENSE" -f "merge $@/.rsync-filter" $(addprefix $<,/LICENSE /ggml/) $@

PATCHES=$(wildcard llama/patches/*.patch)
PATCHED=$(join $(dir $(PATCHES)), $(addsuffix ed, $(addprefix ., $(notdir $(PATCHES)))))
@@ -39,7 +39,15 @@ PATCHED=$(join $(dir $(PATCHES)), $(addsuffix ed, $(addprefix ., $(notdir $(PATC
apply-patches: $(PATCHED)

llama/patches/.%.patched: llama/patches/%.patch
@if git -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then touch $@; else git -C $(WORKDIR) am --abort; exit 1; fi
@if git -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then \
touch $@; \
else \
echo "Patch failed. Resolve any conflicts then continue."; \
echo "1. Run 'git -C $(WORKDIR) am --continue'"; \
echo "2. Run 'make -f $(lastword $(MAKEFILE_LIST)) format-patches'"; \
echo "3. Run 'make -f $(lastword $(MAKEFILE_LIST)) clean apply-patches'"; \
exit 1; \
fi

.PHONY: checkout
checkout: $(WORKDIR)
@@ -60,4 +68,5 @@ format-patches: llama/patches

.PHONE: clean
clean: checkout
@git -C $(WORKDIR) am --abort || true
$(RM) llama/patches/.*.patched
```
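For reference, a short sketch of the vendoring flow these targets implement, pieced together from the help text and the new patch-failure message; the `Makefile.sync` file name is taken from the targets shown, and the recovery commands are the ones the recipe now echoes:

```sh
# Sketch: re-vendor llama.cpp/ggml and re-apply Ollama's local patches.
make -f Makefile.sync clean apply-patches sync

# If a patch no longer applies cleanly, follow the steps the recipe prints:
git -C llama/vendor am --continue           # after resolving the conflicts
make -f Makefile.sync format-patches        # regenerate the patch series
make -f Makefile.sync clean apply-patches   # confirm the patches apply again
```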
`README.md` (23 lines changed)

```
@@ -1,6 +1,6 @@
<div align="center">
<a href="https://ollama.com">
<img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
<img alt="ollama" width="240" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
</a>
</div>

@@ -10,7 +10,7 @@ Get up and running with large language models.

### macOS

[Download](https://ollama.com/download/Ollama-darwin.zip)
[Download](https://ollama.com/download/Ollama.dmg)

### Windows

@@ -360,7 +360,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Tkinter-based client](https://github.com/chyok/ollama-gui) (Python tkinter-based Client for Ollama)
- [LLMChat](https://github.com/trendy-design/llmchat) (Privacy focused, 100% local, intuitive all-in-one chat interface)
- [Local Multimodal AI Chat](https://github.com/Leon-Sander/Local-Multimodal-AI-Chat) (Ollama-based LLM Chat with support for multiple features, including PDF RAG, voice chat, image-based interactions, and integration with OpenAI.)
- [ARGO](https://github.com/xark-argo/argo) (Locally download and run Ollama and Huggingface models with RAG on Mac/Windows/Linux)
- [ARGO](https://github.com/xark-argo/argo) (Locally download and run Ollama and Huggingface models with RAG and deep research on Mac/Windows/Linux)
- [OrionChat](https://github.com/EliasPereirah/OrionChat) - OrionChat is a web interface for chatting with different AI providers
- [G1](https://github.com/bklieger-groq/g1) (Prototype of using prompting strategies to improve the LLM's reasoning through o1-like reasoning chains.)
- [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
@@ -409,6 +409,12 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)

### Cloud

@@ -454,6 +460,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
- [AWS-Strands-With-Ollama](https://github.com/rapidarchitect/ollama_strands) - AWS Strands Agents with Ollama Examples
- [ollama-multirun](https://github.com/attogram/ollama-multirun) - A bash shell script to run a single prompt against any or all of your locally installed ollama models, saving the output and performance statistics as easily navigable web pages. ([Demo](https://attogram.github.io/ai_test_zone/))
- [ollama-bash-toolshed](https://github.com/attogram/ollama-bash-toolshed) - Bash scripts to chat with tool using models. Add new tools to your shed with ease. Runs on Ollama.
- [VT Code](https://github.com/vinhnx/vtcode) - VT Code is a Rust-based terminal coding agent with semantic code intelligence via Tree-sitter. Ollama integration for running local/cloud models with configurable endpoints.

### Apple Vision Pro

@@ -534,6 +542,10 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
```
|
||||
- [Ollama for D](https://github.com/kassane/ollama-d)
|
||||
- [OllamaPlusPlus](https://github.com/HardCodeDev777/OllamaPlusPlus) (Very simple C++ library for Ollama)
|
||||
- [any-llm](https://github.com/mozilla-ai/any-llm) (A single interface to use different llm providers by [mozilla.ai](https://www.mozilla.ai/))
|
||||
- [any-agent](https://github.com/mozilla-ai/any-agent) (A single interface to use and evaluate different agent frameworks by [mozilla.ai](https://www.mozilla.ai/))
|
||||
- [Neuro SAN](https://github.com/cognizant-ai-lab/neuro-san-studio) (Data-driven multi-agent orchestration framework) with [example](https://github.com/cognizant-ai-lab/neuro-san-studio/blob/main/docs/user_guide.md#ollama)
|
||||
- [achatbot-go](https://github.com/ai-bot-pro/achatbot-go) a multimodal (text/audio/image) chatbot.
|
||||
|
||||
### Mobile
|
||||
|
||||
@@ -592,10 +604,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
|
||||
- [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime)
|
||||
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Editor tool to analyze scripts via Ollama)
|
||||
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)
|
||||
- [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases)
|
||||
- [NOMYO Router](https://github.com/nomyo-ai/nomyo-router) (A transparent Ollama proxy with model deployment aware routing which auto-manages multiple Ollama instances in a given network)
|
||||
|
||||
### Supported backends
|
||||
|
||||
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
|
||||
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
|
||||
|
||||
### Observability
|
||||
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration with Ollama.
|
||||
|
||||
@@ -45,6 +45,12 @@ func checkError(resp *http.Response, body []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
authError := AuthorizationError{StatusCode: resp.StatusCode}
|
||||
json.Unmarshal(body, &authError)
|
||||
return authError
|
||||
}
|
||||
|
||||
apiError := StatusError{StatusCode: resp.StatusCode}
|
||||
|
||||
err := json.Unmarshal(body, &apiError)
|
||||
@@ -214,7 +220,8 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
scanner.Buffer(scanBuf, maxBufferSize)
|
||||
for scanner.Scan() {
|
||||
var errorResponse struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
SigninURL string `json:"signin_url,omitempty"`
|
||||
}
|
||||
|
||||
bts := scanner.Bytes()
|
||||
@@ -222,11 +229,13 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
return fmt.Errorf("unmarshal: %w", err)
|
||||
}
|
||||
|
||||
if errorResponse.Error != "" {
|
||||
return errors.New(errorResponse.Error)
|
||||
}
|
||||
|
||||
if response.StatusCode >= http.StatusBadRequest {
|
||||
if response.StatusCode == http.StatusUnauthorized {
|
||||
return AuthorizationError{
|
||||
StatusCode: response.StatusCode,
|
||||
Status: response.Status,
|
||||
SigninURL: errorResponse.SigninURL,
|
||||
}
|
||||
} else if response.StatusCode >= http.StatusBadRequest {
|
||||
return StatusError{
|
||||
StatusCode: response.StatusCode,
|
||||
Status: response.Status,
|
||||
@@ -234,6 +243,10 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
}
|
||||
}
|
||||
|
||||
if errorResponse.Error != "" {
|
||||
return errors.New(errorResponse.Error)
|
||||
}
|
||||
|
||||
if err := fn(bts); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -428,3 +441,21 @@ func (c *Client) Version(ctx context.Context) (string, error) {
|
||||
|
||||
return version.Version, nil
|
||||
}
|
||||
|
||||
// Signout will signout a client for a local ollama server.
|
||||
func (c *Client) Signout(ctx context.Context) error {
|
||||
return c.do(ctx, http.MethodPost, "/api/signout", nil, nil)
|
||||
}
|
||||
|
||||
// Disconnect will disconnect an ollama instance from ollama.com.
|
||||
func (c *Client) Disconnect(ctx context.Context, encodedKey string) error {
|
||||
return c.do(ctx, http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey), nil, nil)
|
||||
}
|
||||
|
||||
func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
|
||||
var resp UserResponse
|
||||
if err := c.do(ctx, http.MethodPost, "/api/me", nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
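
Taken together with the AuthorizationError type added in api/types.go below, this is a minimal sketch of how a caller might use the new Whoami and Signout client methods; the assumption that a 401 surfaces as an AuthorizationError with a SigninURL is based on the SigninHandler change in cmd/cmd.go further down.

package main

import (
    "context"
    "errors"
    "fmt"
    "log"
    "net/http"

    "github.com/ollama/ollama/api"
)

func main() {
    client, err := api.ClientFromEnvironment()
    if err != nil {
        log.Fatal(err)
    }

    user, err := client.Whoami(context.Background())
    if err != nil {
        var aErr api.AuthorizationError
        if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
            // Not signed in; SigninURL points the user at ollama.com.
            fmt.Println("sign in at:", aErr.SigninURL)
            return
        }
        log.Fatal(err)
    }
    fmt.Println("signed in as:", user.Name)

    // Signout ends the local server's session with ollama.com.
    if err := client.Signout(context.Background()); err != nil {
        log.Fatal(err)
    }
}
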
|
||||
|
||||
@@ -89,6 +89,16 @@ func TestClientStream(t *testing.T) {
|
||||
},
|
||||
wantErr: "mid-stream error",
|
||||
},
|
||||
{
|
||||
name: "http status error takes precedence over general error",
|
||||
responses: []any{
|
||||
testError{
|
||||
message: "custom error message",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
},
|
||||
},
|
||||
wantErr: "500",
|
||||
},
|
||||
{
|
||||
name: "successful stream completion",
|
||||
responses: []any{
|
||||
|
||||
api/types.go (382 changed lines)
@@ -11,6 +11,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
@@ -36,6 +38,19 @@ func (e StatusError) Error() string {
|
||||
}
|
||||
}
|
||||
|
||||
type AuthorizationError struct {
|
||||
StatusCode int
|
||||
Status string
|
||||
SigninURL string `json:"signin_url"`
|
||||
}
|
||||
|
||||
func (e AuthorizationError) Error() string {
|
||||
if e.Status != "" {
|
||||
return e.Status
|
||||
}
|
||||
return "something went wrong, please see the ollama server logs for details"
|
||||
}
|
||||
|
||||
// ImageData represents the raw binary data of an image file.
|
||||
type ImageData []byte
|
||||
|
||||
@@ -85,10 +100,23 @@ type GenerateRequest struct {
|
||||
Options map[string]any `json:"options"`
|
||||
|
||||
// Think controls whether thinking/reasoning models will think before
|
||||
// responding. Needs to be a pointer so we can distinguish between false
|
||||
// responding. Can be a boolean (true/false) or a string ("high", "medium", "low")
|
||||
// for supported models. Needs to be a pointer so we can distinguish between false
|
||||
// (request that thinking _not_ be used) and unset (use the old behavior
|
||||
// before this option was introduced)
|
||||
Think *bool `json:"think,omitempty"`
|
||||
Think *ThinkValue `json:"think,omitempty"`
|
||||
|
||||
// Truncate is a boolean that, when set to true, truncates the chat history messages
|
||||
// if the rendered prompt exceeds the context length limit.
|
||||
Truncate *bool `json:"truncate,omitempty"`
|
||||
|
||||
// Shift is a boolean that, when set to true, shifts the chat history
|
||||
// when hitting the context length limit instead of erroring.
|
||||
Shift *bool `json:"shift,omitempty"`
|
||||
|
||||
// DebugRenderOnly is a debug option that, when set to true, returns the rendered
|
||||
// template instead of calling the model.
|
||||
DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
|
||||
}
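
A rough sketch of a generate request exercising the new Think and Truncate fields; the model name is a placeholder and the example assumes a reachable local server and a model that supports thinking.

package main

import (
    "context"
    "fmt"
    "log"

    "github.com/ollama/ollama/api"
)

func main() {
    client, err := api.ClientFromEnvironment()
    if err != nil {
        log.Fatal(err)
    }

    truncate := true
    req := &api.GenerateRequest{
        Model:  "example-thinking-model", // placeholder name
        Prompt: "Why is the sky blue?",
        // Accepts true, false, or "high"/"medium"/"low" for supported models.
        Think:    &api.ThinkValue{Value: "high"},
        Truncate: &truncate,
    }

    err = client.Generate(context.Background(), req, func(r api.GenerateResponse) error {
        fmt.Print(r.Response)
        return nil
    })
    if err != nil {
        log.Fatal(err)
    }
}
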
|
||||
|
||||
// ChatRequest describes a request sent by [Client.Chat].
|
||||
@@ -116,8 +144,21 @@ type ChatRequest struct {
|
||||
Options map[string]any `json:"options"`
|
||||
|
||||
// Think controls whether thinking/reasoning models will think before
|
||||
// responding
|
||||
Think *bool `json:"think,omitempty"`
|
||||
// responding. Can be a boolean (true/false) or a string ("high", "medium", "low")
|
||||
// for supported models.
|
||||
Think *ThinkValue `json:"think,omitempty"`
|
||||
|
||||
// Truncate is a boolean that, when set to true, truncates the chat history messages
|
||||
// if the rendered prompt exceeds the context length limit.
|
||||
Truncate *bool `json:"truncate,omitempty"`
|
||||
|
||||
// Shift is a boolean that, when set to true, shifts the chat history
|
||||
// when hitting the context length limit instead of erroring.
|
||||
Shift *bool `json:"shift,omitempty"`
|
||||
|
||||
// DebugRenderOnly is a debug option that, when set to true, returns the rendered
|
||||
// template instead of calling the model.
|
||||
DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
|
||||
}
|
||||
|
||||
type Tools []Tool
|
||||
@@ -143,6 +184,7 @@ type Message struct {
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Images []ImageData `json:"images,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolName string `json:"tool_name,omitempty"`
|
||||
}
|
||||
|
||||
func (m *Message) UnmarshalJSON(b []byte) error {
|
||||
@@ -162,7 +204,7 @@ type ToolCall struct {
|
||||
}
|
||||
|
||||
type ToolCallFunction struct {
|
||||
Index int `json:"index,omitempty"`
|
||||
Index int `json:"index"`
|
||||
Name string `json:"name"`
|
||||
Arguments ToolCallFunctionArguments `json:"arguments"`
|
||||
}
|
||||
@@ -222,21 +264,76 @@ func (pt PropertyType) String() string {
|
||||
return fmt.Sprintf("%v", []string(pt))
|
||||
}
|
||||
|
||||
type ToolProperty struct {
|
||||
AnyOf []ToolProperty `json:"anyOf,omitempty"`
|
||||
Type PropertyType `json:"type,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
// ToTypeScriptType converts a ToolProperty to a TypeScript type string
|
||||
func (tp ToolProperty) ToTypeScriptType() string {
|
||||
if len(tp.AnyOf) > 0 {
|
||||
var types []string
|
||||
for _, anyOf := range tp.AnyOf {
|
||||
types = append(types, anyOf.ToTypeScriptType())
|
||||
}
|
||||
return strings.Join(types, " | ")
|
||||
}
|
||||
|
||||
if len(tp.Type) == 0 {
|
||||
return "any"
|
||||
}
|
||||
|
||||
if len(tp.Type) == 1 {
|
||||
return mapToTypeScriptType(tp.Type[0])
|
||||
}
|
||||
|
||||
var types []string
|
||||
for _, t := range tp.Type {
|
||||
types = append(types, mapToTypeScriptType(t))
|
||||
}
|
||||
return strings.Join(types, " | ")
|
||||
}
|
||||
|
||||
// mapToTypeScriptType maps JSON Schema types to TypeScript types
|
||||
func mapToTypeScriptType(jsonType string) string {
|
||||
switch jsonType {
|
||||
case "string":
|
||||
return "string"
|
||||
case "number", "integer":
|
||||
return "number"
|
||||
case "boolean":
|
||||
return "boolean"
|
||||
case "array":
|
||||
return "any[]"
|
||||
case "object":
|
||||
return "Record<string, any>"
|
||||
case "null":
|
||||
return "null"
|
||||
default:
|
||||
return "any"
|
||||
}
|
||||
}
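
For illustration, a small hedged example of the TypeScript mapping above applied to a made-up union-typed tool property.

package main

import (
    "fmt"

    "github.com/ollama/ollama/api"
)

func main() {
    // A hypothetical parameter that is either a string, null, or an integer.
    prop := api.ToolProperty{
        AnyOf: []api.ToolProperty{
            {Type: api.PropertyType{"string", "null"}},
            {Type: api.PropertyType{"integer"}},
        },
    }

    // Prints "string | null | number": each anyOf branch is converted and
    // joined, and JSON Schema's "integer" collapses to TypeScript's number.
    fmt.Println(prop.ToTypeScriptType())
}
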
|
||||
|
||||
type ToolFunctionParameters struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]ToolProperty `json:"properties"`
|
||||
}
|
||||
|
||||
func (t *ToolFunctionParameters) String() string {
|
||||
bts, _ := json.Marshal(t)
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
type ToolFunction struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]struct {
|
||||
Type PropertyType `json:"type"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
} `json:"properties"`
|
||||
} `json:"parameters"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters ToolFunctionParameters `json:"parameters"`
|
||||
}
|
||||
|
||||
func (t *ToolFunction) String() string {
|
||||
@@ -247,16 +344,38 @@ func (t *ToolFunction) String() string {
|
||||
// ChatResponse is the response returned by [Client.Chat]. Its fields are
|
||||
// similar to [GenerateResponse].
|
||||
type ChatResponse struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Message Message `json:"message"`
|
||||
DoneReason string `json:"done_reason,omitempty"`
|
||||
// Model is the model name that generated the response.
|
||||
Model string `json:"model"`
|
||||
|
||||
// RemoteModel is the name of the upstream model that generated the response.
|
||||
RemoteModel string `json:"remote_model,omitempty"`
|
||||
|
||||
// RemoteHost is the URL of the upstream Ollama host that generated the response.
|
||||
RemoteHost string `json:"remote_host,omitempty"`
|
||||
|
||||
// CreatedAt is the timestamp of the response.
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
// Message contains the message or part of a message from the model.
|
||||
Message Message `json:"message"`
|
||||
|
||||
// Done specifies if the response is complete.
|
||||
Done bool `json:"done"`
|
||||
|
||||
// DoneReason is the reason the model stopped generating text.
|
||||
DoneReason string `json:"done_reason,omitempty"`
|
||||
|
||||
DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
|
||||
|
||||
Metrics
|
||||
}
|
||||
|
||||
// DebugInfo contains debug information for template rendering
|
||||
type DebugInfo struct {
|
||||
RenderedTemplate string `json:"rendered_template"`
|
||||
ImageCount int `json:"image_count,omitempty"`
|
||||
}
|
||||
|
||||
type Metrics struct {
|
||||
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
||||
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
||||
@@ -309,8 +428,12 @@ type EmbedRequest struct {
|
||||
// this request.
|
||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||
|
||||
// Truncate truncates the input to fit the model's max sequence length.
|
||||
Truncate *bool `json:"truncate,omitempty"`
|
||||
|
||||
// Dimensions truncates the output embedding to the specified dimension.
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
|
||||
// Options lists model-specific options.
|
||||
Options map[string]any `json:"options"`
|
||||
}
|
||||
@@ -348,18 +471,47 @@ type EmbeddingResponse struct {
|
||||
|
||||
// CreateRequest is the request passed to [Client.Create].
|
||||
type CreateRequest struct {
|
||||
Model string `json:"model"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
// Model is the model name to create.
|
||||
Model string `json:"model"`
|
||||
|
||||
// Stream specifies whether the response is streaming; it is true by default.
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
|
||||
// Quantize is the quantization format for the model; leave blank to not change the quantization level.
|
||||
Quantize string `json:"quantize,omitempty"`
|
||||
|
||||
From string `json:"from,omitempty"`
|
||||
Files map[string]string `json:"files,omitempty"`
|
||||
Adapters map[string]string `json:"adapters,omitempty"`
|
||||
Template string `json:"template,omitempty"`
|
||||
License any `json:"license,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
Parameters map[string]any `json:"parameters,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
// From is the name of the model or file to use as the source.
|
||||
From string `json:"from,omitempty"`
|
||||
|
||||
// RemoteHost is the URL of the upstream ollama API for the model (if any).
|
||||
RemoteHost string `json:"remote_host,omitempty"`
|
||||
|
||||
// Files is a map of files include when creating the model.
|
||||
Files map[string]string `json:"files,omitempty"`
|
||||
|
||||
// Adapters is a map of LoRA adapters to include when creating the model.
|
||||
Adapters map[string]string `json:"adapters,omitempty"`
|
||||
|
||||
// Template is the template used when constructing a request to the model.
|
||||
Template string `json:"template,omitempty"`
|
||||
|
||||
// License is a string or list of strings for licenses.
|
||||
License any `json:"license,omitempty"`
|
||||
|
||||
// System is the system prompt for the model.
|
||||
System string `json:"system,omitempty"`
|
||||
|
||||
// Parameters is a map of hyper-parameters which are applied to the model.
|
||||
Parameters map[string]any `json:"parameters,omitempty"`
|
||||
|
||||
// Messages is a list of messages added to the model before chat and generation requests.
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
|
||||
Renderer string `json:"renderer,omitempty"`
|
||||
Parser string `json:"parser,omitempty"`
|
||||
|
||||
// Info is a map of additional information for the model
|
||||
Info map[string]any `json:"info,omitempty"`
|
||||
|
||||
// Deprecated: set the model name with Model instead
|
||||
Name string `json:"name"`
|
||||
@@ -397,8 +549,12 @@ type ShowResponse struct {
|
||||
Parameters string `json:"parameters,omitempty"`
|
||||
Template string `json:"template,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
Renderer string `json:"renderer,omitempty"`
|
||||
Parser string `json:"parser,omitempty"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
RemoteModel string `json:"remote_model,omitempty"`
|
||||
RemoteHost string `json:"remote_host,omitempty"`
|
||||
ModelInfo map[string]any `json:"model_info,omitempty"`
|
||||
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
|
||||
Tensors []Tensor `json:"tensors,omitempty"`
|
||||
@@ -457,23 +613,26 @@ type ProcessResponse struct {
|
||||
|
||||
// ListModelResponse is a single model description in [ListResponse].
|
||||
type ListModelResponse struct {
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
RemoteModel string `json:"remote_model,omitempty"`
|
||||
RemoteHost string `json:"remote_host,omitempty"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// ProcessModelResponse is a single model description in [ProcessResponse].
|
||||
type ProcessModelResponse struct {
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
SizeVRAM int64 `json:"size_vram"`
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
SizeVRAM int64 `json:"size_vram"`
|
||||
ContextLength int `json:"context_length"`
|
||||
}
|
||||
|
||||
type TokenResponse struct {
|
||||
@@ -485,6 +644,12 @@ type GenerateResponse struct {
|
||||
// Model is the model name that generated the response.
|
||||
Model string `json:"model"`
|
||||
|
||||
// RemoteModel is the name of the upstream model that generated the response.
|
||||
RemoteModel string `json:"remote_model,omitempty"`
|
||||
|
||||
// RemoteHost is the URL of the upstream Ollama host that generated the response.
|
||||
RemoteHost string `json:"remote_host,omitempty"`
|
||||
|
||||
// CreatedAt is the timestamp of the response.
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
@@ -506,6 +671,10 @@ type GenerateResponse struct {
|
||||
Context []int `json:"context,omitempty"`
|
||||
|
||||
Metrics
|
||||
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
|
||||
DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
|
||||
}
|
||||
|
||||
// ModelDetails provides details about a model.
|
||||
@@ -518,6 +687,18 @@ type ModelDetails struct {
|
||||
QuantizationLevel string `json:"quantization_level"`
|
||||
}
|
||||
|
||||
// UserResponse provides information about a user.
|
||||
type UserResponse struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Bio string `json:"bio,omitempty"`
|
||||
AvatarURL string `json:"avatarurl,omitempty"`
|
||||
FirstName string `json:"firstname,omitempty"`
|
||||
LastName string `json:"lastname,omitempty"`
|
||||
Plan string `json:"plan,omitempty"`
|
||||
}
|
||||
|
||||
// Tensor describes the metadata for a given tensor.
|
||||
type Tensor struct {
|
||||
Name string `json:"name"`
|
||||
@@ -675,6 +856,113 @@ func DefaultOptions() Options {
|
||||
}
|
||||
}
|
||||
|
||||
// ThinkValue represents a value that can be a boolean or a string ("high", "medium", "low")
|
||||
type ThinkValue struct {
|
||||
// Value can be a bool or string
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
// IsValid checks if the ThinkValue is valid
|
||||
func (t *ThinkValue) IsValid() bool {
|
||||
if t == nil || t.Value == nil {
|
||||
return true // nil is valid (means not set)
|
||||
}
|
||||
|
||||
switch v := t.Value.(type) {
|
||||
case bool:
|
||||
return true
|
||||
case string:
|
||||
return v == "high" || v == "medium" || v == "low"
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// IsBool returns true if the value is a boolean
|
||||
func (t *ThinkValue) IsBool() bool {
|
||||
if t == nil || t.Value == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := t.Value.(bool)
|
||||
return ok
|
||||
}
|
||||
|
||||
// IsString returns true if the value is a string
|
||||
func (t *ThinkValue) IsString() bool {
|
||||
if t == nil || t.Value == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := t.Value.(string)
|
||||
return ok
|
||||
}
|
||||
|
||||
// Bool returns the value as a bool (true if enabled in any way)
|
||||
func (t *ThinkValue) Bool() bool {
|
||||
if t == nil || t.Value == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
switch v := t.Value.(type) {
|
||||
case bool:
|
||||
return v
|
||||
case string:
|
||||
// Any string value ("high", "medium", "low") means thinking is enabled
|
||||
return v == "high" || v == "medium" || v == "low"
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// String returns the value as a string
|
||||
func (t *ThinkValue) String() string {
|
||||
if t == nil || t.Value == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch v := t.Value.(type) {
|
||||
case string:
|
||||
return v
|
||||
case bool:
|
||||
if v {
|
||||
return "medium" // Default level when just true
|
||||
}
|
||||
return ""
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler
|
||||
func (t *ThinkValue) UnmarshalJSON(data []byte) error {
|
||||
// Try to unmarshal as bool first
|
||||
var b bool
|
||||
if err := json.Unmarshal(data, &b); err == nil {
|
||||
t.Value = b
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to unmarshal as string
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err == nil {
|
||||
// Validate string values
|
||||
if s != "high" && s != "medium" && s != "low" {
|
||||
return fmt.Errorf("invalid think value: %q (must be \"high\", \"medium\", \"low\", true, or false)", s)
|
||||
}
|
||||
t.Value = s
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("think must be a boolean or string (\"high\", \"medium\", \"low\", true, or false)")
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
func (t *ThinkValue) MarshalJSON() ([]byte, error) {
|
||||
if t == nil || t.Value == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return json.Marshal(t.Value)
|
||||
}
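
A short sketch of the two JSON shapes the think field now accepts, per the UnmarshalJSON above; the request literals are illustrative only.

package main

import (
    "encoding/json"
    "fmt"
    "log"

    "github.com/ollama/ollama/api"
)

func main() {
    for _, raw := range []string{
        `{"model":"m","think":true}`,
        `{"model":"m","think":"high"}`,
    } {
        var req api.ChatRequest
        if err := json.Unmarshal([]byte(raw), &req); err != nil {
            log.Fatal(err)
        }
        // Bool reports whether thinking is enabled at all; String falls back
        // to "medium" when the value was a bare true.
        fmt.Println(req.Think.Bool(), req.Think.String())
    }
    // Prints:
    // true medium
    // true high
}
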
|
||||
|
||||
type Duration struct {
|
||||
time.Duration
|
||||
}
|
||||
@@ -699,7 +987,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
||||
if t < 0 {
|
||||
d.Duration = time.Duration(math.MaxInt64)
|
||||
} else {
|
||||
d.Duration = time.Duration(int(t) * int(time.Second))
|
||||
d.Duration = time.Duration(t * float64(time.Second))
|
||||
}
|
||||
case string:
|
||||
d.Duration, err = time.ParseDuration(t)
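
Sketching the arithmetic behind the one-line change above: converting to int before multiplying truncated fractional keep_alive values, while the float path preserves them.

package main

import (
    "fmt"
    "time"
)

func main() {
    t := 42.5 // keep_alive given as a JSON number of seconds

    old := time.Duration(int(t) * int(time.Second))  // int(42.5) == 42, so 42s
    fixed := time.Duration(t * float64(time.Second)) // keeps the half second

    fmt.Println(old, fixed) // 42s 42.5s
}
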
|
||||
|
||||
@@ -17,6 +17,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
|
||||
req string
|
||||
exp *Duration
|
||||
}{
|
||||
{
|
||||
name: "Unset",
|
||||
req: `{ }`,
|
||||
exp: nil,
|
||||
},
|
||||
{
|
||||
name: "Positive Integer",
|
||||
req: `{ "keep_alive": 42 }`,
|
||||
@@ -25,7 +30,7 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
|
||||
{
|
||||
name: "Positive Float",
|
||||
req: `{ "keep_alive": 42.5 }`,
|
||||
exp: &Duration{42 * time.Second},
|
||||
exp: &Duration{42500 * time.Millisecond},
|
||||
},
|
||||
{
|
||||
name: "Positive Integer String",
|
||||
@@ -293,6 +298,30 @@ func TestToolFunction_UnmarshalJSON(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolCallFunction_IndexAlwaysMarshals(t *testing.T) {
|
||||
fn := ToolCallFunction{
|
||||
Name: "echo",
|
||||
Arguments: ToolCallFunctionArguments{"message": "hi"},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(fn)
|
||||
require.NoError(t, err)
|
||||
|
||||
raw := map[string]any{}
|
||||
require.NoError(t, json.Unmarshal(data, &raw))
|
||||
require.Contains(t, raw, "index")
|
||||
assert.Equal(t, float64(0), raw["index"])
|
||||
|
||||
fn.Index = 3
|
||||
data, err = json.Marshal(fn)
|
||||
require.NoError(t, err)
|
||||
|
||||
raw = map[string]any{}
|
||||
require.NoError(t, json.Unmarshal(data, &raw))
|
||||
require.Contains(t, raw, "index")
|
||||
assert.Equal(t, float64(3), raw["index"])
|
||||
}
|
||||
|
||||
func TestPropertyType_UnmarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -374,24 +403,21 @@ func TestPropertyType_MarshalJSON(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestThinking_UnmarshalJSON(t *testing.T) {
|
||||
trueVal := true
|
||||
falseVal := false
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedThinking *bool
|
||||
expectedThinking *ThinkValue
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "true",
|
||||
input: `{ "think": true }`,
|
||||
expectedThinking: &trueVal,
|
||||
expectedThinking: &ThinkValue{Value: true},
|
||||
},
|
||||
{
|
||||
name: "false",
|
||||
input: `{ "think": false }`,
|
||||
expectedThinking: &falseVal,
|
||||
expectedThinking: &ThinkValue{Value: false},
|
||||
},
|
||||
{
|
||||
name: "unset",
|
||||
@@ -399,8 +425,23 @@ func TestThinking_UnmarshalJSON(t *testing.T) {
|
||||
expectedThinking: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid",
|
||||
input: `{ "think": "true" }`,
|
||||
name: "string_high",
|
||||
input: `{ "think": "high" }`,
|
||||
expectedThinking: &ThinkValue{Value: "high"},
|
||||
},
|
||||
{
|
||||
name: "string_medium",
|
||||
input: `{ "think": "medium" }`,
|
||||
expectedThinking: &ThinkValue{Value: "medium"},
|
||||
},
|
||||
{
|
||||
name: "string_low",
|
||||
input: `{ "think": "low" }`,
|
||||
expectedThinking: &ThinkValue{Value: "low"},
|
||||
},
|
||||
{
|
||||
name: "invalid_string",
|
||||
input: `{ "think": "invalid" }`,
|
||||
expectedThinking: nil,
|
||||
expectedError: true,
|
||||
},
|
||||
@@ -414,8 +455,60 @@ func TestThinking_UnmarshalJSON(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.expectedThinking, req.Think)
|
||||
if test.expectedThinking == nil {
|
||||
assert.Nil(t, req.Think)
|
||||
} else {
|
||||
require.NotNil(t, req.Think)
|
||||
assert.Equal(t, test.expectedThinking.Value, req.Think.Value)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolFunctionParameters_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
params ToolFunctionParameters
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple object with string property",
|
||||
params: ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"name"},
|
||||
Properties: map[string]ToolProperty{
|
||||
"name": {
|
||||
Type: PropertyType{"string"},
|
||||
Description: "The name of the person",
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
|
||||
},
|
||||
{
|
||||
name: "marshal failure returns empty string",
|
||||
params: ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Defs: func() any {
|
||||
// Create a cycle that will cause json.Marshal to fail
|
||||
type selfRef struct {
|
||||
Self *selfRef
|
||||
}
|
||||
s := &selfRef{}
|
||||
s.Self = s
|
||||
return s
|
||||
}(),
|
||||
Properties: map[string]ToolProperty{},
|
||||
},
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := test.params.String()
|
||||
assert.Equal(t, test.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
api/types_typescript_test.go (142 lines, new file)
@@ -0,0 +1,142 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestToolParameterToTypeScriptType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
param ToolProperty
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "single string type",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{"string"},
|
||||
},
|
||||
expected: "string",
|
||||
},
|
||||
{
|
||||
name: "single number type",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{"number"},
|
||||
},
|
||||
expected: "number",
|
||||
},
|
||||
{
|
||||
name: "integer maps to number",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{"integer"},
|
||||
},
|
||||
expected: "number",
|
||||
},
|
||||
{
|
||||
name: "boolean type",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{"boolean"},
|
||||
},
|
||||
expected: "boolean",
|
||||
},
|
||||
{
|
||||
name: "array type",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{"array"},
|
||||
},
|
||||
expected: "any[]",
|
||||
},
|
||||
{
|
||||
name: "object type",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{"object"},
|
||||
},
|
||||
expected: "Record<string, any>",
|
||||
},
|
||||
{
|
||||
name: "null type",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{"null"},
|
||||
},
|
||||
expected: "null",
|
||||
},
|
||||
{
|
||||
name: "multiple types as union",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{"string", "number"},
|
||||
},
|
||||
expected: "string | number",
|
||||
},
|
||||
{
|
||||
name: "string or null union",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{"string", "null"},
|
||||
},
|
||||
expected: "string | null",
|
||||
},
|
||||
{
|
||||
name: "anyOf with single types",
|
||||
param: ToolProperty{
|
||||
AnyOf: []ToolProperty{
|
||||
{Type: PropertyType{"string"}},
|
||||
{Type: PropertyType{"number"}},
|
||||
},
|
||||
},
|
||||
expected: "string | number",
|
||||
},
|
||||
{
|
||||
name: "anyOf with multiple types in each branch",
|
||||
param: ToolProperty{
|
||||
AnyOf: []ToolProperty{
|
||||
{Type: PropertyType{"string", "null"}},
|
||||
{Type: PropertyType{"number"}},
|
||||
},
|
||||
},
|
||||
expected: "string | null | number",
|
||||
},
|
||||
{
|
||||
name: "nested anyOf",
|
||||
param: ToolProperty{
|
||||
AnyOf: []ToolProperty{
|
||||
{Type: PropertyType{"boolean"}},
|
||||
{
|
||||
AnyOf: []ToolProperty{
|
||||
{Type: PropertyType{"string"}},
|
||||
{Type: PropertyType{"number"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "boolean | string | number",
|
||||
},
|
||||
{
|
||||
name: "empty type returns any",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{},
|
||||
},
|
||||
expected: "any",
|
||||
},
|
||||
{
|
||||
name: "unknown type maps to any",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{"unknown_type"},
|
||||
},
|
||||
expected: "any",
|
||||
},
|
||||
{
|
||||
name: "multiple types including array",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{"string", "array", "null"},
|
||||
},
|
||||
expected: "string | any[] | null",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.param.ToTypeScriptType()
|
||||
if result != tt.expected {
|
||||
t.Errorf("ToTypeScriptType() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
auth/auth.go (15 changed lines)
@@ -18,21 +18,13 @@ import (
|
||||
|
||||
const defaultPrivateKey = "id_ed25519"
|
||||
|
||||
func keyPath() (string, error) {
|
||||
func GetPublicKey() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
||||
}
|
||||
|
||||
func GetPublicKey() (string, error) {
|
||||
keyPath, err := keyPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
||||
privateKeyFile, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||
@@ -59,11 +51,12 @@ func NewNonce(r io.Reader, length int) (string, error) {
|
||||
}
|
||||
|
||||
func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||
keyPath, err := keyPath()
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
||||
privateKeyFile, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||
|
||||
@@ -1,178 +0,0 @@
|
||||
package benchmark
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Command line flags
|
||||
var modelFlag string
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
|
||||
flag.Lookup("m").DefValue = "model"
|
||||
}
|
||||
|
||||
// modelName returns the model name from flags, failing the test if not set
|
||||
func modelName(b *testing.B) string {
|
||||
if modelFlag == "" {
|
||||
b.Fatal("Error: -m flag is required for benchmark tests")
|
||||
}
|
||||
return modelFlag
|
||||
}
|
||||
|
||||
type TestCase struct {
|
||||
name string
|
||||
prompt string
|
||||
maxTokens int
|
||||
}
|
||||
|
||||
// runGenerateBenchmark contains the common generate and metrics logic
|
||||
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
|
||||
start := time.Now()
|
||||
var ttft time.Duration
|
||||
var metrics api.Metrics
|
||||
|
||||
err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||
if ttft == 0 && resp.Response != "" {
|
||||
ttft = time.Since(start)
|
||||
}
|
||||
if resp.Done {
|
||||
metrics = resp.Metrics
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Report custom metrics as part of the benchmark results
|
||||
b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
|
||||
b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
|
||||
|
||||
// Token throughput metrics
|
||||
promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
|
||||
genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
|
||||
b.ReportMetric(promptThroughput, "prompt_tok/s")
|
||||
b.ReportMetric(genThroughput, "gen_tok/s")
|
||||
|
||||
// Token counts
|
||||
b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
|
||||
b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkColdStart runs benchmarks with model loading from cold state
|
||||
func BenchmarkColdStart(b *testing.B) {
|
||||
client := setup(b)
|
||||
tests := []TestCase{
|
||||
{"short_prompt", "Write a long story", 100},
|
||||
{"medium_prompt", "Write a detailed economic analysis", 500},
|
||||
{"long_prompt", "Write a comprehensive AI research paper", 1000},
|
||||
}
|
||||
m := modelName(b)
|
||||
|
||||
for _, tt := range tests {
|
||||
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
|
||||
ctx := b.Context()
|
||||
|
||||
// Set number of tokens as our throughput metric
|
||||
b.SetBytes(int64(tt.maxTokens))
|
||||
|
||||
for b.Loop() {
|
||||
b.StopTimer()
|
||||
// Ensure model is unloaded before each iteration
|
||||
unload(client, m, b)
|
||||
b.StartTimer()
|
||||
|
||||
req := &api.GenerateRequest{
|
||||
Model: m,
|
||||
Prompt: tt.prompt,
|
||||
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
|
||||
}
|
||||
|
||||
runGenerateBenchmark(b, ctx, client, req)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkWarmStart runs benchmarks with pre-loaded model
|
||||
func BenchmarkWarmStart(b *testing.B) {
|
||||
client := setup(b)
|
||||
tests := []TestCase{
|
||||
{"short_prompt", "Write a long story", 100},
|
||||
{"medium_prompt", "Write a detailed economic analysis", 500},
|
||||
{"long_prompt", "Write a comprehensive AI research paper", 1000},
|
||||
}
|
||||
m := modelName(b)
|
||||
|
||||
for _, tt := range tests {
|
||||
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
|
||||
ctx := b.Context()
|
||||
|
||||
// Pre-warm the model
|
||||
warmup(client, m, tt.prompt, b)
|
||||
|
||||
// Set number of tokens as our throughput metric
|
||||
b.SetBytes(int64(tt.maxTokens))
|
||||
|
||||
for b.Loop() {
|
||||
req := &api.GenerateRequest{
|
||||
Model: m,
|
||||
Prompt: tt.prompt,
|
||||
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
|
||||
}
|
||||
|
||||
runGenerateBenchmark(b, ctx, client, req)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// setup verifies server and model availability
|
||||
func setup(b *testing.B) *api.Client {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if _, err := client.Show(b.Context(), &api.ShowRequest{Model: modelName(b)}); err != nil {
|
||||
b.Fatalf("Model unavailable: %v", err)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
// warmup ensures the model is loaded and warmed up
|
||||
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
|
||||
for range 3 {
|
||||
err := client.Generate(
|
||||
context.Background(),
|
||||
&api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: prompt,
|
||||
Options: map[string]any{"num_predict": 50, "temperature": 0.1},
|
||||
},
|
||||
func(api.GenerateResponse) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
b.Logf("Error during model warm-up: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// unload forces model unloading using KeepAlive: 0 parameter
|
||||
func unload(client *api.Client, model string, b *testing.B) {
|
||||
req := &api.GenerateRequest{
|
||||
Model: model,
|
||||
KeepAlive: &api.Duration{Duration: 0},
|
||||
}
|
||||
if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
|
||||
b.Logf("Unload error: %v", err)
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
cmd/cmd.go (332 changed lines)
@@ -47,6 +47,8 @@ import (
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
|
||||
|
||||
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
|
||||
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
|
||||
if name == "" {
|
||||
@@ -56,10 +58,8 @@ func ensureThinkingSupport(ctx context.Context, client *api.Client, name string)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, cap := range resp.Capabilities {
|
||||
if cap == model.CapabilityThinking {
|
||||
return
|
||||
}
|
||||
if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
|
||||
}
|
||||
@@ -288,7 +288,17 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
Think: opts.Think,
|
||||
}
|
||||
|
||||
return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
|
||||
return client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error {
|
||||
if r.RemoteModel != "" && opts.ShowConnect {
|
||||
p.StopAndClear()
|
||||
if strings.HasPrefix(r.RemoteHost, "https://ollama.com") {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", r.RemoteModel)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", r.RemoteModel, r.RemoteHost)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func StopHandler(cmd *cobra.Command, args []string) error {
|
||||
@@ -309,9 +319,10 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
interactive := true
|
||||
|
||||
opts := runOptions{
|
||||
Model: args[0],
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]any{},
|
||||
Model: args[0],
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]any{},
|
||||
ShowConnect: true,
|
||||
}
|
||||
|
||||
format, err := cmd.Flags().GetString("format")
|
||||
@@ -322,11 +333,23 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
thinkFlag := cmd.Flags().Lookup("think")
|
||||
if thinkFlag.Changed {
|
||||
think, err := cmd.Flags().GetBool("think")
|
||||
thinkStr, err := cmd.Flags().GetString("think")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.Think = &think
|
||||
|
||||
// Handle different values for --think
|
||||
switch thinkStr {
|
||||
case "", "true":
|
||||
// --think or --think=true
|
||||
opts.Think = &api.ThinkValue{Value: true}
|
||||
case "false":
|
||||
opts.Think = &api.ThinkValue{Value: false}
|
||||
case "high", "medium", "low":
|
||||
opts.Think = &api.ThinkValue{Value: thinkStr}
|
||||
default:
|
||||
return fmt.Errorf("invalid value for --think: %q (must be true, false, high, medium, or low)", thinkStr)
|
||||
}
|
||||
} else {
|
||||
opts.Think = nil
|
||||
}
|
||||
@@ -357,6 +380,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
prompts = append([]string{string(in)}, prompts...)
|
||||
opts.ShowConnect = false
|
||||
opts.WordWrap = false
|
||||
interactive = false
|
||||
}
|
||||
@@ -423,6 +447,15 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
if interactive {
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
var sErr api.AuthorizationError
|
||||
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
|
||||
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
|
||||
|
||||
if sErr.SigninURL != "" {
|
||||
fmt.Printf(ConnectInstructions, sErr.SigninURL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -443,6 +476,59 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return generate(cmd, opts)
|
||||
}
|
||||
|
||||
func SigninHandler(cmd *cobra.Command, args []string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := client.Whoami(cmd.Context())
|
||||
if err != nil {
|
||||
var aErr api.AuthorizationError
|
||||
if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
|
||||
fmt.Println("You need to be signed in to Ollama to run Cloud models.")
|
||||
fmt.Println()
|
||||
|
||||
if aErr.SigninURL != "" {
|
||||
fmt.Printf(ConnectInstructions, aErr.SigninURL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if user != nil && user.Name != "" {
|
||||
fmt.Printf("You are already signed in as user '%s'\n", user.Name)
|
||||
fmt.Println()
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func SignoutHandler(cmd *cobra.Command, args []string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = client.Signout(cmd.Context())
|
||||
if err != nil {
|
||||
var aErr api.AuthorizationError
|
||||
if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
|
||||
fmt.Println("You are not signed in to ollama.com")
|
||||
fmt.Println()
|
||||
return nil
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("You have signed out of ollama.com")
|
||||
fmt.Println()
|
||||
return nil
|
||||
}
|
||||
|
||||
func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
@@ -454,6 +540,25 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
n := model.ParseName(args[0])
|
||||
if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") {
|
||||
_, err := client.Whoami(cmd.Context())
|
||||
if err != nil {
|
||||
var aErr api.AuthorizationError
|
||||
if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
|
||||
fmt.Println("You need to be signed in to push models to ollama.com.")
|
||||
fmt.Println()
|
||||
|
||||
if aErr.SigninURL != "" {
|
||||
fmt.Printf(ConnectInstructions, aErr.SigninURL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
@@ -490,12 +595,12 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
request := api.PushRequest{Name: args[0], Insecure: insecure}
|
||||
|
||||
n := model.ParseName(args[0])
|
||||
if err := client.Push(cmd.Context(), &request, fn); err != nil {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
if strings.Contains(err.Error(), "access denied") {
|
||||
errStr := strings.ToLower(err.Error())
|
||||
if strings.Contains(errStr, "access denied") || strings.Contains(errStr, "unauthorized") {
|
||||
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
|
||||
}
|
||||
return err
|
||||
@@ -529,7 +634,14 @@ func ListHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
for _, m := range models.Models {
|
||||
if len(args) == 0 || strings.HasPrefix(strings.ToLower(m.Name), strings.ToLower(args[0])) {
|
||||
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")})
|
||||
var size string
|
||||
if m.RemoteModel != "" {
|
||||
size = "-"
|
||||
} else {
|
||||
size = format.HumanBytes(m.Size)
|
||||
}
|
||||
|
||||
data = append(data, []string{m.Name, m.Digest[:12], size, format.HumanTime(m.ModifiedAt, "Never")})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -583,12 +695,13 @@ func ListRunningHandler(cmd *cobra.Command, args []string) error {
|
||||
} else {
|
||||
until = format.HumanTime(m.ExpiresAt, "Never")
|
||||
}
|
||||
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, until})
|
||||
ctxStr := strconv.Itoa(m.ContextLength)
|
||||
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, ctxStr, until})
|
||||
}
|
||||
}
|
||||
|
||||
table := tablewriter.NewWriter(os.Stdout)
|
||||
table.SetHeader([]string{"NAME", "ID", "SIZE", "PROCESSOR", "UNTIL"})
|
||||
table.SetHeader([]string{"NAME", "ID", "SIZE", "PROCESSOR", "CONTEXT", "UNTIL"})
|
||||
table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
|
||||
table.SetAlignment(tablewriter.ALIGN_LEFT)
|
||||
table.SetHeaderLine(false)
|
||||
@@ -613,8 +726,8 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
|
||||
KeepAlive: &api.Duration{Duration: 0},
|
||||
}
|
||||
if err := loadOrUnloadModel(cmd, opts); err != nil {
|
||||
if !strings.Contains(err.Error(), "not found") {
|
||||
return fmt.Errorf("unable to stop existing running model \"%s\": %s", args[0], err)
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "not found") {
|
||||
fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", args[0])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -725,12 +838,36 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
||||
}
|
||||
|
||||
tableRender("Model", func() (rows [][]string) {
|
||||
if resp.RemoteHost != "" {
|
||||
rows = append(rows, []string{"", "Remote model", resp.RemoteModel})
|
||||
rows = append(rows, []string{"", "Remote URL", resp.RemoteHost})
|
||||
}
|
||||
|
||||
if resp.ModelInfo != nil {
|
||||
arch := resp.ModelInfo["general.architecture"].(string)
|
||||
rows = append(rows, []string{"", "architecture", arch})
|
||||
rows = append(rows, []string{"", "parameters", format.HumanNumber(uint64(resp.ModelInfo["general.parameter_count"].(float64)))})
|
||||
rows = append(rows, []string{"", "context length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64), 'f', -1, 64)})
|
||||
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64), 'f', -1, 64)})
|
||||
|
||||
var paramStr string
|
||||
if resp.Details.ParameterSize != "" {
|
||||
paramStr = resp.Details.ParameterSize
|
||||
} else if v, ok := resp.ModelInfo["general.parameter_count"]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
paramStr = format.HumanNumber(uint64(f))
|
||||
}
|
||||
}
|
||||
rows = append(rows, []string{"", "parameters", paramStr})
|
||||
|
||||
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)})
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rows = append(rows, []string{"", "architecture", resp.Details.Family})
|
||||
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
|
||||
@@ -976,8 +1113,54 @@ type runOptions struct {
|
||||
Options map[string]any
|
||||
MultiModal bool
|
||||
KeepAlive *api.Duration
|
||||
Think *bool
|
||||
Think *api.ThinkValue
|
||||
HideThinking bool
|
||||
ShowConnect bool
|
||||
}
|
||||
|
||||
func (r runOptions) Copy() runOptions {
|
||||
var messages []api.Message
|
||||
if r.Messages != nil {
|
||||
messages = make([]api.Message, len(r.Messages))
|
||||
copy(messages, r.Messages)
|
||||
}
|
||||
|
||||
var images []api.ImageData
|
||||
if r.Images != nil {
|
||||
images = make([]api.ImageData, len(r.Images))
|
||||
copy(images, r.Images)
|
||||
}
|
||||
|
||||
var opts map[string]any
|
||||
if r.Options != nil {
|
||||
opts = make(map[string]any, len(r.Options))
|
||||
for k, v := range r.Options {
|
||||
opts[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
var think *api.ThinkValue
|
||||
if r.Think != nil {
|
||||
cThink := *r.Think
|
||||
think = &cThink
|
||||
}
|
||||
|
||||
return runOptions{
|
||||
Model: r.Model,
|
||||
ParentModel: r.ParentModel,
|
||||
Prompt: r.Prompt,
|
||||
Messages: messages,
|
||||
WordWrap: r.WordWrap,
|
||||
Format: r.Format,
|
||||
System: r.System,
|
||||
Images: images,
|
||||
Options: opts,
|
||||
MultiModal: r.MultiModal,
|
||||
KeepAlive: r.KeepAlive,
|
||||
Think: think,
|
||||
HideThinking: r.HideThinking,
|
||||
ShowConnect: r.ShowConnect,
|
||||
}
|
||||
}
|
||||
|
||||
type displayResponseState struct {
|
||||
@@ -1016,10 +1199,11 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
|
||||
}
|
||||
|
||||
switch ch {
|
||||
case ' ':
|
||||
case ' ', '\t':
|
||||
state.wordBuffer = ""
|
||||
case '\n':
|
||||
case '\n', '\r':
|
||||
state.lineLength = 0
|
||||
state.wordBuffer = ""
|
||||
default:
|
||||
state.wordBuffer += string(ch)
|
||||
}
|
||||
@@ -1077,12 +1261,14 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
}()
|
||||
|
||||
var state *displayResponseState = &displayResponseState{}
|
||||
var thinkingContent strings.Builder
|
||||
var latest api.ChatResponse
|
||||
var fullResponse strings.Builder
|
||||
var role string
|
||||
var thinkTagOpened bool = false
|
||||
var thinkTagClosed bool = false
|
||||
|
||||
role := "assistant"
|
||||
|
||||
fn := func(response api.ChatResponse) error {
|
||||
if response.Message.Content != "" || !opts.HideThinking {
|
||||
p.StopAndClear()
|
||||
@@ -1095,14 +1281,21 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
if !thinkTagOpened {
|
||||
fmt.Print(thinkingOutputOpeningText(false))
|
||||
thinkTagOpened = true
|
||||
thinkTagClosed = false
|
||||
}
|
||||
thinkingContent.WriteString(response.Message.Thinking)
|
||||
displayResponse(response.Message.Thinking, opts.WordWrap, state)
|
||||
}
|
||||
|
||||
content := response.Message.Content
|
||||
if thinkTagOpened && !thinkTagClosed && content != "" {
|
||||
if thinkTagOpened && !thinkTagClosed && (content != "" || len(response.Message.ToolCalls) > 0) {
|
||||
if !strings.HasSuffix(thinkingContent.String(), "\n") {
|
||||
fmt.Println()
|
||||
}
|
||||
fmt.Print(thinkingOutputClosingText(false))
|
||||
thinkTagOpened = false
|
||||
thinkTagClosed = true
|
||||
state = &displayResponseState{}
|
||||
}
|
||||
// purposefully not putting thinking blocks in the response, which would
|
||||
// only be needed if we later added tool calling to the cli (they get
|
||||
@@ -1110,6 +1303,13 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
// about to finish some tool calls)
|
||||
fullResponse.WriteString(content)
|
||||
|
||||
if response.Message.ToolCalls != nil {
|
||||
toolCalls := response.Message.ToolCalls
|
||||
if len(toolCalls) > 0 {
|
||||
fmt.Print(renderToolCalls(toolCalls, false))
|
||||
}
|
||||
}
|
||||
|
||||
displayResponse(content, opts.WordWrap, state)
|
||||
|
||||
return nil
|
||||
@@ -1135,6 +1335,14 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// this error should ideally be wrapped properly by the client
|
||||
if strings.Contains(err.Error(), "upstream error") {
|
||||
p.StopAndClear()
|
||||
fmt.Println("An error occurred while processing your message. Please try again.")
|
||||
fmt.Println()
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1186,6 +1394,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
}()
|
||||
|
||||
var state *displayResponseState = &displayResponseState{}
|
||||
var thinkingContent strings.Builder
|
||||
var thinkTagOpened bool = false
|
||||
var thinkTagClosed bool = false
|
||||
|
||||
@@ -1203,17 +1412,31 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
if !thinkTagOpened {
|
||||
fmt.Print(thinkingOutputOpeningText(plainText))
|
||||
thinkTagOpened = true
|
||||
thinkTagClosed = false
|
||||
}
|
||||
thinkingContent.WriteString(response.Thinking)
|
||||
displayResponse(response.Thinking, opts.WordWrap, state)
|
||||
}
|
||||
|
||||
if thinkTagOpened && !thinkTagClosed && content != "" {
|
||||
if thinkTagOpened && !thinkTagClosed && (content != "" || len(response.ToolCalls) > 0) {
|
||||
if !strings.HasSuffix(thinkingContent.String(), "\n") {
|
||||
fmt.Println()
|
||||
}
|
||||
fmt.Print(thinkingOutputClosingText(plainText))
|
||||
thinkTagOpened = false
|
||||
thinkTagClosed = true
|
||||
state = &displayResponseState{}
|
||||
}
|
||||
|
||||
displayResponse(content, opts.WordWrap, state)
|
||||
|
||||
if response.ToolCalls != nil {
|
||||
toolCalls := response.ToolCalls
|
||||
if len(toolCalls) > 0 {
|
||||
fmt.Print(renderToolCalls(toolCalls, plainText))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1416,13 +1639,13 @@ func NewCLI() *cobra.Command {
|
||||
|
||||
createCmd := &cobra.Command{
|
||||
Use: "create MODEL",
|
||||
Short: "Create a model from a Modelfile",
|
||||
Short: "Create a model",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: CreateHandler,
|
||||
}
|
||||
|
||||
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\"")
|
||||
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")")
|
||||
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
|
||||
|
||||
showCmd := &cobra.Command{
|
||||
@@ -1453,7 +1676,8 @@ func NewCLI() *cobra.Command {
|
||||
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
|
||||
runCmd.Flags().String("format", "", "Response format (e.g. json)")
|
||||
runCmd.Flags().Bool("think", false, "Whether to use thinking mode for supported models")
|
||||
runCmd.Flags().String("think", "", "Enable thinking mode: true/false or high/medium/low for supported models")
|
||||
runCmd.Flags().Lookup("think").NoOptDefVal = "true"
|
||||
runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
@@ -1492,6 +1716,22 @@ func NewCLI() *cobra.Command {
|
||||
|
||||
pushCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
|
||||
signinCmd := &cobra.Command{
|
||||
Use: "signin",
|
||||
Short: "Sign in to ollama.com",
|
||||
Args: cobra.ExactArgs(0),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: SigninHandler,
|
||||
}
|
||||
|
||||
signoutCmd := &cobra.Command{
|
||||
Use: "signout",
|
||||
Short: "Sign out from ollama.com",
|
||||
Args: cobra.ExactArgs(0),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: SignoutHandler,
|
||||
}
|
||||
|
||||
listCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
@@ -1558,6 +1798,7 @@ func NewCLI() *cobra.Command {
|
||||
appendEnvDocs(cmd, []envconfig.EnvVar{
|
||||
envVars["OLLAMA_DEBUG"],
|
||||
envVars["OLLAMA_HOST"],
|
||||
envVars["OLLAMA_CONTEXT_LENGTH"],
|
||||
envVars["OLLAMA_KEEP_ALIVE"],
|
||||
envVars["OLLAMA_MAX_LOADED_MODELS"],
|
||||
envVars["OLLAMA_MAX_QUEUE"],
|
||||
@@ -1585,6 +1826,8 @@ func NewCLI() *cobra.Command {
|
||||
stopCmd,
|
||||
pullCmd,
|
||||
pushCmd,
|
||||
signinCmd,
|
||||
signoutCmd,
|
||||
listCmd,
|
||||
psCmd,
|
||||
copyCmd,
|
||||
@@ -1603,7 +1846,7 @@ func NewCLI() *cobra.Command {
|
||||
// to false).
|
||||
//
|
||||
// If capabilities are not provided, we fetch them from the server.
|
||||
func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicitlySetByUser bool) (*bool, error) {
|
||||
func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicitlySetByUser bool) (*api.ThinkValue, error) {
|
||||
if explicitlySetByUser {
|
||||
return runOpts.Think, nil
|
||||
}
|
||||
@@ -1630,9 +1873,34 @@ func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicit
|
||||
}
|
||||
|
||||
if thinkingSupported {
|
||||
thinking := true
|
||||
return &thinking, nil
|
||||
return &api.ThinkValue{Value: true}, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
|
||||
out := ""
|
||||
formatExplanation := ""
|
||||
formatValues := ""
|
||||
if !plainText {
|
||||
formatExplanation = readline.ColorGrey + readline.ColorBold
|
||||
formatValues = readline.ColorDefault
|
||||
out += formatExplanation
|
||||
}
|
||||
for i, toolCall := range toolCalls {
|
||||
argsAsJSON, err := json.Marshal(toolCall.Function.Arguments)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
if i > 0 {
|
||||
out += "\n"
|
||||
}
|
||||
// all tool calls are unexpected since we don't currently support registering any in the CLI
|
||||
out += fmt.Sprintf(" Model called a non-existent function '%s()' with arguments: %s", formatValues+toolCall.Function.Name+formatExplanation, formatValues+string(argsAsJSON)+formatExplanation)
|
||||
}
|
||||
if !plainText {
|
||||
out += readline.ColorDefault
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
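The change above replaces the boolean --think flag with a string flag (NoOptDefVal "true") that feeds api.ThinkValue. The diff does not show the flag-to-value conversion itself, so the sketch below is only one plausible mapping, using a local ThinkValue stand-in and a hypothetical parseThinkFlag helper rather than the repository's actual code.

package main

import (
	"fmt"
	"strconv"
)

// ThinkValue mirrors the shape of api.ThinkValue used in the diff: a single
// field that may hold a bool or a level string such as "high".
type ThinkValue struct {
	Value any
}

// parseThinkFlag is a hypothetical helper: "" means unset, booleans toggle
// thinking, and the three documented levels pass through as strings.
func parseThinkFlag(s string) (*ThinkValue, error) {
	if s == "" {
		return nil, nil // unset; capability inference decides later
	}
	if b, err := strconv.ParseBool(s); err == nil {
		return &ThinkValue{Value: b}, nil
	}
	switch s {
	case "high", "medium", "low":
		return &ThinkValue{Value: s}, nil
	}
	return nil, fmt.Errorf("invalid think value %q", s)
}

func main() {
	for _, s := range []string{"", "true", "low", "bogus"} {
		v, err := parseThinkFlag(s)
		fmt.Printf("%q -> %#v, err=%v\n", s, v, err)
	}
}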
328
cmd/cmd_test.go
@@ -3,10 +3,12 @@ package cmd
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -304,6 +306,8 @@ func TestDeleteHandler(t *testing.T) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
errPayload := `{"error":"model '%s' not found"}`
|
||||
w.Write([]byte(fmt.Sprintf(errPayload, req.Name)))
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -346,7 +350,7 @@ func TestDeleteHandler(t *testing.T) {
|
||||
}
|
||||
|
||||
err := DeleteHandler(cmd, []string{"test-model-not-found"})
|
||||
if err == nil || !strings.Contains(err.Error(), "unable to stop existing running model \"test-model-not-found\"") {
|
||||
if err == nil || !strings.Contains(err.Error(), "model 'test-model-not-found' not found") {
|
||||
t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -488,9 +492,35 @@ func TestPushHandler(t *testing.T) {
|
||||
w.(http.Flusher).Flush()
|
||||
}
|
||||
},
|
||||
"/api/me": func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST request, got %s", r.Method)
|
||||
}
|
||||
},
|
||||
},
|
||||
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n",
|
||||
},
|
||||
{
|
||||
name: "not signed in push",
|
||||
modelName: "notsignedin-model",
|
||||
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
||||
"/api/me": func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST request, got %s", r.Method)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "unauthorized",
|
||||
"signin_url": "https://somethingsomething",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
},
|
||||
expectedOutput: "You need to be signed in to push",
|
||||
},
|
||||
{
|
||||
name: "unauthorized push",
|
||||
modelName: "unauthorized-model",
|
||||
@@ -499,12 +529,17 @@ func TestPushHandler(t *testing.T) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "access denied",
|
||||
"error": "403: {\"errors\":[{\"code\":\"ACCESS DENIED\", \"message\":\"access denied\"}]}",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
"/api/me": func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST request, got %s", r.Method)
|
||||
}
|
||||
},
|
||||
},
|
||||
expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
|
||||
},
|
||||
@@ -522,6 +557,10 @@ func TestPushHandler(t *testing.T) {
|
||||
defer mockServer.Close()
|
||||
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("HOME", tmpDir)
|
||||
t.Setenv("USERPROFILE", tmpDir)
|
||||
initializeKeypair()
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().Bool("insecure", false, "")
|
||||
@@ -557,7 +596,7 @@ func TestPushHandler(t *testing.T) {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
if tt.expectedOutput != "" {
|
||||
if got := string(stdout); got != tt.expectedOutput {
|
||||
if got := string(stdout); !strings.Contains(got, tt.expectedOutput) {
|
||||
t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
|
||||
}
|
||||
}
|
||||
@@ -915,3 +954,286 @@ func TestNewCreateRequest(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunOptions_Copy(t *testing.T) {
|
||||
// Setup test data
|
||||
originalKeepAlive := &api.Duration{Duration: 5 * time.Minute}
|
||||
originalThink := &api.ThinkValue{Value: "test reasoning"}
|
||||
|
||||
original := runOptions{
|
||||
Model: "test-model",
|
||||
ParentModel: "parent-model",
|
||||
Prompt: "test prompt",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", Content: "hi there"},
|
||||
},
|
||||
WordWrap: true,
|
||||
Format: "json",
|
||||
System: "system prompt",
|
||||
Images: []api.ImageData{
|
||||
[]byte("image1"),
|
||||
[]byte("image2"),
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 1000,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
MultiModal: true,
|
||||
KeepAlive: originalKeepAlive,
|
||||
Think: originalThink,
|
||||
HideThinking: false,
|
||||
ShowConnect: true,
|
||||
}
|
||||
|
||||
// Test the copy
|
||||
copied := original.Copy()
|
||||
|
||||
// Test 1: Verify the copy is not the same instance
|
||||
if &copied == &original {
|
||||
t.Error("Copy should return a different instance")
|
||||
}
|
||||
|
||||
// Test 2: Verify all fields are copied correctly
|
||||
tests := []struct {
|
||||
name string
|
||||
got interface{}
|
||||
want interface{}
|
||||
}{
|
||||
{"Model", copied.Model, original.Model},
|
||||
{"ParentModel", copied.ParentModel, original.ParentModel},
|
||||
{"Prompt", copied.Prompt, original.Prompt},
|
||||
{"WordWrap", copied.WordWrap, original.WordWrap},
|
||||
{"Format", copied.Format, original.Format},
|
||||
{"System", copied.System, original.System},
|
||||
{"MultiModal", copied.MultiModal, original.MultiModal},
|
||||
{"HideThinking", copied.HideThinking, original.HideThinking},
|
||||
{"ShowConnect", copied.ShowConnect, original.ShowConnect},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if !reflect.DeepEqual(tt.got, tt.want) {
|
||||
t.Errorf("%s mismatch: got %v, want %v", tt.name, tt.got, tt.want)
|
||||
}
|
||||
}
|
||||
|
||||
// Test 3: Verify Messages slice is deeply copied
|
||||
if len(copied.Messages) != len(original.Messages) {
|
||||
t.Errorf("Messages length mismatch: got %d, want %d", len(copied.Messages), len(original.Messages))
|
||||
}
|
||||
|
||||
if len(copied.Messages) > 0 && &copied.Messages[0] == &original.Messages[0] {
|
||||
t.Error("Messages should be different instances")
|
||||
}
|
||||
|
||||
// Modify original to verify independence
|
||||
if len(original.Messages) > 0 {
|
||||
originalContent := original.Messages[0].Content
|
||||
original.Messages[0].Content = "modified"
|
||||
if len(copied.Messages) > 0 && copied.Messages[0].Content == "modified" {
|
||||
t.Error("Messages should be independent after copy")
|
||||
}
|
||||
// Restore for other tests
|
||||
original.Messages[0].Content = originalContent
|
||||
}
|
||||
|
||||
// Test 4: Verify Images slice is deeply copied
|
||||
if len(copied.Images) != len(original.Images) {
|
||||
t.Errorf("Images length mismatch: got %d, want %d", len(copied.Images), len(original.Images))
|
||||
}
|
||||
|
||||
if len(copied.Images) > 0 && &copied.Images[0] == &original.Images[0] {
|
||||
t.Error("Images should be different instances")
|
||||
}
|
||||
|
||||
// Modify original to verify independence
|
||||
if len(original.Images) > 0 {
|
||||
originalImage := original.Images[0]
|
||||
original.Images[0] = []byte("modified")
|
||||
if len(copied.Images) > 0 && string(copied.Images[0]) == "modified" {
|
||||
t.Error("Images should be independent after copy")
|
||||
}
|
||||
// Restore for other tests
|
||||
original.Images[0] = originalImage
|
||||
}
|
||||
|
||||
// Test 5: Verify Options map is deeply copied
|
||||
if len(copied.Options) != len(original.Options) {
|
||||
t.Errorf("Options length mismatch: got %d, want %d", len(copied.Options), len(original.Options))
|
||||
}
|
||||
|
||||
if len(copied.Options) > 0 && &copied.Options == &original.Options {
|
||||
t.Error("Options map should be different instances")
|
||||
}
|
||||
|
||||
// Modify original to verify independence
|
||||
if len(original.Options) > 0 {
|
||||
originalTemp := original.Options["temperature"]
|
||||
original.Options["temperature"] = 0.9
|
||||
if copied.Options["temperature"] == 0.9 {
|
||||
t.Error("Options should be independent after copy")
|
||||
}
|
||||
// Restore for other tests
|
||||
original.Options["temperature"] = originalTemp
|
||||
}
|
||||
|
||||
// Test 6: Verify KeepAlive pointer is copied (shallow copy)
|
||||
if copied.KeepAlive != original.KeepAlive {
|
||||
t.Error("KeepAlive pointer should be the same (shallow copy)")
|
||||
}
|
||||
|
||||
// Test 7: Verify Think pointer creates a new instance
|
||||
if original.Think != nil && copied.Think == original.Think {
|
||||
t.Error("Think should be a different instance")
|
||||
}
|
||||
|
||||
if original.Think != nil && copied.Think != nil {
|
||||
if !reflect.DeepEqual(copied.Think.Value, original.Think.Value) {
|
||||
t.Errorf("Think.Value mismatch: got %v, want %v", copied.Think.Value, original.Think.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// Test 8: Test with zero values
|
||||
zeroOriginal := runOptions{}
|
||||
zeroCopy := zeroOriginal.Copy()
|
||||
|
||||
if !reflect.DeepEqual(zeroCopy, zeroOriginal) {
|
||||
fmt.Printf("orig: %#v\ncopy: %#v\n", zeroOriginal, zeroCopy)
|
||||
t.Error("Copy of zero value should equal original zero value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunOptions_Copy_EmptySlicesAndMaps(t *testing.T) {
|
||||
// Test with empty slices and maps
|
||||
original := runOptions{
|
||||
Messages: []api.Message{},
|
||||
Images: []api.ImageData{},
|
||||
Options: map[string]any{},
|
||||
}
|
||||
|
||||
copied := original.Copy()
|
||||
|
||||
if copied.Messages == nil {
|
||||
t.Error("Empty Messages slice should remain empty, not nil")
|
||||
}
|
||||
|
||||
if copied.Images == nil {
|
||||
t.Error("Empty Images slice should remain empty, not nil")
|
||||
}
|
||||
|
||||
if copied.Options == nil {
|
||||
t.Error("Empty Options map should remain empty, not nil")
|
||||
}
|
||||
|
||||
if len(copied.Messages) != 0 {
|
||||
t.Error("Empty Messages slice should remain empty")
|
||||
}
|
||||
|
||||
if len(copied.Images) != 0 {
|
||||
t.Error("Empty Images slice should remain empty")
|
||||
}
|
||||
|
||||
if len(copied.Options) != 0 {
|
||||
t.Error("Empty Options map should remain empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunOptions_Copy_NilPointers(t *testing.T) {
|
||||
// Test with nil pointers
|
||||
original := runOptions{
|
||||
KeepAlive: nil,
|
||||
Think: nil,
|
||||
}
|
||||
|
||||
copied := original.Copy()
|
||||
|
||||
if copied.KeepAlive != nil {
|
||||
t.Error("Nil KeepAlive should remain nil")
|
||||
}
|
||||
|
||||
if copied.Think != nil {
|
||||
t.Error("Nil Think should remain nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
think *api.ThinkValue
|
||||
}{
|
||||
{"nil Think", nil},
|
||||
{"bool true", &api.ThinkValue{Value: true}},
|
||||
{"bool false", &api.ThinkValue{Value: false}},
|
||||
{"string value", &api.ThinkValue{Value: "reasoning text"}},
|
||||
{"int value", &api.ThinkValue{Value: 42}},
|
||||
{"nil value", &api.ThinkValue{Value: nil}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
original := runOptions{Think: tt.think}
|
||||
copied := original.Copy()
|
||||
|
||||
if tt.think == nil {
|
||||
if copied.Think != nil {
|
||||
t.Error("Nil Think should remain nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if copied.Think == nil {
|
||||
t.Error("Non-nil Think should not become nil")
|
||||
return
|
||||
}
|
||||
|
||||
if copied.Think == original.Think {
|
||||
t.Error("Think should be a different instance")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(copied.Think.Value, original.Think.Value) {
|
||||
t.Errorf("Think.Value mismatch: got %v, want %v", copied.Think.Value, original.Think.Value)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunOptions_Copy_Independence(t *testing.T) {
|
||||
// Test that modifications to original don't affect copy
|
||||
originalThink := &api.ThinkValue{Value: "original"}
|
||||
original := runOptions{
|
||||
Model: "original-model",
|
||||
Messages: []api.Message{{Role: "user", Content: "original"}},
|
||||
Options: map[string]any{"key": "value"},
|
||||
Think: originalThink,
|
||||
}
|
||||
|
||||
copied := original.Copy()
|
||||
|
||||
// Modify original
|
||||
original.Model = "modified-model"
|
||||
if len(original.Messages) > 0 {
|
||||
original.Messages[0].Content = "modified"
|
||||
}
|
||||
original.Options["key"] = "modified"
|
||||
if original.Think != nil {
|
||||
original.Think.Value = "modified"
|
||||
}
|
||||
|
||||
// Verify copy is unchanged
|
||||
if copied.Model == "modified-model" {
|
||||
t.Error("Copy Model should not be affected by original modification")
|
||||
}
|
||||
|
||||
if len(copied.Messages) > 0 && copied.Messages[0].Content == "modified" {
|
||||
t.Error("Copy Messages should not be affected by original modification")
|
||||
}
|
||||
|
||||
if copied.Options["key"] == "modified" {
|
||||
t.Error("Copy Options should not be affected by original modification")
|
||||
}
|
||||
|
||||
if copied.Think != nil && copied.Think.Value == "modified" {
|
||||
t.Error("Copy Think should not be affected by original modification")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,16 +195,24 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Println("Usage:\n /load <modelname>")
|
||||
continue
|
||||
}
|
||||
origOpts := opts.Copy()
|
||||
|
||||
opts.Model = args[1]
|
||||
opts.Messages = []api.Message{}
|
||||
fmt.Printf("Loading model '%s'\n", opts.Model)
|
||||
opts.Think, err = inferThinkingOption(nil, &opts, thinkExplicitlySet)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
fmt.Printf("Couldn't find model '%s'\n", opts.Model)
|
||||
opts = origOpts.Copy()
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
fmt.Printf("Couldn't find model '%s'\n", opts.Model)
|
||||
opts = origOpts.Copy()
|
||||
continue
|
||||
}
|
||||
if strings.Contains(err.Error(), "does not support thinking") {
|
||||
@@ -272,16 +280,29 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
}
|
||||
fmt.Println("Set 'quiet' mode.")
|
||||
case "think":
|
||||
think := true
|
||||
opts.Think = &think
|
||||
thinkValue := api.ThinkValue{Value: true}
|
||||
var maybeLevel string
|
||||
if len(args) > 2 {
|
||||
maybeLevel = args[2]
|
||||
}
|
||||
if maybeLevel != "" {
|
||||
// TODO(drifkin): validate the level, could be model dependent
|
||||
// though... It will also be validated on the server once a call is
|
||||
// made.
|
||||
thinkValue.Value = maybeLevel
|
||||
}
|
||||
opts.Think = &thinkValue
|
||||
thinkExplicitlySet = true
|
||||
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||
ensureThinkingSupport(cmd.Context(), client, opts.Model)
|
||||
}
|
||||
fmt.Println("Set 'think' mode.")
|
||||
if maybeLevel != "" {
|
||||
fmt.Printf("Set 'think' mode to '%s'.\n", maybeLevel)
|
||||
} else {
|
||||
fmt.Println("Set 'think' mode.")
|
||||
}
|
||||
case "nothink":
|
||||
think := false
|
||||
opts.Think = &think
|
||||
opts.Think = &api.ThinkValue{Value: false}
|
||||
thinkExplicitlySet = true
|
||||
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||
ensureThinkingSupport(cmd.Context(), client, opts.Model)
|
||||
@@ -385,18 +406,21 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
case "modelfile":
|
||||
fmt.Println(resp.Modelfile)
|
||||
case "parameters":
|
||||
fmt.Println("Model defined parameters:")
|
||||
if resp.Parameters == "" {
|
||||
fmt.Println("No parameters were specified for this model.")
|
||||
fmt.Println(" No additional parameters were specified for this model.")
|
||||
} else {
|
||||
if len(opts.Options) > 0 {
|
||||
fmt.Println("User defined parameters:")
|
||||
for k, v := range opts.Options {
|
||||
fmt.Printf("%-*s %v\n", 30, k, v)
|
||||
}
|
||||
fmt.Println()
|
||||
for _, l := range strings.Split(resp.Parameters, "\n") {
|
||||
fmt.Printf(" %s\n", l)
|
||||
}
|
||||
fmt.Println("Model defined parameters:")
|
||||
fmt.Println(resp.Parameters)
|
||||
}
|
||||
fmt.Println()
|
||||
if len(opts.Options) > 0 {
|
||||
fmt.Println("User defined parameters:")
|
||||
for k, v := range opts.Options {
|
||||
fmt.Printf(" %-*s %v\n", 30, k, v)
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
case "system":
|
||||
switch {
|
||||
@@ -475,7 +499,8 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
|
||||
assistant, err := chat(cmd, opts)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "does not support thinking") {
|
||||
if strings.Contains(err.Error(), "does not support thinking") ||
|
||||
strings.Contains(err.Error(), "invalid think value") {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
sb.Reset()
|
||||
continue
|
||||
|
||||
@@ -190,6 +190,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
conv = &gemma2Model{}
|
||||
case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration":
|
||||
conv = &gemma3Model{Architecture: p.Architectures[0]}
|
||||
case "Gemma3nForConditionalGeneration":
|
||||
conv = &gemma3nModel{}
|
||||
case "Phi3ForCausalLM":
|
||||
conv = &phi3Model{}
|
||||
case "Qwen2ForCausalLM":
|
||||
@@ -200,6 +202,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
conv = &bertModel{}
|
||||
case "CohereForCausalLM":
|
||||
conv = &commandrModel{}
|
||||
case "GptOssForCausalLM":
|
||||
conv = &gptossModel{}
|
||||
default:
|
||||
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ type bertModel struct {
|
||||
LayerNormEPS float32 `json:"layer_norm_eps"`
|
||||
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
|
||||
NormEpsilon float32 `json:"norm_epsilon"`
|
||||
normalizeEmbeddings bool
|
||||
|
||||
PoolingType uint32
|
||||
}
|
||||
@@ -54,9 +55,11 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
|
||||
|
||||
var pooling string
|
||||
for _, m := range modules {
|
||||
if m.Type == "sentence_transformers.models.Pooling" {
|
||||
switch m.Type {
|
||||
case "sentence_transformers.models.Pooling":
|
||||
pooling = m.Path
|
||||
break
|
||||
case "sentence_transformers.models.Normalize":
|
||||
p.normalizeEmbeddings = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,6 +93,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv["general.architecture"] = "bert"
|
||||
kv["bert.attention.causal"] = false
|
||||
kv["bert.pooling_type"] = p.PoolingType
|
||||
kv["bert.normalize_embeddings"] = p.normalizeEmbeddings
|
||||
|
||||
kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)
|
||||
|
||||
|
||||
165
convert/convert_gemma3n.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
"gonum.org/v1/gonum/stat/distuv"
|
||||
)
|
||||
|
||||
type gemma3nModel struct {
|
||||
ModelParameters
|
||||
|
||||
TextModel struct {
|
||||
ActivationSparsityPattern []float32 `json:"activation_sparsity_pattern"`
|
||||
AltupActiveIdx uint32 `json:"altup_active_idx"`
|
||||
AltupCoefClip float32 `json:"altup_coef_clip"`
|
||||
AltupCorrectScale bool `json:"altup_correct_scale"`
|
||||
AltupLRMultiplier float32 `json:"altup_lr_multiplier"`
|
||||
AltupNumInputs uint32 `json:"altup_num_inputs"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
HiddenSizePerLayerInput uint32 `json:"hidden_size_per_layer_input"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
NumKVSharedLayers uint32 `json:"num_kv_shared_layers"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
LayerTypes []string `json:"layer_types"`
|
||||
} `json:"text_config"`
|
||||
VisionModel struct{} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma3n"
|
||||
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {
|
||||
norm := distuv.Normal{Mu: 0, Sigma: 1}
|
||||
for _, v := range m.TextModel.ActivationSparsityPattern {
|
||||
if !yield(float32(norm.Quantile(float64(v)))) {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
kv["gemma3n.altup.active_idx"] = m.TextModel.AltupActiveIdx
|
||||
kv["gemma3n.altup.correct_scale"] = m.TextModel.AltupCorrectScale
|
||||
kv["gemma3n.altup.lr_multiplier"] = m.TextModel.AltupLRMultiplier
|
||||
kv["gemma3n.altup.num_inputs"] = m.TextModel.AltupNumInputs
|
||||
kv["gemma3n.attention.head_count_kv"] = m.TextModel.NumKeyValueHeads
|
||||
kv["gemma3n.attention.head_count"] = m.TextModel.NumAttentionHeads
|
||||
kv["gemma3n.attention.layer_norm_rms_epsilon"] = m.TextModel.RMSNormEPS
|
||||
kv["gemma3n.attention.sliding_window"] = m.TextModel.SlidingWindow
|
||||
kv["gemma3n.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
|
||||
for _, t := range m.TextModel.LayerTypes {
|
||||
if !yield(t == "sliding_attention") {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
kv["gemma3n.attention.shared_kv_layers"] = m.TextModel.NumKVSharedLayers
|
||||
kv["gemma3n.block_count"] = m.TextModel.NumHiddenLayers
|
||||
kv["gemma3n.context_length"] = m.TextModel.MaxPositionEmbeddings
|
||||
kv["gemma3n.embedding_length_per_layer_input"] = m.TextModel.HiddenSizePerLayerInput
|
||||
kv["gemma3n.embedding_length"] = m.TextModel.HiddenSize
|
||||
kv["gemma3n.feed_forward_length"] = m.TextModel.IntermediateSize
|
||||
kv["gemma3n.head_dim"] = m.TextModel.HeadDim
|
||||
kv["gemma3n.rope.freq_base_local"] = m.TextModel.RopeLocalBaseFreq
|
||||
kv["gemma3n.rope.freq_base"] = m.TextModel.RopeTheta
|
||||
return kv
|
||||
}
|
||||
|
||||
func (m *gemma3nModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
out, ts := mergeTensors(ts,
|
||||
merge{"altup_proj.*.weight", "altup_proj.weight"},
|
||||
merge{"altup_unembd_proj.*.weight", "altup_unembd_proj.weight"},
|
||||
)
|
||||
|
||||
for _, t := range ts {
|
||||
switch {
|
||||
case strings.Contains(t.Name(), "audio_tower"),
|
||||
strings.Contains(t.Name(), "embed_audio"),
|
||||
strings.Contains(t.Name(), "vision_tower"),
|
||||
strings.Contains(t.Name(), "embed_vision"):
|
||||
// TODO: handle audio and vision towers
|
||||
continue
|
||||
case strings.Contains(t.Name(), "altup_predict_coef"),
|
||||
strings.Contains(t.Name(), "altup_correct_coef"):
|
||||
if m.TextModel.AltupCoefClip > 0 {
|
||||
t.SetRepacker(func(name string, data []float32, shape []uint64) (_ []float32, err error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
|
||||
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
|
||||
t, err = tensor.Clamp(t, -m.TextModel.AltupCoefClip, m.TextModel.AltupCoefClip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return native.VectorF32(t.(*tensor.Dense))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *gemma3nModel) Replacements() []string {
|
||||
return []string{
|
||||
"model.language_model.embed_tokens_per_layer", "per_layer_token_embd",
|
||||
"model.language_model.embed_tokens", "token_embd",
|
||||
"model.language_model.per_layer_model_projection", "per_layer_model_proj",
|
||||
"model.language_model.per_layer_projection_norm", "per_layer_proj_norm", "model.language_model.altup_projections", "altup_proj",
|
||||
"model.language_model.altup_unembed_projections", "altup_unembd_proj",
|
||||
"model.language_model.norm", "output_norm",
|
||||
"model.language_model.layers", "blk",
|
||||
|
||||
"input_layernorm", "attn_norm",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.q_norm", "attn_q_norm",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.k_norm", "attn_k_norm",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"post_attention_layernorm", "post_attention_norm",
|
||||
"pre_feedforward_layernorm", "ffn_norm",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"post_feedforward_layernorm", "post_ffw_norm",
|
||||
"per_layer_input_gate", "inp_gate",
|
||||
"per_layer_projection", "proj",
|
||||
"post_per_layer_input_norm", "post_norm",
|
||||
"altup.", "altup_",
|
||||
"modality_router", "router",
|
||||
"prediction_coefs", "predict_coef",
|
||||
"correction_coefs", "correct_coef",
|
||||
"correct_output_scale", "correct_scale.weight",
|
||||
"laurel.", "laurel_",
|
||||
"linear_left", "l",
|
||||
"linear_right", "r",
|
||||
"post_laurel_norm", "post_norm",
|
||||
}
|
||||
}
|
||||
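gemma3nModel.KV above converts each activation_sparsity_pattern fraction into a standard-normal quantile before storing it as gemma3n.activation_sparsity_scale. The standalone sketch below reproduces just that conversion with made-up pattern values so the resulting z-score cutoffs can be seen; it is illustrative only.

package main

import (
	"fmt"

	"gonum.org/v1/gonum/stat/distuv"
)

func main() {
	// Same distribution as in gemma3nModel.KV: a standard normal.
	norm := distuv.Normal{Mu: 0, Sigma: 1}

	// Hypothetical activation_sparsity_pattern entries (fraction of
	// activations to zero out per layer), not from a real checkpoint.
	pattern := []float32{0.95, 0.5, 0.05}

	for _, v := range pattern {
		// e.g. a 0.95 sparsity target maps to a cutoff of roughly +1.645
		scale := float32(norm.Quantile(float64(v)))
		fmt.Printf("sparsity %.2f -> scale %+.4f\n", v, scale)
	}
}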
266
convert/convert_gptoss.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
)
|
||||
|
||||
type gptossModel struct {
|
||||
ModelParameters
|
||||
HiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
AttentionHeads uint32 `json:"num_attention_heads"`
|
||||
KeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
Experts uint32 `json:"num_experts"`
|
||||
LocalExperts uint32 `json:"num_local_experts"`
|
||||
ExpertsPerToken uint32 `json:"experts_per_token"`
|
||||
RMSNormEpsilon float32 `json:"rms_norm_eps"`
|
||||
InitialContextLength uint32 `json:"initial_context_length"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RopeScalingFactor float32 `json:"rope_scaling_factor"`
|
||||
RopeScaling struct {
|
||||
Factor float32 `json:"factor"`
|
||||
} `json:"rope_scaling"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*gptossModel)(nil)
|
||||
|
||||
func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gptoss"
|
||||
kv["general.file_type"] = uint32(4)
|
||||
kv["gptoss.context_length"] = cmp.Or(m.MaxPositionEmbeddings, uint32(m.RopeScalingFactor*float32(m.InitialContextLength)))
|
||||
kv["gptoss.block_count"] = m.HiddenLayers
|
||||
kv["gptoss.embedding_length"] = m.HiddenSize
|
||||
kv["gptoss.feed_forward_length"] = m.IntermediateSize
|
||||
kv["gptoss.expert_count"] = cmp.Or(m.Experts, m.LocalExperts)
|
||||
kv["gptoss.expert_used_count"] = m.ExpertsPerToken
|
||||
kv["gptoss.attention.head_count"] = m.AttentionHeads
|
||||
kv["gptoss.attention.head_count_kv"] = m.KeyValueHeads
|
||||
kv["gptoss.attention.key_length"] = m.HeadDim
|
||||
kv["gptoss.attention.value_length"] = m.HeadDim
|
||||
kv["gptoss.attention.layer_norm_rms_epsilon"] = cmp.Or(m.RMSNormEpsilon, 1e-5)
|
||||
kv["gptoss.attention.sliding_window"] = m.SlidingWindow
|
||||
kv["gptoss.rope.freq_base"] = m.RopeTheta
|
||||
kv["gptoss.rope.scaling.factor"] = cmp.Or(m.RopeScalingFactor, m.RopeScaling.Factor)
|
||||
kv["gptoss.rope.scaling.original_context_length"] = m.InitialContextLength
|
||||
kv["tokenizer.ggml.bos_token_id"] = uint32(199998) // <|startoftext|>
|
||||
kv["tokenizer.ggml.add_bos_token"] = false
|
||||
kv["tokenizer.ggml.eos_token_id"] = uint32(199999) // <|endoftext|>
|
||||
kv["tokenizer.ggml.eos_token_ids"] = []int32{
|
||||
199999, /* <|endoftext|> */
|
||||
200002, /* <|return|> */
|
||||
200012, /* <|call|> */
|
||||
}
|
||||
kv["tokenizer.ggml.add_eos_token"] = false
|
||||
return kv
|
||||
}
|
||||
|
||||
func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
mxfp4s := make(map[string]*mxfp4)
|
||||
for _, t := range ts {
|
||||
if strings.HasSuffix(t.Name(), ".blocks") || strings.HasSuffix(t.Name(), ".scales") {
|
||||
dot := strings.LastIndex(t.Name(), ".")
|
||||
name, suffix := t.Name()[:dot], t.Name()[dot+1:]
|
||||
if _, ok := mxfp4s[name]; !ok {
|
||||
mxfp4s[name] = &mxfp4{}
|
||||
}
|
||||
|
||||
switch suffix {
|
||||
case "blocks":
|
||||
mxfp4s[name].blocks = t
|
||||
case "scales":
|
||||
mxfp4s[name].scales = t
|
||||
}
|
||||
} else if strings.HasSuffix(t.Name(), "gate_up_exps.bias") {
|
||||
// gate_up_exps is interleaved, need to split into gate_exps and up_exps
|
||||
// e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]
|
||||
out = append(out, slices.Collect(splitDim(t, 1,
|
||||
split{
|
||||
Replacer: strings.NewReplacer("gate_up_exps", "gate_exps"),
|
||||
slices: []tensor.Slice{nil, tensor.S(0, int(t.Shape()[1]), 2)},
|
||||
},
|
||||
split{
|
||||
Replacer: strings.NewReplacer("gate_up_exps", "up_exps"),
|
||||
slices: []tensor.Slice{nil, tensor.S(1, int(t.Shape()[1]), 2)},
|
||||
},
|
||||
))...)
|
||||
} else {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for name, mxfp4 := range mxfp4s {
|
||||
dims := mxfp4.blocks.Shape()
|
||||
if strings.Contains(name, "ffn_down_exps") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name + ".weight",
|
||||
Kind: uint32(ggml.TensorTypeMXFP4),
|
||||
Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
|
||||
WriterTo: mxfp4,
|
||||
})
|
||||
} else if strings.Contains(name, "ffn_gate_up_exps") {
|
||||
// gate_up_exps is interleaved, need to split into gate_exps and up_exps
|
||||
// e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "gate_up", "gate", 1) + ".weight",
|
||||
Kind: uint32(ggml.TensorTypeMXFP4),
|
||||
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
|
||||
WriterTo: mxfp4.slice(1, 0, int(dims[1]), 2),
|
||||
}, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "gate_up", "up", 1) + ".weight",
|
||||
Kind: uint32(ggml.TensorTypeMXFP4),
|
||||
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
|
||||
WriterTo: mxfp4.slice(1, 1, int(dims[1]), 2),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *gptossModel) Replacements() []string {
|
||||
var replacements []string
|
||||
if m.MaxPositionEmbeddings > 0 {
|
||||
// hf flavored model
|
||||
replacements = []string{
|
||||
"lm_head", "output",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.layers", "blk",
|
||||
"input_layernorm", "attn_norm",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_out",
|
||||
"self_attn.sinks", "attn_sinks",
|
||||
"post_attention_layernorm", "ffn_norm",
|
||||
"mlp.router", "ffn_gate_inp",
|
||||
"mlp.experts.gate_up_proj_", "ffn_gate_up_exps.",
|
||||
"mlp.experts.down_proj_", "ffn_down_exps.",
|
||||
"model.norm", "output_norm",
|
||||
}
|
||||
} else {
|
||||
replacements = []string{
|
||||
// noop replacements so other replacements will not be applied
|
||||
".blocks", ".blocks",
|
||||
".scales", ".scales",
|
||||
// real replacements
|
||||
"block", "blk",
|
||||
"attn.norm", "attn_norm",
|
||||
"attn.qkv", "attn_qkv",
|
||||
"attn.sinks", "attn_sinks",
|
||||
"attn.out", "attn_out",
|
||||
"mlp.norm", "ffn_norm",
|
||||
"mlp.gate", "ffn_gate_inp",
|
||||
"mlp.mlp1_", "ffn_gate_up_exps.",
|
||||
"mlp.mlp2_", "ffn_down_exps.",
|
||||
"embedding", "token_embd",
|
||||
"norm", "output_norm",
|
||||
"unembedding", "output",
|
||||
"scale", "weight",
|
||||
}
|
||||
}
|
||||
return replacements
|
||||
}
|
||||
|
||||
type mxfp4 struct {
|
||||
slices []tensor.Slice
|
||||
|
||||
blocks, scales Tensor
|
||||
}
|
||||
|
||||
func (m *mxfp4) slice(dim, start, end, step int) *mxfp4 {
|
||||
slice := slices.Repeat([]tensor.Slice{nil}, len(m.blocks.Shape()))
|
||||
slice[dim] = tensor.S(start, end, step)
|
||||
return &mxfp4{
|
||||
slices: slice,
|
||||
blocks: m.blocks,
|
||||
scales: m.scales,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
|
||||
var b bytes.Buffer
|
||||
if _, err := m.blocks.WriteTo(&b); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
blocksDims := make([]int, len(m.blocks.Shape()))
|
||||
for i, d := range m.blocks.Shape() {
|
||||
blocksDims[i] = int(d)
|
||||
}
|
||||
|
||||
bts := b.Bytes()
|
||||
var tmp [16]byte
|
||||
for i := 0; i < b.Len(); i += 16 {
|
||||
for j := range 8 {
|
||||
// transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
|
||||
a, b := bts[i+j], bts[i+j+8]
|
||||
tmp[2*j+0] = (a & 0x0F) | (b << 4)
|
||||
tmp[2*j+1] = (a >> 4) | (b & 0xF0)
|
||||
}
|
||||
|
||||
copy(bts[i:i+16], tmp[:])
|
||||
}
|
||||
|
||||
var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(bts))
|
||||
|
||||
var s bytes.Buffer
|
||||
if _, err := m.scales.WriteTo(&s); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
scalesDims := slices.Repeat([]int{1}, len(m.blocks.Shape()))
|
||||
for i, d := range m.scales.Shape() {
|
||||
scalesDims[i] = int(d)
|
||||
}
|
||||
|
||||
var scales tensor.Tensor = tensor.New(tensor.WithShape(scalesDims...), tensor.WithBacking(s.Bytes()))
|
||||
|
||||
out, err := tensor.Concat(3, scales, blocks)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(m.slices) > 0 {
|
||||
out, err = out.Slice(m.slices...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
out = tensor.Materialize(out)
|
||||
|
||||
if err := out.Reshape(out.Shape().TotalSize()); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
u8s, err := native.VectorU8(out.(*tensor.Dense))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err := binary.Write(w, binary.LittleEndian, u8s); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return int64(len(u8s)), nil
|
||||
}
|
||||
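The mxfp4.WriteTo method above repacks each 16-byte block of MXFP4 nibbles in place ("a1b2c3 ... x7y8z9 -> 71xa82yb93zc"). The sketch below lifts that inner loop into a standalone function with arbitrary example bytes so the transform can be inspected in isolation; it restates the logic shown rather than adding behaviour.

package main

import "fmt"

// interleaveBlock mirrors the inner loop of mxfp4.WriteTo: the nibbles of the
// first eight bytes are interleaved with those of the last eight, producing
// two output bytes per input pair.
func interleaveBlock(bts []byte) {
	var tmp [16]byte
	for j := 0; j < 8; j++ {
		a, b := bts[j], bts[j+8]
		tmp[2*j+0] = (a & 0x0F) | (b << 4)
		tmp[2*j+1] = (a >> 4) | (b & 0xF0)
	}
	copy(bts, tmp[:])
}

func main() {
	// Arbitrary example block; real MXFP4 data comes from the .blocks tensor.
	block := []byte{
		0x10, 0x32, 0x54, 0x76, 0x98, 0xba, 0xdc, 0xfe,
		0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef,
	}
	interleaveBlock(block)
	fmt.Printf("% x\n", block)
}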
@@ -2,9 +2,6 @@ package convert
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
@@ -30,65 +27,38 @@ func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
|
||||
}
|
||||
|
||||
func (p *mixtralModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
oldnew := []string{
|
||||
"model.layers", "blk",
|
||||
"w1", "ffn_gate_exps",
|
||||
"w2", "ffn_down_exps",
|
||||
"w3", "ffn_up_exps",
|
||||
}
|
||||
|
||||
for i := range p.NumLocalExperts {
|
||||
oldnew = append(oldnew, fmt.Sprintf(".block_sparse_moe.experts.%d.", i), ".")
|
||||
}
|
||||
|
||||
// group experts of the same layer (model.layers.%d) and type (w[123]) into a single tensor
|
||||
namer := strings.NewReplacer(oldnew...)
|
||||
experts := make(map[string]experts)
|
||||
|
||||
// merge experts into a single tensor while removing them from ts
|
||||
ts = slices.DeleteFunc(ts, func(t Tensor) bool {
|
||||
if !strings.Contains(t.Name(), ".block_sparse_moe.experts.") {
|
||||
return false
|
||||
}
|
||||
|
||||
name := namer.Replace(t.Name())
|
||||
experts[name] = append(experts[name], t)
|
||||
return true
|
||||
})
|
||||
|
||||
var out []*ggml.Tensor
|
||||
for n, e := range experts {
|
||||
// TODO(mxyng): sanity check experts
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: n,
|
||||
Kind: e[0].Kind(),
|
||||
Shape: append([]uint64{uint64(len(e))}, e[0].Shape()...),
|
||||
WriterTo: e,
|
||||
merges := make([]merge, 0, p.NumHiddenLayers*6)
|
||||
for i := range p.NumHiddenLayers {
|
||||
merges = append(merges, merge{
|
||||
fmt.Sprintf("blk.%d.*.w1.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.*.w1.bias", i),
|
||||
fmt.Sprintf("blk.%d.ffn_gate_exps.bias", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.*.w2.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.*.w2.bias", i),
|
||||
fmt.Sprintf("blk.%d.ffn_up_exps.bias", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.*.w3.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.*.w3.bias", i),
|
||||
fmt.Sprintf("blk.%d.ffn_down_exps.bias", i),
|
||||
})
|
||||
}
|
||||
|
||||
out, ts := mergeTensors(ts, merges...)
|
||||
return append(out, p.llamaModel.Tensors(ts)...)
|
||||
}
|
||||
|
||||
func (p *mixtralModel) Replacements() []string {
|
||||
return append(
|
||||
p.llamaModel.Replacements(),
|
||||
"model.layers", "blk",
|
||||
"block_sparse_moe.gate", "ffn_gate_inp",
|
||||
"block_sparse_moe.experts.", ".",
|
||||
)
|
||||
}
|
||||
|
||||
type experts []Tensor
|
||||
|
||||
func (e experts) WriteTo(w io.Writer) (int64, error) {
|
||||
// TODO(mxyng): experts _should_ be numerically sorted by expert but this should check
|
||||
for _, t := range e {
|
||||
// the canonical merged experts tensor stacks all experts along a new, 0 axis,
|
||||
// e.g. `tensor.Stack(0, e[0], e[1:]...)`, which requires allocating temporary buffers
|
||||
// this accomplishes the same thing by writing each expert tensor in sequence
|
||||
if _, err := t.WriteTo(w); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
@@ -11,14 +11,14 @@ import (
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -137,9 +137,7 @@ func TestConvertModel(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
keys := maps.Keys(expect)
|
||||
slices.Sort(keys)
|
||||
for _, k := range keys {
|
||||
for _, k := range slices.Sorted(maps.Keys(expect)) {
|
||||
if v, ok := actual[k]; !ok {
|
||||
t.Errorf("missing %s", k)
|
||||
} else if v != expect[k] {
|
||||
@@ -342,15 +340,8 @@ func TestConvertAdapter(t *testing.T) {
|
||||
}
|
||||
|
||||
actual := generateResultsJSON(t, r, m.KV(), m.Tensors())
|
||||
|
||||
keys := maps.Keys(c.Expected)
|
||||
slices.Sort(keys)
|
||||
for _, k := range keys {
|
||||
if v, ok := actual[k]; !ok {
|
||||
t.Errorf("missing %s", k)
|
||||
} else if v != c.Expected[k] {
|
||||
t.Errorf("unexpected %s: want %s, got %s", k, c.Expected[k], v)
|
||||
}
|
||||
if diff := cmp.Diff(c.Expected, actual); diff != "" {
|
||||
t.Errorf("mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -31,28 +31,31 @@ func (t tensorBase) Shape() []uint64 {
|
||||
}
|
||||
|
||||
const (
|
||||
tensorKindF32 uint32 = iota
|
||||
tensorKindF16
|
||||
tensorKindFP32 uint32 = iota
|
||||
tensorKindFP16
|
||||
tensorKindBF16 = 30
|
||||
tensorKindMXFP4 = 39
|
||||
)
|
||||
|
||||
func (t tensorBase) Kind() uint32 {
|
||||
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
|
||||
strings.HasSuffix(t.name, ".bias") ||
|
||||
t.name == "token_types.weight" ||
|
||||
t.name == "v.positional_embedding_vlm" ||
|
||||
t.name == "v.tile_position_embd.weight" ||
|
||||
t.name == "v.pre_tile_position_embd.weight" ||
|
||||
t.name == "v.post_tile_position_embd.weight" {
|
||||
// these tensors are always F32
|
||||
return 0
|
||||
return tensorKindFP32
|
||||
}
|
||||
|
||||
switch len(t.shape) {
|
||||
case 0:
|
||||
panic("invalid tensor shape")
|
||||
case 1:
|
||||
return tensorKindF32
|
||||
return tensorKindFP32
|
||||
default:
|
||||
return tensorKindF16
|
||||
return tensorKindFP16
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
@@ -8,12 +9,12 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"maps"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/d4l3k/go-bfloat16"
|
||||
"github.com/x448/float16"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
type safetensorMetadata struct {
|
||||
@@ -46,8 +47,7 @@ func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]T
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys := maps.Keys(headers)
|
||||
slices.Sort(keys)
|
||||
keys := slices.Sorted(maps.Keys(headers))
|
||||
|
||||
names := make(map[string]struct{}, len(keys))
|
||||
|
||||
@@ -94,6 +94,15 @@ type safetensor struct {
|
||||
*tensorBase
|
||||
}
|
||||
|
||||
func (st safetensor) Kind() uint32 {
|
||||
kind := st.tensorBase.Kind()
|
||||
if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 {
|
||||
kind = tensorKindBF16
|
||||
}
|
||||
|
||||
return kind
|
||||
}
|
||||
|
||||
func (st safetensor) Clone() Tensor {
|
||||
return &safetensor{
|
||||
fs: st.fs,
|
||||
@@ -116,26 +125,41 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if seeker, ok := f.(io.Seeker); ok {
|
||||
if _, err := seeker.Seek(st.offset, io.SeekStart); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
if _, err := io.CopyN(io.Discard, f, st.offset); err != nil {
|
||||
return 0, err
|
||||
r, err := func() (io.Reader, error) {
|
||||
if readerAt, ok := f.(io.ReaderAt); ok {
|
||||
return io.NewSectionReader(readerAt, st.offset, st.size), nil
|
||||
} else if seeker, ok := f.(io.Seeker); ok {
|
||||
_, err := seeker.Seek(st.offset, io.SeekStart)
|
||||
return f, err
|
||||
} else {
|
||||
_, err := io.CopyN(io.Discard, f, st.offset)
|
||||
return f, err
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
br := bufio.NewReaderSize(r, min(32<<10, int(st.size)))
|
||||
// special case when input and output are same type and the
|
||||
// tensor doesn't need repacking
|
||||
if (st.repacker == nil) &&
|
||||
((st.dtype == "F32" && st.Kind() == tensorKindFP32) ||
|
||||
(st.dtype == "F16" && st.Kind() == tensorKindFP16) ||
|
||||
(st.dtype == "U8")) {
|
||||
return io.CopyN(w, br, st.size)
|
||||
}
|
||||
|
||||
var f32s []float32
|
||||
switch st.dtype {
|
||||
case "F32":
|
||||
f32s = make([]float32, st.size/4)
|
||||
if err = binary.Read(f, binary.LittleEndian, f32s); err != nil {
|
||||
if err = binary.Read(br, binary.LittleEndian, f32s); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
case "F16":
|
||||
u16s := make([]uint16, st.size/2)
|
||||
if err = binary.Read(f, binary.LittleEndian, u16s); err != nil {
|
||||
if err = binary.Read(br, binary.LittleEndian, u16s); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -146,7 +170,7 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
||||
|
||||
case "BF16":
|
||||
u8s := make([]uint8, st.size)
|
||||
if err = binary.Read(f, binary.LittleEndian, u8s); err != nil {
|
||||
if err = binary.Read(br, binary.LittleEndian, u8s); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -163,15 +187,18 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
||||
}
|
||||
|
||||
switch st.Kind() {
|
||||
case tensorKindF32:
|
||||
return 0, binary.Write(w, binary.LittleEndian, f32s)
|
||||
case tensorKindF16:
|
||||
case tensorKindFP32:
|
||||
return int64(len(f32s) * 4), binary.Write(w, binary.LittleEndian, f32s)
|
||||
case tensorKindFP16:
|
||||
f16s := make([]uint16, len(f32s))
|
||||
for i := range f32s {
|
||||
f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
|
||||
}
|
||||
|
||||
return 0, binary.Write(w, binary.LittleEndian, f16s)
|
||||
return int64(len(f16s) * 2), binary.Write(w, binary.LittleEndian, f16s)
|
||||
case tensorKindBF16:
|
||||
u8s := bfloat16.EncodeFloat32(f32s)
|
||||
return int64(len(u8s)), binary.Write(w, binary.LittleEndian, u8s)
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
|
||||
}
|
||||
|
||||
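safetensor.WriteTo above now prefers io.NewSectionReader when the opened file supports io.ReaderAt, wrapping it in a bufio.Reader capped at 32 KiB. The sketch below shows that read path on its own with placeholder file name, offset, and size; it illustrates the pattern and is not code from the repository.

package main

import (
	"bufio"
	"fmt"
	"io"
	"os"
)

// readSection streams size bytes starting at offset without moving the file's
// own read position, mirroring the section-reader branch in safetensor.WriteTo.
func readSection(path string, offset, size int64) ([]byte, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	sr := io.NewSectionReader(f, offset, size) // *os.File implements io.ReaderAt
	br := bufio.NewReaderSize(sr, int(min(32<<10, size)))
	return io.ReadAll(br)
}

func main() {
	// Placeholder values; a real caller would take these from the
	// safetensors header.
	data, err := readSection("model.safetensors", 128, 64)
	if err != nil {
		fmt.Println("read failed:", err)
		return
	}
	fmt.Println("read", len(data), "bytes")
}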
294
convert/reader_test.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/d4l3k/go-bfloat16"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/x448/float16"
|
||||
)
|
||||
|
||||
func TestSafetensors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
root, err := os.OpenRoot(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer root.Close()
|
||||
|
||||
cases := []struct {
|
||||
name,
|
||||
dtype string
|
||||
offset,
|
||||
size int64
|
||||
shape []uint64
|
||||
setup func(*testing.T, *os.File)
|
||||
want []byte
|
||||
}{
|
||||
{
|
||||
name: "fp32-fp32",
|
||||
dtype: "F32",
|
||||
size: 32 * 4, // 32 floats, each 4 bytes
|
||||
shape: []uint64{32},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
f32s := make([]float32, 32)
|
||||
for i := range f32s {
|
||||
f32s[i] = float32(i)
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
want: []byte{
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40,
|
||||
0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40,
|
||||
0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41,
|
||||
0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41,
|
||||
0x00, 0x00, 0x80, 0x41, 0x00, 0x00, 0x88, 0x41, 0x00, 0x00, 0x90, 0x41, 0x00, 0x00, 0x98, 0x41,
|
||||
0x00, 0x00, 0xa0, 0x41, 0x00, 0x00, 0xa8, 0x41, 0x00, 0x00, 0xb0, 0x41, 0x00, 0x00, 0xb8, 0x41,
|
||||
0x00, 0x00, 0xc0, 0x41, 0x00, 0x00, 0xc8, 0x41, 0x00, 0x00, 0xd0, 0x41, 0x00, 0x00, 0xd8, 0x41,
|
||||
0x00, 0x00, 0xe0, 0x41, 0x00, 0x00, 0xe8, 0x41, 0x00, 0x00, 0xf0, 0x41, 0x00, 0x00, 0xf8, 0x41,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fp32-fp16",
|
||||
dtype: "F32",
|
||||
size: 32 * 4, // 32 floats, each 4 bytes
|
||||
shape: []uint64{16, 2},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
f32s := make([]float32, 32)
|
||||
for i := range f32s {
|
||||
f32s[i] = float32(i)
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
want: []byte{
|
||||
0x00, 0x00, 0x00, 0x3c, 0x00, 0x40, 0x00, 0x42, 0x00, 0x44, 0x00, 0x45, 0x00, 0x46, 0x00, 0x47,
|
||||
0x00, 0x48, 0x80, 0x48, 0x00, 0x49, 0x80, 0x49, 0x00, 0x4a, 0x80, 0x4a, 0x00, 0x4b, 0x80, 0x4b,
|
||||
0x00, 0x4c, 0x40, 0x4c, 0x80, 0x4c, 0xc0, 0x4c, 0x00, 0x4d, 0x40, 0x4d, 0x80, 0x4d, 0xc0, 0x4d,
|
||||
0x00, 0x4e, 0x40, 0x4e, 0x80, 0x4e, 0xc0, 0x4e, 0x00, 0x4f, 0x40, 0x4f, 0x80, 0x4f, 0xc0, 0x4f,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fp16-fp16",
|
||||
dtype: "F16",
|
||||
size: 32 * 2, // 32 floats, each 2 bytes
|
||||
shape: []uint64{16, 2},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
u16s := make([]uint16, 32)
|
||||
for i := range u16s {
|
||||
u16s[i] = float16.Fromfloat32(float32(i)).Bits()
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, u16s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
want: []byte{
|
||||
0x00, 0x00, 0x00, 0x3c, 0x00, 0x40, 0x00, 0x42, 0x00, 0x44, 0x00, 0x45, 0x00, 0x46, 0x00, 0x47,
|
||||
0x00, 0x48, 0x80, 0x48, 0x00, 0x49, 0x80, 0x49, 0x00, 0x4a, 0x80, 0x4a, 0x00, 0x4b, 0x80, 0x4b,
|
||||
0x00, 0x4c, 0x40, 0x4c, 0x80, 0x4c, 0xc0, 0x4c, 0x00, 0x4d, 0x40, 0x4d, 0x80, 0x4d, 0xc0, 0x4d,
|
||||
0x00, 0x4e, 0x40, 0x4e, 0x80, 0x4e, 0xc0, 0x4e, 0x00, 0x4f, 0x40, 0x4f, 0x80, 0x4f, 0xc0, 0x4f,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fp16-fp32",
|
||||
dtype: "F16",
|
||||
size: 32 * 2, // 32 floats, each 2 bytes
|
||||
shape: []uint64{32},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
u16s := make([]uint16, 32)
|
||||
for i := range u16s {
|
||||
u16s[i] = float16.Fromfloat32(float32(i)).Bits()
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, u16s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
want: []byte{
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40,
|
||||
0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40,
|
||||
0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41,
|
||||
0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41,
|
||||
0x00, 0x00, 0x80, 0x41, 0x00, 0x00, 0x88, 0x41, 0x00, 0x00, 0x90, 0x41, 0x00, 0x00, 0x98, 0x41,
|
||||
0x00, 0x00, 0xa0, 0x41, 0x00, 0x00, 0xa8, 0x41, 0x00, 0x00, 0xb0, 0x41, 0x00, 0x00, 0xb8, 0x41,
|
||||
0x00, 0x00, 0xc0, 0x41, 0x00, 0x00, 0xc8, 0x41, 0x00, 0x00, 0xd0, 0x41, 0x00, 0x00, 0xd8, 0x41,
|
||||
0x00, 0x00, 0xe0, 0x41, 0x00, 0x00, 0xe8, 0x41, 0x00, 0x00, 0xf0, 0x41, 0x00, 0x00, 0xf8, 0x41,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bf16-bf16",
|
||||
dtype: "BF16",
|
||||
size: 32 * 2, // 32 brain floats, each 2 bytes
|
||||
shape: []uint64{16, 2},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
f32s := make([]float32, 32)
|
||||
for i := range f32s {
|
||||
f32s[i] = float32(i)
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, bfloat16.EncodeFloat32(f32s)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
want: []byte{
|
||||
0x00, 0x00, 0x80, 0x3f, 0x00, 0x40, 0x40, 0x40, 0x80, 0x40, 0xa0, 0x40, 0xc0, 0x40, 0xe0, 0x40,
|
||||
0x00, 0x41, 0x10, 0x41, 0x20, 0x41, 0x30, 0x41, 0x40, 0x41, 0x50, 0x41, 0x60, 0x41, 0x70, 0x41,
|
||||
0x80, 0x41, 0x88, 0x41, 0x90, 0x41, 0x98, 0x41, 0xa0, 0x41, 0xa8, 0x41, 0xb0, 0x41, 0xb8, 0x41,
|
||||
0xc0, 0x41, 0xc8, 0x41, 0xd0, 0x41, 0xd8, 0x41, 0xe0, 0x41, 0xe8, 0x41, 0xf0, 0x41, 0xf8, 0x41,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bf16-fp32",
|
||||
dtype: "BF16",
|
||||
size: 32 * 2, // 32 brain floats, each 2 bytes
|
||||
shape: []uint64{32},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
f32s := make([]float32, 32)
|
||||
for i := range f32s {
|
||||
f32s[i] = float32(i)
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, bfloat16.EncodeFloat32(f32s)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
want: []byte{
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40,
|
||||
0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40,
|
||||
0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41,
|
||||
0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41,
|
||||
0x00, 0x00, 0x80, 0x41, 0x00, 0x00, 0x88, 0x41, 0x00, 0x00, 0x90, 0x41, 0x00, 0x00, 0x98, 0x41,
|
||||
0x00, 0x00, 0xa0, 0x41, 0x00, 0x00, 0xa8, 0x41, 0x00, 0x00, 0xb0, 0x41, 0x00, 0x00, 0xb8, 0x41,
|
||||
0x00, 0x00, 0xc0, 0x41, 0x00, 0x00, 0xc8, 0x41, 0x00, 0x00, 0xd0, 0x41, 0x00, 0x00, 0xd8, 0x41,
|
||||
0x00, 0x00, 0xe0, 0x41, 0x00, 0x00, 0xe8, 0x41, 0x00, 0x00, 0xf0, 0x41, 0x00, 0x00, 0xf8, 0x41,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "u8-u8",
|
||||
dtype: "U8",
|
||||
size: 32, // 32 uint8 values, each 1 byte
|
||||
shape: []uint64{32},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
u8s := make([]uint8, 32)
|
||||
for i := range u8s {
|
||||
u8s[i] = uint8(i)
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, u8s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
want: []byte{
|
||||
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
||||
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
path := filepath.Base(t.Name())
|
||||
st := safetensor{
|
||||
fs: root.FS(),
|
||||
path: path,
|
||||
dtype: tt.dtype,
|
||||
offset: tt.offset,
|
||||
size: tt.size,
|
||||
tensorBase: &tensorBase{
|
||||
name: tt.name,
|
||||
shape: tt.shape,
|
||||
},
|
||||
}
|
||||
|
||||
f, err := root.Create(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
tt.setup(t, f)
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := st.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, b.Bytes()); diff != "" {
|
||||
t.Errorf("safetensor.WriteTo() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
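The expected byte tables above follow directly from the IEEE 754 encodings: bfloat16 is simply the top 16 bits of the float32 bit pattern, while half precision re-packs the exponent, so 1.0 becomes 0x3f80 (bf16) or 0x3c00 (fp16), written little-endian. A minimal standard-library sketch; the tests themselves rely on the float16 and bfloat16 helper packages imported by this file, not this function.

// Standard-library illustration only.
package main

import (
	"encoding/binary"
	"fmt"
	"math"
)

// bf16Bits truncates a float32 to bfloat16: keep the sign bit, the 8-bit
// exponent and the top 7 mantissa bits.
func bf16Bits(f float32) uint16 {
	return uint16(math.Float32bits(f) >> 16)
}

func main() {
	var buf [2]byte
	binary.LittleEndian.PutUint16(buf[:], bf16Bits(1.0))
	fmt.Printf("% x\n", buf) // 80 3f -- the second pair in the bf16-bf16 table above
}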
|
||||
|
||||
func TestSafetensorKind(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
st safetensor
|
||||
expected uint32
|
||||
}{
|
||||
{
|
||||
name: "BF16 dtype with non-v. prefix and non-FP32 base kind should return BF16",
|
||||
st: safetensor{
|
||||
tensorBase: &tensorBase{
|
||||
name: "weight.matrix",
|
||||
shape: []uint64{10, 10}, // will default to FP16
|
||||
},
|
||||
dtype: "BF16",
|
||||
},
|
||||
expected: tensorKindBF16,
|
||||
},
|
||||
{
|
||||
name: "BF16 dtype with v. prefix should return base kind",
|
||||
st: safetensor{
|
||||
tensorBase: &tensorBase{
|
||||
name: "v.weight.matrix",
|
||||
shape: []uint64{10, 10}, // will default to FP16
|
||||
},
|
||||
dtype: "BF16",
|
||||
},
|
||||
expected: tensorKindFP16,
|
||||
},
|
||||
{
|
||||
name: "BF16 dtype with FP32 base kind should return FP32",
|
||||
st: safetensor{
|
||||
tensorBase: &tensorBase{
|
||||
name: "weight.matrix",
|
||||
shape: []uint64{10}, // will default to FP32
|
||||
},
|
||||
dtype: "BF16",
|
||||
},
|
||||
expected: tensorKindFP32,
|
||||
},
|
||||
{
|
||||
name: "Non-BF16 dtype should return base kind",
|
||||
st: safetensor{
|
||||
tensorBase: &tensorBase{
|
||||
name: "weight.matrix",
|
||||
shape: []uint64{10, 10}, // will default to FP16
|
||||
},
|
||||
dtype: "FP16",
|
||||
},
|
||||
expected: tensorKindFP16,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.st.Kind()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Kind() = %d, expected %d", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
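Read together, these cases pin down the dtype selection rule the test expects. A rough sketch of that rule using the tensorKind constants referenced above; this illustrates the expected behaviour, not the package's actual Kind() implementation.

// Sketch only: BF16 tensors keep BF16 unless the name carries the "v." prefix
// or the shape-derived base kind is already FP32, in which case the base kind
// wins. Non-BF16 dtypes always use the base kind.
func kindFromCases(name, dtype string, baseKind uint32) uint32 {
	if dtype == "BF16" && !strings.HasPrefix(name, "v.") && baseKind != tensorKindFP32 {
		return tensorKindBF16
	}
	return baseKind
}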
|
||||
@@ -2,7 +2,9 @@ package convert

import (
"cmp"
"io"
"iter"
"path"
"slices"
"strings"

@@ -14,7 +16,8 @@ import (

type split struct {
*strings.Replacer
dim int
dim int
slices []tensor.Slice

// fn is an optional function to apply to the tensor after slicing
fn func(tensor.Tensor) (tensor.Tensor, error)
@@ -30,9 +33,12 @@ func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
shape := slices.Clone(t.Shape())
shape[dim] = cmp.Or(uint64(split.dim), shape[dim]/uint64(len(splits)))

slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
slice[dim] = tensor.S(offset, offset+int(shape[dim]))
offset += int(shape[dim])
slice := split.slices
if len(slice) == 0 {
slice = slices.Repeat([]tensor.Slice{nil}, len(shape))
slice[dim] = tensor.S(offset, offset+int(shape[dim]))
offset += int(shape[dim])
}

t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
@@ -74,3 +80,54 @@ func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
}
}
}

type merge struct {
pattern, name string
}

// mergeTensors merges tensors that match a given pattern into a single tensor.
func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []Tensor) {
var matched []Tensor
for i := range merges {
matched, unmatched = slicesSplitFunc(unmatched, func(t Tensor) bool {
matched, _ := path.Match(merges[i].pattern, t.Name())
return matched
})

if len(matched) > 0 {
out = append(out, &ggml.Tensor{
Name: merges[i].name,
Kind: matched[0].Kind(),
Shape: append([]uint64{uint64(len(matched))}, matched[0].Shape()...),
WriterTo: mergeGroup(matched),
})
}
}

return out, unmatched
}

// slicesSplitFunc splits a slice into two slices based on a predicate function.
func slicesSplitFunc[S ~[]E, E comparable](s S, fn func(e E) bool) (matched, unmatched S) {
for _, e := range s {
if fn(e) {
matched = append(matched, e)
} else {
unmatched = append(unmatched, e)
}
}

return matched, unmatched
}

type mergeGroup []Tensor

func (g mergeGroup) WriteTo(w io.Writer) (int64, error) {
for _, t := range g {
if _, err := t.WriteTo(w); err != nil {
return 0, err
}
}

return 0, nil
}
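slicesSplitFunc is a plain generic partition helper, so a usage sketch needs nothing else from this package. The tensor names below are made up for illustration.

// Hypothetical names; path.Match treats '*' as a run of non-separator
// characters, which is how mergeTensors selects its group.
names := []string{"model.layers.0.w1", "model.layers.0.w3", "lm_head.weight"}
matched, unmatched := slicesSplitFunc(names, func(s string) bool {
	ok, _ := path.Match("model.layers.*.w*", s)
	return ok
})
// matched == ["model.layers.0.w1", "model.layers.0.w3"]
// unmatched == ["lm_head.weight"]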
|
||||
|
||||
File diff suppressed because it is too large
@@ -8,11 +8,10 @@ import (
"fmt"
"io/fs"
"log/slog"
"maps"
"os"
"slices"
"strings"

"golang.org/x/exp/maps"
)

const (
@@ -260,11 +259,8 @@ func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
tokens[token.ID] = token
}

keys := maps.Keys(tokens)
slices.Sort(keys)

v := Vocabulary{Model: "gpt2"}
for _, k := range keys {
for _, k := range slices.Sorted(maps.Keys(tokens)) {
token := tokens[k]
v.Tokens = append(v.Tokens, token.Content)
v.Scores = append(v.Scores, float32(token.ID))
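The replacement loop leans on the Go 1.23 standard library: maps.Keys now returns an iterator and slices.Sorted drains and sorts it in one step, which is what allows the golang.org/x/exp/maps import above to be dropped. A small standalone sketch:

package main

import (
	"fmt"
	"maps"
	"slices"
)

func main() {
	tokens := map[int]string{2: "c", 0: "a", 1: "b"}
	// slices.Sorted consumes the key iterator and returns a sorted slice,
	// replacing the older maps.Keys(...) + slices.Sort(...) two-step.
	for _, k := range slices.Sorted(maps.Keys(tokens)) {
		fmt.Println(k, tokens[k]) // 0 a, 1 b, 2 c
	}
}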
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
//go:build linux || windows
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Determine if the given ROCm lib directory is usable by checking for existence of some glob patterns
|
||||
func rocmLibUsable(libDir string) bool {
|
||||
slog.Debug("evaluating potential rocm lib dir " + libDir)
|
||||
for _, g := range ROCmLibGlobs {
|
||||
res, _ := filepath.Glob(filepath.Join(libDir, g))
|
||||
if len(res) == 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func GetSupportedGFX(libDir string) ([]string, error) {
|
||||
var ret []string
|
||||
files, err := filepath.Glob(filepath.Join(libDir, "rocblas", "library", "TensileLibrary_lazy_gfx*.dat"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, file := range files {
|
||||
ret = append(ret, strings.TrimSuffix(strings.TrimPrefix(filepath.Base(file), "TensileLibrary_lazy_"), ".dat"))
|
||||
}
|
||||
return ret, nil
|
||||
}
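GetSupportedGFX derives the supported gfx targets purely from the rocblas Tensile library filenames. A small sketch of that string transformation with a hypothetical filename:

// Hypothetical filename; the trim calls mirror GetSupportedGFX above.
file := "TensileLibrary_lazy_gfx1030.dat"
gfx := strings.TrimSuffix(strings.TrimPrefix(file, "TensileLibrary_lazy_"), ".dat")
// gfx == "gfx1030"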
|
||||
|
||||
func commonAMDValidateLibDir() (string, error) {
|
||||
// Favor our bundled version
|
||||
|
||||
// Installer payload location if we're running the installed binary
|
||||
rocmTargetDir := filepath.Join(LibOllamaPath, "rocm")
|
||||
if rocmLibUsable(rocmTargetDir) {
|
||||
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
|
||||
return rocmTargetDir, nil
|
||||
}
|
||||
|
||||
// Prefer explicit HIP env var
|
||||
hipPath := os.Getenv("HIP_PATH")
|
||||
if hipPath != "" {
|
||||
hipLibDir := filepath.Join(hipPath, "bin")
|
||||
if rocmLibUsable(hipLibDir) {
|
||||
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
|
||||
return hipLibDir, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Scan the LD_LIBRARY_PATH or PATH
|
||||
pathEnv := "LD_LIBRARY_PATH"
|
||||
if runtime.GOOS == "windows" {
|
||||
pathEnv = "PATH"
|
||||
}
|
||||
|
||||
paths := os.Getenv(pathEnv)
|
||||
for _, path := range filepath.SplitList(paths) {
|
||||
d, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if rocmLibUsable(d) {
|
||||
return d, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Well known location(s)
|
||||
for _, path := range RocmStandardLocations {
|
||||
if rocmLibUsable(path) {
|
||||
return path, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", errors.New("no suitable rocm found, falling back to CPU")
|
||||
}
|
||||
@@ -1,147 +0,0 @@
|
||||
package discover
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
hipSuccess = 0
|
||||
hipErrorNoDevice = 100
|
||||
)
|
||||
|
||||
type hipDevicePropMinimal struct {
|
||||
Name [256]byte
|
||||
unused1 [140]byte
|
||||
GcnArchName [256]byte // gfx####
|
||||
iGPU int // Doesn't seem to actually report correctly
|
||||
unused2 [128]byte
|
||||
}
|
||||
|
||||
// Wrap the amdhip64.dll library for GPU discovery
|
||||
type HipLib struct {
|
||||
dll windows.Handle
|
||||
hipGetDeviceCount uintptr
|
||||
hipGetDeviceProperties uintptr
|
||||
hipMemGetInfo uintptr
|
||||
hipSetDevice uintptr
|
||||
hipDriverGetVersion uintptr
|
||||
}
|
||||
|
||||
func NewHipLib() (*HipLib, error) {
|
||||
// At runtime we depend on v6, so discover GPUs with the same library for a consistent set of GPUs
|
||||
h, err := windows.LoadLibrary("amdhip64_6.dll")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to load amdhip64_6.dll, please make sure to upgrade to the latest amd driver: %w", err)
|
||||
}
|
||||
hl := &HipLib{}
|
||||
hl.dll = h
|
||||
hl.hipGetDeviceCount, err = windows.GetProcAddress(hl.dll, "hipGetDeviceCount")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hl.hipGetDeviceProperties, err = windows.GetProcAddress(hl.dll, "hipGetDeviceProperties")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hl.hipMemGetInfo, err = windows.GetProcAddress(hl.dll, "hipMemGetInfo")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hl.hipSetDevice, err = windows.GetProcAddress(hl.dll, "hipSetDevice")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hl.hipDriverGetVersion, err = windows.GetProcAddress(hl.dll, "hipDriverGetVersion")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hl, nil
|
||||
}
|
||||
|
||||
// The hip library only evaluates the ROCR_VISIBLE_DEVICES variable at startup
|
||||
// so we have to unload/reset the library after we do our initial discovery
|
||||
// to make sure our updates to that variable are processed by llama.cpp
|
||||
func (hl *HipLib) Release() {
|
||||
err := windows.FreeLibrary(hl.dll)
|
||||
if err != nil {
|
||||
slog.Warn("failed to unload amdhip64.dll", "error", err)
|
||||
}
|
||||
hl.dll = 0
|
||||
}
|
||||
|
||||
func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) {
|
||||
if hl.dll == 0 {
|
||||
return 0, 0, errors.New("dll has been unloaded")
|
||||
}
|
||||
var version int
|
||||
status, _, err := syscall.SyscallN(hl.hipDriverGetVersion, uintptr(unsafe.Pointer(&version)))
|
||||
if status != hipSuccess {
|
||||
return 0, 0, fmt.Errorf("failed call to hipDriverGetVersion: %d %s", status, err)
|
||||
}
|
||||
|
||||
slog.Debug("hipDriverGetVersion", "version", version)
|
||||
driverMajor = version / 10000000
|
||||
driverMinor = (version - (driverMajor * 10000000)) / 100000
|
||||
|
||||
return driverMajor, driverMinor, nil
|
||||
}
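The HIP driver reports its version as a single packed integer, which AMDDriverVersion splits positionally. A worked example with a made-up value:

// Hypothetical packed value, decoded the same way as above.
version := 60140252
driverMajor := version / 10000000                        // 6
driverMinor := (version - driverMajor*10000000) / 100000 // 1
_, _ = driverMajor, driverMinor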
|
||||
|
||||
func (hl *HipLib) HipGetDeviceCount() int {
|
||||
if hl.dll == 0 {
|
||||
slog.Error("dll has been unloaded")
|
||||
return 0
|
||||
}
|
||||
var count int
|
||||
status, _, err := syscall.SyscallN(hl.hipGetDeviceCount, uintptr(unsafe.Pointer(&count)))
|
||||
if status == hipErrorNoDevice {
|
||||
slog.Info("AMD ROCm reports no devices found")
|
||||
return 0
|
||||
}
|
||||
if status != hipSuccess {
|
||||
slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err)
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func (hl *HipLib) HipSetDevice(device int) error {
|
||||
if hl.dll == 0 {
|
||||
return errors.New("dll has been unloaded")
|
||||
}
|
||||
status, _, err := syscall.SyscallN(hl.hipSetDevice, uintptr(device))
|
||||
if status != hipSuccess {
|
||||
return fmt.Errorf("failed call to hipSetDevice: %d %s", status, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hl *HipLib) HipGetDeviceProperties(device int) (*hipDevicePropMinimal, error) {
|
||||
if hl.dll == 0 {
|
||||
return nil, errors.New("dll has been unloaded")
|
||||
}
|
||||
var props hipDevicePropMinimal
|
||||
status, _, err := syscall.SyscallN(hl.hipGetDeviceProperties, uintptr(unsafe.Pointer(&props)), uintptr(device))
|
||||
if status != hipSuccess {
|
||||
return nil, fmt.Errorf("failed call to hipGetDeviceProperties: %d %s", status, err)
|
||||
}
|
||||
return &props, nil
|
||||
}
|
||||
|
||||
// free, total, err
|
||||
func (hl *HipLib) HipMemGetInfo() (uint64, uint64, error) {
|
||||
if hl.dll == 0 {
|
||||
return 0, 0, errors.New("dll has been unloaded")
|
||||
}
|
||||
var totalMemory uint64
|
||||
var freeMemory uint64
|
||||
status, _, err := syscall.SyscallN(hl.hipMemGetInfo, uintptr(unsafe.Pointer(&freeMemory)), uintptr(unsafe.Pointer(&totalMemory)))
|
||||
if status != hipSuccess {
|
||||
return 0, 0, fmt.Errorf("failed call to hipMemGetInfo: %d %s", status, err)
|
||||
}
|
||||
return freeMemory, totalMemory, nil
|
||||
}
|
||||
@@ -1,538 +0,0 @@
|
||||
package discover
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
// Discovery logic for AMD/ROCm GPUs
|
||||
|
||||
const (
|
||||
DriverVersionFile = "/sys/module/amdgpu/version"
|
||||
AMDNodesSysfsDir = "/sys/class/kfd/kfd/topology/nodes/"
|
||||
GPUPropertiesFileGlob = AMDNodesSysfsDir + "*/properties"
|
||||
|
||||
// Prefix with the node dir
|
||||
GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
|
||||
|
||||
// Direct Rendering Manager sysfs location
|
||||
DRMDeviceDirGlob = "/sys/class/drm/card*/device"
|
||||
DRMTotalMemoryFile = "mem_info_vram_total"
|
||||
DRMUsedMemoryFile = "mem_info_vram_used"
|
||||
|
||||
// In hex; properties file is in decimal
|
||||
DRMUniqueIDFile = "unique_id"
|
||||
DRMVendorFile = "vendor"
|
||||
DRMDeviceFile = "device"
|
||||
)
|
||||
|
||||
var (
|
||||
// Used to validate if the given ROCm lib is usable
|
||||
ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here...
|
||||
RocmStandardLocations = []string{"/opt/rocm/lib", "/usr/lib64"}
|
||||
)
|
||||
|
||||
// Gather GPU information from the amdgpu driver if any supported GPUs are detected
|
||||
// Only called once during bootstrap
|
||||
func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
||||
resp := []RocmGPUInfo{}
|
||||
if !AMDDetected() {
|
||||
return resp, fmt.Errorf("AMD GPUs not detected")
|
||||
}
|
||||
|
||||
// Opportunistic logging of driver version to aid in troubleshooting
|
||||
driverMajor, driverMinor, err := AMDDriverVersion()
|
||||
if err != nil {
|
||||
// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
|
||||
slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
|
||||
}
|
||||
|
||||
// Determine if the user has already pre-selected which GPUs to look at, then ignore the others
|
||||
var visibleDevices []string
|
||||
hipVD := envconfig.HipVisibleDevices() // zero based index only
|
||||
rocrVD := envconfig.RocrVisibleDevices() // zero based index or UUID
|
||||
gpuDO := envconfig.GpuDeviceOrdinal() // zero based index
|
||||
switch {
|
||||
case rocrVD != "":
|
||||
visibleDevices = strings.Split(rocrVD, ",")
|
||||
case hipVD != "":
|
||||
visibleDevices = strings.Split(hipVD, ",")
|
||||
case gpuDO != "":
|
||||
visibleDevices = strings.Split(gpuDO, ",")
|
||||
}
|
||||
|
||||
gfxOverride := envconfig.HsaOverrideGfxVersion()
|
||||
var supported []string
|
||||
var libDir string
|
||||
|
||||
// The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
|
||||
// from the other IDs to get alignment with the HIP library's expectations (zero is the first GPU, not the CPU)
|
||||
matches, _ := filepath.Glob(GPUPropertiesFileGlob)
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
// /sys/class/kfd/kfd/topology/nodes/<number>/properties
|
||||
a, err := strconv.ParseInt(filepath.Base(filepath.Dir(matches[i])), 10, 64)
|
||||
if err != nil {
|
||||
slog.Debug("parse err", "error", err, "match", matches[i])
|
||||
return false
|
||||
}
|
||||
b, err := strconv.ParseInt(filepath.Base(filepath.Dir(matches[j])), 10, 64)
|
||||
if err != nil {
|
||||
slog.Debug("parse err", "error", err, "match", matches[i])
|
||||
return false
|
||||
}
|
||||
return a < b
|
||||
})
|
||||
gpuCount := 0
|
||||
for _, match := range matches {
|
||||
slog.Debug("evaluating amdgpu node " + match)
|
||||
fp, err := os.Open(match)
|
||||
if err != nil {
|
||||
slog.Debug("failed to open sysfs node", "file", match, "error", err)
|
||||
continue
|
||||
}
|
||||
defer fp.Close()
|
||||
|
||||
scanner := bufio.NewScanner(fp)
|
||||
isCPU := false
|
||||
var major, minor, patch uint64
|
||||
var vendor, device, uniqueID uint64
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
// Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
|
||||
if strings.HasPrefix(line, "gfx_target_version") {
|
||||
ver := strings.Fields(line)
|
||||
|
||||
// Detect CPUs
|
||||
if len(ver) == 2 && ver[1] == "0" {
|
||||
slog.Debug("detected CPU " + match)
|
||||
isCPU = true
|
||||
break
|
||||
}
|
||||
|
||||
if len(ver) != 2 || len(ver[1]) < 5 {
|
||||
slog.Warn("malformed "+match, "gfx_target_version", line)
|
||||
// If this winds up being a CPU, our offsets may be wrong
|
||||
continue
|
||||
}
|
||||
l := len(ver[1])
|
||||
var err1, err2, err3 error
|
||||
patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32)
|
||||
minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
|
||||
major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32)
|
||||
if err1 != nil || err2 != nil || err3 != nil {
|
||||
slog.Debug("malformed int " + line)
|
||||
continue
|
||||
}
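// Worked example (hypothetical value): a properties line such as
// "gfx_target_version 110001" splits into major=11, minor=0, patch=1,
// which is later rendered as Compute "gfx1101" by
// fmt.Sprintf("gfx%d%x%x", major, minor, patch).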
|
||||
} else if strings.HasPrefix(line, "vendor_id") {
|
||||
ver := strings.Fields(line)
|
||||
if len(ver) != 2 {
|
||||
slog.Debug("malformed", "vendor_id", line)
|
||||
continue
|
||||
}
|
||||
vendor, err = strconv.ParseUint(ver[1], 10, 64)
|
||||
if err != nil {
|
||||
slog.Debug("malformed", "vendor_id", line, "error", err)
|
||||
}
|
||||
} else if strings.HasPrefix(line, "device_id") {
|
||||
ver := strings.Fields(line)
|
||||
if len(ver) != 2 {
|
||||
slog.Debug("malformed", "device_id", line)
|
||||
continue
|
||||
}
|
||||
device, err = strconv.ParseUint(ver[1], 10, 64)
|
||||
if err != nil {
|
||||
slog.Debug("malformed", "device_id", line, "error", err)
|
||||
}
|
||||
} else if strings.HasPrefix(line, "unique_id") {
|
||||
ver := strings.Fields(line)
|
||||
if len(ver) != 2 {
|
||||
slog.Debug("malformed", "unique_id", line)
|
||||
continue
|
||||
}
|
||||
uniqueID, err = strconv.ParseUint(ver[1], 10, 64)
|
||||
if err != nil {
|
||||
slog.Debug("malformed", "unique_id", line, "error", err)
|
||||
}
|
||||
}
|
||||
// TODO - any other properties we want to extract and record?
|
||||
// vendor_id + device_id -> pci lookup for "Name"
|
||||
// Other metrics that may help us understand relative performance between multiple GPUs
|
||||
}
|
||||
|
||||
// Note: while ./mem_banks/*/used_memory exists, it doesn't appear to take other VRAM consumers
|
||||
// into consideration, so we instead map the device over to the DRM driver sysfs nodes which
|
||||
// do reliably report VRAM usage.
|
||||
|
||||
if isCPU {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip over any GPUs that are masked
|
||||
if major == 0 && minor == 0 && patch == 0 {
|
||||
slog.Debug("skipping gpu with gfx000")
|
||||
continue
|
||||
}
|
||||
|
||||
// Keep track of numeric IDs based on valid GPUs
|
||||
gpuID := gpuCount
|
||||
gpuCount += 1
|
||||
|
||||
// Look up the memory for the current node
|
||||
totalMemory := uint64(0)
|
||||
usedMemory := uint64(0)
|
||||
var usedFile string
|
||||
mapping := []struct {
|
||||
id uint64
|
||||
filename string
|
||||
}{
|
||||
{vendor, DRMVendorFile},
|
||||
{device, DRMDeviceFile},
|
||||
{uniqueID, DRMUniqueIDFile}, // Not all devices will report this
|
||||
}
|
||||
slog.Debug("mapping amdgpu to drm sysfs nodes", "amdgpu", match, "vendor", vendor, "device", device, "unique_id", uniqueID)
|
||||
// Map over to DRM location to find the total/free memory
|
||||
drmMatches, _ := filepath.Glob(DRMDeviceDirGlob)
|
||||
for _, devDir := range drmMatches {
|
||||
matched := true
|
||||
for _, m := range mapping {
|
||||
if m.id == 0 {
|
||||
// Null ID means it didn't populate, so we can't use it to match
|
||||
continue
|
||||
}
|
||||
filename := filepath.Join(devDir, m.filename)
|
||||
buf, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
slog.Debug("failed to read sysfs node", "file", filename, "error", err)
|
||||
matched = false
|
||||
break
|
||||
}
|
||||
// values here are in hex, strip off the lead 0x and parse so we can compare the numeric (decimal) values in amdgpu
|
||||
cmp, err := strconv.ParseUint(strings.TrimPrefix(strings.TrimSpace(string(buf)), "0x"), 16, 64)
|
||||
if err != nil {
|
||||
slog.Debug("failed to parse sysfs node", "file", filename, "error", err)
|
||||
matched = false
|
||||
break
|
||||
}
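// Example: a DRM vendor file containing "0x1002\n" parses to 4098 here,
// matching the decimal vendor_id 4098 read from the kfd properties node
// (0x1002 is AMD's PCI vendor ID).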
|
||||
if cmp != m.id {
|
||||
matched = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if !matched {
|
||||
continue
|
||||
}
|
||||
|
||||
// Found the matching DRM directory
|
||||
slog.Debug("matched", "amdgpu", match, "drm", devDir)
|
||||
totalFile := filepath.Join(devDir, DRMTotalMemoryFile)
|
||||
buf, err := os.ReadFile(totalFile)
|
||||
if err != nil {
|
||||
slog.Debug("failed to read sysfs node", "file", totalFile, "error", err)
|
||||
break
|
||||
}
|
||||
totalMemory, err = strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64)
|
||||
if err != nil {
|
||||
slog.Debug("failed to parse sysfs node", "file", totalFile, "error", err)
|
||||
break
|
||||
}
|
||||
|
||||
usedFile = filepath.Join(devDir, DRMUsedMemoryFile)
|
||||
usedMemory, err = getFreeMemory(usedFile)
|
||||
if err != nil {
|
||||
slog.Debug("failed to update used memory", "error", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
var name string
|
||||
// TODO - PCI ID lookup
|
||||
if vendor > 0 && device > 0 {
|
||||
name = fmt.Sprintf("%04x:%04x", vendor, device)
|
||||
}
|
||||
|
||||
// Favor UUIDs if available to reduce possibility of getting the numeric IDs wrong
|
||||
var ID string
|
||||
if uniqueID != 0 {
|
||||
ID = fmt.Sprintf("GPU-%016x", uniqueID)
|
||||
} else {
|
||||
ID = strconv.Itoa(gpuID)
|
||||
}
|
||||
|
||||
gpuInfo := RocmGPUInfo{
|
||||
GpuInfo: GpuInfo{
|
||||
Library: "rocm",
|
||||
memInfo: memInfo{
|
||||
TotalMemory: totalMemory,
|
||||
FreeMemory: (totalMemory - usedMemory),
|
||||
},
|
||||
ID: ID,
|
||||
Name: name,
|
||||
Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch),
|
||||
MinimumMemory: rocmMinimumMemory,
|
||||
DriverMajor: driverMajor,
|
||||
DriverMinor: driverMinor,
|
||||
},
|
||||
usedFilepath: usedFile,
|
||||
index: gpuID,
|
||||
}
|
||||
|
||||
// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
|
||||
if totalMemory < IGPUMemLimit {
|
||||
reason := "unsupported Radeon iGPU detected skipping"
|
||||
slog.Info(reason, "id", gpuID, "total", format.HumanBytes2(totalMemory))
|
||||
unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
|
||||
GpuInfo: gpuInfo.GpuInfo,
|
||||
Reason: reason,
|
||||
})
|
||||
continue
|
||||
}
|
||||
minVer, err := strconv.Atoi(RocmComputeMajorMin)
|
||||
if err != nil {
|
||||
slog.Error("invalid RocmComputeMajorMin setting", "value", RocmComputeMajorMin, "error", err)
|
||||
}
|
||||
if int(major) < minVer {
|
||||
reason := fmt.Sprintf("amdgpu too old gfx%d%x%x", major, minor, patch)
|
||||
slog.Warn(reason, "gpu", gpuID)
|
||||
unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
|
||||
GpuInfo: gpuInfo.GpuInfo,
|
||||
Reason: reason,
|
||||
})
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
slog.Debug("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
|
||||
slog.Debug("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
|
||||
|
||||
// If the user wants to filter to a subset of devices, filter out if we aren't a match
|
||||
if len(visibleDevices) > 0 {
|
||||
include := false
|
||||
for _, visible := range visibleDevices {
|
||||
if visible == gpuInfo.ID || visible == strconv.Itoa(gpuInfo.index) {
|
||||
include = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !include {
|
||||
reason := "filtering out device per user request"
|
||||
slog.Info(reason, "id", gpuInfo.ID, "visible_devices", visibleDevices)
|
||||
unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
|
||||
GpuInfo: gpuInfo.GpuInfo,
|
||||
Reason: reason,
|
||||
})
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Final validation is gfx compatibility - load the library if we haven't already loaded it
|
||||
// even if the user overrides, we still need to validate the library
|
||||
if libDir == "" {
|
||||
libDir, err = AMDValidateLibDir()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to verify rocm library: %w", err)
|
||||
slog.Warn(err.Error())
|
||||
unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
|
||||
GpuInfo: gpuInfo.GpuInfo,
|
||||
Reason: err.Error(),
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
gpuInfo.DependencyPath = []string{libDir}
|
||||
|
||||
if gfxOverride == "" {
|
||||
// Only load supported list once
|
||||
if len(supported) == 0 {
|
||||
supported, err = GetSupportedGFX(libDir)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to lookup supported GFX types: %w", err)
|
||||
slog.Warn(err.Error())
|
||||
unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
|
||||
GpuInfo: gpuInfo.GpuInfo,
|
||||
Reason: err.Error(),
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
slog.Debug("rocm supported GPUs", "types", supported)
|
||||
}
|
||||
gfx := gpuInfo.Compute
|
||||
if !slices.Contains[[]string, string](supported, gfx) {
|
||||
reason := fmt.Sprintf("amdgpu is not supported (supported types:%s)", supported)
|
||||
slog.Warn(reason, "gpu_type", gfx, "gpu", gpuInfo.ID, "library", libDir)
|
||||
unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
|
||||
GpuInfo: gpuInfo.GpuInfo,
|
||||
Reason: reason,
|
||||
})
|
||||
|
||||
// TODO - consider discrete markdown just for ROCM troubleshooting?
|
||||
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
|
||||
continue
|
||||
} else {
|
||||
slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx)
|
||||
}
|
||||
} else {
|
||||
slog.Info("skipping rocm gfx compatibility check", "HSA_OVERRIDE_GFX_VERSION", gfxOverride)
|
||||
}
|
||||
|
||||
// Check for env var workarounds
|
||||
if name == "1002:687f" { // Vega RX 56
|
||||
gpuInfo.EnvWorkarounds = append(gpuInfo.EnvWorkarounds, [2]string{"HSA_ENABLE_SDMA", "0"})
|
||||
}
|
||||
|
||||
// The GPU has passed all the verification steps and is supported
|
||||
resp = append(resp, gpuInfo)
|
||||
}
|
||||
if len(resp) == 0 {
|
||||
err := fmt.Errorf("no compatible amdgpu devices detected")
|
||||
slog.Info(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
if err := verifyKFDDriverAccess(); err != nil {
|
||||
err = fmt.Errorf("amdgpu devices detected but permission problems block access: %w", err)
|
||||
slog.Error(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Quick check for AMD driver so we can skip amdgpu discovery if not present
|
||||
func AMDDetected() bool {
|
||||
// Some driver versions (older?) don't have a version file, so just look up the parent dir
|
||||
sysfsDir := filepath.Dir(DriverVersionFile)
|
||||
_, err := os.Stat(sysfsDir)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
slog.Debug("amdgpu driver not detected " + sysfsDir)
|
||||
return false
|
||||
} else if err != nil {
|
||||
slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Prefer to use host installed ROCm, as long as it meets our minimum requirements
|
||||
// failing that, tell the user how to download it on their own
|
||||
func AMDValidateLibDir() (string, error) {
|
||||
libDir, err := commonAMDValidateLibDir()
|
||||
if err == nil {
|
||||
return libDir, nil
|
||||
}
|
||||
|
||||
// Well known ollama installer path
|
||||
installedRocmDir := "/usr/share/ollama/lib/rocm"
|
||||
if rocmLibUsable(installedRocmDir) {
|
||||
return installedRocmDir, nil
|
||||
}
|
||||
|
||||
// If we still haven't found a usable rocm, the user will have to install it on their own
|
||||
slog.Warn("amdgpu detected, but no compatible rocm library found. Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install")
|
||||
return "", errors.New("no suitable rocm found, falling back to CPU")
|
||||
}
|
||||
|
||||
func AMDDriverVersion() (driverMajor, driverMinor int, err error) {
|
||||
_, err = os.Stat(DriverVersionFile)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("amdgpu version file missing: %s %w", DriverVersionFile, err)
|
||||
}
|
||||
fp, err := os.Open(DriverVersionFile)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
defer fp.Close()
|
||||
verString, err := io.ReadAll(fp)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
pattern := `\A(\d+)\.(\d+).*`
|
||||
regex := regexp.MustCompile(pattern)
|
||||
match := regex.FindStringSubmatch(string(verString))
|
||||
if len(match) < 2 {
|
||||
return 0, 0, fmt.Errorf("malformed version string %s", string(verString))
|
||||
}
|
||||
driverMajor, err = strconv.Atoi(match[1])
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
driverMinor, err = strconv.Atoi(match[2])
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return driverMajor, driverMinor, nil
|
||||
}
|
||||
|
||||
func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
|
||||
if len(gpus) == 0 {
|
||||
return nil
|
||||
}
|
||||
for i := range gpus {
|
||||
usedMemory, err := getFreeMemory(gpus[i].usedFilepath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
slog.Debug("updating rocm free memory", "gpu", gpus[i].ID, "name", gpus[i].Name, "before", format.HumanBytes2(gpus[i].FreeMemory), "now", format.HumanBytes2(gpus[i].TotalMemory-usedMemory))
|
||||
gpus[i].FreeMemory = gpus[i].TotalMemory - usedMemory
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getFreeMemory(usedFile string) (uint64, error) {
|
||||
buf, err := os.ReadFile(usedFile)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to read sysfs node %s %w", usedFile, err)
|
||||
}
|
||||
usedMemory, err := strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64)
|
||||
if err != nil {
|
||||
slog.Debug("failed to parse sysfs node", "file", usedFile, "error", err)
|
||||
return 0, fmt.Errorf("failed to parse sysfs node %s %w", usedFile, err)
|
||||
}
|
||||
return usedMemory, nil
|
||||
}
|
||||
|
||||
func verifyKFDDriverAccess() error {
|
||||
// Verify we have permissions - either running as root, or we have group access to the driver
|
||||
fd, err := os.OpenFile("/dev/kfd", os.O_RDWR, 0o666)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrPermission) {
|
||||
return fmt.Errorf("permissions not set up properly. Either run ollama as root, or add you user account to the render group. %w", err)
|
||||
} else if errors.Is(err, fs.ErrNotExist) {
|
||||
// Container runtime failure?
|
||||
return fmt.Errorf("kfd driver not loaded. If running in a container, remember to include '--device /dev/kfd --device /dev/dri'")
|
||||
}
|
||||
return fmt.Errorf("failed to check permission on /dev/kfd: %w", err)
|
||||
}
|
||||
fd.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||
ids := []string{}
|
||||
for _, info := range gpuInfo {
|
||||
if info.Library != "rocm" {
|
||||
// TODO shouldn't happen if things are wired correctly...
|
||||
slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
|
||||
continue
|
||||
}
|
||||
ids = append(ids, info.ID)
|
||||
}
|
||||
// There are 3 potential env vars to use to select GPUs.
|
||||
// ROCR_VISIBLE_DEVICES supports UUID or numeric so is our preferred on linux
|
||||
// GPU_DEVICE_ORDINAL supports numeric IDs only
|
||||
// HIP_VISIBLE_DEVICES supports numeric IDs only
|
||||
return "ROCR_VISIBLE_DEVICES", strings.Join(ids, ",")
|
||||
}
|
||||
@@ -1,218 +0,0 @@
|
||||
package discover
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
// TODO We're looking for this exact name to detect iGPUs since hipGetDeviceProperties never reports integrated==true
|
||||
iGPUName = "AMD Radeon(TM) Graphics"
|
||||
)
|
||||
|
||||
var (
|
||||
// Used to validate if the given ROCm lib is usable
|
||||
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // This is not sufficient to discern v5 vs v6
|
||||
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\6.1\\bin"} // TODO glob?
|
||||
)
|
||||
|
||||
// Only called once during bootstrap
|
||||
func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
||||
resp := []RocmGPUInfo{}
|
||||
hl, err := NewHipLib()
|
||||
if err != nil {
|
||||
slog.Debug(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
defer hl.Release()
|
||||
|
||||
driverMajor, driverMinor, err := hl.AMDDriverVersion()
|
||||
if err != nil {
|
||||
// For now this is benign, but we may eventually need to fail compatibility checks
|
||||
slog.Debug("error looking up amd driver version", "error", err)
|
||||
}
|
||||
|
||||
// Note: the HIP library automatically handles subsetting to any *_VISIBLE_DEVICES the user specified
|
||||
count := hl.HipGetDeviceCount()
|
||||
if count == 0 {
|
||||
err := fmt.Errorf("no compatible amdgpu devices detected")
|
||||
slog.Info(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
libDir, err := AMDValidateLibDir()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to verify rocm library: %w", err)
|
||||
slog.Warn(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var supported []string
|
||||
gfxOverride := envconfig.HsaOverrideGfxVersion()
|
||||
if gfxOverride == "" {
|
||||
supported, err = GetSupportedGFX(libDir)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to lookup supported GFX types: %w", err)
|
||||
slog.Warn(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
slog.Info("skipping rocm gfx compatibility check", "HSA_OVERRIDE_GFX_VERSION", gfxOverride)
|
||||
}
|
||||
|
||||
slog.Debug("detected hip devices", "count", count)
|
||||
// TODO how to determine the underlying device ID when visible devices is causing this to subset?
|
||||
for i := range count {
|
||||
err = hl.HipSetDevice(i)
|
||||
if err != nil {
|
||||
slog.Warn("set device", "id", i, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
props, err := hl.HipGetDeviceProperties(i)
|
||||
if err != nil {
|
||||
slog.Warn("get properties", "id", i, "error", err)
|
||||
continue
|
||||
}
|
||||
n := bytes.IndexByte(props.Name[:], 0)
|
||||
name := string(props.Name[:n])
|
||||
// TODO is UUID actually populated on windows?
|
||||
// Can luid be used on windows for setting visible devices (and is it actually set?)
|
||||
n = bytes.IndexByte(props.GcnArchName[:], 0)
|
||||
gfx := string(props.GcnArchName[:n])
|
||||
slog.Debug("hip device", "id", i, "name", name, "gfx", gfx)
|
||||
// slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0
|
||||
// TODO Why isn't props.iGPU accurate!?
|
||||
|
||||
freeMemory, totalMemory, err := hl.HipMemGetInfo()
|
||||
if err != nil {
|
||||
slog.Warn("get mem info", "id", i, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
gpuInfo := RocmGPUInfo{
|
||||
GpuInfo: GpuInfo{
|
||||
Library: "rocm",
|
||||
memInfo: memInfo{
|
||||
TotalMemory: totalMemory,
|
||||
FreeMemory: freeMemory,
|
||||
},
|
||||
// Free memory reporting on Windows is not reliable until we bump to ROCm v6.2
|
||||
UnreliableFreeMemory: true,
|
||||
|
||||
ID: strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices
|
||||
DependencyPath: []string{libDir},
|
||||
MinimumMemory: rocmMinimumMemory,
|
||||
Name: name,
|
||||
Compute: gfx,
|
||||
DriverMajor: driverMajor,
|
||||
DriverMinor: driverMinor,
|
||||
},
|
||||
index: i,
|
||||
}
|
||||
|
||||
// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
|
||||
if strings.EqualFold(name, iGPUName) || totalMemory < IGPUMemLimit {
|
||||
reason := "unsupported Radeon iGPU detected skipping"
|
||||
slog.Info(reason, "id", gpuInfo.ID, "total", format.HumanBytes2(totalMemory))
|
||||
unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
|
||||
GpuInfo: gpuInfo.GpuInfo,
|
||||
Reason: reason,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Strip off Target Features when comparing
|
||||
if !slices.Contains[[]string, string](supported, strings.Split(gfx, ":")[0]) {
|
||||
reason := fmt.Sprintf("amdgpu is not supported (supported types:%s)", supported)
|
||||
slog.Warn(reason, "gpu_type", gfx, "gpu", gpuInfo.ID, "library", libDir)
|
||||
unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
|
||||
GpuInfo: gpuInfo.GpuInfo,
|
||||
Reason: reason,
|
||||
})
|
||||
// HSA_OVERRIDE_GFX_VERSION not supported on windows
|
||||
continue
|
||||
} else {
|
||||
slog.Debug("amdgpu is supported", "gpu", i, "gpu_type", gfx)
|
||||
}
|
||||
|
||||
slog.Debug("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory))
|
||||
slog.Debug("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory))
|
||||
|
||||
resp = append(resp, gpuInfo)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func AMDValidateLibDir() (string, error) {
|
||||
libDir, err := commonAMDValidateLibDir()
|
||||
if err == nil {
|
||||
return libDir, nil
|
||||
}
|
||||
|
||||
// Installer payload (if we're running from some other location)
|
||||
rocmTargetDir := filepath.Join(LibOllamaPath, "rocm")
|
||||
if rocmLibUsable(rocmTargetDir) {
|
||||
slog.Debug("detected ollama installed ROCm at " + rocmTargetDir)
|
||||
return rocmTargetDir, nil
|
||||
}
|
||||
|
||||
// Should not happen on windows since we include it in the installer, but stand-alone binary might hit this
|
||||
slog.Warn("amdgpu detected, but no compatible rocm library found. Please install ROCm")
|
||||
return "", errors.New("no suitable rocm found, falling back to CPU")
|
||||
}
|
||||
|
||||
func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
|
||||
if len(gpus) == 0 {
|
||||
return nil
|
||||
}
|
||||
hl, err := NewHipLib()
|
||||
if err != nil {
|
||||
slog.Debug(err.Error())
|
||||
return err
|
||||
}
|
||||
defer hl.Release()
|
||||
|
||||
for i := range gpus {
|
||||
err := hl.HipSetDevice(gpus[i].index)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
freeMemory, _, err := hl.HipMemGetInfo()
|
||||
if err != nil {
|
||||
slog.Warn("get mem info", "id", i, "error", err)
|
||||
continue
|
||||
}
|
||||
slog.Debug("updating rocm free memory", "gpu", gpus[i].ID, "name", gpus[i].Name, "before", format.HumanBytes2(gpus[i].FreeMemory), "now", format.HumanBytes2(freeMemory))
|
||||
gpus[i].FreeMemory = freeMemory
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||
ids := []string{}
|
||||
for _, info := range gpuInfo {
|
||||
if info.Library != "rocm" {
|
||||
// TODO shouldn't happen if things are wired correctly...
|
||||
slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
|
||||
continue
|
||||
}
|
||||
ids = append(ids, info.ID)
|
||||
}
|
||||
// There are 3 potential env vars to use to select GPUs.
|
||||
// ROCR_VISIBLE_DEVICES supports UUID or numeric but does not work on Windows
|
||||
// HIP_VISIBLE_DEVICES supports numeric IDs only
|
||||
// GPU_DEVICE_ORDINAL supports numeric IDs only
|
||||
return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",")
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
package discover
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func IsNUMA() bool {
|
||||
if runtime.GOOS != "linux" {
|
||||
// numa support in llama.cpp is linux only
|
||||
return false
|
||||
}
|
||||
ids := map[string]any{}
|
||||
packageIds, _ := filepath.Glob("/sys/devices/system/cpu/cpu*/topology/physical_package_id")
|
||||
for _, packageId := range packageIds {
|
||||
id, err := os.ReadFile(packageId)
|
||||
if err == nil {
|
||||
ids[strings.TrimSpace(string(id))] = struct{}{}
|
||||
}
|
||||
}
|
||||
return len(ids) > 1
|
||||
}
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
@@ -13,47 +15,6 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
var CudartGlobs = []string{
|
||||
"/usr/local/cuda/lib64/libcudart.so*",
|
||||
"/usr/lib/x86_64-linux-gnu/nvidia/current/libcudart.so*",
|
||||
"/usr/lib/x86_64-linux-gnu/libcudart.so*",
|
||||
"/usr/lib/wsl/lib/libcudart.so*",
|
||||
"/usr/lib/wsl/drivers/*/libcudart.so*",
|
||||
"/opt/cuda/lib64/libcudart.so*",
|
||||
"/usr/local/cuda*/targets/aarch64-linux/lib/libcudart.so*",
|
||||
"/usr/lib/aarch64-linux-gnu/nvidia/current/libcudart.so*",
|
||||
"/usr/lib/aarch64-linux-gnu/libcudart.so*",
|
||||
"/usr/local/cuda/lib*/libcudart.so*",
|
||||
"/usr/lib*/libcudart.so*",
|
||||
"/usr/local/lib*/libcudart.so*",
|
||||
}
|
||||
|
||||
var NvmlGlobs = []string{}
|
||||
|
||||
var NvcudaGlobs = []string{
|
||||
"/usr/local/cuda*/targets/*/lib/libcuda.so*",
|
||||
"/usr/lib/*-linux-gnu/nvidia/current/libcuda.so*",
|
||||
"/usr/lib/*-linux-gnu/libcuda.so*",
|
||||
"/usr/lib/wsl/lib/libcuda.so*",
|
||||
"/usr/lib/wsl/drivers/*/libcuda.so*",
|
||||
"/opt/cuda/lib*/libcuda.so*",
|
||||
"/usr/local/cuda/lib*/libcuda.so*",
|
||||
"/usr/lib*/libcuda.so*",
|
||||
"/usr/local/lib*/libcuda.so*",
|
||||
}
|
||||
|
||||
var OneapiGlobs = []string{
|
||||
"/usr/lib/x86_64-linux-gnu/libze_intel_gpu.so*",
|
||||
"/usr/lib*/libze_intel_gpu.so*",
|
||||
}
|
||||
|
||||
var (
|
||||
CudartMgmtName = "libcudart.so*"
|
||||
NvcudaMgmtName = "libcuda.so*"
|
||||
NvmlMgmtName = "" // not currently wired on linux
|
||||
OneapiMgmtName = "libze_intel_gpu.so*"
|
||||
)
|
||||
|
||||
func GetCPUMem() (memInfo, error) {
|
||||
var mem memInfo
|
||||
var total, available, free, buffers, cached, freeSwap uint64
|
||||
@@ -106,16 +67,17 @@ type linuxCpuInfo struct {
|
||||
CoreID string `cpuinfo:"core id"`
|
||||
}
|
||||
|
||||
func GetCPUDetails() ([]CPU, error) {
|
||||
func GetCPUDetails() []CPU {
|
||||
file, err := os.Open(CpuInfoFilename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
slog.Warn("failed to get CPU details", "error", err)
|
||||
return nil
|
||||
}
|
||||
defer file.Close()
|
||||
return linuxCPUDetails(file)
|
||||
}
|
||||
|
||||
func linuxCPUDetails(file io.Reader) ([]CPU, error) {
|
||||
func linuxCPUDetails(file io.Reader) []CPU {
|
||||
reColumns := regexp.MustCompile("\t+: ")
|
||||
scanner := bufio.NewScanner(file)
|
||||
cpuInfos := []linuxCpuInfo{}
|
||||
@@ -194,5 +156,17 @@ func linuxCPUDetails(file io.Reader) ([]CPU, error) {
|
||||
for _, k := range keys {
|
||||
result = append(result, *socketByID[k])
|
||||
}
|
||||
return result, nil
|
||||
return result
|
||||
}
|
||||
|
||||
func IsNUMA() bool {
|
||||
ids := map[string]any{}
|
||||
packageIds, _ := filepath.Glob("/sys/devices/system/cpu/cpu*/topology/physical_package_id")
|
||||
for _, packageId := range packageIds {
|
||||
id, err := os.ReadFile(packageId)
|
||||
if err == nil {
|
||||
ids[strings.TrimSpace(string(id))] = struct{}{}
|
||||
}
|
||||
}
|
||||
return len(ids) > 1
|
||||
}
|
||||
@@ -2062,18 +2062,9 @@ power management:
|
||||
for k, v := range testCases {
|
||||
t.Run(k, func(t *testing.T) {
|
||||
buf := bytes.NewBufferString(v.input)
|
||||
cpus, err := linuxCPUDetails(buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cpus := linuxCPUDetails(buf)
|
||||
|
||||
slog.Info("example", "scenario", k, "cpus", cpus)
|
||||
si := SystemInfo{
|
||||
System: CPUInfo{
|
||||
CPUs: cpus,
|
||||
},
|
||||
}
|
||||
threadCount := si.GetOptimalThreadCount()
|
||||
if len(v.expCPUs) != len(cpus) {
|
||||
t.Fatalf("incorrect number of sockets: expected:%v got:%v", v.expCPUs, cpus)
|
||||
}
|
||||
@@ -2088,10 +2079,6 @@ power management:
|
||||
t.Fatalf("incorrect number of threads: expected:%v got:%v", v.expCPUs[i], c)
|
||||
}
|
||||
}
|
||||
|
||||
if threadCount != v.expThreadCount {
|
||||
t.Fatalf("incorrect thread count expected:%d got:%d", v.expThreadCount, threadCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -26,29 +26,6 @@ var (
|
||||
GetLogicalProcessorInformationEx = k32.NewProc("GetLogicalProcessorInformationEx")
|
||||
)
|
||||
|
||||
var CudartGlobs = []string{
|
||||
"c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll",
|
||||
}
|
||||
|
||||
var NvmlGlobs = []string{
|
||||
"c:\\Windows\\System32\\nvml.dll",
|
||||
}
|
||||
|
||||
var NvcudaGlobs = []string{
|
||||
"c:\\windows\\system*\\nvcuda.dll",
|
||||
}
|
||||
|
||||
var OneapiGlobs = []string{
|
||||
"c:\\Windows\\System32\\DriverStore\\FileRepository\\*\\ze_intel_gpu64.dll",
|
||||
}
|
||||
|
||||
var (
|
||||
CudartMgmtName = "cudart64_*.dll"
|
||||
NvcudaMgmtName = "nvcuda.dll"
|
||||
NvmlMgmtName = "nvml.dll"
|
||||
OneapiMgmtName = "ze_intel_gpu64.dll"
|
||||
)
|
||||
|
||||
func GetCPUMem() (memInfo, error) {
|
||||
memStatus := MEMORYSTATUSEX{length: sizeofMemoryStatusEx}
|
||||
r1, _, err := globalMemoryStatusExProc.Call(uintptr(unsafe.Pointer(&memStatus)))
|
||||
@@ -122,27 +99,22 @@ func (pkg *winPackage) IsMember(target *GROUP_AFFINITY) bool {
|
||||
}
|
||||
|
||||
func getLogicalProcessorInformationEx() ([]byte, error) {
|
||||
buf := make([]byte, 1)
|
||||
buf := make([]byte, 1024)
|
||||
bufSize := len(buf)
|
||||
ret, _, err := GetLogicalProcessorInformationEx.Call(
|
||||
uintptr(RelationAll),
|
||||
uintptr(unsafe.Pointer(&buf[0])),
|
||||
uintptr(unsafe.Pointer(&bufSize)),
|
||||
)
|
||||
if ret != 0 {
|
||||
return nil, fmt.Errorf("failed to determine size info ret:%d %w", ret, err)
|
||||
var err error
|
||||
for range 3 {
|
||||
var ret uintptr
|
||||
ret, _, err = GetLogicalProcessorInformationEx.Call(
|
||||
uintptr(RelationAll),
|
||||
uintptr(unsafe.Pointer(&buf[0])),
|
||||
uintptr(unsafe.Pointer(&bufSize)),
|
||||
)
|
||||
if ret == 1 && bufSize <= len(buf) {
|
||||
return buf, nil
|
||||
}
|
||||
buf = make([]byte, bufSize)
|
||||
}
|
||||
|
||||
buf = make([]byte, bufSize)
|
||||
ret, _, err = GetLogicalProcessorInformationEx.Call(
|
||||
uintptr(RelationAll),
|
||||
uintptr(unsafe.Pointer(&buf[0])),
|
||||
uintptr(unsafe.Pointer(&bufSize)),
|
||||
)
|
||||
if ret == 0 {
|
||||
return nil, fmt.Errorf("failed to gather processor information ret:%d buflen:%d %w", ret, bufSize, err)
|
||||
}
|
||||
return buf, nil
|
||||
return nil, fmt.Errorf("unable to determine CPU details: %w", err)
|
||||
}
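The rewrite above starts with a 1 KiB guess and lets GetLogicalProcessorInformationEx report the size it actually needs, retrying a bounded number of times instead of making a dedicated size-query call first. A generic sketch of that grow-and-retry pattern; queryFn is a stand-in for any such Windows API and the errors import is assumed.

// Generic sketch of the pattern above; queryFn models any call that fills buf,
// rewrites size with the length it needs, and reports success.
func callWithGrowingBuffer(queryFn func(buf []byte, size *int) bool) ([]byte, error) {
	buf := make([]byte, 1024)
	size := len(buf)
	for range 3 {
		// on failure, queryFn leaves the required length in size
		if queryFn(buf, &size) && size <= len(buf) {
			return buf[:size], nil
		}
		buf = make([]byte, size) // grow to whatever the call asked for
	}
	return nil, errors.New("required buffer size kept changing")
}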
|
||||
|
||||
func processSystemLogicalProcessorInforationList(buf []byte) []*winPackage {
|
||||
@@ -217,10 +189,11 @@ func processSystemLogicalProcessorInforationList(buf []byte) []*winPackage {
|
||||
return packages
|
||||
}
|
||||
|
||||
func GetCPUDetails() ([]CPU, error) {
|
||||
func GetCPUDetails() []CPU {
|
||||
buf, err := getLogicalProcessorInformationEx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
slog.Warn("failed to get CPU details", "error", err)
|
||||
return nil
|
||||
}
|
||||
packages := processSystemLogicalProcessorInforationList(buf)
|
||||
cpus := make([]CPU, len(packages))
|
||||
@@ -230,5 +203,10 @@ func GetCPUDetails() ([]CPU, error) {
|
||||
cpus[i].EfficiencyCoreCount = pkg.efficiencyCoreCount
|
||||
cpus[i].ThreadCount = pkg.threadCount
|
||||
}
|
||||
return cpus, nil
|
||||
return cpus
|
||||
}
|
||||
|
||||
func IsNUMA() bool {
|
||||
// numa support in ggml is linux only
|
||||
return false
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
//go:build linux || windows
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed.
|
||||
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
|
||||
var CudaTegra string = os.Getenv("JETSON_JETPACK")
|
||||
|
||||
func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||
ids := []string{}
|
||||
for _, info := range gpuInfo {
|
||||
if info.Library != "cuda" {
|
||||
// TODO shouldn't happen if things are wired correctly...
|
||||
slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
|
||||
continue
|
||||
}
|
||||
ids = append(ids, info.ID)
|
||||
}
|
||||
return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
|
||||
}
|
||||
|
||||
func cudaVariant(gpuInfo CudaGPUInfo) string {
|
||||
if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
|
||||
if CudaTegra != "" {
|
||||
ver := strings.Split(CudaTegra, ".")
|
||||
if len(ver) > 0 {
|
||||
return "jetpack" + ver[0]
|
||||
}
|
||||
} else if data, err := os.ReadFile("/etc/nv_tegra_release"); err == nil {
|
||||
r := regexp.MustCompile(` R(\d+) `)
|
||||
m := r.FindSubmatch(data)
|
||||
if len(m) != 2 {
|
||||
slog.Info("Unexpected format for /etc/nv_tegra_release. Set JETSON_JETPACK to select version")
|
||||
} else {
|
||||
if l4t, err := strconv.Atoi(string(m[1])); err == nil {
|
||||
// Note: mapping from L4t -> JP is inconsistent (can't just subtract 30)
|
||||
// https://developer.nvidia.com/embedded/jetpack-archive
|
||||
switch l4t {
|
||||
case 35:
|
||||
return "jetpack5"
|
||||
case 36:
|
||||
return "jetpack6"
|
||||
default:
|
||||
slog.Info("unsupported L4T version", "nv_tegra_release", string(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers
|
||||
if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
|
||||
return "v11"
|
||||
}
|
||||
return "v12"
|
||||
}
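cudaVariant falls back to parsing /etc/nv_tegra_release when JETSON_JETPACK is unset; the regexp just pulls out the L4T release number. A tiny sketch with hypothetical file contents:

// Hypothetical /etc/nv_tegra_release line; the regexp mirrors the one above.
data := []byte("# R36 (release), REVISION: 3.0, GCID: 12345, BOARD: generic")
m := regexp.MustCompile(` R(\d+) `).FindSubmatch(data)
// m[1] == "36", which the switch above maps to "jetpack6"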
|
||||
751 discover/gpu.go
@@ -1,718 +1,73 @@
|
||||
//go:build linux || windows
|
||||
|
||||
package discover
|
||||
|
||||
/*
|
||||
#cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm
|
||||
#cgo windows LDFLAGS: -lpthread
|
||||
|
||||
#include "gpu_info.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
type cudaHandles struct {
|
||||
deviceCount int
|
||||
cudart *C.cudart_handle_t
|
||||
nvcuda *C.nvcuda_handle_t
|
||||
nvml *C.nvml_handle_t
|
||||
// Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed.
|
||||
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
|
||||
var CudaTegra string = os.Getenv("JETSON_JETPACK")
|
||||
|
||||
// GetSystemInfo returns the last cached state of the GPUs on the system
|
||||
func GetSystemInfo() ml.SystemInfo {
|
||||
memInfo, err := GetCPUMem()
|
||||
if err != nil {
|
||||
slog.Warn("error looking up system memory", "error", err)
|
||||
}
|
||||
var threadCount int
|
||||
cpus := GetCPUDetails()
|
||||
for _, c := range cpus {
|
||||
threadCount += c.CoreCount - c.EfficiencyCoreCount
|
||||
}
|
||||
|
||||
if threadCount == 0 {
|
||||
// Fall back to Go's num CPU
|
||||
threadCount = runtime.NumCPU()
|
||||
}
|
||||
|
||||
return ml.SystemInfo{
|
||||
ThreadCount: threadCount,
|
||||
TotalMemory: memInfo.TotalMemory,
|
||||
FreeMemory: memInfo.FreeMemory,
|
||||
FreeSwap: memInfo.FreeSwap,
|
||||
}
|
||||
}
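GetSystemInfo counts only performance cores when sizing the default thread pool. A worked example with made-up CPU details, assuming the CPU struct fields shown earlier in this package:

// Made-up numbers: a package with 6 performance and 8 efficiency cores
// (CoreCount = 14) contributes 14 - 8 = 6 threads.
cpus := []CPU{{CoreCount: 14, EfficiencyCoreCount: 8}}
threads := 0
for _, c := range cpus {
	threads += c.CoreCount - c.EfficiencyCoreCount
}
// threads == 6; if no CPU details are available, runtime.NumCPU() is used instead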
|
||||
|
||||
type oneapiHandles struct {
|
||||
oneapi *C.oneapi_handle_t
|
||||
deviceCount int
|
||||
}
|
||||
|
||||
const (
|
||||
cudaMinimumMemory = 457 * format.MebiByte
|
||||
rocmMinimumMemory = 457 * format.MebiByte
|
||||
// TODO OneAPI minimum memory
|
||||
)
|
||||
|
||||
var (
|
||||
gpuMutex sync.Mutex
|
||||
bootstrapped bool
|
||||
cpus []CPUInfo
|
||||
cudaGPUs []CudaGPUInfo
|
||||
nvcudaLibPath string
|
||||
cudartLibPath string
|
||||
oneapiLibPath string
|
||||
nvmlLibPath string
|
||||
rocmGPUs []RocmGPUInfo
|
||||
oneapiGPUs []OneapiGPUInfo
|
||||
|
||||
// If any discovered GPUs are incompatible, report why
|
||||
unsupportedGPUs []UnsupportedGPUInfo
|
||||
|
||||
// Keep track of errors during bootstrapping so that if GPUs are missing
|
||||
// they were expected to be present, this may explain why
|
||||
bootstrapErrors []error
|
||||
)
|
||||
|
||||
// With our current CUDA compile flags, older than 5.0 will not work properly
|
||||
// (string values used to allow ldflags overrides at build time)
|
||||
var (
|
||||
CudaComputeMajorMin = "5"
|
||||
CudaComputeMinorMin = "0"
|
||||
)
|
||||
|
||||
var RocmComputeMajorMin = "9"
|
||||
|
||||
// TODO find a better way to detect iGPU instead of minimum memory
|
||||
const IGPUMemLimit = 1 * format.GibiByte // 512M is what they typically report, so anything less than 1G must be iGPU
|
||||
|
||||
// Note: gpuMutex must already be held
|
||||
func initCudaHandles() *cudaHandles {
|
||||
// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
|
||||
|
||||
cHandles := &cudaHandles{}
|
||||
// Short Circuit if we already know which library to use
|
||||
// ignore bootstrap errors in this case since we already recorded them
|
||||
if nvmlLibPath != "" {
|
||||
cHandles.nvml, _, _ = loadNVMLMgmt([]string{nvmlLibPath})
|
||||
return cHandles
|
||||
}
|
||||
if nvcudaLibPath != "" {
|
||||
cHandles.deviceCount, cHandles.nvcuda, _, _ = loadNVCUDAMgmt([]string{nvcudaLibPath})
|
||||
return cHandles
|
||||
}
|
||||
if cudartLibPath != "" {
|
||||
cHandles.deviceCount, cHandles.cudart, _, _ = loadCUDARTMgmt([]string{cudartLibPath})
|
||||
return cHandles
|
||||
}
|
||||
|
||||
slog.Debug("searching for GPU discovery libraries for NVIDIA")
|
||||
var cudartMgmtPatterns []string
|
||||
|
||||
// Aligned with driver, we can't carry as payloads
|
||||
nvcudaMgmtPatterns := NvcudaGlobs
|
||||
cudartMgmtPatterns = append(cudartMgmtPatterns, filepath.Join(LibOllamaPath, "cuda_v*", CudartMgmtName))
|
||||
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartGlobs...)
|
||||
|
||||
if len(NvmlGlobs) > 0 {
|
||||
nvmlLibPaths := FindGPULibs(NvmlMgmtName, NvmlGlobs)
|
||||
if len(nvmlLibPaths) > 0 {
|
||||
nvml, libPath, err := loadNVMLMgmt(nvmlLibPaths)
|
||||
if nvml != nil {
|
||||
slog.Debug("nvidia-ml loaded", "library", libPath)
|
||||
cHandles.nvml = nvml
|
||||
nvmlLibPath = libPath
|
||||
func cudaJetpack() string {
|
||||
if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
|
||||
if CudaTegra != "" {
|
||||
ver := strings.Split(CudaTegra, ".")
|
||||
if len(ver) > 0 {
|
||||
return "jetpack" + ver[0]
|
||||
}
|
||||
if err != nil {
|
||||
bootstrapErrors = append(bootstrapErrors, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nvcudaLibPaths := FindGPULibs(NvcudaMgmtName, nvcudaMgmtPatterns)
|
||||
if len(nvcudaLibPaths) > 0 {
|
||||
deviceCount, nvcuda, libPath, err := loadNVCUDAMgmt(nvcudaLibPaths)
|
||||
if nvcuda != nil {
|
||||
slog.Debug("detected GPUs", "count", deviceCount, "library", libPath)
|
||||
cHandles.nvcuda = nvcuda
|
||||
cHandles.deviceCount = deviceCount
|
||||
nvcudaLibPath = libPath
|
||||
return cHandles
|
||||
}
|
||||
if err != nil {
|
||||
bootstrapErrors = append(bootstrapErrors, err)
|
||||
}
|
||||
}
|
||||
|
||||
cudartLibPaths := FindGPULibs(CudartMgmtName, cudartMgmtPatterns)
|
||||
if len(cudartLibPaths) > 0 {
|
||||
deviceCount, cudart, libPath, err := loadCUDARTMgmt(cudartLibPaths)
|
||||
if cudart != nil {
|
||||
slog.Debug("detected GPUs", "library", libPath, "count", deviceCount)
|
||||
cHandles.cudart = cudart
|
||||
cHandles.deviceCount = deviceCount
|
||||
cudartLibPath = libPath
|
||||
return cHandles
|
||||
}
|
||||
if err != nil {
|
||||
bootstrapErrors = append(bootstrapErrors, err)
|
||||
}
|
||||
}
|
||||
|
||||
return cHandles
|
||||
}
|
||||
|
||||
// Note: gpuMutex must already be held
|
||||
func initOneAPIHandles() *oneapiHandles {
|
||||
oHandles := &oneapiHandles{}
|
||||
|
||||
// Short Circuit if we already know which library to use
|
||||
// ignore bootstrap errors in this case since we already recorded them
|
||||
if oneapiLibPath != "" {
|
||||
oHandles.deviceCount, oHandles.oneapi, _, _ = loadOneapiMgmt([]string{oneapiLibPath})
|
||||
return oHandles
|
||||
}
|
||||
|
||||
oneapiLibPaths := FindGPULibs(OneapiMgmtName, OneapiGlobs)
|
||||
if len(oneapiLibPaths) > 0 {
|
||||
var err error
|
||||
oHandles.deviceCount, oHandles.oneapi, oneapiLibPath, err = loadOneapiMgmt(oneapiLibPaths)
|
||||
if err != nil {
|
||||
bootstrapErrors = append(bootstrapErrors, err)
|
||||
}
|
||||
}
|
||||
|
||||
return oHandles
|
||||
}
|
||||
|
||||
func GetCPUInfo() GpuInfoList {
|
||||
gpuMutex.Lock()
|
||||
if !bootstrapped {
|
||||
gpuMutex.Unlock()
|
||||
GetGPUInfo()
|
||||
} else {
|
||||
gpuMutex.Unlock()
|
||||
}
|
||||
return GpuInfoList{cpus[0].GpuInfo}
|
||||
}
|
||||
|
||||
func GetGPUInfo() GpuInfoList {
|
||||
// TODO - consider exploring lspci (and equivalent on windows) to check for
|
||||
// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
|
||||
gpuMutex.Lock()
|
||||
defer gpuMutex.Unlock()
|
||||
needRefresh := true
|
||||
var cHandles *cudaHandles
|
||||
var oHandles *oneapiHandles
|
||||
defer func() {
|
||||
if cHandles != nil {
|
||||
if cHandles.cudart != nil {
|
||||
C.cudart_release(*cHandles.cudart)
|
||||
}
|
||||
if cHandles.nvcuda != nil {
|
||||
C.nvcuda_release(*cHandles.nvcuda)
|
||||
}
|
||||
if cHandles.nvml != nil {
|
||||
C.nvml_release(*cHandles.nvml)
|
||||
}
|
||||
}
|
||||
if oHandles != nil {
|
||||
if oHandles.oneapi != nil {
|
||||
// TODO - is this needed?
|
||||
C.oneapi_release(*oHandles.oneapi)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if !bootstrapped {
|
||||
slog.Info("looking for compatible GPUs")
|
||||
cudaComputeMajorMin, err := strconv.Atoi(CudaComputeMajorMin)
|
||||
if err != nil {
|
||||
slog.Error("invalid CudaComputeMajorMin setting", "value", CudaComputeMajorMin, "error", err)
|
||||
}
|
||||
cudaComputeMinorMin, err := strconv.Atoi(CudaComputeMinorMin)
|
||||
if err != nil {
|
||||
slog.Error("invalid CudaComputeMinorMin setting", "value", CudaComputeMinorMin, "error", err)
|
||||
}
|
||||
bootstrapErrors = []error{}
|
||||
needRefresh = false
|
||||
var memInfo C.mem_info_t
|
||||
|
||||
mem, err := GetCPUMem()
|
||||
if err != nil {
|
||||
slog.Warn("error looking up system memory", "error", err)
|
||||
}
|
||||
|
||||
details, err := GetCPUDetails()
|
||||
if err != nil {
|
||||
slog.Warn("failed to lookup CPU details", "error", err)
|
||||
}
|
||||
cpus = []CPUInfo{
|
||||
{
|
||||
GpuInfo: GpuInfo{
|
||||
memInfo: mem,
|
||||
Library: "cpu",
|
||||
ID: "0",
|
||||
},
|
||||
CPUs: details,
|
||||
},
|
||||
}
|
||||
|
||||
// Load ALL libraries
|
||||
cHandles = initCudaHandles()
|
||||
|
||||
// NVIDIA
|
||||
for i := range cHandles.deviceCount {
|
||||
if cHandles.cudart != nil || cHandles.nvcuda != nil {
|
||||
gpuInfo := CudaGPUInfo{
|
||||
GpuInfo: GpuInfo{
|
||||
Library: "cuda",
|
||||
},
|
||||
index: i,
|
||||
}
|
||||
var driverMajor int
|
||||
var driverMinor int
|
||||
if cHandles.cudart != nil {
|
||||
C.cudart_bootstrap(*cHandles.cudart, C.int(i), &memInfo)
|
||||
} else {
|
||||
C.nvcuda_bootstrap(*cHandles.nvcuda, C.int(i), &memInfo)
|
||||
driverMajor = int(cHandles.nvcuda.driver_major)
|
||||
driverMinor = int(cHandles.nvcuda.driver_minor)
|
||||
}
|
||||
if memInfo.err != nil {
|
||||
slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
|
||||
C.free(unsafe.Pointer(memInfo.err))
|
||||
continue
|
||||
}
|
||||
gpuInfo.TotalMemory = uint64(memInfo.total)
|
||||
gpuInfo.FreeMemory = uint64(memInfo.free)
|
||||
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
|
||||
gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor)
|
||||
gpuInfo.computeMajor = int(memInfo.major)
|
||||
gpuInfo.computeMinor = int(memInfo.minor)
|
||||
gpuInfo.MinimumMemory = cudaMinimumMemory
|
||||
gpuInfo.DriverMajor = driverMajor
|
||||
gpuInfo.DriverMinor = driverMinor
|
||||
variant := cudaVariant(gpuInfo)
|
||||
|
||||
// Start with our bundled libraries
|
||||
if variant != "" {
|
||||
variantPath := filepath.Join(LibOllamaPath, "cuda_"+variant)
|
||||
if _, err := os.Stat(variantPath); err == nil {
|
||||
// Put the variant directory first in the search path to avoid runtime linking to the wrong library
|
||||
gpuInfo.DependencyPath = append([]string{variantPath}, gpuInfo.DependencyPath...)
|
||||
}
|
||||
}
|
||||
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
|
||||
gpuInfo.Variant = variant
|
||||
|
||||
if int(memInfo.major) < cudaComputeMajorMin || (int(memInfo.major) == cudaComputeMajorMin && int(memInfo.minor) < cudaComputeMinorMin) {
|
||||
unsupportedGPUs = append(unsupportedGPUs,
|
||||
UnsupportedGPUInfo{
|
||||
GpuInfo: gpuInfo.GpuInfo,
|
||||
})
|
||||
slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
|
||||
continue
|
||||
}
|
||||
|
||||
// query the management library as well so we can record any skew between the two
|
||||
// which represents overhead on the GPU we must set aside on subsequent updates
|
||||
if cHandles.nvml != nil {
|
||||
uuid := C.CString(gpuInfo.ID)
|
||||
defer C.free(unsafe.Pointer(uuid))
|
||||
C.nvml_get_free(*cHandles.nvml, uuid, &memInfo.free, &memInfo.total, &memInfo.used)
|
||||
if memInfo.err != nil {
|
||||
slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
|
||||
C.free(unsafe.Pointer(memInfo.err))
|
||||
} else {
|
||||
if memInfo.free != 0 && uint64(memInfo.free) > gpuInfo.FreeMemory {
|
||||
gpuInfo.OSOverhead = uint64(memInfo.free) - gpuInfo.FreeMemory
|
||||
slog.Info("detected OS VRAM overhead",
|
||||
"id", gpuInfo.ID,
|
||||
"library", gpuInfo.Library,
|
||||
"compute", gpuInfo.Compute,
|
||||
"driver", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor),
|
||||
"name", gpuInfo.Name,
|
||||
"overhead", format.HumanBytes2(gpuInfo.OSOverhead),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
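			// Worked example (hypothetical numbers): if the CUDA runtime/driver reports
			// 7.5 GiB free but NVML reports 8.0 GiB free for the same device, the 0.5 GiB
			// difference is recorded as OSOverhead and subtracted from the NVML reading on
			// later refreshes, so memory already claimed by the OS/display stack is never
			// offered to the scheduler twice.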
|
||||
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
|
||||
cudaGPUs = append(cudaGPUs, gpuInfo)
|
||||
}
|
||||
}
|
||||
|
||||
// Intel
|
||||
if envconfig.IntelGPU() {
|
||||
oHandles = initOneAPIHandles()
|
||||
if oHandles != nil && oHandles.oneapi != nil {
|
||||
for d := range oHandles.oneapi.num_drivers {
|
||||
if oHandles.oneapi == nil {
|
||||
// shouldn't happen
|
||||
slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers))
|
||||
continue
|
||||
}
|
||||
devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d))
|
||||
for i := range devCount {
|
||||
gpuInfo := OneapiGPUInfo{
|
||||
GpuInfo: GpuInfo{
|
||||
Library: "oneapi",
|
||||
},
|
||||
driverIndex: int(d),
|
||||
gpuIndex: int(i),
|
||||
}
|
||||
// TODO - split bootstrapping from updating free memory
|
||||
C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo)
|
||||
// TODO - convert this to MinimumMemory based on testing...
|
||||
var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
|
||||
memInfo.free = C.uint64_t(totalFreeMem)
|
||||
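					// Example (hypothetical): a device reporting 16 GiB free is recorded as
					// roughly 15.2 GiB so the MKL allocations made by the ggml-sycl backend
					// still fit alongside the model weights.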
gpuInfo.TotalMemory = uint64(memInfo.total)
|
||||
gpuInfo.FreeMemory = uint64(memInfo.free)
|
||||
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
|
||||
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
|
||||
gpuInfo.DependencyPath = []string{LibOllamaPath}
|
||||
oneapiGPUs = append(oneapiGPUs, gpuInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rocmGPUs, err = AMDGetGPUInfo()
|
||||
if err != nil {
|
||||
bootstrapErrors = append(bootstrapErrors, err)
|
||||
}
|
||||
bootstrapped = true
|
||||
if len(cudaGPUs) == 0 && len(rocmGPUs) == 0 && len(oneapiGPUs) == 0 {
|
||||
slog.Info("no compatible GPUs were discovered")
|
||||
}
|
||||
|
||||
// TODO verify we have runners for the discovered GPUs, filter out any that aren't supported with good error messages
|
||||
}
|
||||
|
||||
// For detected GPUs, load library if not loaded
|
||||
|
||||
// Refresh free memory usage
|
||||
if needRefresh {
|
||||
mem, err := GetCPUMem()
|
||||
if err != nil {
|
||||
slog.Warn("error looking up system memory", "error", err)
|
||||
} else {
|
||||
slog.Debug("updating system memory data",
|
||||
slog.Group(
|
||||
"before",
|
||||
"total", format.HumanBytes2(cpus[0].TotalMemory),
|
||||
"free", format.HumanBytes2(cpus[0].FreeMemory),
|
||||
"free_swap", format.HumanBytes2(cpus[0].FreeSwap),
|
||||
),
|
||||
slog.Group(
|
||||
"now",
|
||||
"total", format.HumanBytes2(mem.TotalMemory),
|
||||
"free", format.HumanBytes2(mem.FreeMemory),
|
||||
"free_swap", format.HumanBytes2(mem.FreeSwap),
|
||||
),
|
||||
)
|
||||
cpus[0].FreeMemory = mem.FreeMemory
|
||||
cpus[0].FreeSwap = mem.FreeSwap
|
||||
}
|
||||
|
||||
var memInfo C.mem_info_t
|
||||
if cHandles == nil && len(cudaGPUs) > 0 {
|
||||
cHandles = initCudaHandles()
|
||||
}
|
||||
for i, gpu := range cudaGPUs {
|
||||
if cHandles.nvml != nil {
|
||||
uuid := C.CString(gpu.ID)
|
||||
defer C.free(unsafe.Pointer(uuid))
|
||||
C.nvml_get_free(*cHandles.nvml, uuid, &memInfo.free, &memInfo.total, &memInfo.used)
|
||||
} else if cHandles.cudart != nil {
|
||||
C.cudart_bootstrap(*cHandles.cudart, C.int(gpu.index), &memInfo)
|
||||
} else if cHandles.nvcuda != nil {
|
||||
C.nvcuda_get_free(*cHandles.nvcuda, C.int(gpu.index), &memInfo.free, &memInfo.total)
|
||||
memInfo.used = memInfo.total - memInfo.free
|
||||
			} else {
				// shouldn't happen
				slog.Warn("no valid cuda library loaded to refresh vram usage")
				break
			}
|
||||
if memInfo.err != nil {
|
||||
slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
|
||||
C.free(unsafe.Pointer(memInfo.err))
|
||||
continue
|
||||
}
|
||||
if memInfo.free == 0 {
|
||||
slog.Warn("error looking up nvidia GPU memory")
|
||||
continue
|
||||
}
|
||||
if cHandles.nvml != nil && gpu.OSOverhead > 0 {
|
||||
// When using the management library update based on recorded overhead
|
||||
memInfo.free -= C.uint64_t(gpu.OSOverhead)
|
||||
}
|
||||
slog.Debug("updating cuda memory data",
|
||||
"gpu", gpu.ID,
|
||||
"name", gpu.Name,
|
||||
"overhead", format.HumanBytes2(gpu.OSOverhead),
|
||||
slog.Group(
|
||||
"before",
|
||||
"total", format.HumanBytes2(gpu.TotalMemory),
|
||||
"free", format.HumanBytes2(gpu.FreeMemory),
|
||||
),
|
||||
slog.Group(
|
||||
"now",
|
||||
"total", format.HumanBytes2(uint64(memInfo.total)),
|
||||
"free", format.HumanBytes2(uint64(memInfo.free)),
|
||||
"used", format.HumanBytes2(uint64(memInfo.used)),
|
||||
),
|
||||
)
|
||||
cudaGPUs[i].FreeMemory = uint64(memInfo.free)
|
||||
}
|
||||
|
||||
if oHandles == nil && len(oneapiGPUs) > 0 {
|
||||
oHandles = initOneAPIHandles()
|
||||
}
|
||||
for i, gpu := range oneapiGPUs {
|
||||
if oHandles.oneapi == nil {
|
||||
// shouldn't happen
|
||||
slog.Warn("nil oneapi handle with device count", "count", oHandles.deviceCount)
|
||||
continue
|
||||
}
|
||||
C.oneapi_check_vram(*oHandles.oneapi, C.int(gpu.driverIndex), C.int(gpu.gpuIndex), &memInfo)
|
||||
// TODO - convert this to MinimumMemory based on testing...
|
||||
var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
|
||||
memInfo.free = C.uint64_t(totalFreeMem)
|
||||
oneapiGPUs[i].FreeMemory = uint64(memInfo.free)
|
||||
}
|
||||
|
||||
err = RocmGPUInfoList(rocmGPUs).RefreshFreeMemory()
|
||||
if err != nil {
|
||||
slog.Debug("problem refreshing ROCm free memory", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
resp := []GpuInfo{}
|
||||
for _, gpu := range cudaGPUs {
|
||||
resp = append(resp, gpu.GpuInfo)
|
||||
}
|
||||
for _, gpu := range rocmGPUs {
|
||||
resp = append(resp, gpu.GpuInfo)
|
||||
}
|
||||
for _, gpu := range oneapiGPUs {
|
||||
resp = append(resp, gpu.GpuInfo)
|
||||
}
|
||||
if len(resp) == 0 {
|
||||
resp = append(resp, cpus[0].GpuInfo)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func FindGPULibs(baseLibName string, defaultPatterns []string) []string {
|
||||
// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
|
||||
gpuLibPaths := []string{}
|
||||
slog.Debug("Searching for GPU library", "name", baseLibName)
|
||||
|
||||
// search our bundled libraries first
|
||||
patterns := []string{filepath.Join(LibOllamaPath, baseLibName)}
|
||||
|
||||
var ldPaths []string
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
ldPaths = strings.Split(os.Getenv("PATH"), string(os.PathListSeparator))
|
||||
case "linux":
|
||||
ldPaths = strings.Split(os.Getenv("LD_LIBRARY_PATH"), string(os.PathListSeparator))
|
||||
}
|
||||
|
||||
// then search the system's LD_LIBRARY_PATH
|
||||
for _, p := range ldPaths {
|
||||
p, err := filepath.Abs(p)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
patterns = append(patterns, filepath.Join(p, baseLibName))
|
||||
}
|
||||
|
||||
// finally, search the default patterns provided by the caller
|
||||
patterns = append(patterns, defaultPatterns...)
|
||||
slog.Debug("gpu library search", "globs", patterns)
|
||||
for _, pattern := range patterns {
|
||||
// Nvidia PhysX known to return bogus results
|
||||
if strings.Contains(pattern, "PhysX") {
|
||||
slog.Debug("skipping PhysX cuda library path", "path", pattern)
|
||||
continue
|
||||
}
|
||||
// Ignore glob discovery errors
|
||||
matches, _ := filepath.Glob(pattern)
|
||||
for _, match := range matches {
|
||||
// Resolve any links so we don't try the same lib multiple times
|
||||
// and weed out any dups across globs
|
||||
libPath := match
|
||||
tmp := match
|
||||
var err error
|
||||
for ; err == nil; tmp, err = os.Readlink(libPath) {
|
||||
if !filepath.IsAbs(tmp) {
|
||||
tmp = filepath.Join(filepath.Dir(libPath), tmp)
|
||||
}
|
||||
libPath = tmp
|
||||
}
|
||||
new := true
|
||||
for _, cmp := range gpuLibPaths {
|
||||
if cmp == libPath {
|
||||
new = false
|
||||
break
|
||||
if l4t, err := strconv.Atoi(string(m[1])); err == nil {
|
||||
// Note: mapping from L4t -> JP is inconsistent (can't just subtract 30)
|
||||
// https://developer.nvidia.com/embedded/jetpack-archive
|
||||
switch l4t {
|
||||
case 35:
|
||||
return "jetpack5"
|
||||
case 36:
|
||||
return "jetpack6"
|
||||
default:
|
||||
// Newer Jetson systems use the SBSU runtime
|
||||
slog.Debug("unrecognized L4T version", "nv_tegra_release", string(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
if new {
|
||||
gpuLibPaths = append(gpuLibPaths, libPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
slog.Debug("discovered GPU libraries", "paths", gpuLibPaths)
|
||||
return gpuLibPaths
|
||||
}
|
||||
|
||||
// Bootstrap the runtime library
|
||||
// Returns: num devices, handle, libPath, error
|
||||
func loadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string, error) {
|
||||
var resp C.cudart_init_resp_t
|
||||
resp.ch.verbose = getVerboseState()
|
||||
var err error
|
||||
for _, libPath := range cudartLibPaths {
|
||||
lib := C.CString(libPath)
|
||||
defer C.free(unsafe.Pointer(lib))
|
||||
C.cudart_init(lib, &resp)
|
||||
if resp.err != nil {
|
||||
err = fmt.Errorf("Unable to load cudart library %s: %s", libPath, C.GoString(resp.err))
|
||||
slog.Debug(err.Error())
|
||||
C.free(unsafe.Pointer(resp.err))
|
||||
} else {
|
||||
err = nil
|
||||
return int(resp.num_devices), &resp.ch, libPath, err
|
||||
}
|
||||
}
|
||||
return 0, nil, "", err
|
||||
}
|
||||
|
||||
// Bootstrap the driver library
|
||||
// Returns: num devices, handle, libPath, error
|
||||
func loadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string, error) {
|
||||
var resp C.nvcuda_init_resp_t
|
||||
resp.ch.verbose = getVerboseState()
|
||||
var err error
|
||||
for _, libPath := range nvcudaLibPaths {
|
||||
lib := C.CString(libPath)
|
||||
defer C.free(unsafe.Pointer(lib))
|
||||
C.nvcuda_init(lib, &resp)
|
||||
if resp.err != nil {
|
||||
// Decide what log level based on the type of error message to help users understand why
|
||||
switch resp.cudaErr {
|
||||
case C.CUDA_ERROR_INSUFFICIENT_DRIVER, C.CUDA_ERROR_SYSTEM_DRIVER_MISMATCH:
|
||||
err = fmt.Errorf("version mismatch between driver and cuda driver library - reboot or upgrade may be required: library %s", libPath)
|
||||
slog.Warn(err.Error())
|
||||
case C.CUDA_ERROR_NO_DEVICE:
|
||||
err = fmt.Errorf("no nvidia devices detected by library %s", libPath)
|
||||
slog.Info(err.Error())
|
||||
case C.CUDA_ERROR_UNKNOWN:
|
||||
err = fmt.Errorf("unknown error initializing cuda driver library %s: %s. see https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for more information", libPath, C.GoString(resp.err))
|
||||
slog.Warn(err.Error())
|
||||
default:
|
||||
msg := C.GoString(resp.err)
|
||||
if strings.Contains(msg, "wrong ELF class") {
|
||||
slog.Debug("skipping 32bit library", "library", libPath)
|
||||
} else {
|
||||
err = fmt.Errorf("Unable to load cudart library %s: %s", libPath, C.GoString(resp.err))
|
||||
slog.Info(err.Error())
|
||||
}
|
||||
}
|
||||
C.free(unsafe.Pointer(resp.err))
|
||||
} else {
|
||||
err = nil
|
||||
return int(resp.num_devices), &resp.ch, libPath, err
|
||||
}
|
||||
}
|
||||
return 0, nil, "", err
|
||||
}
|
||||
|
||||
// Bootstrap the management library
|
||||
// Returns: handle, libPath, error
|
||||
func loadNVMLMgmt(nvmlLibPaths []string) (*C.nvml_handle_t, string, error) {
|
||||
var resp C.nvml_init_resp_t
|
||||
resp.ch.verbose = getVerboseState()
|
||||
var err error
|
||||
for _, libPath := range nvmlLibPaths {
|
||||
lib := C.CString(libPath)
|
||||
defer C.free(unsafe.Pointer(lib))
|
||||
C.nvml_init(lib, &resp)
|
||||
if resp.err != nil {
|
||||
err = fmt.Errorf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err))
|
||||
slog.Info(err.Error())
|
||||
C.free(unsafe.Pointer(resp.err))
|
||||
} else {
|
||||
err = nil
|
||||
return &resp.ch, libPath, err
|
||||
}
|
||||
}
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// bootstrap the Intel GPU library
|
||||
// Returns: num devices, handle, libPath, error
|
||||
func loadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string, error) {
|
||||
var resp C.oneapi_init_resp_t
|
||||
num_devices := 0
|
||||
resp.oh.verbose = getVerboseState()
|
||||
var err error
|
||||
for _, libPath := range oneapiLibPaths {
|
||||
lib := C.CString(libPath)
|
||||
defer C.free(unsafe.Pointer(lib))
|
||||
C.oneapi_init(lib, &resp)
|
||||
if resp.err != nil {
|
||||
err = fmt.Errorf("Unable to load oneAPI management library %s: %s", libPath, C.GoString(resp.err))
|
||||
slog.Debug(err.Error())
|
||||
C.free(unsafe.Pointer(resp.err))
|
||||
} else {
|
||||
err = nil
|
||||
for i := range resp.oh.num_drivers {
|
||||
num_devices += int(C.oneapi_get_device_count(resp.oh, C.int(i)))
|
||||
}
|
||||
return num_devices, &resp.oh, libPath, err
|
||||
}
|
||||
}
|
||||
return 0, nil, "", err
|
||||
}
|
||||
|
||||
func getVerboseState() C.uint16_t {
|
||||
if envconfig.LogLevel() < slog.LevelInfo {
|
||||
return C.uint16_t(1)
|
||||
}
|
||||
return C.uint16_t(0)
|
||||
}
|
||||
|
||||
// Given the list of GPUs this instantiation is targeted for,
|
||||
// figure out the visible devices environment variable
|
||||
//
|
||||
// If different libraries are detected, the first one is what we use
|
||||
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
|
||||
if len(l) == 0 {
|
||||
return "", ""
|
||||
}
|
||||
switch l[0].Library {
|
||||
case "cuda":
|
||||
return cudaGetVisibleDevicesEnv(l)
|
||||
case "rocm":
|
||||
return rocmGetVisibleDevicesEnv(l)
|
||||
case "oneapi":
|
||||
return oneapiGetVisibleDevicesEnv(l)
|
||||
default:
|
||||
slog.Debug("no filter required for library " + l[0].Library)
|
||||
return "", ""
|
||||
}
|
||||
}
|
||||
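The key/value pair returned here is intended to be placed into the environment of whichever process ends up loading the model, so only the targeted devices are visible to it. A minimal sketch of that pattern (the device IDs and the `/usr/bin/env` stand-in for the runner binary are hypothetical, not part of this package):

package main

import (
	"fmt"
	"os"
	"os/exec"
)

func main() {
	// Hypothetical values standing in for what GpuInfoList.GetVisibleDevicesEnv() returns.
	key, value := "CUDA_VISIBLE_DEVICES", "GPU-d110a105,GPU-9c90440f"

	// Inject the device filter into the child environment before launching the runner.
	cmd := exec.Command("/usr/bin/env") // stand-in for the runner binary
	cmd.Env = append(os.Environ(), fmt.Sprintf("%s=%s", key, value))
	out, err := cmd.CombinedOutput()
	if err != nil {
		fmt.Println("run failed:", err)
		return
	}
	fmt.Printf("%s", out)
}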
|
||||
func GetSystemInfo() SystemInfo {
|
||||
gpus := GetGPUInfo()
|
||||
gpuMutex.Lock()
|
||||
defer gpuMutex.Unlock()
|
||||
discoveryErrors := []string{}
|
||||
for _, err := range bootstrapErrors {
|
||||
discoveryErrors = append(discoveryErrors, err.Error())
|
||||
}
|
||||
if len(gpus) == 1 && gpus[0].Library == "cpu" {
|
||||
gpus = []GpuInfo{}
|
||||
}
|
||||
|
||||
return SystemInfo{
|
||||
System: cpus[0],
|
||||
GPUs: gpus,
|
||||
UnsupportedGPUs: unsupportedGPUs,
|
||||
DiscoveryErrors: discoveryErrors,
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build darwin
|
||||
|
||||
package discover
|
||||
|
||||
/*
|
||||
@@ -11,7 +9,6 @@ import "C"
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
@@ -21,39 +18,6 @@ const (
|
||||
metalMinimumMemory = 512 * format.MebiByte
|
||||
)
|
||||
|
||||
func GetGPUInfo() GpuInfoList {
|
||||
mem, _ := GetCPUMem()
|
||||
if runtime.GOARCH == "amd64" {
|
||||
return []GpuInfo{
|
||||
{
|
||||
Library: "cpu",
|
||||
memInfo: mem,
|
||||
},
|
||||
}
|
||||
}
|
||||
info := GpuInfo{
|
||||
Library: "metal",
|
||||
ID: "0",
|
||||
}
|
||||
info.TotalMemory = uint64(C.getRecommendedMaxVRAM())
|
||||
|
||||
// TODO is there a way to gather actual allocated video memory? (currentAllocatedSize doesn't work)
|
||||
info.FreeMemory = info.TotalMemory
|
||||
|
||||
info.MinimumMemory = metalMinimumMemory
|
||||
return []GpuInfo{info}
|
||||
}
|
||||
|
||||
func GetCPUInfo() GpuInfoList {
|
||||
mem, _ := GetCPUMem()
|
||||
return []GpuInfo{
|
||||
{
|
||||
Library: "cpu",
|
||||
memInfo: mem,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func GetCPUMem() (memInfo, error) {
|
||||
return memInfo{
|
||||
TotalMemory: uint64(C.getPhysicalMemory()),
|
||||
@@ -62,13 +26,7 @@ func GetCPUMem() (memInfo, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
|
||||
// No-op on darwin
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func GetSystemInfo() SystemInfo {
|
||||
mem, _ := GetCPUMem()
|
||||
func GetCPUDetails() []CPU {
|
||||
query := "hw.perflevel0.physicalcpu"
|
||||
perfCores, err := syscall.SysctlUint32(query)
|
||||
if err != nil {
|
||||
@@ -81,19 +39,16 @@ func GetSystemInfo() SystemInfo {
|
||||
query = "hw.logicalcpu"
|
||||
logicalCores, _ := syscall.SysctlUint32(query)
|
||||
|
||||
return SystemInfo{
|
||||
System: CPUInfo{
|
||||
GpuInfo: GpuInfo{
|
||||
memInfo: mem,
|
||||
},
|
||||
CPUs: []CPU{
|
||||
{
|
||||
CoreCount: int(perfCores + efficiencyCores),
|
||||
EfficiencyCoreCount: int(efficiencyCores),
|
||||
ThreadCount: int(logicalCores),
|
||||
},
|
||||
},
|
||||
return []CPU{
|
||||
{
|
||||
CoreCount: int(perfCores + efficiencyCores),
|
||||
EfficiencyCoreCount: int(efficiencyCores),
|
||||
ThreadCount: int(logicalCores),
|
||||
},
|
||||
GPUs: GetGPUInfo(),
|
||||
}
|
||||
}
|
||||
|
||||
func IsNUMA() bool {
|
||||
// numa support in ggml is linux only
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
#ifndef __APPLE__
|
||||
#ifndef __GPU_INFO_H__
|
||||
#define __GPU_INFO_H__
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#ifndef _WIN32
|
||||
#include <dlfcn.h>
|
||||
#define LOAD_LIBRARY(lib, flags) dlopen(lib, flags)
|
||||
#define LOAD_SYMBOL(handle, sym) dlsym(handle, sym)
|
||||
#define LOAD_ERR() strdup(dlerror())
|
||||
#define UNLOAD_LIBRARY(handle) dlclose(handle)
|
||||
#else
|
||||
#include <windows.h>
|
||||
#define LOAD_LIBRARY(lib, flags) LoadLibrary(lib)
|
||||
#define LOAD_SYMBOL(handle, sym) GetProcAddress(handle, sym)
|
||||
#define UNLOAD_LIBRARY(handle) FreeLibrary(handle)
|
||||
#define LOAD_ERR() ({\
|
||||
LPSTR messageBuffer = NULL; \
|
||||
size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, \
|
||||
NULL, GetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&messageBuffer, 0, NULL); \
|
||||
char *resp = strdup(messageBuffer); \
|
||||
LocalFree(messageBuffer); \
|
||||
resp; \
|
||||
})
|
||||
|
||||
#endif
|
||||
|
||||
#ifndef LOG
|
||||
#define LOG(verbose, ...) \
|
||||
do { \
|
||||
if (verbose) { \
|
||||
fprintf(stderr, __VA_ARGS__); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define GPU_ID_LEN 64
|
||||
#define GPU_NAME_LEN 96
|
||||
|
||||
typedef struct mem_info {
|
||||
char *err; // If non-null, caller responsible for freeing
|
||||
char gpu_id[GPU_ID_LEN];
|
||||
char gpu_name[GPU_NAME_LEN];
|
||||
uint64_t total;
|
||||
uint64_t free;
|
||||
uint64_t used;
|
||||
|
||||
// Compute Capability
|
||||
int major;
|
||||
int minor;
|
||||
int patch;
|
||||
} mem_info_t;
|
||||
|
||||
void cpu_check_ram(mem_info_t *resp);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#include "gpu_info_cudart.h"
|
||||
#include "gpu_info_nvcuda.h"
|
||||
#include "gpu_info_nvml.h"
|
||||
#include "gpu_info_oneapi.h"
|
||||
|
||||
#endif // __GPU_INFO_H__
|
||||
#endif // __APPLE__
|
||||
@@ -1,184 +0,0 @@
|
||||
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
|
||||
|
||||
#include <string.h>
|
||||
#include <inttypes.h>
|
||||
#include "gpu_info_cudart.h"
|
||||
|
||||
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||
cudartReturn_t ret;
|
||||
resp->err = NULL;
|
||||
resp->num_devices = 0;
|
||||
const int buflen = 256;
|
||||
char buf[buflen + 1];
|
||||
int i;
|
||||
|
||||
struct lookup {
|
||||
char *s;
|
||||
void **p;
|
||||
} l[] = {
|
||||
{"cudaSetDevice", (void *)&resp->ch.cudaSetDevice},
|
||||
{"cudaDeviceSynchronize", (void *)&resp->ch.cudaDeviceSynchronize},
|
||||
{"cudaDeviceReset", (void *)&resp->ch.cudaDeviceReset},
|
||||
{"cudaMemGetInfo", (void *)&resp->ch.cudaMemGetInfo},
|
||||
{"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount},
|
||||
{"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute},
|
||||
{"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion},
|
||||
{"cudaGetDeviceProperties", (void *)&resp->ch.cudaGetDeviceProperties},
|
||||
{NULL, NULL},
|
||||
};
|
||||
|
||||
resp->ch.handle = LOAD_LIBRARY(cudart_lib_path, RTLD_LAZY);
|
||||
if (!resp->ch.handle) {
|
||||
char *msg = LOAD_ERR();
|
||||
LOG(resp->ch.verbose, "library %s load err: %s\n", cudart_lib_path, msg);
|
||||
snprintf(buf, buflen,
|
||||
"Unable to load %s library to query for Nvidia GPUs: %s",
|
||||
cudart_lib_path, msg);
|
||||
free(msg);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
for (i = 0; l[i].s != NULL; i++) {
|
||||
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
|
||||
if (!*(l[i].p)) {
|
||||
char *msg = LOAD_ERR();
|
||||
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
resp->ch.handle = NULL;
|
||||
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
|
||||
msg);
|
||||
free(msg);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
ret = (*resp->ch.cudaSetDevice)(0);
|
||||
if (ret != CUDART_SUCCESS) {
|
||||
LOG(resp->ch.verbose, "cudaSetDevice err: %d\n", ret);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
resp->ch.handle = NULL;
|
||||
if (ret == CUDART_ERROR_INSUFFICIENT_DRIVER) {
|
||||
resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama");
|
||||
return;
|
||||
}
|
||||
snprintf(buf, buflen, "cudart init failure: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
int version = 0;
|
||||
cudartDriverVersion_t driverVersion;
|
||||
driverVersion.major = 0;
|
||||
driverVersion.minor = 0;
|
||||
|
||||
// Report driver version if we're in verbose mode, ignore errors
|
||||
ret = (*resp->ch.cudaDriverGetVersion)(&version);
|
||||
if (ret != CUDART_SUCCESS) {
|
||||
LOG(resp->ch.verbose, "cudaDriverGetVersion failed: %d\n", ret);
|
||||
} else {
|
||||
driverVersion.major = version / 1000;
|
||||
driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
|
||||
LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
|
||||
}
|
||||
|
||||
ret = (*resp->ch.cudaGetDeviceCount)(&resp->num_devices);
|
||||
if (ret != CUDART_SUCCESS) {
|
||||
LOG(resp->ch.verbose, "cudaGetDeviceCount err: %d\n", ret);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
resp->ch.handle = NULL;
|
||||
snprintf(buf, buflen, "unable to get device count: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void cudart_bootstrap(cudart_handle_t h, int i, mem_info_t *resp) {
|
||||
resp->err = NULL;
|
||||
cudartMemory_t memInfo = {0,0,0};
|
||||
cudartReturn_t ret;
|
||||
const int buflen = 256;
|
||||
char buf[buflen + 1];
|
||||
|
||||
if (h.handle == NULL) {
|
||||
resp->err = strdup("cudart handle isn't initialized");
|
||||
return;
|
||||
}
|
||||
|
||||
ret = (*h.cudaSetDevice)(i);
|
||||
if (ret != CUDART_SUCCESS) {
|
||||
snprintf(buf, buflen, "cudart device failed to initialize");
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
cudaDeviceProp_t props;
|
||||
ret = (*h.cudaGetDeviceProperties)(&props, i);
|
||||
if (ret != CUDART_SUCCESS) {
|
||||
LOG(h.verbose, "[%d] device properties lookup failure: %d\n", i, ret);
|
||||
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
|
||||
resp->major = 0;
|
||||
resp->minor = 0;
|
||||
} else {
|
||||
int allNull = 1;
|
||||
for (int j = 0; j < 16; j++) {
|
||||
if (props.uuid.bytes[j] != 0) {
|
||||
allNull = 0;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (allNull != 0) {
|
||||
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
|
||||
} else {
|
||||
// GPU-d110a105-ac29-1d54-7b49-9c90440f215b
|
||||
snprintf(&resp->gpu_id[0], GPU_ID_LEN,
|
||||
"GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
|
||||
props.uuid.bytes[0],
|
||||
props.uuid.bytes[1],
|
||||
props.uuid.bytes[2],
|
||||
props.uuid.bytes[3],
|
||||
props.uuid.bytes[4],
|
||||
props.uuid.bytes[5],
|
||||
props.uuid.bytes[6],
|
||||
props.uuid.bytes[7],
|
||||
props.uuid.bytes[8],
|
||||
props.uuid.bytes[9],
|
||||
props.uuid.bytes[10],
|
||||
props.uuid.bytes[11],
|
||||
props.uuid.bytes[12],
|
||||
props.uuid.bytes[13],
|
||||
props.uuid.bytes[14],
|
||||
props.uuid.bytes[15]
|
||||
);
|
||||
}
|
||||
resp->major = props.major;
|
||||
resp->minor = props.minor;
|
||||
|
||||
// TODO add other useful properties from props
|
||||
}
|
||||
ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
|
||||
if (ret != CUDART_SUCCESS) {
|
||||
snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
resp->total = memInfo.total;
|
||||
resp->free = memInfo.free;
|
||||
resp->used = memInfo.used;
|
||||
|
||||
LOG(h.verbose, "[%s] CUDA totalMem %" PRId64 "\n", resp->gpu_id, resp->total);
|
||||
LOG(h.verbose, "[%s] CUDA freeMem %" PRId64 "\n", resp->gpu_id, resp->free);
|
||||
LOG(h.verbose, "[%s] CUDA usedMem %" PRId64 "\n", resp->gpu_id, resp->used);
|
||||
LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
|
||||
}
|
||||
|
||||
void cudart_release(cudart_handle_t h) {
|
||||
LOG(h.verbose, "releasing cudart library\n");
|
||||
UNLOAD_LIBRARY(h.handle);
|
||||
h.handle = NULL;
|
||||
}
|
||||
|
||||
#endif // __APPLE__
|
||||
@@ -1,148 +0,0 @@
|
||||
#ifndef __APPLE__
|
||||
#ifndef __GPU_INFO_CUDART_H__
|
||||
#define __GPU_INFO_CUDART_H__
|
||||
#include "gpu_info.h"
|
||||
|
||||
// Just enough typedef's to dlopen/dlsym for memory information
|
||||
typedef enum cudartReturn_enum {
|
||||
CUDART_SUCCESS = 0,
|
||||
CUDART_ERROR_INVALID_VALUE = 1,
|
||||
CUDART_ERROR_MEMORY_ALLOCATION = 2,
|
||||
CUDART_ERROR_INSUFFICIENT_DRIVER = 35,
|
||||
// Other values omitted for now...
|
||||
} cudartReturn_t;
|
||||
|
||||
typedef enum cudartDeviceAttr_enum {
|
||||
cudartDevAttrComputeCapabilityMajor = 75,
|
||||
cudartDevAttrComputeCapabilityMinor = 76,
|
||||
|
||||
// TODO - not yet wired up but may be useful for Jetson or other
|
||||
// integrated GPU scenarios with shared memory
|
||||
cudaDevAttrIntegrated = 18
|
||||
|
||||
} cudartDeviceAttr_t;
|
||||
|
||||
typedef void *cudartDevice_t; // Opaque is sufficient
|
||||
typedef struct cudartMemory_st {
|
||||
size_t total;
|
||||
size_t free;
|
||||
size_t used;
|
||||
} cudartMemory_t;
|
||||
|
||||
typedef struct cudartDriverVersion {
|
||||
int major;
|
||||
int minor;
|
||||
} cudartDriverVersion_t;
|
||||
|
||||
typedef struct cudaUUID {
|
||||
unsigned char bytes[16];
|
||||
} cudaUUID_t;
|
||||
typedef struct cudaDeviceProp {
|
||||
char name[256]; /**< ASCII string identifying device */
|
||||
cudaUUID_t uuid; /**< 16-byte unique identifier */
|
||||
char luid[8]; /**< 8-byte locally unique identifier. Value is undefined on TCC and non-Windows platforms */
|
||||
unsigned int luidDeviceNodeMask; /**< LUID device node mask. Value is undefined on TCC and non-Windows platforms */
|
||||
size_t totalGlobalMem; /**< Global memory available on device in bytes */
|
||||
size_t sharedMemPerBlock; /**< Shared memory available per block in bytes */
|
||||
int regsPerBlock; /**< 32-bit registers available per block */
|
||||
int warpSize; /**< Warp size in threads */
|
||||
size_t memPitch; /**< Maximum pitch in bytes allowed by memory copies */
|
||||
int maxThreadsPerBlock; /**< Maximum number of threads per block */
|
||||
int maxThreadsDim[3]; /**< Maximum size of each dimension of a block */
|
||||
int maxGridSize[3]; /**< Maximum size of each dimension of a grid */
|
||||
int clockRate; /**< Clock frequency in kilohertz */
|
||||
size_t totalConstMem; /**< Constant memory available on device in bytes */
|
||||
int major; /**< Major compute capability */
|
||||
int minor; /**< Minor compute capability */
|
||||
size_t textureAlignment; /**< Alignment requirement for textures */
|
||||
size_t texturePitchAlignment; /**< Pitch alignment requirement for texture references bound to pitched memory */
|
||||
int deviceOverlap; /**< Device can concurrently copy memory and execute a kernel. Deprecated. Use instead asyncEngineCount. */
|
||||
int multiProcessorCount; /**< Number of multiprocessors on device */
|
||||
int kernelExecTimeoutEnabled; /**< Specified whether there is a run time limit on kernels */
|
||||
int integrated; /**< Device is integrated as opposed to discrete */
|
||||
int canMapHostMemory; /**< Device can map host memory with cudaHostAlloc/cudaHostGetDevicePointer */
|
||||
int computeMode; /**< Compute mode (See ::cudaComputeMode) */
|
||||
int maxTexture1D; /**< Maximum 1D texture size */
|
||||
int maxTexture1DMipmap; /**< Maximum 1D mipmapped texture size */
|
||||
int maxTexture1DLinear; /**< Deprecated, do not use. Use cudaDeviceGetTexture1DLinearMaxWidth() or cuDeviceGetTexture1DLinearMaxWidth() instead. */
|
||||
int maxTexture2D[2]; /**< Maximum 2D texture dimensions */
|
||||
int maxTexture2DMipmap[2]; /**< Maximum 2D mipmapped texture dimensions */
|
||||
int maxTexture2DLinear[3]; /**< Maximum dimensions (width, height, pitch) for 2D textures bound to pitched memory */
|
||||
int maxTexture2DGather[2]; /**< Maximum 2D texture dimensions if texture gather operations have to be performed */
|
||||
int maxTexture3D[3]; /**< Maximum 3D texture dimensions */
|
||||
int maxTexture3DAlt[3]; /**< Maximum alternate 3D texture dimensions */
|
||||
int maxTextureCubemap; /**< Maximum Cubemap texture dimensions */
|
||||
int maxTexture1DLayered[2]; /**< Maximum 1D layered texture dimensions */
|
||||
int maxTexture2DLayered[3]; /**< Maximum 2D layered texture dimensions */
|
||||
int maxTextureCubemapLayered[2];/**< Maximum Cubemap layered texture dimensions */
|
||||
int maxSurface1D; /**< Maximum 1D surface size */
|
||||
int maxSurface2D[2]; /**< Maximum 2D surface dimensions */
|
||||
int maxSurface3D[3]; /**< Maximum 3D surface dimensions */
|
||||
int maxSurface1DLayered[2]; /**< Maximum 1D layered surface dimensions */
|
||||
int maxSurface2DLayered[3]; /**< Maximum 2D layered surface dimensions */
|
||||
int maxSurfaceCubemap; /**< Maximum Cubemap surface dimensions */
|
||||
int maxSurfaceCubemapLayered[2];/**< Maximum Cubemap layered surface dimensions */
|
||||
size_t surfaceAlignment; /**< Alignment requirements for surfaces */
|
||||
int concurrentKernels; /**< Device can possibly execute multiple kernels concurrently */
|
||||
int ECCEnabled; /**< Device has ECC support enabled */
|
||||
int pciBusID; /**< PCI bus ID of the device */
|
||||
int pciDeviceID; /**< PCI device ID of the device */
|
||||
int pciDomainID; /**< PCI domain ID of the device */
|
||||
int tccDriver; /**< 1 if device is a Tesla device using TCC driver, 0 otherwise */
|
||||
int asyncEngineCount; /**< Number of asynchronous engines */
|
||||
int unifiedAddressing; /**< Device shares a unified address space with the host */
|
||||
int memoryClockRate; /**< Peak memory clock frequency in kilohertz */
|
||||
int memoryBusWidth; /**< Global memory bus width in bits */
|
||||
int l2CacheSize; /**< Size of L2 cache in bytes */
|
||||
int persistingL2CacheMaxSize; /**< Device's maximum l2 persisting lines capacity setting in bytes */
|
||||
int maxThreadsPerMultiProcessor;/**< Maximum resident threads per multiprocessor */
|
||||
int streamPrioritiesSupported; /**< Device supports stream priorities */
|
||||
int globalL1CacheSupported; /**< Device supports caching globals in L1 */
|
||||
int localL1CacheSupported; /**< Device supports caching locals in L1 */
|
||||
size_t sharedMemPerMultiprocessor; /**< Shared memory available per multiprocessor in bytes */
|
||||
int regsPerMultiprocessor; /**< 32-bit registers available per multiprocessor */
|
||||
int managedMemory; /**< Device supports allocating managed memory on this system */
|
||||
int isMultiGpuBoard; /**< Device is on a multi-GPU board */
|
||||
int multiGpuBoardGroupID; /**< Unique identifier for a group of devices on the same multi-GPU board */
|
||||
int hostNativeAtomicSupported; /**< Link between the device and the host supports native atomic operations */
|
||||
int singleToDoublePrecisionPerfRatio; /**< Ratio of single precision performance (in floating-point operations per second) to double precision performance */
|
||||
int pageableMemoryAccess; /**< Device supports coherently accessing pageable memory without calling cudaHostRegister on it */
|
||||
int concurrentManagedAccess; /**< Device can coherently access managed memory concurrently with the CPU */
|
||||
int computePreemptionSupported; /**< Device supports Compute Preemption */
|
||||
int canUseHostPointerForRegisteredMem; /**< Device can access host registered memory at the same virtual address as the CPU */
|
||||
int cooperativeLaunch; /**< Device supports launching cooperative kernels via ::cudaLaunchCooperativeKernel */
|
||||
int cooperativeMultiDeviceLaunch; /**< Deprecated, cudaLaunchCooperativeKernelMultiDevice is deprecated. */
|
||||
size_t sharedMemPerBlockOptin; /**< Per device maximum shared memory per block usable by special opt in */
|
||||
int pageableMemoryAccessUsesHostPageTables; /**< Device accesses pageable memory via the host's page tables */
|
||||
int directManagedMemAccessFromHost; /**< Host can directly access managed memory on the device without migration. */
|
||||
int maxBlocksPerMultiProcessor; /**< Maximum number of resident blocks per multiprocessor */
|
||||
int accessPolicyMaxWindowSize; /**< The maximum value of ::cudaAccessPolicyWindow::num_bytes. */
|
||||
size_t reservedSharedMemPerBlock; /**< Shared memory reserved by CUDA driver per block in bytes */
|
||||
} cudaDeviceProp_t;
|
||||
|
||||
typedef struct cudart_handle {
|
||||
void *handle;
|
||||
uint16_t verbose;
|
||||
cudartReturn_t (*cudaSetDevice)(int device);
|
||||
cudartReturn_t (*cudaDeviceSynchronize)(void);
|
||||
cudartReturn_t (*cudaDeviceReset)(void);
|
||||
cudartReturn_t (*cudaMemGetInfo)(size_t *, size_t *);
|
||||
cudartReturn_t (*cudaGetDeviceCount)(int *);
|
||||
cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device);
|
||||
cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion);
|
||||
cudartReturn_t (*cudaGetDeviceProperties) (cudaDeviceProp_t* prop, int device);
|
||||
} cudart_handle_t;
|
||||
|
||||
typedef struct cudart_init_resp {
|
||||
char *err; // If err is non-null handle is invalid
|
||||
cudart_handle_t ch;
|
||||
int num_devices;
|
||||
} cudart_init_resp_t;
|
||||
|
||||
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp);
|
||||
void cudart_bootstrap(cudart_handle_t ch, int device_id, mem_info_t *resp);
|
||||
// TODO - if we keep this library longer term, add cudart_get_free
|
||||
void cudart_release(cudart_handle_t ch);
|
||||
|
||||
#endif // __GPU_INFO_CUDART_H__
|
||||
#endif // __APPLE__
|
||||
@@ -1,251 +0,0 @@
|
||||
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
|
||||
|
||||
#include <string.h>
|
||||
#include <inttypes.h>
|
||||
#include "gpu_info_nvcuda.h"
|
||||
|
||||
void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
|
||||
LOG(resp->ch.verbose, "initializing %s\n", nvcuda_lib_path);
|
||||
CUresult ret;
|
||||
resp->err = NULL;
|
||||
resp->num_devices = 0;
|
||||
resp->cudaErr = CUDA_SUCCESS;
|
||||
const int buflen = 256;
|
||||
char buf[buflen + 1];
|
||||
int i;
|
||||
|
||||
struct lookup {
|
||||
char *s;
|
||||
void **p;
|
||||
} l[] = {
|
||||
|
||||
{"cuInit", (void *)&resp->ch.cuInit},
|
||||
{"cuDriverGetVersion", (void *)&resp->ch.cuDriverGetVersion},
|
||||
{"cuDeviceGetCount", (void *)&resp->ch.cuDeviceGetCount},
|
||||
{"cuDeviceGet", (void *)&resp->ch.cuDeviceGet},
|
||||
{"cuDeviceGetAttribute", (void *)&resp->ch.cuDeviceGetAttribute},
|
||||
{"cuDeviceGetUuid", (void *)&resp->ch.cuDeviceGetUuid},
|
||||
{"cuDeviceGetName", (void *)&resp->ch.cuDeviceGetName},
|
||||
{"cuCtxCreate_v3", (void *)&resp->ch.cuCtxCreate_v3},
|
||||
{"cuMemGetInfo_v2", (void *)&resp->ch.cuMemGetInfo_v2},
|
||||
{"cuCtxDestroy", (void *)&resp->ch.cuCtxDestroy},
|
||||
{NULL, NULL},
|
||||
};
|
||||
|
||||
resp->ch.handle = LOAD_LIBRARY(nvcuda_lib_path, RTLD_LAZY);
|
||||
if (!resp->ch.handle) {
|
||||
char *msg = LOAD_ERR();
|
||||
LOG(resp->ch.verbose, "library %s load err: %s\n", nvcuda_lib_path, msg);
|
||||
snprintf(buf, buflen,
|
||||
"Unable to load %s library to query for Nvidia GPUs: %s",
|
||||
nvcuda_lib_path, msg);
|
||||
free(msg);
|
||||
resp->err = strdup(buf);
|
||||
resp->cudaErr = -1;
|
||||
return;
|
||||
}
|
||||
|
||||
for (i = 0; l[i].s != NULL; i++) {
|
||||
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
|
||||
if (!*(l[i].p)) {
|
||||
char *msg = LOAD_ERR();
|
||||
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
resp->ch.handle = NULL;
|
||||
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
|
||||
msg);
|
||||
free(msg);
|
||||
resp->err = strdup(buf);
|
||||
resp->cudaErr = -1;
|
||||
return;
|
||||
}
|
||||
LOG(resp->ch.verbose, "dlsym: %s - %p\n", l[i].s, *l[i].p);
|
||||
}
|
||||
|
||||
LOG(resp->ch.verbose, "calling cuInit\n");
|
||||
ret = (*resp->ch.cuInit)(0);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(resp->ch.verbose, "cuInit err: %d\n", ret);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
resp->ch.handle = NULL;
|
||||
snprintf(buf, buflen, "cuda driver library init failure: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
resp->cudaErr = ret;
|
||||
return;
|
||||
}
|
||||
|
||||
int version = 0;
|
||||
resp->ch.driver_major = 0;
|
||||
resp->ch.driver_minor = 0;
|
||||
|
||||
// Report driver version if we're in verbose mode, ignore errors
|
||||
LOG(resp->ch.verbose, "calling cuDriverGetVersion\n");
|
||||
ret = (*resp->ch.cuDriverGetVersion)(&version);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(resp->ch.verbose, "cuDriverGetVersion failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(resp->ch.verbose, "raw version 0x%x\n", version);
|
||||
resp->ch.driver_major = version / 1000;
|
||||
resp->ch.driver_minor = (version - (resp->ch.driver_major * 1000)) / 10;
|
||||
LOG(resp->ch.verbose, "CUDA driver version: %d.%d\n", resp->ch.driver_major, resp->ch.driver_minor);
|
||||
}
|
||||
|
||||
LOG(resp->ch.verbose, "calling cuDeviceGetCount\n");
|
||||
ret = (*resp->ch.cuDeviceGetCount)(&resp->num_devices);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(resp->ch.verbose, "cuDeviceGetCount err: %d\n", ret);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
resp->ch.handle = NULL;
|
||||
snprintf(buf, buflen, "unable to get device count: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
resp->cudaErr = ret;
|
||||
return;
|
||||
}
|
||||
LOG(resp->ch.verbose, "device count %d\n", resp->num_devices);
|
||||
}
|
||||
|
||||
const int buflen = 256;
|
||||
void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
|
||||
resp->err = NULL;
|
||||
nvcudaMemory_t memInfo = {0,0};
|
||||
CUresult ret;
|
||||
CUdevice device = -1;
|
||||
CUcontext ctx = NULL;
|
||||
char buf[buflen + 1];
|
||||
CUuuid uuid = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
|
||||
|
||||
if (h.handle == NULL) {
|
||||
resp->err = strdup("cuda driver library handle isn't initialized");
|
||||
return;
|
||||
}
|
||||
|
||||
ret = (*h.cuDeviceGet)(&device, i);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
snprintf(buf, buflen, "cuda driver library device failed to initialize");
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
int major = 0;
|
||||
int minor = 0;
|
||||
ret = (*h.cuDeviceGetAttribute)(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(h.verbose, "[%d] device major lookup failure: %d\n", i, ret);
|
||||
} else {
|
||||
ret = (*h.cuDeviceGetAttribute)(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(h.verbose, "[%d] device minor lookup failure: %d\n", i, ret);
|
||||
} else {
|
||||
resp->minor = minor;
|
||||
resp->major = major;
|
||||
}
|
||||
}
|
||||
|
||||
ret = (*h.cuDeviceGetUuid)(&uuid, device);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(h.verbose, "[%d] device uuid lookup failure: %d\n", i, ret);
|
||||
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
|
||||
} else {
|
||||
// GPU-d110a105-ac29-1d54-7b49-9c90440f215b
|
||||
snprintf(&resp->gpu_id[0], GPU_ID_LEN,
|
||||
"GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
|
||||
uuid.bytes[0],
|
||||
uuid.bytes[1],
|
||||
uuid.bytes[2],
|
||||
uuid.bytes[3],
|
||||
uuid.bytes[4],
|
||||
uuid.bytes[5],
|
||||
uuid.bytes[6],
|
||||
uuid.bytes[7],
|
||||
uuid.bytes[8],
|
||||
uuid.bytes[9],
|
||||
uuid.bytes[10],
|
||||
uuid.bytes[11],
|
||||
uuid.bytes[12],
|
||||
uuid.bytes[13],
|
||||
uuid.bytes[14],
|
||||
uuid.bytes[15]
|
||||
);
|
||||
}
|
||||
|
||||
ret = (*h.cuDeviceGetName)(&resp->gpu_name[0], GPU_NAME_LEN, device);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(h.verbose, "[%d] device name lookup failure: %d\n", i, ret);
|
||||
resp->gpu_name[0] = '\0';
|
||||
}
|
||||
|
||||
// To get memory we have to set (and release) a context
|
||||
ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
snprintf(buf, buflen, "cuda driver library failed to get device context %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
ret = (*h.cuMemGetInfo_v2)(&memInfo.free, &memInfo.total);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
snprintf(buf, buflen, "cuda driver library device memory info lookup failure %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
// Best effort on failure...
|
||||
(*h.cuCtxDestroy)(ctx);
|
||||
return;
|
||||
}
|
||||
|
||||
resp->total = memInfo.total;
|
||||
resp->free = memInfo.free;
|
||||
|
||||
LOG(h.verbose, "[%s] CUDA totalMem %" PRId64 "mb\n", resp->gpu_id, resp->total / 1024 / 1024);
|
||||
LOG(h.verbose, "[%s] CUDA freeMem %" PRId64 "mb\n", resp->gpu_id, resp->free / 1024 / 1024);
|
||||
LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
|
||||
|
||||
|
||||
|
||||
ret = (*h.cuCtxDestroy)(ctx);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(1, "cuda driver library failed to release device context %d", ret);
|
||||
}
|
||||
}
|
||||
|
||||
void nvcuda_get_free(nvcuda_handle_t h, int i, uint64_t *free, uint64_t *total) {
|
||||
CUresult ret;
|
||||
CUcontext ctx = NULL;
|
||||
CUdevice device = -1;
|
||||
*free = 0;
|
||||
*total = 0;
|
||||
|
||||
ret = (*h.cuDeviceGet)(&device, i);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(1, "cuda driver library device failed to initialize");
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
// To get memory we have to set (and release) a context
|
||||
ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(1, "cuda driver library failed to get device context %d", ret);
|
||||
return;
|
||||
}
|
||||
|
||||
ret = (*h.cuMemGetInfo_v2)(free, total);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(1, "cuda driver library device memory info lookup failure %d", ret);
|
||||
// Best effort on failure...
|
||||
(*h.cuCtxDestroy)(ctx);
|
||||
return;
|
||||
}
|
||||
|
||||
ret = (*h.cuCtxDestroy)(ctx);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(1, "cuda driver library failed to release device context %d", ret);
|
||||
}
|
||||
}
|
||||
|
||||
void nvcuda_release(nvcuda_handle_t h) {
|
||||
LOG(h.verbose, "releasing cuda driver library\n");
|
||||
UNLOAD_LIBRARY(h.handle);
|
||||
// TODO and other context release logic?
|
||||
h.handle = NULL;
|
||||
}
|
||||
|
||||
#endif // __APPLE__
|
||||
@@ -1,79 +0,0 @@
|
||||
#ifndef __APPLE__
|
||||
#ifndef __GPU_INFO_NVCUDA_H__
|
||||
#define __GPU_INFO_NVCUDA_H__
|
||||
#include "gpu_info.h"
|
||||
|
||||
// Just enough typedef's to dlopen/dlsym for memory information
|
||||
typedef enum cudaError_enum {
|
||||
CUDA_SUCCESS = 0,
|
||||
CUDA_ERROR_INVALID_VALUE = 1,
|
||||
CUDA_ERROR_OUT_OF_MEMORY = 2,
|
||||
CUDA_ERROR_NOT_INITIALIZED = 3,
|
||||
CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
|
||||
CUDA_ERROR_NO_DEVICE = 100,
|
||||
CUDA_ERROR_SYSTEM_DRIVER_MISMATCH = 803,
|
||||
CUDA_ERROR_UNKNOWN = 999,
|
||||
// Other values omitted for now...
|
||||
} CUresult;
|
||||
|
||||
typedef enum CUdevice_attribute_enum {
|
||||
CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75,
|
||||
CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76,
|
||||
|
||||
// TODO - not yet wired up but may be useful for Jetson or other
|
||||
// integrated GPU scenarios with shared memory
|
||||
CU_DEVICE_ATTRIBUTE_INTEGRATED = 18
|
||||
|
||||
} CUdevice_attribute;
|
||||
|
||||
typedef void *nvcudaDevice_t; // Opaque is sufficient
|
||||
typedef struct nvcudaMemory_st {
|
||||
uint64_t total;
|
||||
uint64_t free;
|
||||
} nvcudaMemory_t;
|
||||
|
||||
typedef struct nvcudaDriverVersion {
|
||||
int major;
|
||||
int minor;
|
||||
} nvcudaDriverVersion_t;
|
||||
|
||||
typedef struct CUuuid_st {
|
||||
unsigned char bytes[16];
|
||||
} CUuuid;
|
||||
|
||||
typedef int CUdevice;
|
||||
typedef void* CUcontext;
|
||||
|
||||
typedef struct nvcuda_handle {
|
||||
void *handle;
|
||||
uint16_t verbose;
|
||||
int driver_major;
|
||||
int driver_minor;
|
||||
CUresult (*cuInit)(unsigned int Flags);
|
||||
CUresult (*cuDriverGetVersion)(int *driverVersion);
|
||||
CUresult (*cuDeviceGetCount)(int *);
|
||||
CUresult (*cuDeviceGet)(CUdevice* device, int ordinal);
|
||||
CUresult (*cuDeviceGetAttribute)(int* pi, CUdevice_attribute attrib, CUdevice dev);
|
||||
CUresult (*cuDeviceGetUuid)(CUuuid* uuid, CUdevice dev); // signature compatible with cuDeviceGetUuid_v2
|
||||
CUresult (*cuDeviceGetName)(char *name, int len, CUdevice dev);
|
||||
|
||||
// Context specific aspects
|
||||
CUresult (*cuCtxCreate_v3)(CUcontext* pctx, void *params, int len, unsigned int flags, CUdevice dev);
|
||||
CUresult (*cuMemGetInfo_v2)(uint64_t* free, uint64_t* total);
|
||||
CUresult (*cuCtxDestroy)(CUcontext ctx);
|
||||
} nvcuda_handle_t;
|
||||
|
||||
typedef struct nvcuda_init_resp {
|
||||
char *err; // If err is non-null handle is invalid
|
||||
nvcuda_handle_t ch;
|
||||
int num_devices;
|
||||
CUresult cudaErr;
|
||||
} nvcuda_init_resp_t;
|
||||
|
||||
void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp);
|
||||
void nvcuda_bootstrap(nvcuda_handle_t ch, int device_id, mem_info_t *resp);
|
||||
void nvcuda_get_free(nvcuda_handle_t ch, int device_id, uint64_t *free, uint64_t *total);
|
||||
void nvcuda_release(nvcuda_handle_t ch);
|
||||
|
||||
#endif // __GPU_INFO_NVCUDA_H__
|
||||
#endif // __APPLE__
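
A minimal sketch (illustrative only, not part of this diff) of how the function pointers declared in the header above could be resolved at runtime, assuming a POSIX dlopen/dlsym environment and the `nvcuda_handle_t`/`CUresult` declarations from this header:

```c
#include <dlfcn.h>
#include <stdint.h>
#include <stdio.h>

#include "gpu_info_nvcuda.h"

// Hypothetical helper: resolve just enough symbols from the CUDA driver
// library to query VRAM. Error handling is abbreviated; the real loader in
// this repo drives dlsym through a lookup table instead.
static int wire_and_query(nvcuda_handle_t *h, const char *lib_path,
                          uint64_t *free, uint64_t *total) {
  h->handle = dlopen(lib_path, RTLD_LAZY);
  if (!h->handle) {
    fprintf(stderr, "dlopen %s failed: %s\n", lib_path, dlerror());
    return -1;
  }
  *(void **)(&h->cuInit) = dlsym(h->handle, "cuInit");
  *(void **)(&h->cuMemGetInfo_v2) = dlsym(h->handle, "cuMemGetInfo_v2");
  if (!h->cuInit || !h->cuMemGetInfo_v2) {
    fprintf(stderr, "dlsym failed: %s\n", dlerror());
    dlclose(h->handle);
    h->handle = NULL;
    return -1;
  }
  if ((*h->cuInit)(0) != CUDA_SUCCESS) {
    return -1;
  }
  // Note: a context must be current for cuMemGetInfo_v2 to succeed, which is
  // why nvcuda_get_free above wraps the call in cuCtxCreate_v3/cuCtxDestroy.
  return (*h->cuMemGetInfo_v2)(free, total) == CUDA_SUCCESS ? 0 : -1;
}
```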
|
||||
@@ -1,104 +0,0 @@
|
||||
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include "gpu_info_nvml.h"
|
||||
|
||||
void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) {
|
||||
nvmlReturn_t ret;
|
||||
resp->err = NULL;
|
||||
const int buflen = 256;
|
||||
char buf[buflen + 1];
|
||||
int i;
|
||||
|
||||
struct lookup {
|
||||
char *s;
|
||||
void **p;
|
||||
} l[] = {
|
||||
{"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2},
|
||||
{"nvmlShutdown", (void *)&resp->ch.nvmlShutdown},
|
||||
{"nvmlDeviceGetHandleByUUID", (void *)&resp->ch.nvmlDeviceGetHandleByUUID},
|
||||
{"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo},
|
||||
{NULL, NULL},
|
||||
};
|
||||
|
||||
resp->ch.handle = LOAD_LIBRARY(nvml_lib_path, RTLD_LAZY);
|
||||
if (!resp->ch.handle) {
|
||||
char *msg = LOAD_ERR();
|
||||
LOG(resp->ch.verbose, "library %s load err: %s\n", nvml_lib_path, msg);
|
||||
snprintf(buf, buflen,
|
||||
"Unable to load %s library to query for Nvidia GPUs: %s",
|
||||
nvml_lib_path, msg);
|
||||
free(msg);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO once we've squashed the remaining corner cases remove this log
|
||||
// LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", nvml_lib_path);
|
||||
|
||||
for (i = 0; l[i].s != NULL; i++) {
|
||||
// TODO once we've squashed the remaining corner cases remove this log
|
||||
// LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
|
||||
|
||||
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
|
||||
if (!*(l[i].p)) {
|
||||
resp->ch.handle = NULL;
|
||||
char *msg = LOAD_ERR();
|
||||
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
|
||||
msg);
|
||||
free(msg);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
ret = (*resp->ch.nvmlInit_v2)();
|
||||
if (ret != NVML_SUCCESS) {
|
||||
LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
resp->ch.handle = NULL;
|
||||
snprintf(buf, buflen, "nvml vram init failure: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void nvml_get_free(nvml_handle_t h, char *uuid, uint64_t *free, uint64_t *total, uint64_t *used) {
|
||||
nvmlDevice_t device;
|
||||
nvmlMemory_t memInfo = {0};
|
||||
nvmlReturn_t ret;
|
||||
ret = (*h.nvmlDeviceGetHandleByUUID)((const char *)(uuid), &device);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
LOG(1, "unable to get device handle %s: %d", uuid, ret);
|
||||
*free = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
LOG(1, "device memory info lookup failure %s: %d", uuid, ret);
|
||||
*free = 0;
|
||||
return;
|
||||
}
|
||||
*free = memInfo.free;
|
||||
*total = memInfo.total;
|
||||
*used = memInfo.used;
|
||||
}
|
||||
|
||||
|
||||
void nvml_release(nvml_handle_t h) {
|
||||
LOG(h.verbose, "releasing nvml library\n");
|
||||
nvmlReturn_t ret;
|
||||
ret = (*h.nvmlShutdown)();
|
||||
if (ret != NVML_SUCCESS) {
|
||||
LOG(1, "error during nvmlShutdown %d", ret);
|
||||
}
|
||||
UNLOAD_LIBRARY(h.handle);
|
||||
h.handle = NULL;
|
||||
}
|
||||
|
||||
#endif // __APPLE__
|
||||
@@ -1,48 +0,0 @@
|
||||
#ifndef __APPLE__
|
||||
#ifndef __GPU_INFO_NVML_H__
|
||||
#define __GPU_INFO_NVML_H__
|
||||
#include "gpu_info.h"
|
||||
|
||||
// Just enough typedefs to dlopen/dlsym for memory information
|
||||
typedef enum nvmlReturn_enum {
|
||||
NVML_SUCCESS = 0,
|
||||
// Other values omitted for now...
|
||||
} nvmlReturn_t;
|
||||
typedef void *nvmlDevice_t; // Opaque is sufficient
|
||||
typedef struct nvmlMemory_st {
|
||||
unsigned long long total;
|
||||
unsigned long long free;
|
||||
unsigned long long used;
|
||||
} nvmlMemory_t;
|
||||
|
||||
typedef enum nvmlBrandType_enum
|
||||
{
|
||||
NVML_BRAND_UNKNOWN = 0,
|
||||
} nvmlBrandType_t;
|
||||
|
||||
typedef struct nvml_handle {
|
||||
void *handle;
|
||||
uint16_t verbose;
|
||||
nvmlReturn_t (*nvmlInit_v2)(void);
|
||||
nvmlReturn_t (*nvmlShutdown)(void);
|
||||
nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char *, nvmlDevice_t *);
|
||||
nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
|
||||
} nvml_handle_t;
|
||||
|
||||
typedef struct nvml_init_resp {
|
||||
char *err; // If err is non-null handle is invalid
|
||||
nvml_handle_t ch;
|
||||
} nvml_init_resp_t;
|
||||
|
||||
typedef struct nvml_compute_capability {
|
||||
char *err;
|
||||
int major;
|
||||
int minor;
|
||||
} nvml_compute_capability_t;
|
||||
|
||||
void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp);
|
||||
void nvml_get_free(nvml_handle_t ch, char *uuid, uint64_t *free, uint64_t *total, uint64_t *used);
|
||||
void nvml_release(nvml_handle_t ch);
|
||||
|
||||
#endif // __GPU_INFO_NVML_H__
|
||||
#endif // __APPLE__
|
||||
@@ -1,259 +0,0 @@
|
||||
#ifndef __APPLE__
|
||||
|
||||
#include "gpu_info_oneapi.h"
|
||||
|
||||
#include <string.h>
|
||||
|
||||
void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp) {
|
||||
ze_result_t ret;
|
||||
resp->err = NULL;
|
||||
resp->oh.devices = NULL;
|
||||
resp->oh.num_devices = NULL;
|
||||
resp->oh.drivers = NULL;
|
||||
resp->oh.num_drivers = 0;
|
||||
const int buflen = 256;
|
||||
char buf[buflen + 1];
|
||||
int i, d;
|
||||
struct lookup {
|
||||
char *s;
|
||||
void **p;
|
||||
} l[] = {
|
||||
{"zesInit", (void *)&resp->oh.zesInit},
|
||||
{"zesDriverGet", (void *)&resp->oh.zesDriverGet},
|
||||
{"zesDeviceGet", (void *)&resp->oh.zesDeviceGet},
|
||||
{"zesDeviceGetProperties", (void *)&resp->oh.zesDeviceGetProperties},
|
||||
{"zesDeviceEnumMemoryModules",
|
||||
(void *)&resp->oh.zesDeviceEnumMemoryModules},
|
||||
{"zesMemoryGetProperties", (void *)&resp->oh.zesMemoryGetProperties},
|
||||
{"zesMemoryGetState", (void *)&resp->oh.zesMemoryGetState},
|
||||
{NULL, NULL},
|
||||
};
|
||||
|
||||
resp->oh.handle = LOAD_LIBRARY(oneapi_lib_path, RTLD_LAZY);
|
||||
if (!resp->oh.handle) {
|
||||
char *msg = LOAD_ERR();
|
||||
snprintf(buf, buflen,
|
||||
"Unable to load %s library to query for Intel GPUs: %s\n",
|
||||
oneapi_lib_path, msg);
|
||||
free(msg);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO once we've squashed the remaining corner cases remove this log
|
||||
LOG(resp->oh.verbose,
|
||||
"wiring Level-Zero management library functions in %s\n",
|
||||
oneapi_lib_path);
|
||||
|
||||
for (i = 0; l[i].s != NULL; i++) {
|
||||
// TODO once we've squashed the remaining corner cases remove this log
|
||||
LOG(resp->oh.verbose, "dlsym: %s\n", l[i].s);
|
||||
|
||||
*l[i].p = LOAD_SYMBOL(resp->oh.handle, l[i].s);
|
||||
if (!*(l[i].p)) {
|
||||
resp->oh.handle = NULL;
|
||||
char *msg = LOAD_ERR();
|
||||
LOG(resp->oh.verbose, "dlerr: %s\n", msg);
|
||||
UNLOAD_LIBRARY(resp->oh.handle);
|
||||
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, msg);
|
||||
free(msg);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
LOG(resp->oh.verbose, "calling zesInit\n");
|
||||
|
||||
ret = (*resp->oh.zesInit)(0);
|
||||
if (ret != ZE_RESULT_SUCCESS) {
|
||||
LOG(resp->oh.verbose, "zesInit err: %x\n", ret);
|
||||
snprintf(buf, buflen, "oneapi vram init failure: %x", ret);
|
||||
resp->err = strdup(buf);
|
||||
oneapi_release(resp->oh);
|
||||
return;
|
||||
}
|
||||
|
||||
LOG(resp->oh.verbose, "calling zesDriverGet\n");
|
||||
ret = (*resp->oh.zesDriverGet)(&resp->oh.num_drivers, NULL);
|
||||
if (ret != ZE_RESULT_SUCCESS) {
|
||||
LOG(resp->oh.verbose, "zesDriverGet err: %x\n", ret);
|
||||
snprintf(buf, buflen, "unable to get driver count: %x", ret);
|
||||
resp->err = strdup(buf);
|
||||
oneapi_release(resp->oh);
|
||||
return;
|
||||
}
|
||||
LOG(resp->oh.verbose, "oneapi driver count: %d\n", resp->oh.num_drivers);
|
||||
resp->oh.drivers = malloc(resp->oh.num_drivers * sizeof(zes_driver_handle_t));
|
||||
resp->oh.num_devices = malloc(resp->oh.num_drivers * sizeof(uint32_t));
|
||||
memset(&resp->oh.num_devices[0], 0, resp->oh.num_drivers * sizeof(uint32_t));
|
||||
resp->oh.devices =
|
||||
malloc(resp->oh.num_drivers * sizeof(zes_device_handle_t *));
|
||||
ret = (*resp->oh.zesDriverGet)(&resp->oh.num_drivers, &resp->oh.drivers[0]);
|
||||
if (ret != ZE_RESULT_SUCCESS) {
|
||||
LOG(resp->oh.verbose, "zesDriverGet err: %x\n", ret);
|
||||
snprintf(buf, buflen, "unable to get driver count: %x", ret);
|
||||
resp->err = strdup(buf);
|
||||
oneapi_release(resp->oh);
|
||||
return;
|
||||
}
|
||||
|
||||
for (d = 0; d < resp->oh.num_drivers; d++) {
|
||||
LOG(resp->oh.verbose, "calling zesDeviceGet count %d: %p\n", d, resp->oh.drivers[d]);
|
||||
ret = (*resp->oh.zesDeviceGet)(resp->oh.drivers[d],
|
||||
&resp->oh.num_devices[d], NULL);
|
||||
if (ret != ZE_RESULT_SUCCESS) {
|
||||
LOG(resp->oh.verbose, "zesDeviceGet err: %x\n", ret);
|
||||
snprintf(buf, buflen, "unable to get device count: %x", ret);
|
||||
resp->err = strdup(buf);
|
||||
oneapi_release(resp->oh);
|
||||
return;
|
||||
}
|
||||
resp->oh.devices[d] =
|
||||
malloc(resp->oh.num_devices[d] * sizeof(zes_device_handle_t));
|
||||
ret = (*resp->oh.zesDeviceGet)(
|
||||
resp->oh.drivers[d], &resp->oh.num_devices[d], resp->oh.devices[d]);
|
||||
if (ret != ZE_RESULT_SUCCESS) {
|
||||
LOG(resp->oh.verbose, "zesDeviceGet err: %x\n", ret);
|
||||
snprintf(buf, buflen, "unable to get device count: %x", ret);
|
||||
resp->err = strdup(buf);
|
||||
oneapi_release(resp->oh);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
void oneapi_check_vram(oneapi_handle_t h, int driver, int device,
|
||||
mem_info_t *resp) {
|
||||
ze_result_t ret;
|
||||
resp->err = NULL;
|
||||
uint64_t totalMem = 0;
|
||||
uint64_t usedMem = 0;
|
||||
const int buflen = 256;
|
||||
char buf[buflen + 1];
|
||||
int i, d, m;
|
||||
|
||||
if (h.handle == NULL) {
|
||||
resp->err = strdup("Level-Zero handle not initialized");
|
||||
return;
|
||||
}
|
||||
|
||||
if (driver >= h.num_drivers || device >= h.num_devices[driver]) {
|
||||
resp->err = strdup("driver of device index out of bounds");
|
||||
return;
|
||||
}
|
||||
|
||||
resp->total = 0;
|
||||
resp->free = 0;
|
||||
|
||||
zes_device_ext_properties_t ext_props;
|
||||
ext_props.stype = ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES;
|
||||
ext_props.pNext = NULL;
|
||||
|
||||
zes_device_properties_t props;
|
||||
props.stype = ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES;
|
||||
props.pNext = &ext_props;
|
||||
|
||||
ret = (*h.zesDeviceGetProperties)(h.devices[driver][device], &props);
|
||||
if (ret != ZE_RESULT_SUCCESS) {
|
||||
snprintf(buf, buflen, "unable to get device properties: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
snprintf(&resp->gpu_name[0], GPU_NAME_LEN, "%s", props.modelName);
|
||||
|
||||
// TODO this needs to map to ONEAPI_DEVICE_SELECTOR syntax
|
||||
// (this is probably wrong...)
|
||||
// TODO - the driver isn't included - what if there are multiple drivers?
|
||||
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", device);
|
||||
|
||||
if (h.verbose) {
|
||||
// When in verbose mode, report more information about
|
||||
// the card we discover.
|
||||
LOG(h.verbose, "[%d:%d] oneAPI device name: %s\n", driver, device,
|
||||
props.modelName);
|
||||
LOG(h.verbose, "[%d:%d] oneAPI brand: %s\n", driver, device,
|
||||
props.brandName);
|
||||
LOG(h.verbose, "[%d:%d] oneAPI vendor: %s\n", driver, device,
|
||||
props.vendorName);
|
||||
LOG(h.verbose, "[%d:%d] oneAPI S/N: %s\n", driver, device,
|
||||
props.serialNumber);
|
||||
LOG(h.verbose, "[%d:%d] oneAPI board number: %s\n", driver, device,
|
||||
props.boardNumber);
|
||||
}
|
||||
|
||||
// TODO
|
||||
// Compute Capability equivalent in resp->major, resp->minor, resp->patch
|
||||
|
||||
uint32_t memCount = 0;
|
||||
ret = (*h.zesDeviceEnumMemoryModules)(h.devices[driver][device], &memCount,
|
||||
NULL);
|
||||
if (ret != ZE_RESULT_SUCCESS) {
|
||||
snprintf(buf, buflen, "unable to enumerate Level-Zero memory modules: %x",
|
||||
ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
LOG(h.verbose, "discovered %d Level-Zero memory modules\n", memCount);
|
||||
|
||||
zes_mem_handle_t *mems = malloc(memCount * sizeof(zes_mem_handle_t));
|
||||
(*h.zesDeviceEnumMemoryModules)(h.devices[driver][device], &memCount, mems);
|
||||
|
||||
for (m = 0; m < memCount; m++) {
|
||||
zes_mem_state_t state;
|
||||
state.stype = ZES_STRUCTURE_TYPE_MEM_STATE;
|
||||
state.pNext = NULL;
|
||||
ret = (*h.zesMemoryGetState)(mems[m], &state);
|
||||
if (ret != ZE_RESULT_SUCCESS) {
|
||||
snprintf(buf, buflen, "unable to get memory state: %x", ret);
|
||||
resp->err = strdup(buf);
|
||||
free(mems);
|
||||
return;
|
||||
}
|
||||
|
||||
resp->total += state.size;
|
||||
resp->free += state.free;
|
||||
}
|
||||
|
||||
free(mems);
|
||||
}
|
||||
|
||||
void oneapi_release(oneapi_handle_t h) {
|
||||
int d;
|
||||
LOG(h.verbose, "releasing oneapi library\n");
|
||||
for (d = 0; d < h.num_drivers; d++) {
|
||||
if (h.devices != NULL && h.devices[d] != NULL) {
|
||||
free(h.devices[d]);
|
||||
}
|
||||
}
|
||||
if (h.devices != NULL) {
|
||||
free(h.devices);
|
||||
h.devices = NULL;
|
||||
}
|
||||
if (h.num_devices != NULL) {
|
||||
free(h.num_devices);
|
||||
h.num_devices = NULL;
|
||||
}
|
||||
if (h.drivers != NULL) {
|
||||
free(h.drivers);
|
||||
h.drivers = NULL;
|
||||
}
|
||||
h.num_drivers = 0;
|
||||
UNLOAD_LIBRARY(h.handle);
|
||||
h.handle = NULL;
|
||||
}
|
||||
|
||||
int oneapi_get_device_count(oneapi_handle_t h, int driver) {
|
||||
if (h.handle == NULL || h.num_devices == NULL) {
|
||||
return 0;
|
||||
}
|
||||
if (driver >= h.num_drivers) {
|
||||
return 0;
|
||||
}
|
||||
return (int)h.num_devices[driver];
|
||||
}
|
||||
|
||||
#endif // __APPLE__
|
||||
@@ -1,203 +0,0 @@
|
||||
#ifndef __APPLE__
|
||||
#ifndef __GPU_INFO_ONEAPI_H__
|
||||
#define __GPU_INFO_ONEAPI_H__
|
||||
#include "gpu_info.h"
|
||||
|
||||
#define ZE_MAX_DEVICE_NAME 256
|
||||
#define ZE_MAX_DEVICE_UUID_SIZE 16
|
||||
#define ZES_STRING_PROPERTY_SIZE 64
|
||||
#define ZE_BIT(_i) (1 << _i)
|
||||
|
||||
// Just enough typedefs to dlopen/dlsym for memory information
|
||||
typedef enum ze_result_t {
|
||||
ZE_RESULT_SUCCESS = 0,
|
||||
// Other values omitted for now...
|
||||
} ze_result_t;
|
||||
|
||||
typedef uint8_t ze_bool_t;
|
||||
typedef struct _zes_driver_handle_t *zes_driver_handle_t;
|
||||
typedef struct _zes_device_handle_t *zes_device_handle_t;
|
||||
typedef struct _zes_mem_handle_t *zes_mem_handle_t;
|
||||
|
||||
typedef enum _ze_structure_type_t {
|
||||
ZE_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff
|
||||
} ze_structure_type_t;
|
||||
|
||||
typedef enum _zes_structure_type_t {
|
||||
ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES = 0x1,
|
||||
ZES_STRUCTURE_TYPE_MEM_PROPERTIES = 0xb,
|
||||
ZES_STRUCTURE_TYPE_MEM_STATE = 0x1e,
|
||||
ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES = 0x2d,
|
||||
ZES_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff
|
||||
} zes_structure_type_t;
|
||||
|
||||
typedef enum _zes_mem_type_t {
|
||||
ZES_MEM_TYPE_FORCE_UINT32 = 0x7fffffff
|
||||
} zes_mem_type_t;
|
||||
|
||||
typedef enum _zes_mem_loc_t {
|
||||
ZES_MEM_LOC_SYSTEM = 0,
|
||||
ZES_MEM_LOC_DEVICE = 1,
|
||||
ZES_MEM_LOC_FORCE_UINT32 = 0x7fffffff
|
||||
} zes_mem_loc_t;
|
||||
|
||||
typedef enum _zes_mem_health_t {
|
||||
ZES_MEM_HEALTH_FORCE_UINT32 = 0x7fffffff
|
||||
} zes_mem_health_t;
|
||||
|
||||
typedef struct _ze_device_uuid_t {
|
||||
uint8_t id[ZE_MAX_DEVICE_UUID_SIZE];
|
||||
} ze_device_uuid_t;
|
||||
|
||||
typedef struct _zes_uuid_t {
|
||||
uint8_t id[ZE_MAX_DEVICE_UUID_SIZE];
|
||||
} zes_uuid_t;
|
||||
|
||||
typedef enum _ze_device_type_t {
|
||||
ZE_DEVICE_TYPE_GPU = 1,
|
||||
ZE_DEVICE_TYPE_CPU = 2,
|
||||
ZE_DEVICE_TYPE_FPGA = 3,
|
||||
ZE_DEVICE_TYPE_MCA = 4,
|
||||
ZE_DEVICE_TYPE_VPU = 5,
|
||||
ZE_DEVICE_TYPE_FORCE_UINT32 = 0x7fffffff
|
||||
} ze_device_type_t;
|
||||
|
||||
typedef enum _zes_device_type_t {
|
||||
ZES_DEVICE_TYPE_GPU = 1,
|
||||
ZES_DEVICE_TYPE_CPU = 2,
|
||||
ZES_DEVICE_TYPE_FPGA = 3,
|
||||
ZES_DEVICE_TYPE_MCA = 4,
|
||||
ZES_DEVICE_TYPE_VPU = 5,
|
||||
ZES_DEVICE_TYPE_FORCE_UINT32 = 0x7fffffff
|
||||
} zes_device_type_t;
|
||||
|
||||
typedef uint32_t ze_device_property_flags_t;
|
||||
typedef enum _ze_device_property_flag_t {
|
||||
ZE_DEVICE_PROPERTY_FLAG_INTEGRATED = ZE_BIT(0),
|
||||
ZE_DEVICE_PROPERTY_FLAG_SUBDEVICE = ZE_BIT(1),
|
||||
ZE_DEVICE_PROPERTY_FLAG_ECC = ZE_BIT(2),
|
||||
ZE_DEVICE_PROPERTY_FLAG_ONDEMANDPAGING = ZE_BIT(3),
|
||||
ZE_DEVICE_PROPERTY_FLAG_FORCE_UINT32 = 0x7fffffff
|
||||
} ze_device_property_flag_t;
|
||||
|
||||
typedef uint32_t zes_device_property_flags_t;
|
||||
typedef enum _zes_device_property_flag_t {
|
||||
ZES_DEVICE_PROPERTY_FLAG_INTEGRATED = ZE_BIT(0),
|
||||
ZES_DEVICE_PROPERTY_FLAG_SUBDEVICE = ZE_BIT(1),
|
||||
ZES_DEVICE_PROPERTY_FLAG_ECC = ZE_BIT(2),
|
||||
ZES_DEVICE_PROPERTY_FLAG_ONDEMANDPAGING = ZE_BIT(3),
|
||||
ZES_DEVICE_PROPERTY_FLAG_FORCE_UINT32 = 0x7fffffff
|
||||
} zes_device_property_flag_t;
|
||||
|
||||
typedef struct _ze_device_properties_t {
|
||||
ze_structure_type_t stype;
|
||||
void *pNext;
|
||||
ze_device_type_t type;
|
||||
uint32_t vendorId;
|
||||
uint32_t deviceId;
|
||||
ze_device_property_flags_t flags;
|
||||
uint32_t subdeviceId;
|
||||
uint32_t coreClockRate;
|
||||
uint64_t maxMemAllocSize;
|
||||
uint32_t maxHardwareContexts;
|
||||
uint32_t maxCommandQueuePriority;
|
||||
uint32_t numThreadsPerEU;
|
||||
uint32_t physicalEUSimdWidth;
|
||||
uint32_t numEUsPerSubslice;
|
||||
uint32_t numSubslicesPerSlice;
|
||||
uint32_t numSlices;
|
||||
uint64_t timerResolution;
|
||||
uint32_t timestampValidBits;
|
||||
uint32_t kernelTimestampValidBits;
|
||||
ze_device_uuid_t uuid;
|
||||
char name[ZE_MAX_DEVICE_NAME];
|
||||
} ze_device_properties_t;
|
||||
|
||||
typedef struct _zes_device_properties_t {
|
||||
zes_structure_type_t stype;
|
||||
void *pNext;
|
||||
ze_device_properties_t core;
|
||||
uint32_t numSubdevices;
|
||||
char serialNumber[ZES_STRING_PROPERTY_SIZE];
|
||||
char boardNumber[ZES_STRING_PROPERTY_SIZE];
|
||||
char brandName[ZES_STRING_PROPERTY_SIZE];
|
||||
char modelName[ZES_STRING_PROPERTY_SIZE];
|
||||
char vendorName[ZES_STRING_PROPERTY_SIZE];
|
||||
char driverVersion[ZES_STRING_PROPERTY_SIZE];
|
||||
} zes_device_properties_t;
|
||||
|
||||
typedef struct _zes_device_ext_properties_t {
|
||||
zes_structure_type_t stype;
|
||||
void *pNext;
|
||||
zes_uuid_t uuid;
|
||||
zes_device_type_t type;
|
||||
zes_device_property_flags_t flags;
|
||||
} zes_device_ext_properties_t;
|
||||
|
||||
typedef struct _zes_mem_properties_t {
|
||||
zes_structure_type_t stype;
|
||||
void *pNext;
|
||||
zes_mem_type_t type;
|
||||
ze_bool_t onSubdevice;
|
||||
uint32_t subdeviceId;
|
||||
zes_mem_loc_t location;
|
||||
uint64_t physicalSize;
|
||||
int32_t busWidth;
|
||||
int32_t numChannels;
|
||||
} zes_mem_properties_t;
|
||||
|
||||
typedef struct _zes_mem_state_t {
|
||||
zes_structure_type_t stype;
|
||||
const void *pNext;
|
||||
zes_mem_health_t health;
|
||||
uint64_t free;
|
||||
uint64_t size;
|
||||
} zes_mem_state_t;
|
||||
|
||||
typedef struct oneapi_handle {
|
||||
void *handle;
|
||||
uint16_t verbose;
|
||||
|
||||
uint32_t num_drivers;
|
||||
zes_driver_handle_t *drivers;
|
||||
uint32_t *num_devices;
|
||||
zes_device_handle_t **devices;
|
||||
|
||||
// TODO Driver major, minor information
|
||||
// int driver_major;
|
||||
// int driver_minor;
|
||||
|
||||
ze_result_t (*zesInit)(int);
|
||||
ze_result_t (*zesDriverGet)(uint32_t *pCount, zes_driver_handle_t *phDrivers);
|
||||
ze_result_t (*zesDeviceGet)(zes_driver_handle_t hDriver, uint32_t *pCount,
|
||||
zes_device_handle_t *phDevices);
|
||||
ze_result_t (*zesDeviceGetProperties)(zes_device_handle_t hDevice,
|
||||
zes_device_properties_t *pProperties);
|
||||
ze_result_t (*zesDeviceEnumMemoryModules)(zes_device_handle_t hDevice,
|
||||
uint32_t *pCount,
|
||||
zes_mem_handle_t *phMemory);
|
||||
ze_result_t (*zesMemoryGetProperties)(zes_mem_handle_t hMemory,
|
||||
zes_mem_properties_t *pProperties);
|
||||
ze_result_t (*zesMemoryGetState)(zes_mem_handle_t hMemory,
|
||||
zes_mem_state_t *pState);
|
||||
|
||||
} oneapi_handle_t;
|
||||
|
||||
typedef struct oneapi_init_resp {
|
||||
char *err; // If err is non-null handle is invalid
|
||||
oneapi_handle_t oh;
|
||||
} oneapi_init_resp_t;
|
||||
|
||||
typedef struct oneapi_version_resp {
|
||||
ze_result_t status;
|
||||
char *str; // Contains version or error string if status != 0
|
||||
} oneapi_version_resp_t;
|
||||
|
||||
void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp);
|
||||
void oneapi_check_vram(oneapi_handle_t h, int driver, int device,
|
||||
mem_info_t *resp);
|
||||
void oneapi_release(oneapi_handle_t h);
|
||||
int oneapi_get_device_count(oneapi_handle_t h, int driver);
|
||||
|
||||
#endif // __GPU_INFO_ONEAPI_H__
|
||||
#endif // __APPLE__
|
||||
@@ -1,21 +0,0 @@
|
||||
//go:build linux || windows
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func oneapiGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||
ids := []string{}
|
||||
for _, info := range gpuInfo {
|
||||
if info.Library != "oneapi" {
|
||||
// TODO shouldn't happen if things are wired correctly...
|
||||
slog.Debug("oneapiGetVisibleDevicesEnv skipping over non-sycl device", "library", info.Library)
|
||||
continue
|
||||
}
|
||||
ids = append(ids, info.ID)
|
||||
}
|
||||
return "ONEAPI_DEVICE_SELECTOR", "level_zero:" + strings.Join(ids, ",")
|
||||
}
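
A usage sketch (illustrative only; the runner binary path and surrounding wiring are assumptions, not taken from this diff) of how the key/value pair returned above might be applied to a subprocess environment so that only the selected Level-Zero devices are visible:

```go
package main

import (
	"os"
	"os/exec"
)

func main() {
	// Hypothetical values; in practice they come from oneapiGetVisibleDevicesEnv
	// above, e.g. key = "ONEAPI_DEVICE_SELECTOR", val = "level_zero:0,1".
	key, val := "ONEAPI_DEVICE_SELECTOR", "level_zero:0,1"

	cmd := exec.Command("./ollama-runner") // hypothetical runner binary
	cmd.Env = append(os.Environ(), key+"="+val)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		os.Exit(1)
	}
}
```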
|
||||
@@ -1,60 +0,0 @@
|
||||
package discover
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBasicGetGPUInfo(t *testing.T) {
|
||||
info := GetGPUInfo()
|
||||
assert.NotEmpty(t, len(info))
|
||||
assert.Contains(t, "cuda rocm cpu metal", info[0].Library)
|
||||
if info[0].Library != "cpu" {
|
||||
assert.Greater(t, info[0].TotalMemory, uint64(0))
|
||||
assert.Greater(t, info[0].FreeMemory, uint64(0))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCPUMemInfo(t *testing.T) {
|
||||
info, err := GetCPUMem()
|
||||
require.NoError(t, err)
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
t.Skip("CPU memory not populated on darwin")
|
||||
case "linux", "windows":
|
||||
assert.Greater(t, info.TotalMemory, uint64(0))
|
||||
assert.Greater(t, info.FreeMemory, uint64(0))
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func TestByLibrary(t *testing.T) {
|
||||
type testCase struct {
|
||||
input []GpuInfo
|
||||
expect int
|
||||
}
|
||||
|
||||
testCases := map[string]*testCase{
|
||||
"empty": {input: []GpuInfo{}, expect: 0},
|
||||
"cpu": {input: []GpuInfo{{Library: "cpu"}}, expect: 1},
|
||||
"cpu + GPU": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda"}}, expect: 2},
|
||||
"cpu + 2 GPU no variant": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda"}, {Library: "cuda"}}, expect: 2},
|
||||
"cpu + 2 GPU same variant": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda", Variant: "v11"}, {Library: "cuda", Variant: "v11"}}, expect: 2},
|
||||
"cpu + 2 GPU diff variant": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda", Variant: "v11"}, {Library: "cuda", Variant: "v12"}}, expect: 3},
|
||||
}
|
||||
|
||||
for k, v := range testCases {
|
||||
t.Run(k, func(t *testing.T) {
|
||||
resp := (GpuInfoList)(v.input).ByLibrary()
|
||||
if len(resp) != v.expect {
|
||||
t.Fatalf("expected length %d, got %d => %+v", v.expect, len(resp), resp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TODO - add some logic to figure out card type through other means and actually verify we got back what we expected
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
// '../lib/ollama' on Linux and the executable's directory on macOS
|
||||
// note: in distribution builds, additional GPU-specific libraries are
|
||||
// found in subdirectories of the returned path, such as
|
||||
// 'cuda_v11', 'cuda_v12', 'rocm', etc.
|
||||
// 'cuda_v12', 'rocm', etc.
|
||||
var LibOllamaPath string = func() string {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
|
||||
488
discover/runner.go
Normal file
@@ -0,0 +1,488 @@
|
||||
package discover
|
||||
|
||||
// Runner-based GPU discovery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
var (
|
||||
deviceMu sync.Mutex
|
||||
devices []ml.DeviceInfo
|
||||
libDirs map[string]struct{}
|
||||
rocmDir string
|
||||
exe string
|
||||
bootstrapped bool
|
||||
)
|
||||
|
||||
func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.DeviceInfo {
|
||||
deviceMu.Lock()
|
||||
defer deviceMu.Unlock()
|
||||
startDiscovery := time.Now()
|
||||
msg := "overall device VRAM discovery took"
|
||||
defer func() {
|
||||
slog.Debug(msg, "duration", time.Since(startDiscovery))
|
||||
}()
|
||||
|
||||
if !bootstrapped {
|
||||
msg = "GPU bootstrap discovery took"
|
||||
libDirs = make(map[string]struct{})
|
||||
var err error
|
||||
exe, err = os.Executable()
|
||||
if err != nil {
|
||||
slog.Error("unable to lookup executable path", "error", err)
|
||||
return nil
|
||||
}
|
||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||
exe = eval
|
||||
}
|
||||
files, err := filepath.Glob(filepath.Join(LibOllamaPath, "*", "*ggml-*"))
|
||||
if err != nil {
|
||||
slog.Debug("unable to lookup runner library directories", "error", err)
|
||||
}
|
||||
for _, file := range files {
|
||||
libDirs[filepath.Dir(file)] = struct{}{}
|
||||
}
|
||||
|
||||
// Our current packaging model places ggml-hip in the main directory
|
||||
// but keeps rocm in an isolated directory. We have to add it to
|
||||
// the [LD_LIBRARY_]PATH so ggml-hip will load properly
|
||||
rocmDir = filepath.Join(LibOllamaPath, "rocm")
|
||||
if _, err := os.Stat(rocmDir); err != nil {
|
||||
rocmDir = ""
|
||||
}
|
||||
|
||||
if len(libDirs) == 0 {
|
||||
libDirs[""] = struct{}{}
|
||||
}
|
||||
|
||||
slog.Info("discovering available GPUs...")
|
||||
requested := envconfig.LLMLibrary()
|
||||
jetpack := cudaJetpack()
|
||||
|
||||
// For our initial discovery pass, we gather all the known GPUs through
|
||||
// all the libraries that were detected. This pass may include GPUs that
|
||||
// are enumerated, but not actually supported.
|
||||
// We run this in serial to avoid potentially initializing a GPU multiple
|
||||
// times concurrently leading to memory contention
|
||||
// TODO refactor so we group the lib dirs and do serial per version, but parallel for different libs
|
||||
for dir := range libDirs {
|
||||
bootstrapTimeout := 30 * time.Second
|
||||
var dirs []string
|
||||
if dir != "" {
|
||||
if requested != "" && filepath.Base(dir) != requested {
|
||||
slog.Debug("skipping available library at users request", "requested", requested, "libDir", dir)
|
||||
continue
|
||||
} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if dir == "" {
|
||||
dirs = []string{LibOllamaPath}
|
||||
} else {
|
||||
dirs = []string{LibOllamaPath, dir}
|
||||
}
|
||||
|
||||
// ROCm can take a long time on some systems, so give it more time before giving up
|
||||
if dir != "" && strings.Contains(filepath.Base(dir), "rocm") {
|
||||
bootstrapTimeout = 60 * time.Second
|
||||
}
|
||||
// Typically bootstrapping takes < 1s, but on some systems, with devices
|
||||
// in low power/idle mode, initialization can take multiple seconds. We
|
||||
// set a long timeout just for bootstrap discovery to reduce the chance
|
||||
// of giving up too quickly
|
||||
ctx1stPass, cancel := context.WithTimeout(ctx, bootstrapTimeout)
|
||||
defer cancel()
|
||||
|
||||
// For this pass, we retain duplicates in case any are incompatible with some libraries
|
||||
devices = append(devices, bootstrapDevices(ctx1stPass, dirs, nil)...)
|
||||
}
|
||||
|
||||
// In the second pass, we more deeply initialize the GPUs to weed out devices that
|
||||
// aren't supported by a given library. We run this phase in parallel to speed up discovery.
|
||||
slog.Debug("filtering out unsupported or overlapping GPU library combinations", "count", len(devices))
|
||||
ctx2ndPass, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
var wg sync.WaitGroup
|
||||
needsDelete := make([]bool, len(devices))
|
||||
supportedMu := sync.Mutex{}
|
||||
supported := make(map[string]map[string]map[string]int) // [Library][libDir][ID] = pre-deletion devices index
|
||||
for i := range devices {
|
||||
libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1]
|
||||
if devices[i].Library == "Metal" {
|
||||
continue
|
||||
}
|
||||
slog.Debug("verifying GPU is supported", "library", libDir, "description", devices[i].Description, "compute", devices[i].Compute(), "pci_id", devices[i].PCIID)
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
var envVar string
|
||||
id := devices[i].ID
|
||||
if devices[i].Library == "ROCm" {
|
||||
if runtime.GOOS != "linux" {
|
||||
envVar = "HIP_VISIBLE_DEVICES"
|
||||
} else {
|
||||
envVar = "ROCR_VISIBLE_DEVICES"
|
||||
}
|
||||
} else if devices[i].Library == "CUDA" {
|
||||
envVar = "CUDA_VISIBLE_DEVICES"
|
||||
} else if devices[i].Library == "Vulkan" {
|
||||
id = devices[i].FilteredID
|
||||
envVar = "GGML_VK_VISIBLE_DEVICES"
|
||||
} else {
|
||||
slog.Error("Unknown Library:" + devices[i].Library)
|
||||
}
|
||||
|
||||
extraEnvs := map[string]string{
|
||||
"GGML_CUDA_INIT": "1", // force deep initialization to trigger crash on unsupported GPUs
|
||||
envVar: id, // Filter to just this one GPU
|
||||
}
|
||||
if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 {
|
||||
needsDelete[i] = true
|
||||
} else {
|
||||
supportedMu.Lock()
|
||||
if _, ok := supported[devices[i].Library]; !ok {
|
||||
supported[devices[i].Library] = make(map[string]map[string]int)
|
||||
}
|
||||
if _, ok := supported[devices[i].Library][libDir]; !ok {
|
||||
supported[devices[i].Library][libDir] = make(map[string]int)
|
||||
}
|
||||
supported[devices[i].Library][libDir][devices[i].ID] = i
|
||||
supportedMu.Unlock()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
logutil.Trace("supported GPU library combinations", "supported", supported)
|
||||
|
||||
filterOutVulkanThatAreSupportedByOtherGPU(needsDelete)
|
||||
|
||||
// Mark for deletion any overlaps - favoring the library version that can cover all GPUs if possible
|
||||
filterOverlapByLibrary(supported, needsDelete)
|
||||
|
||||
// TODO if we ever support multiple ROCm library versions this algorithm will need to be adjusted to keep the rocmID numeric value correct
|
||||
rocmID := 0
|
||||
for i := 0; i < len(needsDelete); i++ {
|
||||
if needsDelete[i] {
|
||||
logutil.Trace("removing unsupported or overlapping GPU combination", "libDir", devices[i].LibraryPath[len(devices[i].LibraryPath)-1], "description", devices[i].Description, "compute", devices[i].Compute(), "pci_id", devices[i].PCIID)
|
||||
devices = append(devices[:i], devices[i+1:]...)
|
||||
needsDelete = append(needsDelete[:i], needsDelete[i+1:]...)
|
||||
i--
|
||||
} else if devices[i].Library == "ROCm" {
|
||||
if _, err := strconv.Atoi(devices[i].ID); err == nil {
|
||||
// Replace the numeric ID with the post-filtered IDs
|
||||
devices[i].FilteredID = devices[i].ID
|
||||
devices[i].ID = strconv.Itoa(rocmID)
|
||||
}
|
||||
rocmID++
|
||||
}
|
||||
}
|
||||
|
||||
// Now filter out any overlap with different libraries (favor CUDA/HIP over others)
|
||||
for i := 0; i < len(devices); i++ {
|
||||
for j := i + 1; j < len(devices); j++ {
|
||||
// For this pass, we only drop exact duplicates
|
||||
switch devices[i].Compare(devices[j]) {
|
||||
case ml.SameBackendDevice:
|
||||
// Same library and device, skip it
|
||||
devices = append(devices[:j], devices[j+1:]...)
|
||||
j--
|
||||
continue
|
||||
case ml.DuplicateDevice:
|
||||
// Different library, choose based on priority
|
||||
var droppedDevice ml.DeviceInfo
|
||||
if devices[i].Library == "CUDA" || devices[i].Library == "ROCm" {
|
||||
droppedDevice = devices[j]
|
||||
} else {
|
||||
droppedDevice = devices[i]
|
||||
devices[i] = devices[j]
|
||||
}
|
||||
devices = append(devices[:j], devices[j+1:]...)
|
||||
j--
|
||||
|
||||
typeStr := "discrete"
|
||||
if droppedDevice.Integrated {
|
||||
typeStr = "iGPU"
|
||||
}
|
||||
slog.Debug("dropping duplicate device",
|
||||
"id", droppedDevice.ID,
|
||||
"library", droppedDevice.Library,
|
||||
"compute", droppedDevice.Compute(),
|
||||
"name", droppedDevice.Name,
|
||||
"description", droppedDevice.Description,
|
||||
"libdirs", strings.Join(droppedDevice.LibraryPath, ","),
|
||||
"driver", droppedDevice.Driver(),
|
||||
"pci_id", droppedDevice.PCIID,
|
||||
"type", typeStr,
|
||||
"total", format.HumanBytes2(droppedDevice.TotalMemory),
|
||||
"available", format.HumanBytes2(droppedDevice.FreeMemory),
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reset the libDirs to what we actually wind up using for future refreshes
|
||||
libDirs = make(map[string]struct{})
|
||||
for _, dev := range devices {
|
||||
dir := dev.LibraryPath[len(dev.LibraryPath)-1]
|
||||
if dir != LibOllamaPath {
|
||||
libDirs[dir] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(libDirs) == 0 {
|
||||
libDirs[""] = struct{}{}
|
||||
}
|
||||
|
||||
bootstrapped = true
|
||||
} else {
|
||||
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
|
||||
// metal never updates free VRAM
|
||||
return devices
|
||||
}
|
||||
|
||||
slog.Debug("refreshing free memory")
|
||||
updated := make([]bool, len(devices))
|
||||
allDone := func() bool {
|
||||
allDone := true
|
||||
for _, done := range updated {
|
||||
if !done {
|
||||
allDone = false
|
||||
break
|
||||
}
|
||||
}
|
||||
return allDone
|
||||
}
|
||||
|
||||
// First try to use existing runners to refresh VRAM since they're already
|
||||
// active on GPU(s)
|
||||
for _, runner := range runners {
|
||||
if runner == nil {
|
||||
continue
|
||||
}
|
||||
deviceIDs := runner.GetActiveDeviceIDs()
|
||||
if len(deviceIDs) == 0 {
|
||||
// Skip this runner since it doesn't have active GPU devices
|
||||
continue
|
||||
}
|
||||
|
||||
// Check to see if this runner is active on any devices that need a refresh
|
||||
skip := true
|
||||
devCheck:
|
||||
for _, dev := range deviceIDs {
|
||||
for i := range devices {
|
||||
if dev == devices[i].DeviceID {
|
||||
if !updated[i] {
|
||||
skip = false
|
||||
break devCheck
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if skip {
|
||||
continue
|
||||
}
|
||||
|
||||
// Typical refresh on existing runner is ~500ms but allow longer if the system
|
||||
// is under stress before giving up and using stale data.
|
||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
start := time.Now()
|
||||
updatedDevices := runner.GetDeviceInfos(ctx)
|
||||
slog.Debug("existing runner discovery took", "duration", time.Since(start))
|
||||
for _, u := range updatedDevices {
|
||||
for i := range devices {
|
||||
if u.DeviceID == devices[i].DeviceID {
|
||||
updated[i] = true
|
||||
devices[i].FreeMemory = u.FreeMemory
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
// Short circuit if we've updated all the devices
|
||||
if allDone() {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allDone() {
|
||||
slog.Debug("unable to refresh all GPUs with existing runners, performing bootstrap discovery")
|
||||
|
||||
// Bootstrapping may take longer in some cases (AMD windows), but we
|
||||
// would rather use stale free data to get the model running sooner
|
||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
for dir := range libDirs {
|
||||
updatedDevices := bootstrapDevices(ctx, []string{LibOllamaPath, dir}, nil)
|
||||
for _, u := range updatedDevices {
|
||||
for i := range devices {
|
||||
if u.DeviceID == devices[i].DeviceID {
|
||||
updated[i] = true
|
||||
devices[i].FreeMemory = u.FreeMemory
|
||||
break
|
||||
}
|
||||
}
|
||||
// TODO - consider evaluating if new devices have appeared (e.g. hotplug)
|
||||
}
|
||||
if allDone() {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allDone() {
|
||||
slog.Warn("unable to refresh free memory, using old values")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return devices
|
||||
}
|
||||
|
||||
func filterOutVulkanThatAreSupportedByOtherGPU(needsDelete []bool) {
|
||||
// Filter out Vulkan devices that share a PCI ID with a non-Vulkan device that is not marked for deletion
|
||||
for i := range devices {
|
||||
if devices[i].Library != "Vulkan" || needsDelete[i] {
|
||||
continue
|
||||
}
|
||||
if devices[i].PCIID == "" {
|
||||
continue
|
||||
}
|
||||
for j := range devices {
|
||||
if i == j {
|
||||
continue
|
||||
}
|
||||
if devices[j].PCIID == "" {
|
||||
continue
|
||||
}
|
||||
if devices[j].PCIID == devices[i].PCIID && devices[j].Library != "Vulkan" && !needsDelete[j] {
|
||||
needsDelete[i] = true
|
||||
slog.Debug("dropping Vulkan duplicate by PCI ID",
|
||||
"vulkan_id", devices[i].ID,
|
||||
"vulkan_libdir", devices[i].LibraryPath[len(devices[i].LibraryPath)-1],
|
||||
"pci_id", devices[i].PCIID,
|
||||
"kept_library", devices[j].Library,
|
||||
"kept_id", devices[j].ID,
|
||||
)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func filterOverlapByLibrary(supported map[string]map[string]map[string]int, needsDelete []bool) {
|
||||
// For multi-GPU systems, use the newest version that supports all the GPUs
|
||||
for _, byLibDirs := range supported {
|
||||
libDirs := make([]string, 0, len(byLibDirs))
|
||||
for libDir := range byLibDirs {
|
||||
libDirs = append(libDirs, libDir)
|
||||
}
|
||||
sort.Sort(sort.Reverse(sort.StringSlice(libDirs)))
|
||||
anyMissing := false
|
||||
var newest string
|
||||
for _, newest = range libDirs {
|
||||
for _, libDir := range libDirs {
|
||||
if libDir == newest {
|
||||
continue
|
||||
}
|
||||
if len(byLibDirs[newest]) != len(byLibDirs[libDir]) {
|
||||
anyMissing = true
|
||||
break
|
||||
}
|
||||
for dev := range byLibDirs[newest] {
|
||||
if _, found := byLibDirs[libDir][dev]; !found {
|
||||
anyMissing = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !anyMissing {
|
||||
break
|
||||
}
|
||||
}
|
||||
// Now we can mark overlaps for deletion
|
||||
for _, libDir := range libDirs {
|
||||
if libDir == newest {
|
||||
continue
|
||||
}
|
||||
for dev, i := range byLibDirs[libDir] {
|
||||
if _, found := byLibDirs[newest][dev]; found {
|
||||
needsDelete[i] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type bootstrapRunner struct {
|
||||
port int
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
func (r *bootstrapRunner) GetPort() int {
|
||||
return r.port
|
||||
}
|
||||
|
||||
func (r *bootstrapRunner) HasExited() bool {
|
||||
if r.cmd != nil && r.cmd.ProcessState != nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs map[string]string) []ml.DeviceInfo {
|
||||
var out io.Writer
|
||||
if envconfig.LogLevel() == logutil.LevelTrace {
|
||||
out = os.Stderr
|
||||
}
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
slog.Debug("bootstrap discovery took", "duration", time.Since(start), "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs)
|
||||
}()
|
||||
|
||||
logutil.Trace("starting runner for device discovery", "libDirs", ollamaLibDirs, "extraEnvs", extraEnvs)
|
||||
cmd, port, err := llm.StartRunner(
|
||||
true, // ollama engine
|
||||
"", // no model
|
||||
ollamaLibDirs,
|
||||
out,
|
||||
extraEnvs,
|
||||
)
|
||||
if err != nil {
|
||||
slog.Debug("failed to start runner to discovery GPUs", "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
go func() {
|
||||
cmd.Wait() // exit status ignored
|
||||
}()
|
||||
|
||||
defer cmd.Process.Kill()
|
||||
devices, err := ml.GetDevicesFromRunner(ctx, &bootstrapRunner{port: port, cmd: cmd})
|
||||
if err != nil {
|
||||
if cmd.ProcessState != nil && cmd.ProcessState.ExitCode() >= 0 {
|
||||
// Expected during bootstrapping while we filter out unsupported AMD GPUs
|
||||
logutil.Trace("runner exited", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs, "code", cmd.ProcessState.ExitCode())
|
||||
} else {
|
||||
slog.Info("failure during GPU discovery", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs, "error", err)
|
||||
}
|
||||
}
|
||||
logutil.Trace("runner enumerated devices", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "devices", devices)
|
||||
|
||||
return devices
|
||||
}
|
||||
108
discover/runner_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package discover
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/app/lifecycle"
|
||||
)
|
||||
|
||||
func init() {
|
||||
lifecycle.InitLogging()
|
||||
}
|
||||
|
||||
func TestFilterOverlapByLibrary(t *testing.T) {
|
||||
type testcase struct {
|
||||
name string
|
||||
inp map[string]map[string]map[string]int
|
||||
exp []bool
|
||||
}
|
||||
for _, tc := range []testcase{
|
||||
{
|
||||
name: "empty",
|
||||
inp: map[string]map[string]map[string]int{},
|
||||
exp: []bool{}, // needs deletion
|
||||
},
|
||||
{
|
||||
name: "single no overlap",
|
||||
inp: map[string]map[string]map[string]int{
|
||||
"CUDA": {
|
||||
"cuda_v12": {
|
||||
"GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
exp: []bool{false},
|
||||
},
|
||||
{
|
||||
name: "100% overlap pick 2nd",
|
||||
inp: map[string]map[string]map[string]int{
|
||||
"CUDA": {
|
||||
"cuda_v12": {
|
||||
"GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 0,
|
||||
"GPU-cd6c3216-03d2-a8eb-8235-2ffbf571712e": 1,
|
||||
},
|
||||
"cuda_v13": {
|
||||
"GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 2,
|
||||
"GPU-cd6c3216-03d2-a8eb-8235-2ffbf571712e": 3,
|
||||
},
|
||||
},
|
||||
},
|
||||
exp: []bool{true, true, false, false},
|
||||
},
|
||||
{
|
||||
name: "100% overlap pick 1st",
|
||||
inp: map[string]map[string]map[string]int{
|
||||
"CUDA": {
|
||||
"cuda_v13": {
|
||||
"GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 0,
|
||||
"GPU-cd6c3216-03d2-a8eb-8235-2ffbf571712e": 1,
|
||||
},
|
||||
"cuda_v12": {
|
||||
"GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 2,
|
||||
"GPU-cd6c3216-03d2-a8eb-8235-2ffbf571712e": 3,
|
||||
},
|
||||
},
|
||||
},
|
||||
exp: []bool{false, false, true, true},
|
||||
},
|
||||
{
|
||||
name: "partial overlap pick older",
|
||||
inp: map[string]map[string]map[string]int{
|
||||
"CUDA": {
|
||||
"cuda_v13": {
|
||||
"GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 0,
|
||||
},
|
||||
"cuda_v12": {
|
||||
"GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 1,
|
||||
"GPU-cd6c3216-03d2-a8eb-8235-2ffbf571712e": 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
exp: []bool{true, false, false},
|
||||
},
|
||||
{
|
||||
name: "no overlap",
|
||||
inp: map[string]map[string]map[string]int{
|
||||
"CUDA": {
|
||||
"cuda_v13": {
|
||||
"GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 0,
|
||||
},
|
||||
"cuda_v12": {
|
||||
"GPU-cd6c3216-03d2-a8eb-8235-2ffbf571712e": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
exp: []bool{false, false},
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
needsDelete := make([]bool, len(tc.exp))
|
||||
filterOverlapByLibrary(tc.inp, needsDelete)
|
||||
for i, exp := range tc.exp {
|
||||
if needsDelete[i] != exp {
|
||||
t.Fatalf("expected: %v\ngot: %v", tc.exp, needsDelete)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,12 @@
|
||||
package discover
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
type memInfo struct {
|
||||
@@ -13,52 +15,6 @@ type memInfo struct {
|
||||
FreeSwap uint64 `json:"free_swap,omitempty"` // TODO split this out for system only
|
||||
}
|
||||
|
||||
// Beginning of an `ollama info` command
|
||||
type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
|
||||
memInfo
|
||||
Library string `json:"library,omitempty"`
|
||||
|
||||
// Optional variant to select (e.g. versions, cpu feature flags)
|
||||
Variant string `json:"variant"`
|
||||
|
||||
// MinimumMemory represents the minimum memory required to use the GPU
|
||||
MinimumMemory uint64 `json:"-"`
|
||||
|
||||
// Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
|
||||
DependencyPath []string `json:"lib_path,omitempty"`
|
||||
|
||||
// Extra environment variables specific to the GPU as list of [key,value]
|
||||
EnvWorkarounds [][2]string `json:"envs,omitempty"`
|
||||
|
||||
// Set to true if we can NOT reliably discover FreeMemory. A value of true indicates
|
||||
// the FreeMemory is best effort, and may over or under report actual memory usage
|
||||
// False indicates FreeMemory can generally be trusted on this GPU
|
||||
UnreliableFreeMemory bool
|
||||
|
||||
// GPU information
|
||||
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
|
||||
Name string `json:"name"` // user friendly name if available
|
||||
Compute string `json:"compute"` // Compute Capability or gfx
|
||||
|
||||
// Driver Information - TODO no need to put this on each GPU
|
||||
DriverMajor int `json:"driver_major,omitempty"`
|
||||
DriverMinor int `json:"driver_minor,omitempty"`
|
||||
|
||||
// TODO other performance capability info to help in scheduling decisions
|
||||
}
|
||||
|
||||
func (gpu GpuInfo) RunnerName() string {
|
||||
if gpu.Variant != "" {
|
||||
return gpu.Library + "_" + gpu.Variant
|
||||
}
|
||||
return gpu.Library
|
||||
}
|
||||
|
||||
type CPUInfo struct {
|
||||
GpuInfo
|
||||
CPUs []CPU
|
||||
}
|
||||
|
||||
// CPU type represents a CPU Package occupying a socket
|
||||
type CPU struct {
|
||||
ID string `cpuinfo:"processor"`
|
||||
@@ -69,115 +25,47 @@ type CPU struct {
|
||||
ThreadCount int
|
||||
}
|
||||
|
||||
type CudaGPUInfo struct {
|
||||
GpuInfo
|
||||
OSOverhead uint64 // Memory overhead between the driver library and management library
|
||||
index int //nolint:unused,nolintlint
|
||||
computeMajor int //nolint:unused,nolintlint
|
||||
computeMinor int //nolint:unused,nolintlint
|
||||
}
|
||||
type CudaGPUInfoList []CudaGPUInfo
|
||||
|
||||
type RocmGPUInfo struct {
|
||||
GpuInfo
|
||||
usedFilepath string //nolint:unused,nolintlint
|
||||
index int //nolint:unused,nolintlint
|
||||
}
|
||||
type RocmGPUInfoList []RocmGPUInfo
|
||||
|
||||
type OneapiGPUInfo struct {
|
||||
GpuInfo
|
||||
driverIndex int //nolint:unused,nolintlint
|
||||
gpuIndex int //nolint:unused,nolintlint
|
||||
}
|
||||
type OneapiGPUInfoList []OneapiGPUInfo
|
||||
|
||||
type GpuInfoList []GpuInfo
|
||||
|
||||
type UnsupportedGPUInfo struct {
|
||||
GpuInfo
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// Split up the set of gpu info's by Library and variant
|
||||
func (l GpuInfoList) ByLibrary() []GpuInfoList {
|
||||
resp := []GpuInfoList{}
|
||||
libs := []string{}
|
||||
for _, info := range l {
|
||||
found := false
|
||||
requested := info.Library
|
||||
if info.Variant != "" {
|
||||
requested += "_" + info.Variant
|
||||
}
|
||||
for i, lib := range libs {
|
||||
if lib == requested {
|
||||
resp[i] = append(resp[i], info)
|
||||
found = true
|
||||
break
|
||||
func LogDetails(devices []ml.DeviceInfo) {
|
||||
for _, dev := range devices {
|
||||
var libs []string
|
||||
for _, dir := range dev.LibraryPath {
|
||||
if strings.Contains(dir, filepath.Join("lib", "ollama")) {
|
||||
libs = append(libs, filepath.Base(dir))
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
libs = append(libs, requested)
|
||||
resp = append(resp, []GpuInfo{info})
|
||||
typeStr := "discrete"
|
||||
if dev.Integrated {
|
||||
typeStr = "iGPU"
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// Report the GPU information into the log at Info level
|
||||
func (l GpuInfoList) LogDetails() {
|
||||
for _, g := range l {
|
||||
slog.Info("inference compute",
|
||||
"id", g.ID,
|
||||
"library", g.Library,
|
||||
"variant", g.Variant,
|
||||
"compute", g.Compute,
|
||||
"driver", fmt.Sprintf("%d.%d", g.DriverMajor, g.DriverMinor),
|
||||
"name", g.Name,
|
||||
"total", format.HumanBytes2(g.TotalMemory),
|
||||
"available", format.HumanBytes2(g.FreeMemory),
|
||||
"id", dev.ID,
|
||||
"library", dev.Library,
|
||||
"compute", dev.Compute(),
|
||||
"name", dev.Name,
|
||||
"description", dev.Description,
|
||||
"libdirs", strings.Join(libs, ","),
|
||||
"driver", dev.Driver(),
|
||||
"pci_id", dev.PCIID,
|
||||
"type", typeStr,
|
||||
"total", format.HumanBytes2(dev.TotalMemory),
|
||||
"available", format.HumanBytes2(dev.FreeMemory),
|
||||
)
|
||||
}
|
||||
// CPU inference
|
||||
if len(devices) == 0 {
|
||||
dev, _ := GetCPUMem()
|
||||
slog.Info("inference compute",
|
||||
"id", "cpu",
|
||||
"library", "cpu",
|
||||
"compute", "",
|
||||
"name", "cpu",
|
||||
"description", "cpu",
|
||||
"libdirs", "ollama",
|
||||
"driver", "",
|
||||
"pci_id", "",
|
||||
"type", "",
|
||||
"total", format.HumanBytes2(dev.TotalMemory),
|
||||
"available", format.HumanBytes2(dev.FreeMemory),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by Free Space
|
||||
type ByFreeMemory []GpuInfo
|
||||
|
||||
func (a ByFreeMemory) Len() int { return len(a) }
|
||||
func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory }
|
||||
|
||||
type SystemInfo struct {
|
||||
System CPUInfo `json:"system"`
|
||||
GPUs []GpuInfo `json:"gpus"`
|
||||
UnsupportedGPUs []UnsupportedGPUInfo `json:"unsupported_gpus"`
|
||||
DiscoveryErrors []string `json:"discovery_errors"`
|
||||
}
|
||||
|
||||
// Return the optimal number of threads to use for inference
|
||||
func (si SystemInfo) GetOptimalThreadCount() int {
|
||||
if len(si.System.CPUs) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
coreCount := 0
|
||||
for _, c := range si.System.CPUs {
|
||||
coreCount += c.CoreCount - c.EfficiencyCoreCount
|
||||
}
|
||||
|
||||
return coreCount
|
||||
}
|
||||
|
||||
// For each GPU, check if it does NOT support flash attention
|
||||
func (l GpuInfoList) FlashAttentionSupported() bool {
|
||||
for _, gpu := range l {
|
||||
supportsFA := gpu.Library == "metal" ||
|
||||
(gpu.Library == "cuda" && gpu.DriverMajor >= 7) ||
|
||||
gpu.Library == "rocm"
|
||||
|
||||
if !supportsFA {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
* [Quickstart](../README.md#quickstart)
|
||||
* [Examples](./examples.md)
|
||||
* [Importing models](./import.md)
|
||||
* [MacOS Documentation](./macos.md)
|
||||
* [Linux Documentation](./linux.md)
|
||||
* [Windows Documentation](./windows.md)
|
||||
* [Docker Documentation](./docker.md)
|
||||
|
||||
247
docs/api.md
@@ -500,21 +500,30 @@ The `message` object has the following fields:
|
||||
- `thinking`: (for thinking models) the model's thinking process
|
||||
- `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
|
||||
- `tool_calls` (optional): a list of tools in JSON that the model wants to use
|
||||
- `tool_name` (optional): add the name of the tool that was executed to inform the model of the result
|
||||
|
||||
Advanced parameters (optional):
|
||||
|
||||
- `format`: the format to return a response in. Format can be `json` or a JSON schema.
|
||||
- `format`: the format to return a response in. Format can be `json` or a JSON schema.
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
|
||||
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
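
As a brief illustration of the advanced parameters listed above (model name and values chosen arbitrarily for the example), a non-streaming request that also overrides `temperature` and `keep_alive` might look like:

```shell
curl http://localhost:11434/api/chat -d '{
  "model": "llama3.2",
  "messages": [
    { "role": "user", "content": "hello" }
  ],
  "stream": false,
  "keep_alive": "10m",
  "options": {
    "temperature": 0.2
  }
}'
```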
|
||||
|
||||
### Tool calling
|
||||
|
||||
Tool calling is supported by providing a list of tools in the `tools` parameter. The model will generate a response that includes a list of tool calls. See the [Chat request (Streaming with tools)](#chat-request-streaming-with-tools) example below.
|
||||
|
||||
Models can also explain the result of the tool call in the response. See the [Chat request (With history, with tools)](#chat-request-with-history-with-tools) example below.
|
||||
|
||||
[See models with tool calling capabilities](https://ollama.com/search?c=tool).
|
||||
|
||||
### Structured outputs
|
||||
|
||||
Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [Chat request (Structured outputs)](#chat-request-structured-outputs) example below.
|
||||
|
||||
### Examples
|
||||
|
||||
#### Chat Request (Streaming)
|
||||
#### Chat request (Streaming)
|
||||
|
||||
##### Request
|
||||
|
||||
@@ -569,6 +578,88 @@ Final response:
|
||||
}
|
||||
```
|
||||
|
||||
#### Chat request (Streaming with tools)
|
||||
|
||||
##### Request
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "llama3.2",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what is the weather in tokyo?"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather in a given city",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city to get the weather for"
|
||||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"stream": true
|
||||
}'
|
||||
```
|
||||
|
||||
##### Response
|
||||
|
||||
A stream of JSON objects is returned:
|
||||
```json
|
||||
{
|
||||
"model": "llama3.2",
|
||||
"created_at": "2025-07-07T20:22:19.184789Z",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": {
|
||||
"city": "Tokyo"
|
||||
}
|
||||
},
|
||||
}
|
||||
]
|
||||
},
|
||||
"done": false
|
||||
}
|
||||
```
|
||||
|
||||
Final response:
|
||||
|
||||
```json
|
||||
{
|
||||
"model":"llama3.2",
|
||||
"created_at":"2025-07-07T20:22:19.19314Z",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": ""
|
||||
},
|
||||
"done_reason": "stop",
|
||||
"done": true,
|
||||
"total_duration": 182242375,
|
||||
"load_duration": 41295167,
|
||||
"prompt_eval_count": 169,
|
||||
"prompt_eval_duration": 24573166,
|
||||
"eval_count": 15,
|
||||
"eval_duration": 115959084
|
||||
}
|
||||
```
|
||||
|
||||
#### Chat request (No streaming)
|
||||
|
||||
##### Request
|
||||
@@ -606,6 +697,74 @@ curl http://localhost:11434/api/chat -d '{
|
||||
}
|
||||
```
|
||||
|
||||
#### Chat request (No streaming, with tools)

##### Request

```shell
curl http://localhost:11434/api/chat -d '{
  "model": "llama3.2",
  "messages": [
    {
      "role": "user",
      "content": "what is the weather in tokyo?"
    }
  ],
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "get_weather",
        "description": "Get the weather in a given city",
        "parameters": {
          "type": "object",
          "properties": {
            "city": {
              "type": "string",
              "description": "The city to get the weather for"
            }
          },
          "required": ["city"]
        }
      }
    }
  ],
  "stream": false
}'
```

##### Response

```json
{
  "model": "llama3.2",
  "created_at": "2025-07-07T20:32:53.844124Z",
  "message": {
    "role": "assistant",
    "content": "",
    "tool_calls": [
      {
        "function": {
          "name": "get_weather",
          "arguments": {
            "city": "Tokyo"
          }
        }
      }
    ]
  },
  "done_reason": "stop",
  "done": true,
  "total_duration": 3244883583,
  "load_duration": 2969184542,
  "prompt_eval_count": 169,
  "prompt_eval_duration": 141656333,
  "eval_count": 18,
  "eval_duration": 133293625
}
```
#### Chat request (Structured outputs)

##### Request

@@ -712,6 +871,87 @@ Final response:
}
```
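The request body for this example is not visible in this excerpt. A minimal sketch of a structured-outputs request, reusing the `format` parameter described earlier in this document (the model name, prompt, and schema are chosen only for illustration), might look like:

```shell
curl http://localhost:11434/api/chat -d '{
  "model": "llama3.2",
  "messages": [
    {
      "role": "user",
      "content": "Ollama is 22 years old and is busy saving the world. Reply in JSON."
    }
  ],
  "format": {
    "type": "object",
    "properties": {
      "age": { "type": "integer" },
      "available": { "type": "boolean" }
    },
    "required": ["age", "available"]
  },
  "stream": false
}'
```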
#### Chat request (With history, with tools)

##### Request

```shell
curl http://localhost:11434/api/chat -d '{
  "model": "llama3.2",
  "messages": [
    {
      "role": "user",
      "content": "what is the weather in Toronto?"
    },
    // the message from the model appended to history
    {
      "role": "assistant",
      "content": "",
      "tool_calls": [
        {
          "function": {
            "name": "get_temperature",
            "arguments": {
              "city": "Toronto"
            }
          }
        }
      ]
    },
    // the tool call result appended to history
    {
      "role": "tool",
      "content": "11 degrees celsius",
      "tool_name": "get_temperature"
    }
  ],
  "stream": false,
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "get_weather",
        "description": "Get the weather in a given city",
        "parameters": {
          "type": "object",
          "properties": {
            "city": {
              "type": "string",
              "description": "The city to get the weather for"
            }
          },
          "required": ["city"]
        }
      }
    }
  ]
}'
```

##### Response

```json
{
  "model": "llama3.2",
  "created_at": "2025-07-07T20:43:37.688511Z",
  "message": {
    "role": "assistant",
    "content": "The current temperature in Toronto is 11°C."
  },
  "done_reason": "stop",
  "done": true,
  "total_duration": 890771750,
  "load_duration": 707634750,
  "prompt_eval_count": 94,
  "prompt_eval_duration": 91703208,
  "eval_count": 11,
  "eval_duration": 90282125
}
```
#### Chat request (with images)

##### Request

@@ -1353,7 +1593,7 @@ Then there is a series of downloading responses. Until any of the download is co
```json
{
  "status": "downloading digestname",
  "status": "pulling digestname",
  "digest": "digestname",
  "total": 2142590208,
  "completed": 241970
@@ -1468,6 +1708,7 @@ Advanced parameters:
- `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
- `dimensions`: number of dimensions for the embedding

### Examples
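The examples for this endpoint are not visible in this excerpt. A minimal sketch of an embeddings request exercising the parameters above (the `/api/embed` path and the model name are assumptions made for illustration):

```shell
curl http://localhost:11434/api/embed -d '{
  "model": "all-minilm",
  "input": "Why is the sky blue?",
  "truncate": true,
  "dimensions": 256,
  "keep_alive": "5m"
}'
```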
@@ -1,59 +0,0 @@
# Benchmark

Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.

## When to use

Run these benchmarks when:
- Making changes to the model inference engine
- Modifying model loading/unloading logic
- Changing prompt processing or token generation code
- Implementing a new model architecture
- Testing performance across different hardware setups

## Prerequisites
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`

## Usage and Examples

>[!NOTE]
>All commands must be run from the root directory of the Ollama project.

Basic syntax:
```bash
go test -bench=. ./benchmark/... -m $MODEL_NAME
```

Required flags:
- `-bench=.`: Run all benchmarks
- `-m`: Model name to benchmark

Optional flags:
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)

Common usage patterns:

Single benchmark run with a model specified:
```bash
go test -bench=. ./benchmark/... -m llama3.3
```
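A repeated run for statistical analysis, combining the optional flags documented above (model name and values are illustrative):

```bash
go test -bench=. ./benchmark/... -m llama3.3 -count 5 -timeout 30m
```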
## Output metrics

The benchmark reports several key metrics:

- `gen_tok/s`: Generated tokens per second
- `prompt_tok/s`: Prompt processing tokens per second
- `ttft_ms`: Time to first token in milliseconds
- `load_ms`: Model load time in milliseconds
- `gen_tokens`: Total tokens generated
- `prompt_tokens`: Total prompt tokens processed

Each benchmark runs two scenarios:
- Cold start: Model is loaded from disk for each test
- Warm start: Model is pre-loaded in memory

Three prompt lengths are tested for each scenario:
- Short prompt (100 tokens)
- Medium prompt (500 tokens)
- Long prompt (1000 tokens)
docs/cloud.md (new file, +40)
@@ -0,0 +1,40 @@
# Cloud

> Ollama's cloud is currently in preview. For full documentation, see [Ollama's documentation](https://docs.ollama.com/cloud).

## Cloud Models

[Cloud models](https://ollama.com/cloud) are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldn’t fit on a personal computer.

Ollama currently supports the following cloud models, with more coming soon:

- `gpt-oss:20b-cloud`
- `gpt-oss:120b-cloud`
- `deepseek-v3.1:671b-cloud`
- `qwen3-coder:480b-cloud`
### Get started

To run a cloud model, open the terminal and run:

```
ollama run gpt-oss:120b-cloud
```

To run cloud models with integrations that work with Ollama, first download the cloud model:

```
ollama pull qwen3-coder:480b-cloud
```

Then sign in to Ollama:

```
ollama signin
```

Finally, access the model using the model name `qwen3-coder:480b-cloud` via Ollama's local API or tooling.
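As a sketch, the downloaded cloud model can then be called through the local chat API exactly like a local model (endpoint and payload shape as documented earlier in this document; the prompt is illustrative):

```shell
curl http://localhost:11434/api/chat -d '{
  "model": "qwen3-coder:480b-cloud",
  "messages": [
    { "role": "user", "content": "Write a hello world program in Go" }
  ],
  "stream": false
}'
```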
## Cloud API access

Cloud models can also be accessed directly on ollama.com's API. For more information, see the [docs](https://docs.ollama.com/cloud).
@@ -11,6 +11,10 @@ Then build and run Ollama from the root directory of the repository:
go run . serve
```

> [!NOTE]
> Ollama includes native code compiled with CGO. From time to time these data structures can change and CGO can get out of sync resulting in unexpected crashes. You can force a full build of the native code by running `go clean -cache` first.

## macOS (Apple Silicon)

macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required.
@@ -118,7 +122,7 @@ To run tests, use `go test`:
go test ./...
```

> NOTE: In rare cirumstances, you may need to change a package using the new
> NOTE: In rare circumstances, you may need to change a package using the new
> "synctest" package in go1.24.
>
> If you do not have the "synctest" package enabled, you will not see build or

docs/faq.md
@@ -20,9 +20,9 @@ Please refer to the [GPU docs](./gpu.md).
|
||||
|
||||
## How can I specify the context window size?
|
||||
|
||||
By default, Ollama uses a context window size of 4096 tokens.
|
||||
By default, Ollama uses a context window size of 4096 tokens for most models. The `gpt-oss` model has a default context window size of 8192 tokens.
|
||||
|
||||
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
|
||||
This can be overridden in Settings in the Windows and macOS App, or with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
|
||||
|
||||
```shell
|
||||
OLLAMA_CONTEXT_LENGTH=8192 ollama serve
|
||||
@@ -46,6 +46,8 @@ curl http://localhost:11434/api/generate -d '{
|
||||
}'
|
||||
```
|
||||
|
||||
Setting the context length higher may prevent the model from fitting entirely on the GPU, which makes the model run more slowly.
|
||||
|
||||
## How can I tell if my model was loaded onto the GPU?
|
||||
|
||||
Use the `ollama ps` command to see what models are currently loaded into memory.
|
||||
@@ -57,8 +59,8 @@ ollama ps
|
||||
> **Output**:
|
||||
>
|
||||
> ```
|
||||
> NAME ID SIZE PROCESSOR UNTIL
|
||||
> llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
|
||||
> NAME ID SIZE PROCESSOR CONTEXT UNTIL
|
||||
> gpt-oss:20b 05afbac4bad6 16 GB 100% GPU 8192 4 minutes from now
|
||||
> ```
|
||||
|
||||
The `Processor` column will show which memory the model was loaded into:
|
||||
@@ -148,9 +150,11 @@ docker build -t ollama-with-ca .
|
||||
docker run -d -e HTTPS_PROXY=https://my.proxy.example.com -p 11434:11434 ollama-with-ca
|
||||
```
|
||||
|
||||
## Does Ollama send my prompts and answers back to ollama.com?
|
||||
## Does Ollama send my prompts and responses back to ollama.com?
|
||||
|
||||
No. Ollama runs locally, and conversation data does not leave your machine.
|
||||
If you're running a model locally, your prompts and responses will always stay on your machine. Ollama Turbo in the App allows you to run your queries on Ollama's servers if you don't have a powerful enough GPU. Web search lets a model query the web, giving you more accurate and up-to-date information. Both Turbo and web search require sending your prompts and responses to Ollama.com. This data is neither logged nor stored.
|
||||
|
||||
If you don't want to see the Turbo and web search options in the app, you can disable them in Settings by turning on Airplane mode. In Airplane mode, all models will run locally, and your prompts and responses will stay on your machine.
|
||||
|
||||
## How can I expose Ollama on my network?
|
||||
|
||||
@@ -292,7 +296,7 @@ If too many requests are sent to the server, it will respond with a 503 error in
|
||||
|
||||
## How does Ollama handle concurrent requests?
|
||||
|
||||
Ollama supports two levels of concurrent processing. If your system has sufficient available memory (system memory when using CPU inference, or VRAM for GPU inference) then multiple models can be loaded at the same time. For a given model, if there is sufficient available memory when the model is loaded, it is configured to allow parallel request processing.
|
||||
Ollama supports two levels of concurrent processing. If your system has sufficient available memory (system memory when using CPU inference, or VRAM for GPU inference) then multiple models can be loaded at the same time. For a given model, if there is sufficient available memory when the model is loaded, it can be configured to allow parallel request processing.
|
||||
|
||||
If there is insufficient available memory to load a new model request while one or more models are already loaded, all new requests will be queued until the new model can be loaded. As prior models become idle, one or more will be unloaded to make room for the new model. Queued requests will be processed in order. When using GPU inference new models must be able to completely fit in VRAM to allow concurrent model loads.
|
||||
|
||||
@@ -301,7 +305,7 @@ Parallel request processing for a given model results in increasing the context
|
||||
The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:
|
||||
|
||||
- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 * the number of GPUs or 3 for CPU inference.
|
||||
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
|
||||
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default is 1, and will handle 1 request per model at a time.
|
||||
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
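
For example, a sketch of launching the server with two concurrently loaded models, four parallel requests per model, and the default queue size (values chosen only for illustration):

```shell
OLLAMA_MAX_LOADED_MODELS=2 OLLAMA_NUM_PARALLEL=4 OLLAMA_MAX_QUEUE=512 ollama serve
```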
|
||||
|
||||
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
|
||||
@@ -333,3 +337,16 @@ The currently available K/V cache quantization types are:
|
||||
How much the cache quantization impacts the model's response quality will depend on the model and the task. Models that have a high GQA count (e.g. Qwen2) may see a larger impact on precision from quantization than models with a low GQA count.
|
||||
|
||||
You may need to experiment with different quantization types to find the best balance between memory usage and quality.
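
For example, a minimal sketch of starting the server with a quantized K/V cache, using the `q8_0` type mentioned above (this sketch also assumes flash attention must be enabled for cache quantization to take effect):

```shell
OLLAMA_FLASH_ATTENTION=1 OLLAMA_KV_CACHE_TYPE=q8_0 ollama serve
```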
|
||||
|
||||
## How can I stop Ollama from starting when I log in to my computer?

Ollama for Windows and macOS registers as a login item during installation. You can disable this if you prefer not to have Ollama automatically start. Ollama will respect this setting across upgrades, unless you uninstall the application.
|
||||
|
||||
**Windows**
|
||||
- Remove `%APPDATA%\Microsoft\Windows\Start Menu\Programs\Startup\Ollama.lnk`
|
||||
|
||||
**MacOS Monterey (v12)**
|
||||
- Open `Settings` -> `Users & Groups` -> `Login Items` and find the `Ollama` entry, then click the `-` (minus) to remove
|
||||
|
||||
**MacOS Ventura (v13) and later**
|
||||
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background", then click the slider to disable.
|
||||
|
||||
docs/gpu.md
@@ -1,21 +1,28 @@
|
||||
# GPU
|
||||
## Nvidia
|
||||
Ollama supports Nvidia GPUs with compute capability 5.0+.
|
||||
Ollama supports Nvidia GPUs with compute capability 5.0+ and driver version 531 and newer.
|
||||
|
||||
Check your compute compatibility to see if your card is supported:
|
||||
[https://developer.nvidia.com/cuda-gpus](https://developer.nvidia.com/cuda-gpus)
|
||||
|
||||
| Compute Capability | Family | Cards |
|
||||
| ------------------ | ------------------- | ----------------------------------------------------------------------------------------------------------- |
|
||||
| 9.0 | NVIDIA | `H200` `H100` |
|
||||
| 12.0 | GeForce RTX 50xx | `RTX 5060` `RTX 5060 Ti` `RTX 5070` `RTX 5070 Ti` `RTX 5080` `RTX 5090` |
|
||||
| | NVIDIA Professional | `RTX PRO 4000 Blackwell` `RTX PRO 4500 Blackwell` `RTX PRO 5000 Blackwell` `RTX PRO 6000 Blackwell` |
| 11.0 | Jetson | `T4000` `T5000` (Requires driver 580 or newer) |
| 10.3 | NVIDIA Professional | `B300` `GB300` (Requires driver 580 or newer) |
| 10.0 | NVIDIA Professional | `B200` `GB200` (Requires driver 580 or newer) |
|
||||
| 9.0 | NVIDIA | `H200` `H100` `GH200` |
|
||||
| 8.9 | GeForce RTX 40xx | `RTX 4090` `RTX 4080 SUPER` `RTX 4080` `RTX 4070 Ti SUPER` `RTX 4070 Ti` `RTX 4070 SUPER` `RTX 4070` `RTX 4060 Ti` `RTX 4060` |
|
||||
| | NVIDIA Professional | `L4` `L40` `RTX 6000` |
|
||||
| 8.7 | Jetson | `Orin Nano` `Orin NX` `AGX Orin` |
|
||||
| 8.6 | GeForce RTX 30xx | `RTX 3090 Ti` `RTX 3090` `RTX 3080 Ti` `RTX 3080` `RTX 3070 Ti` `RTX 3070` `RTX 3060 Ti` `RTX 3060` `RTX 3050 Ti` `RTX 3050` |
|
||||
| | NVIDIA Professional | `A40` `RTX A6000` `RTX A5000` `RTX A4000` `RTX A3000` `RTX A2000` `A10` `A16` `A2` |
|
||||
| 8.0 | NVIDIA | `A100` `A30` |
|
||||
| 7.5 | GeForce GTX/RTX | `GTX 1650 Ti` `TITAN RTX` `RTX 2080 Ti` `RTX 2080` `RTX 2070` `RTX 2060` |
|
||||
| | NVIDIA Professional | `T4` `RTX 5000` `RTX 4000` `RTX 3000` `T2000` `T1200` `T1000` `T600` `T500` |
|
||||
| | Quadro | `RTX 8000` `RTX 6000` `RTX 5000` `RTX 4000` |
|
||||
| 7.2 | Jetson | `Xavier NX` `AGX Xavier` (Jetpack 5) |
|
||||
| 7.0 | NVIDIA | `TITAN V` `V100` `Quadro GV100` |
|
||||
| 6.1 | NVIDIA TITAN | `TITAN Xp` `TITAN X` |
|
||||
| | GeForce GTX | `GTX 1080 Ti` `GTX 1080` `GTX 1070 Ti` `GTX 1070` `GTX 1060` `GTX 1050 Ti` `GTX 1050` |
|
||||
@@ -49,20 +56,23 @@ sudo modprobe nvidia_uvm`
|
||||
Ollama supports the following AMD GPUs:
|
||||
|
||||
### Linux Support
|
||||
| Family | Cards and accelerators |
|
||||
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` `Vega 56` |
|
||||
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `VII` `SSG` |
|
||||
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` `MI50` |
|
||||
| Family | Cards and accelerators |
|
||||
| -------------- | -------------------------------------------------------------------------------------------------------------------- |
|
||||
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` |
|
||||
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` |
|
||||
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` |
|
||||
|
||||
### Windows Support
|
||||
With ROCm v6.1, the following GPUs are supported on Windows.
|
||||
With ROCm v6.2, the following GPUs are supported on Windows.
|
||||
|
||||
| Family | Cards and accelerators |
|
||||
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` |
|
||||
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
|
||||
|
||||
### Known Workarounds
|
||||
|
||||
- The RX Vega 56 requires `HSA_ENABLE_SDMA=0` to disable SDMA
|
||||
|
||||
### Overrides on Linux
|
||||
Ollama leverages the AMD ROCm library, which does not support all AMD GPUs. In
|
||||
@@ -83,8 +93,6 @@ At this time, the known supported GPU types on linux are the following LLVM Targ
|
||||
This table shows some example GPUs that map to these LLVM targets:
|
||||
| **LLVM Target** | **An Example GPU** |
|
||||
|-----------------|---------------------|
|
||||
| gfx900 | Radeon RX Vega 56 |
|
||||
| gfx906 | Radeon Instinct MI50 |
|
||||
| gfx908 | Radeon Instinct MI100 |
|
||||
| gfx90a | Radeon Instinct MI210 |
|
||||
| gfx940 | Radeon Instinct MI300 |
|
||||
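As a sketch of how such an override is applied, using the `HSA_OVERRIDE_GFX_VERSION` variable listed later in this document (the value `9.0.0` is an illustrative mapping for the `gfx900` target in the table above; pick the supported target closest to your GPU):

```shell
HSA_OVERRIDE_GFX_VERSION="9.0.0" ollama serve
```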
|
||||
@@ -53,6 +53,8 @@ FROM /path/to/safetensors/directory
|
||||
|
||||
If you create the Modelfile in the same directory as the weights, you can use the command `FROM .`.
|
||||
|
||||
If you do not create a Modelfile, Ollama will act as if there were a Modelfile containing only the command `FROM .`.
|
||||
|
||||
Now run the `ollama create` command from the directory where you created the `Modelfile`:
|
||||
|
||||
```shell
|
||||
|
||||
@@ -11,12 +11,13 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
## Manual install
|
||||
|
||||
> [!NOTE]
|
||||
> If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
|
||||
> If you are upgrading from a prior version, you **MUST** remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
|
||||
|
||||
Download and extract the package:
|
||||
|
||||
```shell
|
||||
curl -L https://ollama.com/download/ollama-linux-amd64.tgz -o ollama-linux-amd64.tgz
|
||||
curl -LO https://ollama.com/download/ollama-linux-amd64.tgz
|
||||
sudo rm -rf /usr/lib/ollama
|
||||
sudo tar -C /usr -xzf ollama-linux-amd64.tgz
|
||||
```
|
||||
|
||||
@@ -34,7 +35,11 @@ ollama -v
|
||||
|
||||
### AMD GPU install
|
||||
|
||||
If you have an AMD GPU, also download and extract the additional ROCm package:
|
||||
If you have an AMD GPU, **also** download and extract the additional ROCm package:
|
||||
|
||||
> [!IMPORTANT]
|
||||
> The ROCm tgz contains only AMD dependent libraries. You must extract **both** `ollama-linux-amd64.tgz` and `ollama-linux-amd64-rocm.tgz` into the same location.
|
||||
|
||||
|
||||
```shell
|
||||
curl -L https://ollama.com/download/ollama-linux-amd64-rocm.tgz -o ollama-linux-amd64-rocm.tgz
|
||||
|
||||
docs/macos.md (new file, +42)
@@ -0,0 +1,42 @@
|
||||
# Ollama for macOS
|
||||
|
||||
## System Requirements
|
||||
|
||||
* MacOS Sonoma (v14) or newer
|
||||
* Apple M series (CPU and GPU support) or x86 (CPU only)
|
||||
|
||||
|
||||
## Filesystem Requirements
|
||||
|
||||
The preferred method of installation is to mount the `ollama.dmg` and drag-and-drop the Ollama application to the system-wide `Applications` folder. Upon startup, the Ollama app will verify the `ollama` CLI is present in your PATH, and if not detected, will prompt for permission to create a link in `/usr/local/bin`
|
||||
|
||||
Once you've installed Ollama, you'll need additional space for storing the Large Language models, which can be tens to hundreds of GB in size. If your home directory doesn't have enough space, you can change where the binaries are installed, and where the models are stored.
|
||||
|
||||
### Changing Install Location
|
||||
|
||||
To install the Ollama application somewhere other than `Applications`, place the Ollama application in the desired location, and ensure the CLI `Ollama.app/Contents/Resources/ollama` or a sym-link to the CLI can be found in your path. Upon first start decline the "Move to Applications?" request.
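
For example, a sym-link for the CLI might be created like this (the custom install path shown is illustrative):

```shell
sudo ln -s /Users/me/Apps/Ollama.app/Contents/Resources/ollama /usr/local/bin/ollama
```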
|
||||
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
Ollama on MacOS stores files in a few different locations.
|
||||
- `~/.ollama` contains models and configuration
|
||||
- `~/.ollama/logs` contains logs
|
||||
- *app.log* contains most recent logs from the GUI application
|
||||
- *server.log* contains the most recent server logs
|
||||
- `<install location>/Ollama.app/Contents/Resources/ollama` the CLI binary
|
||||
|
||||
## Uninstall
|
||||
|
||||
To fully remove Ollama from your system, remove the following files and folders:
|
||||
|
||||
```
|
||||
sudo rm -rf /Applications/Ollama.app
|
||||
sudo rm /usr/local/bin/ollama
|
||||
rm -rf "~/Library/Application Support/Ollama"
|
||||
rm -rf "~/Library/Saved Application State/com.electron.ollama.savedState"
|
||||
rm -rf ~/Library/Caches/com.electron.ollama/
|
||||
rm -rf ~/Library/Caches/ollama
|
||||
rm -rf ~/Library/WebKit/com.electron.ollama
|
||||
rm -rf ~/.ollama
|
||||
```
|
||||
@@ -150,7 +150,7 @@ PARAMETER <parameter> <parametervalue>
|
||||
|
||||
| Parameter | Description | Value Type | Example Usage |
|
||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
|
||||
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
|
||||
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 4096) | int | num_ctx 4096 |
|
||||
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
||||
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
|
||||
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
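
Taken together, a small Modelfile using parameters from this table might look like the following sketch (base model and values are illustrative):

```
FROM llama3.2
PARAMETER num_ctx 4096
PARAMETER temperature 0.7
PARAMETER repeat_penalty 1.1
```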
|
||||
|
||||
@@ -72,7 +72,7 @@ client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
|
||||
# Define the schema for the response
|
||||
class FriendInfo(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
age: int
|
||||
is_available: bool
|
||||
|
||||
class FriendList(BaseModel):
|
||||
|
||||
@@ -9,7 +9,7 @@ cat ~/.ollama/logs/server.log
|
||||
On **Linux** systems with systemd, the logs can be found with this command:
|
||||
|
||||
```shell
|
||||
journalctl -u ollama --no-pager --follow --pager-end
|
||||
journalctl -u ollama --no-pager --follow --pager-end
|
||||
```
|
||||
|
||||
When you run Ollama in a **container**, the logs go to stdout/stderr in the container:
|
||||
@@ -23,7 +23,7 @@ docker logs <container-name>
|
||||
If manually running `ollama serve` in a terminal, the logs will be on that terminal.
|
||||
|
||||
When you run Ollama on **Windows**, there are a few different locations. You can view them in the explorer window by hitting `<cmd>+R` and type in:
|
||||
- `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
|
||||
- `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
|
||||
- `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH)
|
||||
- `explorer %HOMEPATH%\.ollama` to browse where models and configuration is stored
|
||||
|
||||
@@ -38,26 +38,14 @@ Join the [Discord](https://discord.gg/ollama) for help interpreting the logs.
|
||||
|
||||
## LLM libraries
|
||||
|
||||
Ollama includes multiple LLM libraries compiled for different GPUs and CPU vector features. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library. `cpu_avx2` will perform the best, followed by `cpu_avx` an the slowest but most compatible is `cpu`. Rosetta emulation under MacOS will work with the `cpu` library.
|
||||
|
||||
In the server log, you will see a message that looks something like this (varies from release to release):
|
||||
|
||||
```
|
||||
Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v11 rocm_v5]
|
||||
```
|
||||
Ollama includes multiple LLM libraries compiled for different GPU libraries and versions. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library.
|
||||
|
||||
**Experimental LLM Library Override**
|
||||
|
||||
You can set OLLAMA_LLM_LIBRARY to any of the available LLM libraries to bypass autodetection, so for example, if you have a CUDA card, but want to force the CPU LLM library with AVX2 vector support, use:
|
||||
You can set OLLAMA_LLM_LIBRARY to any of the available LLM libraries to limit autodetection; for example, if you have both CUDA and AMD GPUs but want to force only the CUDA v13 library, use:
|
||||
|
||||
```shell
|
||||
OLLAMA_LLM_LIBRARY="cpu_avx2" ollama serve
|
||||
```
|
||||
|
||||
You can see what features your CPU has with the following.
|
||||
|
||||
```shell
|
||||
cat /proc/cpuinfo| grep flags | head -1
|
||||
OLLAMA_LLM_LIBRARY="cuda_v13" ollama serve
|
||||
```
|
||||
|
||||
## Installing older or pre-release versions on Linux
|
||||
@@ -92,12 +80,15 @@ If none of those resolve the problem, gather additional information and file an
|
||||
- Set `CUDA_ERROR_LEVEL=50` and try again to get more diagnostic logs
|
||||
- Check dmesg for any errors `sudo dmesg | grep -i nvrm` and `sudo dmesg | grep -i nvidia`
|
||||
|
||||
You may get more details for initialization failures by enabling debug prints in the uvm driver. You should only use this temporarily while troubleshooting
|
||||
- `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm uvm_debug_prints=1`
|
||||
|
||||
|
||||
## AMD GPU Discovery
|
||||
|
||||
On linux, AMD GPU access typically requires `video` and/or `render` group membership to access the `/dev/kfd` device. If permissions are not set up correctly, Ollama will detect this and report an error in the server log.
|
||||
|
||||
When running in a container, in some Linux distributions and container runtimes, the ollama process may be unable to access the GPU. Use `ls -lnd /dev/kfd /dev/dri /dev/dri/*` on the host system to determine the **numeric** group IDs on your system, and pass additional `--group-add ...` arguments to the container so it can access the required devices. For example, in the following output `crw-rw---- 1 0 44 226, 0 Sep 16 16:55 /dev/dri/card0` the group ID column is `44`
|
||||
When running in a container, in some Linux distributions and container runtimes, the ollama process may be unable to access the GPU. Use `ls -lnd /dev/kfd /dev/dri /dev/dri/*` on the host system to determine the **numeric** group IDs on your system, and pass additional `--group-add ...` arguments to the container so it can access the required devices. For example, in the following output `crw-rw---- 1 0 44 226, 0 Sep 16 16:55 /dev/dri/card0` the group ID column is `44`
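
A sketch of passing those groups through to a container (the group IDs 44 and 993 are illustrative values read from the `ls` output above, and the `ollama/ollama:rocm` image tag is an assumption):

```shell
docker run -d --device /dev/kfd --device /dev/dri \
  --group-add 44 --group-add 993 \
  -v ollama:/root/.ollama -p 11434:11434 \
  --name ollama ollama/ollama:rocm
```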
|
||||
|
||||
If you are experiencing problems getting Ollama to correctly discover or use your GPU for inference, the following may help isolate the failure.
|
||||
- `AMD_LOG_LEVEL=3` Enable info log levels in the AMD HIP/ROCm libraries. This can help show more detailed error codes that can help troubleshoot problems
|
||||
|
||||
@@ -30,20 +30,6 @@ To install the Ollama application in a location different than your home directo
|
||||
OllamaSetup.exe /DIR="d:\some\location"
|
||||
```
|
||||
|
||||
### Changing Model Location
|
||||
|
||||
To change where Ollama stores the downloaded models instead of using your home directory, set the environment variable `OLLAMA_MODELS` in your user account.
|
||||
|
||||
1. Start the Settings (Windows 11) or Control Panel (Windows 10) application and search for _environment variables_.
|
||||
|
||||
2. Click on _Edit environment variables for your account_.
|
||||
|
||||
3. Edit or create a new variable for your user account for `OLLAMA_MODELS` where you want the models stored
|
||||
|
||||
4. Click OK/Apply to save.
|
||||
|
||||
If Ollama is already running, Quit the tray application and relaunch it from the Start menu, or a new terminal started after you saved the environment variables.
|
||||
|
||||
## API Access
|
||||
|
||||
Here's a quick example showing API access from `powershell`
|
||||
@@ -82,9 +68,9 @@ If you'd like to install or integrate Ollama as a service, a standalone
|
||||
`ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
|
||||
and GPU library dependencies for Nvidia. If you have an AMD GPU, also download
|
||||
and extract the additional ROCm package `ollama-windows-amd64-rocm.zip` into the
|
||||
same directory. This allows for embedding Ollama in existing applications, or
|
||||
running it as a system service via `ollama serve` with tools such as
|
||||
[NSSM](https://nssm.cc/).
|
||||
same directory. Both zip files are necessary for a complete AMD installation.
|
||||
This allows for embedding Ollama in existing applications, or running it as a
|
||||
system service via `ollama serve` with tools such as [NSSM](https://nssm.cc/).
|
||||
|
||||
> [!NOTE]
|
||||
> If you are upgrading from a prior version, you should remove the old directories first.
|
||||
|
||||
@@ -24,6 +24,9 @@ func Host() *url.URL {
|
||||
switch {
|
||||
case !ok:
|
||||
scheme, hostport = "http", s
|
||||
if s == "ollama.com" {
|
||||
scheme, hostport = "https", "ollama.com:443"
|
||||
}
|
||||
case scheme == "http":
|
||||
defaultPort = "80"
|
||||
case scheme == "https":
|
||||
@@ -134,8 +137,19 @@ func LoadTimeout() (loadTimeout time.Duration) {
|
||||
return loadTimeout
|
||||
}
|
||||
|
||||
func Bool(k string) func() bool {
|
||||
return func() bool {
|
||||
func Remotes() []string {
|
||||
var r []string
|
||||
raw := strings.TrimSpace(Var("OLLAMA_REMOTES"))
|
||||
if raw == "" {
|
||||
r = []string{"ollama.com"}
|
||||
} else {
|
||||
r = strings.Split(raw, ",")
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func BoolWithDefault(k string) func(defaultValue bool) bool {
|
||||
return func(defaultValue bool) bool {
|
||||
if s := Var(k); s != "" {
|
||||
b, err := strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
@@ -145,7 +159,14 @@ func Bool(k string) func() bool {
|
||||
return b
|
||||
}
|
||||
|
||||
return false
|
||||
return defaultValue
|
||||
}
|
||||
}
|
||||
|
||||
func Bool(k string) func() bool {
|
||||
withDefault := BoolWithDefault(k)
|
||||
return func() bool {
|
||||
return withDefault(false)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,7 +187,7 @@ func LogLevel() slog.Level {
|
||||
|
||||
var (
|
||||
// FlashAttention enables the experimental flash attention feature.
|
||||
FlashAttention = Bool("OLLAMA_FLASH_ATTENTION")
|
||||
FlashAttention = BoolWithDefault("OLLAMA_FLASH_ATTENTION")
|
||||
// KvCacheType is the quantization type for the K/V cache.
|
||||
KvCacheType = String("OLLAMA_KV_CACHE_TYPE")
|
||||
// NoHistory disables readline history.
|
||||
@@ -199,6 +220,7 @@ var (
|
||||
CudaVisibleDevices = String("CUDA_VISIBLE_DEVICES")
|
||||
HipVisibleDevices = String("HIP_VISIBLE_DEVICES")
|
||||
RocrVisibleDevices = String("ROCR_VISIBLE_DEVICES")
|
||||
VkVisibleDevices = String("GGML_VK_VISIBLE_DEVICES")
|
||||
GpuDeviceOrdinal = String("GPU_DEVICE_ORDINAL")
|
||||
HsaOverrideGfxVersion = String("HSA_OVERRIDE_GFX_VERSION")
|
||||
)
|
||||
@@ -219,7 +241,7 @@ func Uint(key string, defaultValue uint) func() uint {
|
||||
|
||||
var (
|
||||
// NumParallel sets the number of parallel model requests. NumParallel can be configured via the OLLAMA_NUM_PARALLEL environment variable.
|
||||
NumParallel = Uint("OLLAMA_NUM_PARALLEL", 0)
|
||||
NumParallel = Uint("OLLAMA_NUM_PARALLEL", 1)
|
||||
// MaxRunners sets the maximum number of loaded models. MaxRunners can be configured via the OLLAMA_MAX_LOADED_MODELS environment variable.
|
||||
MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", 0)
|
||||
// MaxQueue sets the maximum number of queued requests. MaxQueue can be configured via the OLLAMA_MAX_QUEUE environment variable.
|
||||
@@ -252,7 +274,7 @@ type EnvVar struct {
|
||||
func AsMap() map[string]EnvVar {
|
||||
ret := map[string]EnvVar{
|
||||
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
||||
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"},
|
||||
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
|
||||
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
|
||||
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
|
||||
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
|
||||
@@ -270,6 +292,7 @@ func AsMap() map[string]EnvVar {
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
|
||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
||||
|
||||
// Informational
|
||||
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
||||
@@ -288,6 +311,7 @@ func AsMap() map[string]EnvVar {
|
||||
ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices(), "Set which NVIDIA devices are visible"}
|
||||
ret["HIP_VISIBLE_DEVICES"] = EnvVar{"HIP_VISIBLE_DEVICES", HipVisibleDevices(), "Set which AMD devices are visible by numeric ID"}
|
||||
ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices(), "Set which AMD devices are visible by UUID or numeric ID"}
|
||||
ret["GGML_VK_VISIBLE_DEVICES"] = EnvVar{"GGML_VK_VISIBLE_DEVICES", VkVisibleDevices(), "Set which Vulkan devices are visible by numeric ID"}
|
||||
ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal(), "Set which AMD devices are visible by numeric ID"}
|
||||
ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion(), "Override the gfx used for all detected AMD GPUs"}
|
||||
ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGPU(), "Enable experimental Intel GPU detection"}
|
||||
|
||||
@@ -37,6 +37,7 @@ func TestHost(t *testing.T) {
|
||||
"https": {"https://1.2.3.4", "https://1.2.3.4:443"},
|
||||
"https port": {"https://1.2.3.4:4321", "https://1.2.3.4:4321"},
|
||||
"proxy path": {"https://example.com/ollama", "https://example.com:443/ollama"},
|
||||
"ollama.com": {"ollama.com", "https://ollama.com:443"},
|
||||
}
|
||||
|
||||
for name, tt := range cases {
|
||||
|
||||
@@ -10,4 +10,5 @@ type Config interface {
|
||||
Strings(string, ...[]string) []string
|
||||
Ints(string, ...[]int32) []int32
|
||||
Floats(string, ...[]float32) []float32
|
||||
Bools(string, ...[]bool) []bool
|
||||
}
|
||||
|
||||
fs/ggml/ggml.go
@@ -1,14 +1,17 @@
|
||||
package ggml
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/util/bufioutil"
|
||||
)
|
||||
|
||||
@@ -34,7 +37,8 @@ func (kv KV) Kind() string {
|
||||
}
|
||||
|
||||
func (kv KV) ParameterCount() uint64 {
|
||||
return keyValue(kv, "general.parameter_count", uint64(0))
|
||||
val, _ := keyValue(kv, "general.parameter_count", uint64(0))
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) FileType() FileType {
|
||||
@@ -53,16 +57,66 @@ func (kv KV) EmbeddingLength() uint64 {
|
||||
return uint64(kv.Uint("embedding_length"))
|
||||
}
|
||||
|
||||
func (kv KV) HeadCount() uint64 {
|
||||
return uint64(kv.Uint("attention.head_count"))
|
||||
func (kv KV) HeadCount() []uint64 {
|
||||
headCountDefault := uint32(1)
|
||||
headCount := kv.UintOrArrayValueAsArray("attention.head_count", headCountDefault)
|
||||
if len(headCount) == 1 {
|
||||
headCountDefault = headCount[0]
|
||||
}
|
||||
nLayers := int(kv.BlockCount())
|
||||
if len(headCount) > nLayers {
|
||||
slog.Warn("got more elements of attention.head_count than layers", "len(headCount)", len(headCount), "layers", nLayers)
|
||||
}
|
||||
out := make([]uint64, nLayers)
|
||||
for i := range nLayers {
|
||||
if i >= len(headCount) {
|
||||
out[i] = uint64(headCountDefault)
|
||||
} else {
|
||||
out[i] = uint64(headCount[i])
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (kv KV) HeadCountKV() uint64 {
|
||||
return uint64(kv.Uint("attention.head_count_kv", 1))
|
||||
func (kv KV) HeadCountMax() uint64 {
|
||||
return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1))
|
||||
}
|
||||
|
||||
func (kv KV) EmbeddingHeadCount() uint64 {
|
||||
if heads := kv.HeadCount(); heads > 0 {
|
||||
func (kv KV) HeadCountMin() uint64 {
|
||||
return uint64(kv.UintOrMinArrayValue("attention.head_count", 1))
|
||||
}
|
||||
|
||||
func (kv KV) HeadCountKV() []uint64 {
|
||||
headCountKVDefault := uint32(1)
|
||||
headCountKV := kv.UintOrArrayValueAsArray("attention.head_count_kv", headCountKVDefault)
|
||||
if len(headCountKV) == 1 {
|
||||
headCountKVDefault = headCountKV[0]
|
||||
}
|
||||
nLayers := int(kv.BlockCount())
|
||||
if len(headCountKV) > nLayers {
|
||||
slog.Warn("got more elements of attention.head_count than layers", "len(headCountKV)", len(headCountKV), "layers", nLayers)
|
||||
}
|
||||
out := make([]uint64, nLayers)
|
||||
for i := range nLayers {
|
||||
if i >= len(headCountKV) {
|
||||
out[i] = uint64(headCountKVDefault)
|
||||
} else {
|
||||
out[i] = uint64(headCountKV[i])
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (kv KV) HeadCountKVMax() uint64 {
|
||||
return uint64(kv.UintOrMaxArrayValue("attention.head_count_kv", 1))
|
||||
}
|
||||
|
||||
func (kv KV) HeadCountKVMin() uint64 {
|
||||
return uint64(kv.UintOrMinArrayValue("attention.head_count_kv", 1))
|
||||
}
|
||||
|
||||
func (kv KV) EmbeddingHeadCountMax() uint64 {
|
||||
if heads := kv.HeadCountMin(); heads > 0 {
|
||||
return kv.EmbeddingLength() / heads
|
||||
}
|
||||
|
||||
@@ -70,15 +124,11 @@ func (kv KV) EmbeddingHeadCount() uint64 {
|
||||
}
|
||||
|
||||
func (kv KV) EmbeddingHeadCountK() uint64 {
|
||||
return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCount())))
|
||||
return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCountMax())))
|
||||
}
|
||||
|
||||
func (kv KV) EmbeddingHeadCountV() uint64 {
|
||||
return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCount())))
|
||||
}
|
||||
|
||||
func (kv KV) GQA() uint64 {
|
||||
return kv.HeadCount() / kv.HeadCountKV()
|
||||
return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCountMax())))
|
||||
}
|
||||
|
||||
func (kv KV) ContextLength() uint64 {
|
||||
@@ -89,45 +139,116 @@ func (kv KV) ChatTemplate() string {
|
||||
return kv.String("tokenizer.chat_template")
|
||||
}
|
||||
|
||||
// ssm architecture parameters
|
||||
|
||||
func (kv KV) SSMConvKernel() uint64 {
|
||||
return uint64(kv.Uint("ssm.conv_kernel"))
|
||||
}
|
||||
|
||||
func (kv KV) SSMInnerSize() uint64 {
|
||||
return uint64(kv.Uint("ssm.inner_size"))
|
||||
}
|
||||
|
||||
func (kv KV) SSMStateSize() uint64 {
|
||||
return uint64(kv.Uint("ssm.state_size"))
|
||||
}
|
||||
|
||||
func (kv KV) SSMGroupCount() uint64 {
|
||||
return uint64(kv.Uint("ssm.group_count"))
|
||||
}
|
||||
|
||||
// general types
|
||||
|
||||
func (kv KV) String(key string, defaultValue ...string) string {
|
||||
return keyValue(kv, key, append(defaultValue, "")...)
|
||||
val, _ := keyValue(kv, key, append(defaultValue, "")...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
|
||||
return keyValue(kv, key, append(defaultValue, 0)...)
|
||||
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Float(key string, defaultValue ...float32) float32 {
|
||||
return keyValue(kv, key, append(defaultValue, 0)...)
|
||||
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Bool(key string, defaultValue ...bool) bool {
|
||||
return keyValue(kv, key, append(defaultValue, false)...)
|
||||
val, _ := keyValue(kv, key, append(defaultValue, false)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) UintOrMaxArrayValue(key string, defaultValue uint32) uint32 {
|
||||
_, max := kv.UintOrArrayValue(key, defaultValue)
|
||||
return max
|
||||
}
|
||||
|
||||
func (kv KV) UintOrMinArrayValue(key string, defaultValue uint32) uint32 {
|
||||
min, _ := kv.UintOrArrayValue(key, defaultValue)
|
||||
return min
|
||||
}
|
||||
|
||||
func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) {
|
||||
arrVal := kv.UintOrArrayValueAsArray(key, defaultValue)
|
||||
return slices.Min(arrVal), slices.Max(arrVal)
|
||||
}
|
||||
|
||||
func (kv KV) UintOrArrayValueAsArray(key string, defaultValue uint32) []uint32 {
|
||||
if u32, ok := keyValue(kv, key, uint32(0)); ok {
|
||||
return []uint32{u32}
|
||||
} else if u32s, ok := keyValue(kv, key, &array[uint32]{}); ok {
|
||||
return u32s.values
|
||||
} else if i32s, ok := keyValue(kv, key, &array[int32]{}); ok {
|
||||
dst := make([]uint32, len(i32s.values))
|
||||
for i, v := range i32s.values {
|
||||
if v < 0 {
|
||||
slog.Warn("array values are unexpectedly negative", "key", key, "i", i, "v", v)
|
||||
}
|
||||
dst[i] = uint32(v)
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
return []uint32{defaultValue}
|
||||
}
|
||||
|
||||
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
|
||||
return keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]}).values
|
||||
val, _ := keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]})
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
|
||||
return keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]}).values
|
||||
val, _ := keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]})
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
|
||||
return keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]}).values
|
||||
val, _ := keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]})
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
|
||||
return keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]}).values
|
||||
val, _ := keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]})
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
|
||||
val, _ := keyValue(kv, key, &array[bool]{values: append(defaultValue, []bool(nil))[0]})
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) OllamaEngineRequired() bool {
|
||||
return slices.Contains([]string{
|
||||
"gemma3",
|
||||
"gemma3n",
|
||||
"mistral3",
|
||||
"qwen3",
|
||||
"qwen3moe",
|
||||
"llama4",
|
||||
"mllama",
|
||||
"qwen25vl",
|
||||
"gptoss", "gpt-oss",
|
||||
}, kv.Architecture())
|
||||
}
|
||||
|
||||
@@ -143,17 +264,17 @@ type arrayValueTypes interface {
|
||||
*array[string] | *array[float32] | *array[float64] | *array[bool]
|
||||
}
|
||||
|
||||
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) T {
|
||||
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
|
||||
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
|
||||
key = kv.Architecture() + "." + key
|
||||
}
|
||||
|
||||
if val, ok := kv[key]; ok {
|
||||
return val.(T)
|
||||
if val, ok := kv[key].(T); ok {
|
||||
return val, true
|
||||
}
|
||||
|
||||
slog.Debug("key not found", "key", key, "default", defaultValue[0])
|
||||
return defaultValue[0]
|
||||
slog.Debug("key with type not found", "key", key, "default", defaultValue[0])
|
||||
return defaultValue[0], false
|
||||
}
|
||||
|
||||
type Tensors struct {
|
||||
@@ -222,36 +343,37 @@ type Tensor struct {
|
||||
|
||||
func (t Tensor) block() (n int) {
|
||||
if _, err := fmt.Sscanf(t.Name, "blk.%d.", &n); err != nil {
|
||||
return -1
|
||||
return math.MaxInt
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (t Tensor) blockSize() uint64 {
|
||||
return (TensorType)(t.Kind).BlockSize()
|
||||
return TensorType(t.Kind).BlockSize()
|
||||
}
|
||||
|
||||
func (t TensorType) BlockSize() uint64 {
|
||||
switch t {
|
||||
case
|
||||
0, // F32
|
||||
1, // F16
|
||||
24, // I8
|
||||
25, // I16
|
||||
26, // I32
|
||||
27, // I64
|
||||
28, // F64
|
||||
30: // BF16
|
||||
TensorTypeF32,
|
||||
TensorTypeF16,
|
||||
TensorTypeI8,
|
||||
TensorTypeI16,
|
||||
TensorTypeI32,
|
||||
TensorTypeI64,
|
||||
TensorTypeF64,
|
||||
TensorTypeBF16:
|
||||
return 1
|
||||
case
|
||||
2, // Q4_0
|
||||
3, // Q4_1
|
||||
6, // Q5_0
|
||||
7, // Q5_1
|
||||
8, // Q8_0
|
||||
9, // Q8_1
|
||||
20: // IQ4_NL
|
||||
TensorTypeQ4_0,
|
||||
TensorTypeQ4_1,
|
||||
TensorTypeQ5_0,
|
||||
TensorTypeQ5_1,
|
||||
TensorTypeQ8_0,
|
||||
TensorTypeQ8_1,
|
||||
tensorTypeIQ4_NL,
|
||||
4, TensorTypeMXFP4:
|
||||
return 32
|
||||
default:
|
||||
return 256
|
||||
@@ -324,6 +446,8 @@ func (t TensorType) TypeSize() uint64 {
|
||||
return blockSize/8 + blockSize/16 + blockSize/32
|
||||
case TensorTypeBF16:
|
||||
return 2
|
||||
case 4, TensorTypeMXFP4:
|
||||
return 1 + blockSize/2
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
@@ -423,23 +547,68 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
context *= uint64(numParallel)
|
||||
|
||||
embedding := f.KV().EmbeddingLength()
|
||||
heads := f.KV().HeadCount()
|
||||
headsKV := f.KV().HeadCountKV()
|
||||
heads := f.KV().HeadCountMax()
|
||||
headsArr := f.KV().HeadCount()
|
||||
headsKV := f.KV().HeadCountKVMax()
|
||||
headsKVArr := f.KV().HeadCountKV()
|
||||
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
|
||||
|
||||
embeddingHeads := f.KV().EmbeddingHeadCount()
|
||||
embeddingHeads := f.KV().EmbeddingHeadCountMax()
|
||||
embeddingHeadsK := f.KV().EmbeddingHeadCountK()
|
||||
embeddingHeadsV := f.KV().EmbeddingHeadCountV()
|
||||
|
||||
layers := f.Tensors().GroupLayers()
|
||||
|
||||
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
||||
|
||||
// Default for models unless special-cased below. These defaults mirror the
|
||||
// cache usage in llama.cpp under the assumption that models without special
|
||||
// cases below will use the llamarunner and caching will be handled by the
|
||||
// llama.cpp layer.
|
||||
//
|
||||
// This also assumes that a layer without heads or headsKV set is recurrent
|
||||
// which is usually the case. Some models (eg nemotronh) use "blocks" in
|
||||
// place of layers where some are MLP blocks that don't have any cache.
|
||||
// Models like this will need a special case below to be accurately
|
||||
// estimated.
|
||||
var kvTotal uint64
|
||||
kv = make([]uint64, f.KV().BlockCount())
|
||||
kvSizeAttn := uint64(0)
|
||||
kvSizeRecurrent := uint64(0)
|
||||
for i := range kv {
|
||||
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
headsL := headsArr[i]
|
||||
headsKVL := headsKVArr[i]
|
||||
if headsL > 0 && headsKVL > 0 {
|
||||
// full attention layer
|
||||
// NOTE: Assumes uniform values for all attn layers
|
||||
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKVL) * bytesPerElement)
|
||||
kvSizeAttn += kv[i]
|
||||
} else {
|
||||
// recurrent layer
|
||||
ssmDConv := f.KV().SSMConvKernel()
|
||||
ssmDState := f.KV().SSMStateSize()
|
||||
ssmDInner := f.KV().SSMInnerSize()
|
||||
ssmNGroups := f.KV().SSMGroupCount()
|
||||
nEmbdR := uint64(0)
|
||||
if ssmDConv > 0 {
|
||||
nEmbdR = (ssmDConv - 1) * (ssmDInner + 2*ssmNGroups*ssmDState)
|
||||
}
|
||||
nEmbdS := ssmDState * ssmDInner
|
||||
|
||||
// recurrent always uses F32 in llama.cpp backend
|
||||
// https://github.com/ggml-org/llama.cpp/blob/master/src/llama-model.cpp#L18644
|
||||
bytesPerElementRecurrent := kvCacheBytesPerElement("f32")
|
||||
|
||||
kv[i] = (nEmbdR + nEmbdS) * uint64(bytesPerElementRecurrent)
|
||||
kvSizeRecurrent += kv[i]
|
||||
}
|
||||
kvTotal += kv[i]
|
||||
}
|
||||
slog.Debug("default cache size estimate", "attention MiB", float32(kvSizeAttn)/(1024.*1024.), "attention bytes", kvSizeAttn, "recurrent MiB", float32(kvSizeRecurrent)/(1024.*1024.), "recurrent bytes", kvSizeRecurrent)
|
||||
|
||||
switch f.KV().Architecture() {
|
||||
case "llama", "llama4":
|
||||
@@ -504,7 +673,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
// vocab graph
|
||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||
)
|
||||
case "gemma", "gemma2", "gemma3":
|
||||
case "gemma", "gemma2", "gemma3", "gemma3n":
|
||||
fullOffload = max(
|
||||
4*batch*(embedding+vocab),
|
||||
4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
|
||||
@@ -517,6 +686,11 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
embedding*embeddingHeadsK*heads*9/16,
|
||||
)
|
||||
|
||||
if f.KV().Architecture() == "gemma3n" {
|
||||
fullOffload *= 4
|
||||
partialOffload *= 4
|
||||
}
|
||||
|
||||
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
|
||||
// engine. Gemma3 always uses the Ollama engine.
|
||||
if f.KV().Architecture() == "gemma3" {
|
||||
@@ -602,6 +776,22 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
4*qkvBias.Shape[0],
|
||||
)
|
||||
}
|
||||
case "gptoss", "gpt-oss":
|
||||
kv = make([]uint64, f.KV().BlockCount())
|
||||
for i := range kv {
|
||||
kv[i] = uint64(float64((embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
if i%2 == 0 {
|
||||
kv[i] *= (uint64(numParallel)*4096 + batch)
|
||||
} else {
|
||||
kv[i] *= context
|
||||
}
|
||||
}
|
||||
|
||||
partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
|
||||
if useFlashAttention {
|
||||
// rough estimate of graph size with flash attention on
|
||||
partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
@@ -676,7 +866,11 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
|
||||
|
||||
// SupportsKVCacheType checks if the requested cache type is supported
|
||||
func (f GGML) SupportsKVCacheType(cacheType string) bool {
|
||||
return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
|
||||
if cacheType == "" || cacheType == "f16" {
|
||||
return true
|
||||
}
|
||||
|
||||
return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
|
||||
}
|
||||
|
||||
// SupportsFlashAttention checks if the model supports flash attention
|
||||
@@ -686,12 +880,26 @@ func (f GGML) SupportsFlashAttention() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check head counts match and are non-zero
|
||||
headCountK := f.KV().EmbeddingHeadCountK()
|
||||
headCountV := f.KV().EmbeddingHeadCountV()
|
||||
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
|
||||
}
|
||||
|
||||
// FlashAttention checks if the model should enable flash attention
|
||||
func (f GGML) FlashAttention() bool {
|
||||
return slices.Contains([]string{
|
||||
"gemma3",
|
||||
"gptoss", "gpt-oss",
|
||||
"qwen3",
|
||||
"qwen3moe",
|
||||
}, f.KV().String("general.architecture"))
|
||||
}
|
||||
|
||||
// kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
|
||||
func kvCacheBytesPerElement(cacheType string) float64 {
|
||||
switch cacheType {
|
||||
@@ -699,6 +907,8 @@ func kvCacheBytesPerElement(cacheType string) float64 {
|
||||
return 1 // 1/2 of fp16
|
||||
case "q4_0":
|
||||
return 0.5 // 1/4 of fp16
|
||||
case "f32":
|
||||
return 4 // f32 (default for recurrent)
|
||||
default:
|
||||
return 2 // f16 (default)
|
||||
}
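To make the per-element figures concrete, a rough back-of-the-envelope sketch (assumed parameter names, not the real accessors): total default-cache bytes are roughly layers × context × (headDimK + headDimV) × kvHeads × bytesPerElement.

// Sketch: estimate total default-cache size from the per-element figure above.
func estimateKVCacheBytes(layers, context, headDimK, headDimV, kvHeads uint64, cacheType string) uint64 {
	perToken := float64((headDimK+headDimV)*kvHeads) * kvCacheBytesPerElement(cacheType)
	return layers * uint64(perToken) * context
}

// For example, 32 layers, 8192 context, 128-dim K/V heads, 8 KV heads at q8_0:
// 32 * 8192 * (128+128) * 8 * 1 byte ≈ 512 MiB.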
@@ -269,3 +269,33 @@ func TestKeyValue(t *testing.T) {
|
||||
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadCount(t *testing.T) {
|
||||
valuesArray := []int32{1, 5, 3, 4}
|
||||
cases := []struct {
|
||||
kv KV
|
||||
want uint64
|
||||
}{
|
||||
{
|
||||
kv: KV{
|
||||
"general.architecture": "abc",
|
||||
"abc.attention.head_count": &array[int32]{values: valuesArray, size: len(valuesArray)},
|
||||
},
|
||||
want: uint64(5),
|
||||
},
|
||||
{
|
||||
kv: KV{
|
||||
"general.architecture": "abc",
|
||||
"abc.attention.head_count": uint32(3),
|
||||
},
|
||||
want: uint64(3),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
got := tt.kv.HeadCountMax()
|
||||
if got != tt.want {
|
||||
t.Errorf("unexpected max value: got=%d want=%d", got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -509,7 +509,10 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
|
||||
}
|
||||
|
||||
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
alignment := kv.Uint("general.alignment", 32)
|
||||
arch := kv.String("general.architecture")
|
||||
if arch == "" {
|
||||
return fmt.Errorf("architecture not set")
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil {
|
||||
return err
|
||||
@@ -528,17 +531,22 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
}
|
||||
|
||||
for _, key := range slices.Sorted(maps.Keys(kv)) {
|
||||
if err := ggufWriteKV(f, key, kv[key]); err != nil {
|
||||
if err := ggufWriteKV(f, arch, key, kv[key]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
slices.SortStableFunc(ts, func(a, b *Tensor) int {
|
||||
if i, j := a.block(), b.block(); i > 0 && j > 0 {
|
||||
return cmp.Compare(i, j)
|
||||
}
|
||||
return cmp.Compare(a.Name, b.Name)
|
||||
})
|
||||
slices.SortStableFunc(
|
||||
ts,
|
||||
func(a, b *Tensor) int {
|
||||
return cmp.Or(
|
||||
cmp.Compare(a.block(), b.block()),
|
||||
cmp.Compare(a.Name, b.Name),
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
alignment := kv.Uint("general.alignment", 32)
|
||||
|
||||
var s uint64
|
||||
for i := range ts {
|
||||
@@ -571,7 +579,14 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
|
||||
func ggufWriteKV(ws io.WriteSeeker, arch, k string, v any) error {
|
||||
if !strings.HasPrefix(k, arch+".") &&
|
||||
!strings.HasPrefix(k, "general.") &&
|
||||
!strings.HasPrefix(k, "adapter.") &&
|
||||
!strings.HasPrefix(k, "tokenizer.") {
|
||||
k = arch + "." + k
|
||||
}
|
||||
|
||||
slog.Debug(k, "type", fmt.Sprintf("%T", v))
|
||||
if err := binary.Write(ws, binary.LittleEndian, uint64(len(k))); err != nil {
|
||||
return err
|
||||
@@ -609,6 +624,10 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
|
||||
err = writeGGUFArray(ws, ggufTypeString, v)
|
||||
case *array[string]:
|
||||
err = writeGGUFArray(ws, ggufTypeString, v.values)
|
||||
case []bool:
|
||||
err = writeGGUFArray(ws, ggufTypeBool, v)
|
||||
case *array[bool]:
|
||||
err = writeGGUFArray(ws, ggufTypeBool, v.values)
|
||||
default:
|
||||
return fmt.Errorf("improper type for '%s'", k)
|
||||
}
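The prefix check added at the top of ggufWriteKV means metadata keys that aren't already architecture-, general-, adapter-, or tokenizer-scoped get namespaced under the architecture on write. A standalone sketch of that rule (illustrative only, consistent with the test expectations below):

// Sketch of the key-prefixing rule, assuming arch = "test":
//   "key"           -> "test.key"
//   "attention.key" -> "test.attention.key"
//   "general.alignment", "tokenizer.key", "adapter.key" -> unchanged
func prefixKVKey(arch, k string) string {
	if !strings.HasPrefix(k, arch+".") &&
		!strings.HasPrefix(k, "general.") &&
		!strings.HasPrefix(k, "adapter.") &&
		!strings.HasPrefix(k, "tokenizer.") {
		return arch + "." + k
	}
	return k
}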
@@ -11,24 +11,24 @@ import (
|
||||
)
|
||||
|
||||
func TestWriteGGUF(t *testing.T) {
|
||||
r := rand.New(rand.NewPCG(0, 0))
|
||||
b := bytes.NewBuffer(make([]byte, 2*3))
|
||||
for range 8 {
|
||||
t.Run("shuffle", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ts := []*Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.1.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.2.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.3.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.4.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.5.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
|
||||
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
|
||||
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: b},
|
||||
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: b},
|
||||
}
|
||||
|
||||
r.Shuffle(len(ts), func(i, j int) {
|
||||
rand.Shuffle(len(ts), func(i, j int) {
|
||||
ts[i], ts[j] = ts[j], ts[i]
|
||||
})
|
||||
|
||||
@@ -39,7 +39,12 @@ func TestWriteGGUF(t *testing.T) {
|
||||
defer w.Close()
|
||||
|
||||
if err := WriteGGUF(w, KV{
|
||||
"general.alignment": uint32(16),
|
||||
"general.architecture": "test",
|
||||
"general.alignment": uint32(16),
|
||||
"test.key": "value",
|
||||
"attention.key": "value2",
|
||||
"tokenizer.key": "value3",
|
||||
"adapter.key": "value4",
|
||||
}, ts); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -56,21 +61,26 @@ func TestWriteGGUF(t *testing.T) {
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(KV{
|
||||
"general.architecture": "test",
|
||||
"general.alignment": uint32(16),
|
||||
"general.parameter_count": uint64(54),
|
||||
"test.key": "value",
|
||||
"test.attention.key": "value2",
|
||||
"tokenizer.key": "value3",
|
||||
"adapter.key": "value4",
|
||||
}, ff.KV()); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(Tensors{
|
||||
Offset: 608,
|
||||
Offset: 800,
|
||||
items: []*Tensor{
|
||||
{Name: "blk.0.attn_norm.weight", Offset: 0, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.1.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.2.attn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.3.attn_norm.weight", Offset: 96, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.4.attn_norm.weight", Offset: 128, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.5.attn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.0.ffn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.1.ffn_down.weight", Offset: 96, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.1.ffn_up.weight", Offset: 128, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.2.ffn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
|
||||
{Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
|
||||
{Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
|
||||
{Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},
|
||||
|
||||
@@ -14,9 +14,9 @@ const (
|
||||
FileTypeF16
|
||||
fileTypeQ4_0
|
||||
fileTypeQ4_1
|
||||
fileTypeQ4_1_F16 // unused by GGML
|
||||
fileTypeQ4_2 // unused by GGML
|
||||
fileTypeQ4_3 // unused by GGML
|
||||
fileTypeMXFP4 // originally fileTypeQ4_1_F16 // unused by GGML
|
||||
fileTypeQ4_2 // unused by GGML
|
||||
fileTypeQ4_3 // unused by GGML
|
||||
FileTypeQ8_0
|
||||
fileTypeQ5_0
|
||||
fileTypeQ5_1
|
||||
@@ -97,6 +97,8 @@ func (t FileType) String() string {
|
||||
return "Q4_0"
|
||||
case fileTypeQ4_1:
|
||||
return "Q4_1"
|
||||
case fileTypeMXFP4:
|
||||
return "MXFP4"
|
||||
case FileTypeQ8_0:
|
||||
return "Q8_0"
|
||||
case fileTypeQ5_0:
|
||||
@@ -172,6 +174,8 @@ func (ftype FileType) ToTensorType() TensorType {
|
||||
return TensorTypeQ2_K
|
||||
case FileTypeBF16:
|
||||
return TensorTypeBF16
|
||||
case fileTypeMXFP4:
|
||||
return TensorTypeMXFP4
|
||||
default:
|
||||
slog.Warn("unsupported file type", "type", ftype)
|
||||
return 0 // F32
|
||||
@@ -187,7 +191,7 @@ const (
|
||||
TensorTypeF16
|
||||
TensorTypeQ4_0
|
||||
TensorTypeQ4_1
|
||||
tensorTypeQ4_2 // unused by GGML
|
||||
tensorTypeQ4_2
|
||||
tensorTypeQ4_3 // unused by GGML
|
||||
TensorTypeQ5_0
|
||||
TensorTypeQ5_1
|
||||
@@ -222,9 +226,10 @@ const (
|
||||
tensorTypeIQ4_NL_4_4 // unused by GGML
|
||||
tensorTypeIQ4_NL_4_8 // unused by GGML
|
||||
tensorTypeIQ4_NL_8_8 // unused by GGML
|
||||
TensorTypeMXFP4
|
||||
)
|
||||
|
||||
// ParseFileType parses the provided GGUF file type
|
||||
// ParseTensorType parses the provided GGUF tensor type
|
||||
// Only Ollama supported types are considered valid
|
||||
func ParseTensorType(s string) (TensorType, error) {
|
||||
switch s {
|
||||
@@ -260,6 +265,8 @@ func ParseTensorType(s string) (TensorType, error) {
|
||||
return TensorTypeF64, nil
|
||||
case "BF16":
|
||||
return TensorTypeBF16, nil
|
||||
case "MXFP4":
|
||||
return TensorTypeMXFP4, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported quantization type %s", s)
|
||||
}
|
||||
@@ -312,6 +319,8 @@ func (t TensorType) String() string {
|
||||
return "F64"
|
||||
case TensorTypeBF16:
|
||||
return "BF16"
|
||||
case 4, TensorTypeMXFP4:
|
||||
return "MXFP4"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
@@ -65,7 +65,7 @@ func Open(path string) (f *File, err error) {
return nil, err
}

if f.Version != 3 {
if f.Version < 2 {
return nil, fmt.Errorf("%w version %v", ErrUnsupported, f.Version)
}

go.mod (4 changes)
@@ -25,6 +25,7 @@ require (
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
golang.org/x/image v0.22.0
|
||||
golang.org/x/tools v0.30.0
|
||||
gonum.org/v1/gonum v0.15.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -44,7 +45,6 @@ require (
|
||||
github.com/xtgo/set v1.0.0 // indirect
|
||||
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
gonum.org/v1/gonum v0.15.0 // indirect
|
||||
gorgonia.org/vecf32 v0.9.0 // indirect
|
||||
gorgonia.org/vecf64 v0.9.0 // indirect
|
||||
)
|
||||
@@ -71,7 +71,7 @@ require (
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.36.0
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
|
||||
golang.org/x/net v0.38.0 // indirect
|
||||
golang.org/x/sys v0.31.0
|
||||
golang.org/x/term v0.30.0
|
||||
|
||||
harmony/harmonyparser.go (new file, 544 lines)
@@ -0,0 +1,544 @@
package harmony
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type harmonyParserState int
|
||||
|
||||
const (
|
||||
harmonyParserState_LookingForMessageStart harmonyParserState = iota
|
||||
harmonyParserState_ParsingHeader
|
||||
harmonyParserState_ParsingContent
|
||||
)
|
||||
|
||||
func (s harmonyParserState) String() string {
|
||||
switch s {
|
||||
// we're looking for the message start tag
|
||||
case harmonyParserState_LookingForMessageStart:
|
||||
return "LookingForMessageStart"
|
||||
case harmonyParserState_ParsingHeader:
|
||||
return "ParsingHeader"
|
||||
case harmonyParserState_ParsingContent:
|
||||
return "ParsingContent"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
type HarmonyParser struct {
|
||||
state harmonyParserState
|
||||
MessageStartTag string
|
||||
MessageEndTag string
|
||||
HeaderEndTag string
|
||||
acc strings.Builder
|
||||
lifetimeAcc strings.Builder
|
||||
}
|
||||
|
||||
type HarmonyEvent interface {
|
||||
isHarmonyEvent()
|
||||
}
|
||||
|
||||
type HarmonyEventMessageStart struct{}
|
||||
|
||||
func (HarmonyEventMessageStart) isHarmonyEvent() {}
|
||||
|
||||
type HarmonyEventHeaderComplete struct {
|
||||
Header HarmonyHeader
|
||||
}
|
||||
|
||||
func (HarmonyEventHeaderComplete) isHarmonyEvent() {}
|
||||
|
||||
type HarmonyEventContentEmitted struct {
|
||||
Content string
|
||||
}
|
||||
|
||||
func (HarmonyEventContentEmitted) isHarmonyEvent() {}
|
||||
|
||||
type HarmonyEventMessageEnd struct{}
|
||||
|
||||
func (HarmonyEventMessageEnd) isHarmonyEvent() {}
|
||||
|
||||
type HarmonyHeader struct {
|
||||
Role string
|
||||
Channel string
|
||||
Recipient string
|
||||
}
|
||||
|
||||
func (s *HarmonyParser) AddImplicitStart() {
|
||||
s.acc.WriteString("<|start|>assistant")
|
||||
}
|
||||
|
||||
func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) {
|
||||
if lastMessage != nil && lastMessage.Role == "assistant" {
|
||||
// handle prefilling conditions
|
||||
if lastMessage.Content != "" {
|
||||
s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
|
||||
return
|
||||
} else if lastMessage.Thinking != "" {
|
||||
s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
|
||||
return
|
||||
}
|
||||
}
|
||||
s.AddImplicitStart()
|
||||
}
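A brief usage sketch (assumed caller code, not part of the file): when the last message is a partially filled assistant reply, the parser is seeded so the model's continuation is parsed as the rest of that message.

// Sketch: prefill with a partial assistant answer before streaming the continuation.
last := &api.Message{Role: "assistant", Content: "The sky is blue because"}
p := &HarmonyParser{MessageStartTag: "<|start|>", MessageEndTag: "<|end|>", HeaderEndTag: "<|message|>"}
p.AddImplicitStartOrPrefill(last)
// p now behaves as if "<|start|>assistant<|channel|>final<|message|>" had already been streamed.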
func (s *HarmonyParser) AddContent(content string) []HarmonyEvent {
|
||||
s.lifetimeAcc.WriteString(content)
|
||||
s.acc.WriteString(content)
|
||||
|
||||
var events []HarmonyEvent
|
||||
|
||||
keepLooping := true
|
||||
// we loop because we might pass through multiple parsing states in a single
|
||||
// call to addContent, and we want to make sure callers don't have to wait for
|
||||
// data that's already unambiguous
|
||||
for keepLooping {
|
||||
var newEvents []HarmonyEvent
|
||||
newEvents, keepLooping = eat(s)
|
||||
events = append(events, newEvents...)
|
||||
}
|
||||
|
||||
return events
|
||||
}
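A caller-side usage sketch (not part of the file) showing that arbitrary chunk boundaries still yield the same event stream, with each event emitted as soon as it is unambiguous:

// Sketch: stream chunks in and handle events as they become available.
p := &HarmonyParser{MessageStartTag: "<|start|>", MessageEndTag: "<|end|>", HeaderEndTag: "<|message|>"}
for _, chunk := range []string{"<|start|>assi", "stant<|message|>Hel", "lo<|end|>"} {
	for _, ev := range p.AddContent(chunk) {
		switch ev := ev.(type) {
		case HarmonyEventMessageStart:
			// a new message began
		case HarmonyEventHeaderComplete:
			_ = ev.Header.Role // "assistant"
		case HarmonyEventContentEmitted:
			_ = ev.Content // "Hel", then "lo"
		case HarmonyEventMessageEnd:
			// the message finished
		}
	}
}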
// the additional bool return is true iff we should continue eating
|
||||
func eat(s *HarmonyParser) ([]HarmonyEvent, bool) {
|
||||
switch s.state {
|
||||
case harmonyParserState_LookingForMessageStart:
|
||||
// does the acc contain the message start tag?
|
||||
if strings.Contains(s.acc.String(), s.MessageStartTag) {
|
||||
// split the acc into the message start tag and the rest
|
||||
split := strings.SplitN(s.acc.String(), s.MessageStartTag, 2)
|
||||
before := split[0]
|
||||
if before != "" {
|
||||
slog.Warn("harmony parser: found message start tag in the middle of the content", "content", s.acc.String())
|
||||
}
|
||||
after := split[1]
|
||||
s.acc.Reset()
|
||||
s.acc.WriteString(after)
|
||||
s.state = harmonyParserState_ParsingHeader
|
||||
return []HarmonyEvent{HarmonyEventMessageStart{}}, true
|
||||
}
|
||||
|
||||
// no match, so we keep accumulating
|
||||
return nil, false
|
||||
case harmonyParserState_ParsingHeader:
|
||||
if strings.Contains(s.acc.String(), s.HeaderEndTag) {
|
||||
split := strings.SplitN(s.acc.String(), s.HeaderEndTag, 2)
|
||||
header := split[0]
|
||||
after := split[1]
|
||||
s.acc.Reset()
|
||||
s.acc.WriteString(after)
|
||||
s.state = harmonyParserState_ParsingContent
|
||||
return []HarmonyEvent{HarmonyEventHeaderComplete{Header: s.parseHeader(header)}}, true
|
||||
}
|
||||
return nil, false
|
||||
case harmonyParserState_ParsingContent:
|
||||
if strings.Contains(s.acc.String(), s.MessageEndTag) {
|
||||
// if we already have the message end tag, we can emit the content up to it
|
||||
split := strings.SplitN(s.acc.String(), s.MessageEndTag, 2)
|
||||
content := split[0]
|
||||
after := split[1]
|
||||
s.acc.Reset()
|
||||
s.acc.WriteString(after)
|
||||
s.state = harmonyParserState_LookingForMessageStart
|
||||
events := []HarmonyEvent{}
|
||||
if content != "" {
|
||||
events = append(events, HarmonyEventContentEmitted{Content: content})
|
||||
}
|
||||
events = append(events, HarmonyEventMessageEnd{})
|
||||
return events, true
|
||||
} else if overlapLen := overlap(s.acc.String(), s.MessageEndTag); overlapLen > 0 {
|
||||
// if our suffix contains the start of the message end tag, we can emit
|
||||
// the content up to the start of the message end tag
|
||||
content := s.acc.String()[:len(s.acc.String())-overlapLen]
|
||||
remaining := s.acc.String()[len(s.acc.String())-overlapLen:]
|
||||
s.acc.Reset()
|
||||
s.acc.WriteString(remaining)
|
||||
// emit the content we know isn't part of the message end tag, and keep
|
||||
// accumulating to disambiguate the rest
|
||||
if content == "" {
|
||||
return nil, false
|
||||
}
|
||||
return []HarmonyEvent{HarmonyEventContentEmitted{Content: content}}, false
|
||||
} else {
|
||||
// no end tag, so it's still normal content that we can immediately emit
|
||||
content := s.acc.String()
|
||||
if content == "" {
|
||||
return nil, false
|
||||
}
|
||||
s.acc.Reset()
|
||||
return []HarmonyEvent{HarmonyEventContentEmitted{Content: content}}, false
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (s *HarmonyParser) parseHeader(raw string) HarmonyHeader {
|
||||
harmonyHeader := HarmonyHeader{}
|
||||
|
||||
// if `<|constrain|>` is present, ensure it has a space before it so it gets
|
||||
// parsed as a separate token, even if the model didn't include the space
|
||||
if strings.Contains(raw, "<|constrain|>") {
|
||||
raw = strings.Replace(raw, "<|constrain|>", " <|constrain|>", 1)
|
||||
raw = strings.TrimSpace(raw)
|
||||
}
|
||||
|
||||
// look for the optional channel tag, which is `<|channel|>` followed by the
|
||||
// channel name, all without any whitespace
|
||||
channelIndex := strings.Index(raw, "<|channel|>")
|
||||
if channelIndex != -1 {
|
||||
before := raw[:channelIndex]
|
||||
after := raw[channelIndex+len("<|channel|>"):]
|
||||
// the channel name is `after` all the way up to the first (if any) whitespace character
|
||||
idx := strings.IndexFunc(after, func(r rune) bool {
|
||||
return unicode.IsSpace(r)
|
||||
})
|
||||
if idx == -1 {
|
||||
idx = len(after)
|
||||
}
|
||||
harmonyHeader.Channel = after[:idx]
|
||||
after = after[idx:]
|
||||
// now we remove the channel tag from the raw string to further process
|
||||
raw = before + after
|
||||
raw = strings.TrimSpace(raw)
|
||||
}
|
||||
|
||||
// split the header into whitespace-separated tokens
|
||||
tokens := strings.Fields(raw)
|
||||
|
||||
// the first token is treated as the role
|
||||
if len(tokens) == 0 {
|
||||
slog.Error("harmony parser: missing role in header", "header", raw)
|
||||
return harmonyHeader
|
||||
}
|
||||
role := tokens[0]
|
||||
tokens = tokens[1:]
|
||||
// special case: if role starts with to= then it's a tool call
|
||||
if strings.HasPrefix(role, "to=") {
|
||||
harmonyHeader.Recipient = role[3:]
|
||||
harmonyHeader.Role = "tool"
|
||||
} else {
|
||||
harmonyHeader.Role = role
|
||||
}
|
||||
|
||||
// the recipient (if any) can be specified before or after the channel tag, so
|
||||
// we check it at the end once we've already parsed the channel and role
|
||||
if harmonyHeader.Recipient == "" && len(tokens) > 0 && strings.HasPrefix(tokens[0], "to=") {
|
||||
harmonyHeader.Recipient = tokens[0][3:]
|
||||
}
|
||||
|
||||
return harmonyHeader
|
||||
}
|
||||
|
||||
// longest overlap between suffix of s and prefix of delim
|
||||
func overlap(s, delim string) int {
|
||||
max := min(len(delim), len(s))
|
||||
for i := max; i > 0; i-- {
|
||||
if strings.HasSuffix(s, delim[:i]) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
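overlap is what lets the parser emit everything that is definitely content while buffering a suffix that might be the start of the end tag. A same-package sketch of the idea:

// Sketch: "Hello<|e" overlaps "<|end|>" by 3, so "Hello" is safe to emit and
// "<|e" is held back until more content disambiguates it.
acc := "Hello<|e"
if n := overlap(acc, "<|end|>"); n > 0 {
	emit := acc[:len(acc)-n] // "Hello"
	hold := acc[len(acc)-n:] // "<|e"
	_, _ = emit, hold
}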
// harmonyMessageState represents the current state of message processing
|
||||
type harmonyMessageState int
|
||||
|
||||
const (
|
||||
harmonyMessageState_Normal harmonyMessageState = iota
|
||||
harmonyMessageState_Thinking
|
||||
harmonyMessageState_ToolCalling
|
||||
)
|
||||
|
||||
// HarmonyMessageHandler processes harmony events and accumulates content appropriately.
|
||||
// This is a higher level interface that maps harmony concepts into ollama concepts
|
||||
type HarmonyMessageHandler struct {
|
||||
state harmonyMessageState
|
||||
HarmonyParser *HarmonyParser
|
||||
FunctionNameMap *FunctionNameMap
|
||||
toolAccumulator *HarmonyToolCallAccumulator
|
||||
convertedTools map[string]struct{}
|
||||
}
|
||||
|
||||
// NewHarmonyMessageHandler creates a new message handler
|
||||
func NewHarmonyMessageHandler() *HarmonyMessageHandler {
|
||||
return &HarmonyMessageHandler{
|
||||
state: harmonyMessageState_Normal,
|
||||
HarmonyParser: &HarmonyParser{
|
||||
MessageStartTag: "<|start|>",
|
||||
MessageEndTag: "<|end|>",
|
||||
HeaderEndTag: "<|message|>",
|
||||
},
|
||||
FunctionNameMap: NewFunctionNameMap(),
|
||||
convertedTools: make(map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// AddContent processes the content and returns the content, thinking, and tool content.
|
||||
// content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser
|
||||
func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyToolCallAccumulator) (string, string, string) {
|
||||
contentSb := strings.Builder{}
|
||||
thinkingSb := strings.Builder{}
|
||||
toolContentSb := strings.Builder{}
|
||||
|
||||
events := h.HarmonyParser.AddContent(content)
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case HarmonyEventHeaderComplete:
|
||||
logutil.Trace("harmony event header complete", "header", event.Header)
|
||||
switch event.Header.Channel {
|
||||
case "analysis":
|
||||
if event.Header.Recipient != "" {
|
||||
h.state = harmonyMessageState_ToolCalling
|
||||
// event.Header.Recipient is the tool name, something like
|
||||
// "browser.search" for a built-in, or "functions.calc" for a
|
||||
// custom one
|
||||
toolParser.SetToolName(event.Header.Recipient)
|
||||
} else {
|
||||
h.state = harmonyMessageState_Thinking
|
||||
}
|
||||
case "commentary":
|
||||
if event.Header.Recipient != "" {
|
||||
h.state = harmonyMessageState_ToolCalling
|
||||
toolParser.SetToolName(event.Header.Recipient)
|
||||
} else {
|
||||
h.state = harmonyMessageState_Normal
|
||||
}
|
||||
case "final":
|
||||
h.state = harmonyMessageState_Normal
|
||||
}
|
||||
case HarmonyEventContentEmitted:
|
||||
logutil.Trace("harmony event content", "content", event.Content, "state", h.state)
|
||||
if h.state == harmonyMessageState_Normal {
|
||||
contentSb.WriteString(event.Content)
|
||||
} else if h.state == harmonyMessageState_Thinking {
|
||||
thinkingSb.WriteString(event.Content)
|
||||
} else if h.state == harmonyMessageState_ToolCalling {
|
||||
toolContentSb.WriteString(event.Content)
|
||||
}
|
||||
case HarmonyEventMessageEnd:
|
||||
h.state = harmonyMessageState_Normal
|
||||
}
|
||||
}
|
||||
return contentSb.String(), thinkingSb.String(), toolContentSb.String()
|
||||
}
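To summarize the routing above: analysis without a recipient becomes thinking, any header with a recipient becomes tool-call content, and commentary or final without a recipient is normal content. A caller-side sketch (not part of the file):

// Sketch: split one streamed chunk into content, thinking, and raw tool JSON.
h := NewHarmonyMessageHandler()
h.HarmonyParser.AddImplicitStart()
tp := h.CreateToolParser()
content, thinking, toolContent := h.AddContent("<|channel|>analysis<|message|>Let me check.<|end|>", tp)
_, _, _ = content, thinking, toolContent // "", "Let me check.", ""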
func (h *HarmonyMessageHandler) CreateToolParser() *HarmonyToolCallAccumulator {
|
||||
return &HarmonyToolCallAccumulator{
|
||||
state: harmonyToolCallState_Normal,
|
||||
currentToolName: nil,
|
||||
}
|
||||
}
|
||||
|
||||
type harmonyToolCallState int
|
||||
|
||||
const (
|
||||
harmonyToolCallState_Normal harmonyToolCallState = iota
|
||||
harmonyToolCallState_ToolCalling
|
||||
)
|
||||
|
||||
type HarmonyToolCallAccumulator struct {
|
||||
state harmonyToolCallState
|
||||
acc strings.Builder
|
||||
currentToolName *string
|
||||
}
|
||||
|
||||
func (a *HarmonyToolCallAccumulator) SetToolName(toolName string) {
|
||||
a.currentToolName = &toolName
|
||||
}
|
||||
|
||||
func (a *HarmonyToolCallAccumulator) Add(content string) {
|
||||
a.acc.WriteString(content)
|
||||
}
|
||||
|
||||
func (a *HarmonyToolCallAccumulator) Drain() (*string, string) {
|
||||
str := a.acc.String()
|
||||
a.state = harmonyToolCallState_Normal
|
||||
a.acc.Reset()
|
||||
return a.currentToolName, str
|
||||
}
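A usage sketch of the accumulator (caller code, not part of the file): tool content is appended as it streams, then drained once the message is complete and decoded as the call's JSON arguments.

// Sketch: accumulate tool-call JSON, then drain and decode it at the end.
acc := &HarmonyToolCallAccumulator{}
acc.SetToolName("functions.get_weather")
acc.Add(`{"location":"San Francisco, CA"}`)
name, raw := acc.Drain()
var args map[string]any
if name != nil && json.Unmarshal([]byte(raw), &args) == nil {
	_ = args["location"] // "San Francisco, CA"
}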
func (a *HarmonyToolCallAccumulator) Content() string {
|
||||
return a.acc.String()
|
||||
}
|
||||
|
||||
// FunctionNameMap maps a user-specified function name to a valid function
|
||||
// name for harmony (which look like TypeScript identifiers). This is needed to
|
||||
// transform user-specified function names, which might contain characters that
|
||||
// are not allowed in TypeScript identifiers
|
||||
type FunctionNameMap struct {
|
||||
userToHarmony map[string]string
|
||||
harmonyToUser map[string]string
|
||||
}
|
||||
|
||||
func NewFunctionNameMap() *FunctionNameMap {
|
||||
return &FunctionNameMap{
|
||||
userToHarmony: make(map[string]string),
|
||||
harmonyToUser: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes the handler with tools and optional last message
|
||||
// Implements the Parser interface
|
||||
func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
||||
// Initialize the harmony parser
|
||||
if h.HarmonyParser == nil {
|
||||
h.HarmonyParser = &HarmonyParser{
|
||||
MessageStartTag: "<|start|>",
|
||||
MessageEndTag: "<|end|>",
|
||||
HeaderEndTag: "<|message|>",
|
||||
}
|
||||
}
|
||||
|
||||
// Handle prefill for chat mode
|
||||
if lastMessage != nil {
|
||||
h.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
|
||||
} else {
|
||||
h.HarmonyParser.AddImplicitStart()
|
||||
}
|
||||
|
||||
// Initialize tool accumulator
|
||||
h.toolAccumulator = h.CreateToolParser()
|
||||
|
||||
// Process tools and return renamed versions
|
||||
if len(tools) == 0 {
|
||||
return tools
|
||||
}
|
||||
|
||||
processedTools := make([]api.Tool, len(tools))
|
||||
copy(processedTools, tools)
|
||||
for i, tool := range processedTools {
|
||||
if tool.Function.Name != "" {
|
||||
processedTools[i].Function.Name = h.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
|
||||
h.convertedTools[tool.Function.Name] = struct{}{}
|
||||
}
|
||||
}
|
||||
return processedTools
|
||||
}
|
||||
|
||||
// Add implements the Parser interface - processes streamed content and extracts content, thinking, and tool calls
|
||||
func (h *HarmonyMessageHandler) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
content, thinking, toolContent := h.AddContent(s, h.toolAccumulator)
|
||||
if toolContent != "" {
|
||||
h.toolAccumulator.Add(toolContent)
|
||||
}
|
||||
|
||||
// tool calls always happen one at a time, and always at the end of a message,
|
||||
// so for simplicity we defer parsing them until we know we're done
|
||||
if done {
|
||||
toolName, raw := h.toolAccumulator.Drain()
|
||||
if toolName != nil {
|
||||
name := strings.TrimPrefix(*toolName, "functions.")
|
||||
name = h.FunctionNameMap.OriginalFromConverted(name)
|
||||
var args api.ToolCallFunctionArguments
|
||||
if err := json.Unmarshal([]byte(raw), &args); err != nil {
|
||||
return "", "", nil, fmt.Errorf("error parsing tool call: raw='%s', err=%w", raw, err)
|
||||
}
|
||||
calls = append(calls, api.ToolCall{Function: api.ToolCallFunction{Name: name, Arguments: args}})
|
||||
}
|
||||
}
|
||||
|
||||
return content, thinking, calls, nil
|
||||
}
|
||||
|
||||
// HasToolSupport implements the Parser interface
|
||||
func (h *HarmonyMessageHandler) HasToolSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// HasThinkingSupport implements the Parser interface
|
||||
func (h *HarmonyMessageHandler) HasThinkingSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string {
|
||||
harmonyFunctionName := m.deriveName(userFunctionName)
|
||||
// built-in functions should not be renamed
|
||||
if userFunctionName == "browser.open" || userFunctionName == "browser.search" || userFunctionName == "browser.find" || userFunctionName == "python" {
|
||||
harmonyFunctionName = userFunctionName
|
||||
}
|
||||
m.userToHarmony[userFunctionName] = harmonyFunctionName
|
||||
m.harmonyToUser[harmonyFunctionName] = userFunctionName
|
||||
return harmonyFunctionName
|
||||
}
|
||||
|
||||
// OriginalFromConverted looks up the reverse-mapping of a previously-converted
|
||||
// user->harmony function name. To unmap reliably, the mapping must exist, as
|
||||
// the conversion process is not reversible without the appropriate state
|
||||
func (m *FunctionNameMap) OriginalFromConverted(harmonyFunctionName string) string {
|
||||
if userFunctionName, ok := m.harmonyToUser[harmonyFunctionName]; ok {
|
||||
return userFunctionName
|
||||
}
|
||||
slog.Warn("harmony parser: no reverse mapping found for function name", "harmonyFunctionName", harmonyFunctionName)
|
||||
// fallback to the original function name if we can't find a mapping
|
||||
return harmonyFunctionName
|
||||
}
|
||||
|
||||
// convertToValidChars converts a user-specified function name to a valid
|
||||
// TypeScript identifier.
|
||||
//
|
||||
// Limitations:
|
||||
//
|
||||
// - This doesn't restrict reserved TypeScript keywords.
|
||||
// - We don't perform a real ID_Start/ID_Continue check, and instead use the more
|
||||
// restrictive unicode.IsLetter/unicode.IsDigit check. Unclear what kind of
|
||||
// identifiers these models were trained on, so in the end we might want to
|
||||
// convert unicode-heavy identifiers to their closest ASCII equivalents.
|
||||
func (m *FunctionNameMap) convertToValidChars(userFunctionName string) string {
|
||||
mapper := func(r rune) rune {
|
||||
// first, replace certain characters with underscores
|
||||
if r == ' ' || r == '-' || r == '.' {
|
||||
return '_'
|
||||
}
|
||||
|
||||
if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' {
|
||||
return r
|
||||
}
|
||||
|
||||
// finally, remove any other characters
|
||||
return -1
|
||||
}
|
||||
candidate := strings.Map(mapper, userFunctionName)
|
||||
|
||||
// set a default name if we end up with nothing left
|
||||
if candidate == "" {
|
||||
return "unnamed"
|
||||
}
|
||||
|
||||
// if the candidate starts with a number, prepend an underscore to make it a
|
||||
// valid identifier
|
||||
if unicode.IsDigit(rune(candidate[0])) {
|
||||
candidate = "_" + candidate
|
||||
}
|
||||
|
||||
return candidate
|
||||
}
|
||||
|
||||
func (m *FunctionNameMap) deriveName(userFunctionName string) string {
|
||||
originalCandidate := m.convertToValidChars(userFunctionName)
|
||||
candidate := originalCandidate
|
||||
|
||||
// Check for dupes, and if so, add a number to the end.
|
||||
// We start at 2 because if we have dupes and the first is never renamed, it
|
||||
// makes sense for them to be named, say, `f`, `f_2`, `f_3`
|
||||
count := 2
|
||||
for {
|
||||
if _, exists := m.harmonyToUser[candidate]; !exists {
|
||||
break
|
||||
}
|
||||
candidate = fmt.Sprintf("%s_%d", originalCandidate, count)
|
||||
count++
|
||||
}
|
||||
|
||||
return candidate
|
||||
}
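An illustrative run of the dedupe rule (consistent with the ConvertAndAdd tests further down):

// Sketch: colliding names pick up _2, _3 suffixes; built-ins pass through untouched.
m := NewFunctionNameMap()
_ = m.ConvertAndAdd("get weather")    // "get_weather"
_ = m.ConvertAndAdd("get-weather")    // "get_weather_2"
_ = m.ConvertAndAdd("browser.search") // "browser.search" (built-in, not renamed)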
harmony/harmonyparser_test.go (new file, 538 lines)
@@ -0,0 +1,538 @@
package harmony
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHeaderParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
in, wantRole, wantChannel, wantRecipient string
|
||||
}{
|
||||
{
|
||||
in: "assistant<|channel|>analysis",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "analysis",
|
||||
wantRecipient: "",
|
||||
},
|
||||
{
|
||||
in: "assistant<|channel|>analysis to=functions.get_weather",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "analysis",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
{
|
||||
in: "assistant to=functions.get_weather<|channel|>analysis",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "analysis",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
// special case where the role is replaced by the recipient (matches reference code)
|
||||
{
|
||||
in: "to=functions.get_weather<|channel|>analysis",
|
||||
wantRole: "tool",
|
||||
wantChannel: "analysis",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
// extra token after the recipient is ignored
|
||||
{
|
||||
in: "assistant to=functions.get_weather abc<|channel|>analysis",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "analysis",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
// with constrain tag, recipient after channel tag
|
||||
{
|
||||
in: "assistant<|channel|>commentary to=functions.get_weather <|constrain|>json",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "commentary",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
// with constrain tag, recipient before channel tag
|
||||
{
|
||||
in: "assistant to=functions.get_weather<|channel|>commentary <|constrain|>json",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "commentary",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
// constrain tag without space
|
||||
{
|
||||
in: "assistant<|channel|>commentary to=functions.get_weather<|constrain|>json",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "commentary",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
// constrain tag without space, different order
|
||||
{
|
||||
in: "assistant to=functions.get_weather<|channel|>commentary<|constrain|>json",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "commentary",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
parser := HarmonyParser{
|
||||
MessageStartTag: "<|start|>",
|
||||
MessageEndTag: "<|end|>",
|
||||
HeaderEndTag: "<|message|>",
|
||||
}
|
||||
header := parser.parseHeader(tt.in)
|
||||
|
||||
if header.Role != tt.wantRole {
|
||||
t.Errorf("case %d: got role \"%s\", want \"%s\"", i, header.Role, tt.wantRole)
|
||||
}
|
||||
if header.Channel != tt.wantChannel {
|
||||
t.Errorf("case %d: got channel \"%s\", want \"%s\"", i, header.Channel, tt.wantChannel)
|
||||
}
|
||||
if header.Recipient != tt.wantRecipient {
|
||||
t.Errorf("case %d: got recipient \"%s\", want \"%s\"", i, header.Recipient, tt.wantRecipient)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHarmonyParserHeaderEvent(t *testing.T) {
|
||||
tests := []struct {
|
||||
in, wantRole, wantChannel, wantRecipient string
|
||||
implicitStart bool
|
||||
}{
|
||||
{
|
||||
in: "<|start|>user<|message|>What is 2 + 2?<|end|>",
|
||||
wantRole: "user",
|
||||
wantChannel: "",
|
||||
wantRecipient: "",
|
||||
},
|
||||
{
|
||||
in: "<|start|>assistant<|channel|>analysis<|message|>What is 2 + 2?<|end|>",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "analysis",
|
||||
wantRecipient: "",
|
||||
},
|
||||
{
|
||||
in: "<|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{\"location\":\"San Francisco\"}<|call|><|start|>functions.get_weather to=assistant<|message|>{\"sunny\": true, \"temperature\": 20}<|end|>",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "commentary",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
{
|
||||
in: "<|channel|>analysis<|message|>User asks weather in SF. We need location. Use get_current_weather with location \"San Francisco, CA\".<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{\"location\":\"San Francisco, CA\"}<|call|>",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "analysis",
|
||||
wantRecipient: "",
|
||||
implicitStart: true,
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
parser := HarmonyParser{
|
||||
MessageStartTag: "<|start|>",
|
||||
MessageEndTag: "<|end|>",
|
||||
HeaderEndTag: "<|message|>",
|
||||
}
|
||||
if tt.implicitStart {
|
||||
parser.AddImplicitStart()
|
||||
}
|
||||
gotEvents := parser.AddContent(tt.in)
|
||||
if len(gotEvents) == 0 {
|
||||
t.Errorf("case %d: got no events, want at least one", i)
|
||||
}
|
||||
|
||||
var firstHeaderEvent *HarmonyEventHeaderComplete
|
||||
// print events
|
||||
for _, event := range gotEvents {
|
||||
fmt.Printf("event: %+v\n", event)
|
||||
}
|
||||
for _, event := range gotEvents {
|
||||
if event, ok := event.(HarmonyEventHeaderComplete); ok {
|
||||
firstHeaderEvent = &event
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if firstHeaderEvent == nil {
|
||||
t.Errorf("case %d: got no header complete event, want one", i)
|
||||
continue
|
||||
}
|
||||
gotHeader := firstHeaderEvent.Header
|
||||
if gotHeader.Role != tt.wantRole || gotHeader.Channel != tt.wantChannel || gotHeader.Recipient != tt.wantRecipient {
|
||||
t.Errorf("case %d: got header %+v, want role=%s channel=%s recipient=%s", i, gotHeader, tt.wantRole, tt.wantChannel, tt.wantRecipient)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHarmonyParserNonStreaming(t *testing.T) {
|
||||
tests := []struct {
|
||||
in string
|
||||
implicitStart bool
|
||||
wantEvents []HarmonyEvent
|
||||
}{
|
||||
{
|
||||
in: "<|start|>user<|message|>What is 2 + 2?<|end|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||
HarmonyEventContentEmitted{Content: "What is 2 + 2?"},
|
||||
HarmonyEventMessageEnd{},
|
||||
},
|
||||
},
|
||||
{
|
||||
in: "<|start|>assistant<|channel|>analysis<|message|>The answer is 4<|end|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "analysis", Recipient: ""}},
|
||||
HarmonyEventContentEmitted{Content: "The answer is 4"},
|
||||
HarmonyEventMessageEnd{},
|
||||
},
|
||||
},
|
||||
{
|
||||
in: "<|start|>assistant<|channel|>commentary to=functions.calc<|message|>Computing...<|end|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "commentary", Recipient: "functions.calc"}},
|
||||
HarmonyEventContentEmitted{Content: "Computing..."},
|
||||
HarmonyEventMessageEnd{},
|
||||
},
|
||||
},
|
||||
{
|
||||
in: "<|start|>user<|message|><|end|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||
HarmonyEventMessageEnd{},
|
||||
},
|
||||
},
|
||||
{
|
||||
in: "<|start|>user<|message|>Hello<|end|><|start|>assistant<|message|>Hi!<|end|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||
HarmonyEventContentEmitted{Content: "Hello"},
|
||||
HarmonyEventMessageEnd{},
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "", Recipient: ""}},
|
||||
HarmonyEventContentEmitted{Content: "Hi!"},
|
||||
HarmonyEventMessageEnd{},
|
||||
},
|
||||
},
|
||||
{
|
||||
in: "<|channel|>analysis<|message|>Thinking about the request<|end|>",
|
||||
implicitStart: true,
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}, HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "analysis", Recipient: ""}}, HarmonyEventContentEmitted{Content: "Thinking about the request"}, HarmonyEventMessageEnd{}},
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
parser := HarmonyParser{
|
||||
MessageStartTag: "<|start|>",
|
||||
MessageEndTag: "<|end|>",
|
||||
HeaderEndTag: "<|message|>",
|
||||
}
|
||||
if tt.implicitStart {
|
||||
parser.AddImplicitStart()
|
||||
}
|
||||
gotEvents := parser.AddContent(tt.in)
|
||||
if !reflect.DeepEqual(gotEvents, tt.wantEvents) {
|
||||
t.Errorf("case %d: got events %#v, want %#v", i, gotEvents, tt.wantEvents)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHarmonyParserStreaming(t *testing.T) {
|
||||
type step struct {
|
||||
input string
|
||||
wantEvents []HarmonyEvent
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
implicitStart bool
|
||||
steps []step
|
||||
}{
|
||||
{
|
||||
desc: "simple message streamed character by character",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<",
|
||||
wantEvents: nil,
|
||||
},
|
||||
{
|
||||
input: "|",
|
||||
wantEvents: nil,
|
||||
},
|
||||
{
|
||||
input: "start|>u",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}},
|
||||
},
|
||||
{
|
||||
input: "ser<|mess",
|
||||
wantEvents: nil,
|
||||
},
|
||||
{
|
||||
input: "age|>Hi",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||
HarmonyEventContentEmitted{Content: "Hi"},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: " there",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: " there"}},
|
||||
},
|
||||
{
|
||||
input: "<|e",
|
||||
wantEvents: nil,
|
||||
},
|
||||
{
|
||||
input: "nd|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "message with channel streamed",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<|start|>assistant",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}},
|
||||
},
|
||||
{
|
||||
input: "<|chan",
|
||||
wantEvents: nil,
|
||||
},
|
||||
{
|
||||
input: "nel|>analysis",
|
||||
wantEvents: nil,
|
||||
},
|
||||
{
|
||||
input: "<|message|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "analysis", Recipient: ""}}},
|
||||
},
|
||||
{
|
||||
input: "Thinking",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "Thinking"}},
|
||||
},
|
||||
{
|
||||
input: "...",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "..."}},
|
||||
},
|
||||
{
|
||||
input: "<|end|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "message with channel and recipient",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<|start|>assistant<|channel|>commentary to=functions.calc<|message|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "commentary", Recipient: "functions.calc"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "{\"x\": 5}",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "{\"x\": 5}"}},
|
||||
},
|
||||
{
|
||||
input: "<|end|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "message with channel and recipient (receipient before channel)",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<|start|>assistant to=functions.calc<|channel|>commentary<|message|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "commentary", Recipient: "functions.calc"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "{\"x\": 5}",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "{\"x\": 5}"}},
|
||||
},
|
||||
{
|
||||
input: "<|end|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "implicit start with channel",
|
||||
implicitStart: true,
|
||||
steps: []step{
|
||||
{
|
||||
input: "<|channel|>thinking",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}},
|
||||
},
|
||||
{
|
||||
input: "<|message|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "thinking", Recipient: ""}}},
|
||||
},
|
||||
{
|
||||
input: "Processing request",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "Processing request"}},
|
||||
},
|
||||
{
|
||||
input: "<|end|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "multiple messages streamed",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<|start|>user<|message|>Hello<|end|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||
HarmonyEventContentEmitted{Content: "Hello"},
|
||||
HarmonyEventMessageEnd{},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "<|start|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}},
|
||||
},
|
||||
{
|
||||
input: "assistant<|message|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "", Recipient: ""}}},
|
||||
},
|
||||
{
|
||||
input: "Hi!",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "Hi!"}},
|
||||
},
|
||||
{
|
||||
input: "<|end|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "empty message",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<|start|>system<|message|><|end|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "system", Channel: "", Recipient: ""}},
|
||||
HarmonyEventMessageEnd{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "partial tag that looks like end but isn't",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<|start|>user<|message|>test<|e",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||
HarmonyEventContentEmitted{Content: "test"},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "xample|>more",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "<|example|>more"}},
|
||||
},
|
||||
{
|
||||
input: "<|end|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
parser := HarmonyParser{
|
||||
MessageStartTag: "<|start|>",
|
||||
MessageEndTag: "<|end|>",
|
||||
HeaderEndTag: "<|message|>",
|
||||
}
|
||||
if tc.implicitStart {
|
||||
parser.AddImplicitStart()
|
||||
}
|
||||
|
||||
for i, step := range tc.steps {
|
||||
gotEvents := parser.AddContent(step.input)
|
||||
if !reflect.DeepEqual(gotEvents, step.wantEvents) {
|
||||
t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFunctionConvertToValidChars tests only FunctionNameMap.convert(), which doesn't
|
||||
// handle any saving (and therefore no dupe handling)
|
||||
func TestFunctionConvertToValidChars(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{name: "replace spaces with underscores", in: "get weather", want: "get_weather"},
|
||||
{name: "replace hyphens with underscores", in: "get-weather", want: "get_weather"},
|
||||
{name: "replace periods with underscores", in: "get.weather", want: "get_weather"},
|
||||
{name: "disallow non-word characters", in: "get weather!", want: "get_weather"},
|
||||
{name: "strip out invalid non-alphanumeric unicode characters", in: "a🫠bc", want: "abc"},
|
||||
{name: "names that only contain invalid characters", in: "🫠", want: "unnamed"},
|
||||
{name: "leading number", in: "123", want: "_123"},
|
||||
{name: "$ allowed", in: "$", want: "$"},
|
||||
// show that we allow weird unicode letter characters, though we might want
|
||||
// to convert them to their closest ASCII equivalents in the future
|
||||
{name: "allow weird unicode letter characters", in: "𝓸𝓵𝓵𝓪𝓶𝓪", want: "𝓸𝓵𝓵𝓪𝓶𝓪"},
|
||||
// names that look like words but are invalid (i.e., not ID_Start/ID_Continue)
|
||||
{name: "disallow non-word characters that look like words", in: "ⓞⓛⓛⓐⓜⓐ123", want: "_123"},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parser := NewFunctionNameMap()
|
||||
got := parser.convertToValidChars(tt.in)
|
||||
if got != tt.want {
|
||||
t.Errorf("case %d: got %q, want %q", i, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFunctionConvertAndAdd(t *testing.T) {
|
||||
// make a fresh map for each test, but within a test use the same map so we can test for dupe handling
|
||||
tests := []struct {
|
||||
name string
|
||||
in []string
|
||||
want []string
|
||||
}{
|
||||
{name: "basic dupe handling", in: []string{"get weather", "get weather"}, want: []string{"get_weather", "get_weather_2"}},
|
||||
{name: "dupes from different user-specified names", in: []string{"get weather", "get_weather", "get-weather"}, want: []string{"get_weather", "get_weather_2", "get_weather_3"}},
|
||||
{name: "non dupes after dupes", in: []string{"get weather", "get_weather", "get-weather", "something-different"}, want: []string{"get_weather", "get_weather_2", "get_weather_3", "something_different"}},
|
||||
{name: "multiple sets of dupes", in: []string{"a", "a", "b", "a", "a", "b", "a"}, want: []string{"a", "a_2", "b", "a_3", "a_4", "b_2", "a_5"}},
|
||||
{name: "built-in functions should not be renamed", in: []string{"browser.open", "python", "not.a.built-in.function", "browser.not_a_real_built_in"}, want: []string{"browser.open", "python", "not_a_built_in_function", "browser_not_a_real_built_in"}},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
parser := NewFunctionNameMap()
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
for j, in := range tt.in {
|
||||
got := parser.ConvertAndAdd(in)
|
||||
want := tt.want[j]
|
||||
if got != want {
|
||||
t.Errorf("case %d: got %q, want %q", i, got, want)
|
||||
}
|
||||
// check that the maps are correct
|
||||
if parser.userToHarmony[in] != want {
|
||||
t.Errorf("case %d: userToHarmony[%q] = %q, want %q", i, in, parser.userToHarmony[in], want)
|
||||
}
|
||||
if parser.harmonyToUser[want] != in {
|
||||
t.Errorf("case %d: harmonyToUser[%q] = %q, want %q", i, want, parser.harmonyToUser[want], in)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,10 +2,16 @@

This directory contains integration tests to exercise Ollama end-to-end to verify behavior

By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...`
By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag: `go test -tags=integration ./...`. Some tests require additional tags so that testing can be scoped and the total run time kept reasonable. For example, testing a broad set of models requires `-tags=integration,models` and a longer timeout (~60m or more, depending on the speed of your GPU). To view the current set of tag combinations use `find integration -type f | xargs grep "go:build"`


The integration tests have 2 modes of operating.

1. By default, they will start the server on a random port, run the tests, and then shutdown the server.
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote, based on your `OLLAMA_HOST` environment variable

> [!IMPORTANT]
> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree with `go build .`, and build GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.


Many tests use a default small model suitable to run on many systems. You can override this default model by setting `OLLAMA_TEST_DEFAULT_MODEL`
@@ -22,13 +22,12 @@ func TestAPIGenerate(t *testing.T) {
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: smol,
|
||||
Prompt: "why is the sky blue? be brief",
|
||||
Prompt: blueSkyPrompt,
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
anyResp := []string{"rayleigh", "scattering"}
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
@@ -120,14 +119,14 @@ func TestAPIGenerate(t *testing.T) {
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
for _, resp := range blueSkyExpected {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Errorf("none of %v found in %s", anyResp, response)
|
||||
t.Errorf("none of %v found in %s", blueSkyExpected, response)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for generate")
|
||||
@@ -181,7 +180,7 @@ func TestAPIChat(t *testing.T) {
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "why is the sky blue? be brief",
|
||||
Content: blueSkyPrompt,
|
||||
},
|
||||
},
|
||||
Options: map[string]interface{}{
|
||||
@@ -189,7 +188,6 @@ func TestAPIChat(t *testing.T) {
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
anyResp := []string{"rayleigh", "scattering"}
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
@@ -279,14 +277,14 @@ func TestAPIChat(t *testing.T) {
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
for _, resp := range blueSkyExpected {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Errorf("none of %v found in %s", anyResp, response)
|
||||
t.Errorf("none of %v found in %s", blueSkyExpected, response)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for chat")
|
||||
@@ -390,7 +388,7 @@ func TestAPIEmbeddings(t *testing.T) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
req := api.EmbeddingRequest{
|
||||
Model: "orca-mini",
|
||||
Model: libraryEmbedModels[0],
|
||||
Prompt: "why is the sky blue?",
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
@@ -410,3 +408,99 @@ func TestAPIEmbeddings(t *testing.T) {
|
||||
t.Errorf("zero length embedding response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIToolCalling(t *testing.T) {
|
||||
initialTimeout := 60 * time.Second
|
||||
streamTimeout := 30 * time.Second
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
modelName := "qwen3:0.6b"
|
||||
if err := PullIfMissing(ctx, client, modelName); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather in a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req := api.ChatRequest{
|
||||
Model: modelName,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Call get_weather with location set to San Francisco.",
|
||||
},
|
||||
},
|
||||
Tools: tools,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
},
|
||||
}
|
||||
|
||||
stallTimer := time.NewTimer(initialTimeout)
|
||||
var gotToolCall bool
|
||||
var lastToolCall api.ToolCall
|
||||
|
||||
fn := func(response api.ChatResponse) error {
|
||||
if len(response.Message.ToolCalls) > 0 {
|
||||
gotToolCall = true
|
||||
lastToolCall = response.Message.ToolCalls[len(response.Message.ToolCalls)-1]
|
||||
}
|
||||
if !stallTimer.Reset(streamTimeout) {
|
||||
return fmt.Errorf("stall was detected while streaming response, aborting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
stream := true
|
||||
req.Stream = &stream
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
genErr = client.Chat(ctx, &req, fn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
t.Errorf("tool-calling chat never started. Timed out after: %s", initialTimeout.String())
|
||||
case <-done:
|
||||
if genErr != nil {
|
||||
t.Fatalf("chat failed: %v", genErr)
|
||||
}
|
||||
|
||||
if !gotToolCall {
|
||||
t.Fatalf("expected at least one tool call, got none")
|
||||
}
|
||||
|
||||
if lastToolCall.Function.Name != "get_weather" {
|
||||
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
|
||||
}
|
||||
|
||||
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
|
||||
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for tool-calling chat")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,23 +11,27 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBlueSky(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: smol,
|
||||
Prompt: "why is the sky blue?",
|
||||
req := api.ChatRequest{
|
||||
Model: smol,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: blueSkyPrompt,
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
|
||||
ChatTestHelper(ctx, t, req, blueSkyExpected)
|
||||
}
|
||||
|
||||
func TestUnicode(t *testing.T) {
|
||||
@@ -35,10 +39,15 @@ func TestUnicode(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
req := api.ChatRequest{
|
||||
// DeepSeek has a Unicode tokenizer regex, making it a unicode torture test
|
||||
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K",
|
||||
Prompt: "天空为什么是蓝色的?",
|
||||
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", // TODO is there an ollama-engine model we can switch to and keep the coverage?
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "天空为什么是蓝色的?", // Why is the sky blue?
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
@@ -50,17 +59,39 @@ func TestUnicode(t *testing.T) {
|
||||
}
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
DoGenerate(ctx, t, client, req, []string{"散射", "频率"}, 120*time.Second, 120*time.Second)
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
slog.Info("loading", "model", req.Model)
|
||||
err := client.Generate(ctx, &api.GenerateRequest{Model: req.Model}, func(response api.GenerateResponse) error { return nil })
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", req.Model, err)
|
||||
}
|
||||
defer func() {
|
||||
// best effort unload once we're done with the model
|
||||
client.Generate(ctx, &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
|
||||
}()
|
||||
|
||||
skipIfNotGPULoaded(ctx, t, client, req.Model, 100)
|
||||
|
||||
DoChat(ctx, t, client, req, []string{
|
||||
"散射", // scattering
|
||||
"频率", // frequency
|
||||
}, 120*time.Second, 120*time.Second)
|
||||
}
|
||||
|
||||
func TestExtendedUnicodeOutput(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: "gemma2:2b",
|
||||
Prompt: "Output some smily face emoji",
|
||||
req := api.ChatRequest{
|
||||
Model: "gemma2:2b",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Output some smily face emoji",
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
@@ -69,8 +100,10 @@ func TestExtendedUnicodeOutput(t *testing.T) {
|
||||
}
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
DoGenerate(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
DoChat(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
|
||||
}
|
||||
|
||||
func TestUnicodeModelDir(t *testing.T) {
|
||||
@@ -84,7 +117,9 @@ func TestUnicodeModelDir(t *testing.T) {
|
||||
}
|
||||
|
||||
modelDir, err := os.MkdirTemp("", "ollama_埃")
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(modelDir)
|
||||
slog.Info("unicode", "OLLAMA_MODELS", modelDir)
|
||||
|
||||
@@ -93,14 +128,19 @@ func TestUnicodeModelDir(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
req := api.GenerateRequest{
|
||||
Model: smol,
|
||||
Prompt: "why is the sky blue?",
|
||||
req := api.ChatRequest{
|
||||
Model: smol,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: blueSkyPrompt,
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
|
||||
ChatTestHelper(ctx, t, req, blueSkyExpected)
|
||||
}
|
||||
|
||||
@@ -4,257 +4,185 @@ package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
func TestMultiModelConcurrency(t *testing.T) {
|
||||
var (
|
||||
req = [2]api.GenerateRequest{
|
||||
{
|
||||
Model: "llama3.2:1b",
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "tinydolphin",
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
},
|
||||
}
|
||||
resp = [2][]string{
|
||||
{"sunlight"},
|
||||
{"england", "english", "massachusetts", "pilgrims", "british", "festival"},
|
||||
}
|
||||
)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(req))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*240)
|
||||
defer cancel()
|
||||
// Send multiple requests in parallel (concurrently) to a single model and ensure the responses are as expected
|
||||
func TestConcurrentChat(t *testing.T) {
|
||||
// Assumes all requests have the same model
|
||||
req, resp := ChatRequests()
|
||||
numParallel := int(envconfig.NumParallel() + 1)
|
||||
iterLimit := 3
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
for i := 0; i < len(req); i++ {
|
||||
require.NoError(t, PullIfMissing(ctx, client, req[i].Model))
|
||||
}
|
||||
|
||||
for i := 0; i < len(req); i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
// Note: CPU based inference can crawl so don't give up too quickly
|
||||
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 30*time.Second)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestIntegrationConcurrentPredict(t *testing.T) {
|
||||
req, resp := GenerateRequests()
|
||||
reqLimit := len(req)
|
||||
iterLimit := 5
|
||||
|
||||
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
|
||||
maxVram, err := strconv.ParseUint(s, 10, 64)
|
||||
require.NoError(t, err)
|
||||
// Don't hammer on small VRAM cards...
|
||||
if maxVram < 4*format.GibiByte {
|
||||
reqLimit = min(reqLimit, 2)
|
||||
iterLimit = 2
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 9*time.Minute)
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// Get the server running (if applicable) and warm the model up with a single initial request
|
||||
DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 10*time.Second)
|
||||
slog.Info("loading", "model", req[0].Model)
|
||||
err := client.Generate(ctx,
|
||||
&api.GenerateRequest{Model: req[0].Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
|
||||
func(response api.GenerateResponse) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", req[0].Model, err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(reqLimit)
|
||||
for i := 0; i < reqLimit; i++ {
|
||||
r := rand.New(rand.NewSource(0))
|
||||
wg.Add(numParallel)
|
||||
for i := range numParallel {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterLimit; j++ {
|
||||
slog.Info("Starting", "req", i, "iter", j)
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
slog.Info("exceeded soft timeout, winding down test")
|
||||
return
|
||||
}
|
||||
k := r.Int() % len(req)
|
||||
slog.Info("Starting", "thread", i, "iter", j)
|
||||
// On slower GPUs it can take a while to process the concurrent requests
|
||||
// so we allow a much longer initial timeout
|
||||
DoGenerate(ctx, t, client, req[i], resp[i], 120*time.Second, 20*time.Second)
|
||||
DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
|
||||
// Stress the scheduler and attempt to load more models than will fit to cause thrashing
|
||||
// This test will always load at least 2 models even on CPU based systems
|
||||
func TestMultiModelStress(t *testing.T) {
|
||||
s := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
|
||||
s := os.Getenv("OLLAMA_MAX_VRAM")
|
||||
if s == "" {
|
||||
t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
|
||||
s = "0"
|
||||
}
|
||||
|
||||
maxVram, err := strconv.ParseUint(s, 10, 64)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if maxVram < 2*format.GibiByte {
|
||||
t.Skip("VRAM less than 2G, skipping model stress tests")
|
||||
|
||||
// All models compatible with ollama-engine
|
||||
smallModels := []string{
|
||||
"llama3.2:1b",
|
||||
"qwen3:0.6b",
|
||||
"gemma2:2b",
|
||||
"deepseek-r1:1.5b", // qwen2 arch
|
||||
"gemma3:270m",
|
||||
}
|
||||
mediumModels := []string{
|
||||
"llama3.2:3b", // ~3.4G
|
||||
"qwen3:8b", // ~6.6G
|
||||
"gpt-oss:20b", // ~15G
|
||||
"deepseek-r1:7b", // ~5.6G
|
||||
"gemma3:4b", // ~5.8G
|
||||
"gemma2:9b", // ~8.1G
|
||||
}
|
||||
|
||||
type model struct {
|
||||
name string
|
||||
size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM
|
||||
}
|
||||
|
||||
smallModels := []model{
|
||||
{
|
||||
name: "llama3.2:1b",
|
||||
size: 2876 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "phi",
|
||||
size: 2616 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "gemma:2b",
|
||||
size: 2364 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "stable-code:3b",
|
||||
size: 2608 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "starcoder2:3b",
|
||||
size: 2166 * format.MebiByte,
|
||||
},
|
||||
}
|
||||
mediumModels := []model{
|
||||
{
|
||||
name: "llama2",
|
||||
size: 5118 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "mistral",
|
||||
size: 4620 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "orca-mini:7b",
|
||||
size: 5118 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "dolphin-mistral",
|
||||
size: 4620 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "gemma:7b",
|
||||
size: 5000 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "codellama:7b",
|
||||
size: 5118 * format.MebiByte,
|
||||
},
|
||||
}
|
||||
|
||||
// These seem to be too slow to be useful...
|
||||
// largeModels := []model{
|
||||
// {
|
||||
// name: "llama2:13b",
|
||||
// size: 7400 * format.MebiByte,
|
||||
// },
|
||||
// {
|
||||
// name: "codellama:13b",
|
||||
// size: 7400 * format.MebiByte,
|
||||
// },
|
||||
// {
|
||||
// name: "orca-mini:13b",
|
||||
// size: 7400 * format.MebiByte,
|
||||
// },
|
||||
// {
|
||||
// name: "gemma:7b",
|
||||
// size: 5000 * format.MebiByte,
|
||||
// },
|
||||
// {
|
||||
// name: "starcoder2:15b",
|
||||
// size: 9100 * format.MebiByte,
|
||||
// },
|
||||
// }
|
||||
|
||||
var chosenModels []model
|
||||
var chosenModels []string
|
||||
switch {
|
||||
case maxVram < 10000*format.MebiByte:
|
||||
slog.Info("selecting small models")
|
||||
chosenModels = smallModels
|
||||
// case maxVram < 30000*format.MebiByte:
|
||||
default:
|
||||
slog.Info("selecting medium models")
|
||||
chosenModels = mediumModels
|
||||
// default:
|
||||
// slog.Info("selecting large models")
|
||||
// chosenModels = largeModels
|
||||
}
|
||||
|
||||
req, resp := GenerateRequests()
|
||||
|
||||
for i := range req {
|
||||
if i > len(chosenModels) {
|
||||
break
|
||||
}
|
||||
req[i].Model = chosenModels[i].name
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
initialTimeout := 120 * time.Second
|
||||
streamTimeout := 20 * time.Second
|
||||
|
||||
// Make sure all the models are pulled before we get started
|
||||
for _, r := range req {
|
||||
require.NoError(t, PullIfMissing(ctx, client, r.Model))
|
||||
for _, model := range chosenModels {
|
||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
consumed := uint64(256 * format.MebiByte) // Assume some baseline usage
|
||||
for i := 0; i < len(req); i++ {
|
||||
// Always get at least 2 models, but don't overshoot VRAM too much or we'll take too long
|
||||
if i > 1 && consumed > maxVram {
|
||||
slog.Info("achieved target vram exhaustion", "count", i, "vram", format.HumanBytes2(maxVram), "models", format.HumanBytes2(consumed))
|
||||
break
|
||||
// Determine how many models we can load in parallel before we exceed VRAM
|
||||
// The intent is to go 1 over what can fit so we force the scheduler to thrash
|
||||
targetLoadCount := 0
|
||||
slog.Info("Loading models to find how many can fit in VRAM before overflowing")
|
||||
chooseModels:
|
||||
for i, model := range chosenModels {
|
||||
req := &api.GenerateRequest{Model: model}
|
||||
slog.Info("loading", "model", model)
|
||||
err = client.Generate(ctx, req, func(response api.GenerateResponse) error { return nil })
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", model, err)
|
||||
}
|
||||
consumed += chosenModels[i].size
|
||||
slog.Info("target vram", "count", i, "vram", format.HumanBytes2(maxVram), "models", format.HumanBytes2(consumed))
|
||||
targetLoadCount++
|
||||
if i > 0 {
|
||||
models, err := client.ListRunning(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to list running models: %s", err)
|
||||
}
|
||||
if len(models.Models) < targetLoadCount {
|
||||
loaded := []string{}
|
||||
for _, m := range models.Models {
|
||||
loaded = append(loaded, m.Name)
|
||||
}
|
||||
slog.Info("found model load capacity", "target", targetLoadCount, "current", loaded, "chosen", chosenModels[:targetLoadCount])
|
||||
break
|
||||
}
|
||||
// Effectively limit model count to 2 on CPU only systems to avoid thrashing and timeouts
|
||||
for _, m := range models.Models {
|
||||
if m.SizeVRAM == 0 {
|
||||
slog.Info("model running on CPU", "name", m.Name, "target", targetLoadCount, "chosen", chosenModels[:targetLoadCount])
|
||||
initialTimeout = 240 * time.Second
|
||||
streamTimeout = 30 * time.Second
|
||||
break chooseModels
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if targetLoadCount == len(chosenModels) {
|
||||
// TODO consider retrying the medium models
|
||||
slog.Warn("all models being used without exceeding VRAM, set OLLAMA_MAX_VRAM so test can pick larger models")
|
||||
}
|
||||
|
||||
r := rand.New(rand.NewSource(0))
|
||||
var wg sync.WaitGroup
|
||||
for i := range targetLoadCount {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
reqs, resps := ChatRequests()
|
||||
for j := 0; j < 3; j++ {
|
||||
slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
|
||||
DoGenerate(ctx, t, client, req[i], resp[i], 120*time.Second, 5*time.Second)
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
slog.Info("exceeded soft timeout, winding down test")
|
||||
return
|
||||
}
|
||||
k := r.Int() % len(reqs)
|
||||
reqs[k].Model = chosenModels[i]
|
||||
slog.Info("Starting", "model", reqs[k].Model, "iteration", j, "request", reqs[k].Messages[0].Content)
|
||||
DoChat(ctx, t, client, reqs[k], resps[k], initialTimeout, streamTimeout)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(2 * time.Second)
|
||||
time.Sleep(10 * time.Second)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
@@ -265,7 +193,21 @@ func TestMultiModelStress(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
for _, m := range models.Models {
|
||||
slog.Info("loaded model snapshot", "model", m)
|
||||
var procStr string
|
||||
switch {
|
||||
case m.SizeVRAM == 0:
|
||||
procStr = "100% CPU"
|
||||
case m.SizeVRAM == m.Size:
|
||||
procStr = "100% GPU"
|
||||
case m.SizeVRAM > m.Size || m.Size == 0:
|
||||
procStr = "Unknown"
|
||||
default:
|
||||
sizeCPU := m.Size - m.SizeVRAM
|
||||
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
|
||||
procStr = fmt.Sprintf("%d%%/%d%%", int(cpuPercent), int(100-cpuPercent))
|
||||
}
|
||||
|
||||
slog.Info("loaded model snapshot", "model", m.Name, "CPU/GPU", procStr, "expires", format.HumanTime(m.ExpiresAt, "Never"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -19,9 +21,14 @@ func TestLongInputContext(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: "llama2",
|
||||
Prompt: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
|
||||
req := api.ChatRequest{
|
||||
Model: smol,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
@@ -34,7 +41,7 @@ func TestLongInputContext(t *testing.T) {
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("PullIfMissing failed: %v", err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia"}, 120*time.Second, 10*time.Second)
|
||||
DoChat(ctx, t, client, req, []string{"russia", "german", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
func TestContextExhaustion(t *testing.T) {
|
||||
@@ -46,9 +53,14 @@ func TestContextExhaustion(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: "llama2",
|
||||
Prompt: "Write me a story with a ton of emojis?",
|
||||
req := api.ChatRequest{
|
||||
Model: smol,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Write me a story in english with a lot of emojis",
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
@@ -61,5 +73,212 @@ func TestContextExhaustion(t *testing.T) {
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("PullIfMissing failed: %v", err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
|
||||
DoChat(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water", "time", "travel", "world"}, 120*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
// Send multiple generate requests with prior context and ensure the response is coherent and expected
|
||||
func TestParallelGenerateWithHistory(t *testing.T) {
|
||||
modelName := "gpt-oss:20b"
|
||||
req, resp := GenerateRequests()
|
||||
numParallel := 2
|
||||
iterLimit := 2
|
||||
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
initialTimeout := 120 * time.Second
|
||||
streamTimeout := 20 * time.Second
|
||||
|
||||
// Get the server running (if applicable) and warm the model up with a single initial request
|
||||
slog.Info("loading", "model", modelName)
|
||||
err := client.Generate(ctx,
|
||||
&api.GenerateRequest{Model: modelName, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
|
||||
func(response api.GenerateResponse) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", modelName, err)
|
||||
}
|
||||
gpuPercent := getGPUPercent(ctx, t, client, modelName)
|
||||
if gpuPercent < 80 {
|
||||
slog.Warn("Low GPU percentage - increasing timeouts", "percent", gpuPercent)
|
||||
initialTimeout = 240 * time.Second
|
||||
streamTimeout = 30 * time.Second
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numParallel)
|
||||
for i := range numParallel {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
k := i % len(req)
|
||||
req[k].Model = modelName
|
||||
for j := 0; j < iterLimit; j++ {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
slog.Info("exceeded soft timeout, winding down test")
|
||||
return
|
||||
}
|
||||
slog.Info("Starting", "thread", i, "iter", j)
|
||||
// On slower GPUs it can take a while to process the concurrent requests
|
||||
// so we allow a much longer initial timeout
|
||||
c := DoGenerate(ctx, t, client, req[k], resp[k], initialTimeout, streamTimeout)
|
||||
req[k].Context = c
|
||||
req[k].Prompt = "tell me more!"
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Send generate requests with prior context and ensure the response is coherent and expected
|
||||
func TestGenerateWithHistory(t *testing.T) {
|
||||
req := api.GenerateRequest{
|
||||
Model: smol,
|
||||
Prompt: rainbowPrompt,
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]any{
|
||||
"num_ctx": 16384,
|
||||
},
|
||||
}
|
||||
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// Get the server running (if applicable) and warm the model up with a single initial request
|
||||
slog.Info("loading", "model", req.Model)
|
||||
err := client.Generate(ctx,
|
||||
&api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: req.Options},
|
||||
func(response api.GenerateResponse) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", req.Model, err)
|
||||
}
|
||||
|
||||
req.Context = DoGenerate(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
|
||||
|
||||
for i := 0; i < len(rainbowFollowups); i++ {
|
||||
req.Prompt = rainbowFollowups[i]
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
slog.Info("exceeded soft timeout, winding down test")
|
||||
return
|
||||
}
|
||||
req.Context = DoGenerate(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
// Send multiple chat requests with prior context and ensure the response is coherent and expected
|
||||
func TestParallelChatWithHistory(t *testing.T) {
|
||||
modelName := "gpt-oss:20b"
|
||||
req, resp := ChatRequests()
|
||||
numParallel := 2
|
||||
iterLimit := 2
|
||||
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
initialTimeout := 120 * time.Second
|
||||
streamTimeout := 20 * time.Second
|
||||
|
||||
// Get the server running (if applicable) and warm the model up with a single initial empty request
|
||||
slog.Info("loading", "model", modelName)
|
||||
err := client.Generate(ctx,
|
||||
&api.GenerateRequest{Model: modelName, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
|
||||
func(response api.GenerateResponse) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", modelName, err)
|
||||
}
|
||||
gpuPercent := getGPUPercent(ctx, t, client, modelName)
|
||||
if gpuPercent < 80 {
|
||||
slog.Warn("Low GPU percentage - increasing timeouts", "percent", gpuPercent)
|
||||
initialTimeout = 240 * time.Second
|
||||
streamTimeout = 30 * time.Second
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numParallel)
|
||||
for i := range numParallel {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
k := i % len(req)
|
||||
req[k].Model = modelName
|
||||
for j := 0; j < iterLimit; j++ {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
slog.Info("exceeded soft timeout, winding down test")
|
||||
return
|
||||
}
|
||||
slog.Info("Starting", "thread", i, "iter", j)
|
||||
// On slower GPUs it can take a while to process the concurrent requests
|
||||
// so we allow a much longer initial timeout
|
||||
assistant := DoChat(ctx, t, client, req[k], resp[k], initialTimeout, streamTimeout)
|
||||
if assistant == nil {
|
||||
t.Fatalf("didn't get an assistant response for context")
|
||||
}
|
||||
req[k].Messages = append(req[k].Messages,
|
||||
*assistant,
|
||||
api.Message{Role: "user", Content: "tell me more!"},
|
||||
)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Send chat requests with prior context and ensure the response is coherent and expected
|
||||
func TestChatWithHistory(t *testing.T) {
|
||||
req := api.ChatRequest{
|
||||
Model: smol,
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]any{
|
||||
"num_ctx": 16384,
|
||||
},
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: rainbowPrompt,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// Get the server running (if applicable) and warm the model up with a single initial request
|
||||
slog.Info("loading", "model", req.Model)
|
||||
err := client.Generate(ctx,
|
||||
&api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: req.Options},
|
||||
func(response api.GenerateResponse) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", req.Model, err)
|
||||
}
|
||||
|
||||
assistant := DoChat(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
|
||||
|
||||
for i := 0; i < len(rainbowFollowups); i++ {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
slog.Info("exceeded soft timeout, winding down test")
|
||||
return
|
||||
}
|
||||
req.Messages = append(req.Messages,
|
||||
*assistant,
|
||||
api.Message{Role: "user", Content: rainbowFollowups[i]},
|
||||
)
|
||||
|
||||
assistant = DoChat(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
|
||||
if assistant == nil {
|
||||
t.Fatalf("didn't get an assistant response for context")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
@@ -38,14 +39,14 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
req := api.EmbeddingRequest{
|
||||
Model: "all-minilm",
|
||||
Prompt: "why is the sky blue?",
|
||||
Model: "all-minilm",
|
||||
Prompt: "why is the sky blue?",
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
}
|
||||
|
||||
res, err := embeddingTestHelper(ctx, client, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(res.Embedding) != 384 {
|
||||
@@ -73,9 +74,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
|
||||
}
|
||||
|
||||
res, err := embedTestHelper(ctx, client, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(res.Embeddings) != 1 {
|
||||
@@ -111,9 +111,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
|
||||
}
|
||||
|
||||
res, err := embedTestHelper(ctx, client, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(res.Embeddings) != 2 {
|
||||
@@ -155,93 +154,148 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
|
||||
truncTrue, truncFalse := true, false
|
||||
|
||||
type testReq struct {
|
||||
Name string
|
||||
Request api.EmbedRequest
|
||||
want, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
reqs := []testReq{
|
||||
cases := []struct {
|
||||
name string
|
||||
request api.EmbedRequest
|
||||
check func(*api.EmbedResponse, error)
|
||||
}{
|
||||
{
|
||||
Name: "Target Truncation",
|
||||
Request: api.EmbedRequest{
|
||||
name: "target truncation",
|
||||
request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Default Truncate",
|
||||
Request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Options: map[string]any{"num_ctx": 1},
|
||||
check: func(got *api.EmbedResponse, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Explicit Truncate",
|
||||
Request: api.EmbedRequest{
|
||||
name: "default truncate",
|
||||
request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Options: map[string]any{"num_ctx": 3},
|
||||
},
|
||||
check: func(got *api.EmbedResponse, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "explicit truncate",
|
||||
request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncTrue,
|
||||
Options: map[string]any{"num_ctx": 3},
|
||||
},
|
||||
check: func(got *api.EmbedResponse, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "truncate error",
|
||||
request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncFalse,
|
||||
Options: map[string]any{"num_ctx": 3},
|
||||
},
|
||||
check: func(res *api.EmbedResponse, err error) {
|
||||
if err.Error() != "input exceeds maximum context length" {
|
||||
t.Fatalf("expected truncation error, got: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "input after truncate error",
|
||||
request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncTrue,
|
||||
Options: map[string]any{"num_ctx": 1},
|
||||
},
|
||||
check: func(res *api.EmbedResponse, err error) {
|
||||
if err.Error() != "input after truncation exceeds maximum context length" {
|
||||
t.Fatalf("expected truncation error, got: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "input after truncate error",
|
||||
request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncTrue,
|
||||
Options: map[string]any{"num_ctx": 0},
|
||||
},
|
||||
check: func(res *api.EmbedResponse, err error) {
|
||||
if err.Error() != "input after truncation exceeds maximum context length" {
|
||||
t.Fatalf("expected truncation error, got: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "boundary truncation",
|
||||
request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue? Why is the sky blue? hi there my",
|
||||
Options: map[string]any{"num_ctx": 16},
|
||||
},
|
||||
check: func(res *api.EmbedResponse, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
res := make(map[string]*api.EmbedResponse)
|
||||
|
||||
for _, req := range reqs {
|
||||
response, err := embedTestHelper(ctx, client, t, req.Request)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
res[req.Name] = response
|
||||
}
|
||||
|
||||
if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
|
||||
t.Fatal("expected default request to truncate correctly")
|
||||
}
|
||||
|
||||
if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
|
||||
t.Fatal("expected default request and truncate true request to be the same")
|
||||
}
|
||||
|
||||
// check that truncate set to false returns an error if context length is exceeded
|
||||
_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncFalse,
|
||||
Options: map[string]any{"num_ctx": 1},
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
for _, req := range cases {
|
||||
t.Run(req.name, func(t *testing.T) {
|
||||
req.check(embedTestHelper(ctx, client, t, req.request))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
|
||||
t.Helper()
|
||||
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("failed to pull model %s: %v", req.Model, err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
response, err := client.Embeddings(ctx, &req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
return client.Embeddings(ctx, &req)
|
||||
}
|
||||
|
||||
func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
||||
t.Helper()
|
||||
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("failed to pull model %s: %v", req.Model, err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
response, err := client.Embed(ctx, &req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
return client.Embed(ctx, &req)
|
||||
}
|
||||
|
||||
75  integration/library_models_test.go  Normal file
@@ -0,0 +1,75 @@
|
||||
//go:build integration && library
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// First run of this scenario on a target system will take a long time to download
|
||||
// ~1.5TB of models. Set a sufficiently large -timeout for your network speed
|
||||
func TestLibraryModelsChat(t *testing.T) {
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
targetArch := os.Getenv("OLLAMA_TEST_ARCHITECTURE")
|
||||
|
||||
chatModels := libraryChatModels
|
||||
for _, model := range chatModels {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
t.Skip("skipping remaining tests to avoid excessive runtime")
|
||||
}
|
||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
if targetArch != "" {
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Name: model})
|
||||
if err != nil {
|
||||
t.Fatalf("unable to show model: %s", err)
|
||||
}
|
||||
arch := resp.ModelInfo["general.architecture"].(string)
|
||||
if arch != targetArch {
|
||||
t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch))
|
||||
}
|
||||
}
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: blueSkyPrompt,
|
||||
},
|
||||
},
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0.1,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
anyResp := blueSkyExpected
|
||||
// Special cases
|
||||
if model == "duckdb-nsql" {
|
||||
anyResp = []string{"select", "from"}
|
||||
} else if model == "granite3-guardian" || model == "shieldgemma" || model == "llama-guard3" || model == "bespoke-minicheck" {
|
||||
anyResp = []string{"yes", "no", "safe", "unsafe"}
|
||||
} else if model == "openthinker" {
|
||||
anyResp = []string{"plugin", "im_sep", "components", "function call"}
|
||||
} else if model == "starcoder" || model == "starcoder2" || model == "magicoder" || model == "deepseek-coder" {
|
||||
req.Messages[0].Content = "def fibonacci():"
|
||||
anyResp = []string{"f(n)", "sequence", "n-1", "main()", "__main__", "while"}
|
||||
}
|
||||
DoChat(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVisionModels(t *testing.T) {
|
||||
@@ -32,18 +31,25 @@ func TestVisionModels(t *testing.T) {
|
||||
for _, v := range testCases {
|
||||
t.Run(v.model, func(t *testing.T) {
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
require.NoError(t, err)
|
||||
req := api.GenerateRequest{
|
||||
Model: v.model,
|
||||
Prompt: "what does the text in this image say?",
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req := api.ChatRequest{
|
||||
Model: v.model,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "what does the text in this image say?",
|
||||
Images: []api.ImageData{
|
||||
image,
|
||||
},
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
Images: []api.ImageData{
|
||||
image,
|
||||
},
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
@@ -52,9 +58,18 @@ func TestVisionModels(t *testing.T) {
|
||||
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
|
||||
resp := "the ollam"
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Preload to skip if we're less than 80% on GPU to avoid extremely slow tests
|
||||
err = client.Generate(ctx, &api.GenerateRequest{Model: req.Model}, func(response api.GenerateResponse) error { return nil })
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", req.Model, err)
|
||||
}
|
||||
skipIfNotGPULoaded(ctx, t, client, req.Model, 80)
|
||||
|
||||
// llava models on CPU can be quite slow to start
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
||||
DoChat(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -62,7 +77,9 @@ func TestVisionModels(t *testing.T) {
|
||||
func TestIntegrationSplitBatch(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req := api.GenerateRequest{
|
||||
Model: "gemma3:4b",
|
||||
// Fill up a chunk of the batch so the image will partially spill over into the next one
|
||||
@@ -84,7 +101,9 @@ func TestIntegrationSplitBatch(t *testing.T) {
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// llava models on CPU can be quite slow to start,
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
|
||||
}
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
|
||||
// package to avoid circular dependencies
|
||||
|
||||
var (
|
||||
stream = false
|
||||
req = [2]api.GenerateRequest{
|
||||
{
|
||||
Model: smol,
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: smol,
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
},
|
||||
}
|
||||
resp = [2][]string{
|
||||
{"sunlight", "scattering", "interact"},
|
||||
{"england", "english", "massachusetts", "pilgrims"},
|
||||
}
|
||||
)
|
||||
|
||||
func TestIntegrationSimple(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
||||
defer cancel()
|
||||
GenerateTestHelper(ctx, t, req[0], resp[0])
|
||||
}
|
||||
@@ -13,12 +13,12 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestMaxQueue(t *testing.T) {
|
||||
t.Skip("this test needs to be re-evaluated to use a proper embedding model")
|
||||
|
||||
if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
|
||||
t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size")
|
||||
return
|
||||
@@ -45,7 +45,9 @@ func TestMaxQueue(t *testing.T) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Context for the worker threads so we can shut them down
|
||||
// embedCtx, embedCancel := context.WithCancel(ctx)
|
||||
@@ -89,7 +91,9 @@ func TestMaxQueue(t *testing.T) {
|
||||
switch {
|
||||
case genErr == nil:
|
||||
successCount++
|
||||
require.Greater(t, len(resp.Embedding), 5) // somewhat arbitrary, but sufficient to be reasonable
|
||||
if len(resp.Embedding) < 5 { // somewhat arbitrary, but sufficient to be reasonable
|
||||
t.Fatalf("embeddings shorter than expected: %d", len(resp.Embedding))
|
||||
}
|
||||
case errors.Is(genErr, context.Canceled):
|
||||
canceledCount++
|
||||
case strings.Contains(genErr.Error(), "busy"):
|
||||
@@ -97,7 +101,9 @@ func TestMaxQueue(t *testing.T) {
|
||||
case strings.Contains(genErr.Error(), "connection reset by peer"):
|
||||
resetByPeerCount++
|
||||
default:
|
||||
require.NoError(t, genErr, "%d request failed", i)
|
||||
if genErr != nil {
|
||||
t.Fatalf("%d request failed", i)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Info("embed finished", "id", i)
|
||||
@@ -108,8 +114,13 @@ func TestMaxQueue(t *testing.T) {
|
||||
embedwg.Wait()
|
||||
|
||||
slog.Info("embeds completed", "success", successCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
|
||||
require.Equal(t, resetByPeerCount, 0, "Connections reset by peer, have you updated your fd and socket limits?")
|
||||
require.True(t, busyCount > 0, "no requests hit busy error but some should have")
|
||||
require.True(t, canceledCount == 0, "no requests should have been canceled due to timeout")
|
||||
|
||||
if resetByPeerCount != 0 {
|
||||
t.Fatalf("Connections reset by peer, have you updated your fd and socket limits? %d", resetByPeerCount)
|
||||
}
|
||||
if busyCount == 0 {
|
||||
t.Fatalf("no requests hit busy error but some should have")
|
||||
}
|
||||
if canceledCount > 0 {
|
||||
t.Fatalf("no requests should have been canceled due to timeout %d", canceledCount)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,36 +19,7 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
var (
|
||||
started = time.Now()
|
||||
chatModels = []string{
|
||||
"granite3-moe:latest",
|
||||
"granite-code:latest",
|
||||
"nemotron-mini:latest",
|
||||
"command-r:latest",
|
||||
"gemma2:latest",
|
||||
"gemma:latest",
|
||||
"internlm2:latest",
|
||||
"phi3.5:latest",
|
||||
"phi3:latest",
|
||||
// "phi:latest", // flaky, sometimes generates no response on first query
|
||||
"stablelm2:latest", // Predictions are off, crashes on small VRAM GPUs
|
||||
"falcon:latest",
|
||||
"falcon2:latest",
|
||||
"minicpm-v:latest",
|
||||
"mistral:latest",
|
||||
"orca-mini:latest",
|
||||
"llama2:latest",
|
||||
"llama3.1:latest",
|
||||
"llama3.2:latest",
|
||||
"llama3.2-vision:latest",
|
||||
"qwen2.5-coder:latest",
|
||||
"qwen:latest",
|
||||
"solar-pro:latest",
|
||||
}
|
||||
)
|
||||
|
||||
func TestModelsGenerate(t *testing.T) {
|
||||
func TestModelsChat(t *testing.T) {
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
@@ -68,6 +39,13 @@ func TestModelsGenerate(t *testing.T) {
|
||||
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
|
||||
}
|
||||
|
||||
var chatModels []string
|
||||
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
|
||||
chatModels = ollamaEngineChatModels
|
||||
} else {
|
||||
chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
|
||||
}
|
||||
|
||||
for _, model := range chatModels {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
@@ -87,17 +65,41 @@ func TestModelsGenerate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
initialTimeout := 120 * time.Second
|
||||
streamTimeout := 30 * time.Second
|
||||
slog.Info("loading", "model", model)
|
||||
err := client.Generate(ctx,
|
||||
&api.GenerateRequest{Model: model, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
|
||||
func(response api.GenerateResponse) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", model, err)
|
||||
}
|
||||
gpuPercent := getGPUPercent(ctx, t, client, model)
|
||||
if gpuPercent < 80 {
|
||||
slog.Warn("Low GPU percentage - increasing timeouts", "percent", gpuPercent)
|
||||
initialTimeout = 240 * time.Second
|
||||
streamTimeout = 40 * time.Second
|
||||
}
|
||||
|
||||
// TODO - fiddle with context size
|
||||
req := api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: "why is the sky blue?",
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: blueSkyPrompt,
|
||||
},
|
||||
},
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
anyResp := []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}
|
||||
DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
|
||||
DoChat(ctx, t, client, req, blueSkyExpected, initialTimeout, streamTimeout)
|
||||
// best effort unload once we're done with the model
|
||||
client.Generate(ctx, &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -151,8 +153,9 @@ func TestModelsEmbed(t *testing.T) {
|
||||
}
|
||||
}
|
||||
req := api.EmbeddingRequest{
|
||||
Model: model,
|
||||
Prompt: "why is the sky blue?",
|
||||
Model: model,
|
||||
Prompt: "why is the sky blue?",
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
@@ -162,6 +165,10 @@ func TestModelsEmbed(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("embeddings call failed %s", err)
|
||||
}
|
||||
defer func() {
|
||||
// best effort unload once we're done with the model
|
||||
client.Generate(ctx, &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
|
||||
}()
|
||||
if len(resp.Embedding) == 0 {
|
||||
t.Errorf("zero length embedding response")
|
||||
}
|
||||
|
||||
281  integration/model_perf_test.go  Normal file
@@ -0,0 +1,281 @@
|
||||
//go:build integration && perf
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log/slog"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
var (
|
||||
// Models that don't work reliably with the large context prompt in this test case
|
||||
longContextFlakes = []string{
|
||||
"granite-code:latest",
|
||||
"nemotron-mini:latest",
|
||||
"falcon:latest", // 2k model
|
||||
"falcon2:latest", // 2k model
|
||||
"minicpm-v:latest",
|
||||
"qwen:latest",
|
||||
"solar-pro:latest",
|
||||
}
|
||||
)
|
||||
|
||||
// Note: this test case can take a long time to run, particularly on models with
|
||||
// large contexts. Run with -timeout set to a large value to get reasonable coverage
|
||||
// Example usage:
|
||||
//
|
||||
// go test --tags=integration,perf -count 1 ./integration -v -timeout 90m -run TestModelsPerf 2>&1 | tee int.log
|
||||
// cat int.log | grep MODEL_PERF_HEADER | head -1| cut -f2- -d: > perf.csv
|
||||
// cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv
|
||||
func TestModelsPerf(t *testing.T) {
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
doModelPerfTest(t, ollamaEngineChatModels)
} else {
doModelPerfTest(t, append(ollamaEngineChatModels, llamaRunnerChatModels...))
}
}

func TestLibraryModelsPerf(t *testing.T) {
doModelPerfTest(t, libraryChatModels)
}

func doModelPerfTest(t *testing.T, chatModels []string) {
softTimeout, hardTimeout := getTimeouts(t)
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()

// TODO use info API eventually
var maxVram uint64
var err error
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
maxVram, err = strconv.ParseUint(s, 10, 64)
if err != nil {
t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
}
} else {
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
}

data, err := ioutil.ReadFile(filepath.Join("testdata", "shakespeare.txt"))
if err != nil {
t.Fatalf("failed to open test data file: %s", err)
}
longPrompt := "summarize the following: " + string(data)

targetArch := os.Getenv("OLLAMA_TEST_ARCHITECTURE")

for _, model := range chatModels {
if !strings.Contains(model, ":") {
model = model + ":latest"
}
t.Run(model, func(t *testing.T) {
if time.Now().Sub(started) > softTimeout {
t.Skip("skipping remaining tests to avoid excessive runtime")
}
if err := PullIfMissing(ctx, client, model); err != nil {
t.Fatalf("pull failed %s", err)
}
var maxContext int

resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
if err != nil {
t.Fatalf("show failed: %s", err)
}
arch := resp.ModelInfo["general.architecture"].(string)
maxContext = int(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))
if targetArch != "" && arch != targetArch {
t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch))
}

if maxVram > 0 {
resp, err := client.List(ctx)
if err != nil {
t.Fatalf("list models failed %v", err)
}
for _, m := range resp.Models {
// For these tests we want to exercise some amount of overflow on the CPU
if m.Name == model && float32(m.Size)*0.75 > float32(maxVram) {
t.Skipf("model %s is too large %s for available VRAM %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
}
}
}
slog.Info("scenario", "model", model, "max_context", maxContext)
loaded := false
defer func() {
// best effort unload once we're done with the model
if loaded {
client.Generate(ctx, &api.GenerateRequest{Model: model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
}
}()

// Some models don't handle the long context data well so skip them to avoid flaky test results
longContextFlake := false
for _, flake := range longContextFlakes {
if model == flake {
longContextFlake = true
break
}
}

// iterate through a few context sizes for coverage without excessive runtime
var contexts []int
keepGoing := true
if maxContext > 16384 {
contexts = []int{4096, 8192, 16384, maxContext}
} else if maxContext > 8192 {
contexts = []int{4096, 8192, maxContext}
} else if maxContext > 4096 {
contexts = []int{4096, maxContext}
} else if maxContext > 0 {
contexts = []int{maxContext}
} else {
t.Fatal("unknown max context size")
}
for _, numCtx := range contexts {
if !keepGoing && numCtx > 8192 { // Always try up to 8k before bailing out
break
}
skipLongPrompt := false

// Workaround bug 11172 temporarily...
maxPrompt := longPrompt
// If we fill the context too full with the prompt, many models
// quickly hit context shifting and go bad.
if len(maxPrompt) > numCtx*2 { // typically yields ~1/2 full context
maxPrompt = maxPrompt[:numCtx*2]
}
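// Worked example of the cap above, assuming the rough ~4 characters per token
// average for English text: with num_ctx=4096 the prompt is trimmed to 8192
// characters, i.e. roughly 2048 tokens, leaving about half the context window
// free and avoiding early context shifting.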
testCases := []struct {
name string
prompt string
anyResp []string
}{
{"blue_sky", blueSkyPrompt, blueSkyExpected},
{"max", maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy", "love", "sorrow", "beauty"}},
}
var gpuPercent int
for _, tc := range testCases {
if len(tc.prompt) > 100 && (longContextFlake || skipLongPrompt) {
slog.Info("skipping long prompt", "model", model, "num_ctx", numCtx, "gpu_percent", gpuPercent)
continue
}
req := api.ChatRequest{
Model: model,
Messages: []api.Message{
{
Role: "user",
Content: tc.prompt,
},
},
KeepAlive: &api.Duration{Duration: 20 * time.Second}, // long enough to ensure a ps returns
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
"num_ctx": numCtx,
},
}
atLeastOne := false
var resp api.ChatResponse

stream := false
req.Stream = &stream

// Avoid potentially getting stuck indefinitely
limit := 5 * time.Minute
genCtx, cancel := context.WithDeadlineCause(
ctx,
time.Now().Add(limit),
fmt.Errorf("generate on model %s with ctx %d took longer than %v", model, numCtx, limit),
)
defer cancel()

err = client.Chat(genCtx, &req, func(rsp api.ChatResponse) error {
resp = rsp
return nil
})
if err != nil {
// Avoid excessive test runs, but don't consider a failure with massive context
if numCtx > 16384 && strings.Contains(err.Error(), "took longer") {
slog.Warn("max context was taking too long, skipping", "error", err)
keepGoing = false
skipLongPrompt = true
continue
}
t.Fatalf("generate error: ctx:%d err:%s", numCtx, err)
}
loaded = true
for _, expResp := range tc.anyResp {
if strings.Contains(strings.ToLower(resp.Message.Content), expResp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("response didn't contain expected values: ctx:%d expected:%v response:%s ", numCtx, tc.anyResp, resp.Message.Content)
}
models, err := client.ListRunning(ctx)
if err != nil {
slog.Warn("failed to list running models", "error", err)
continue
}
if len(models.Models) > 1 {
slog.Warn("multiple models loaded, may impact performance results", "loaded", models.Models)
}
for _, m := range models.Models {
if m.Name == model {
if m.SizeVRAM == 0 {
slog.Info("Model fully loaded into CPU")
gpuPercent = 0
keepGoing = false
skipLongPrompt = true
} else if m.SizeVRAM == m.Size {
slog.Info("Model fully loaded into GPU")
gpuPercent = 100
} else {
sizeCPU := m.Size - m.SizeVRAM
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
gpuPercent = int(100 - cpuPercent)
slog.Info("Model split between CPU/GPU", "CPU", cpuPercent, "GPU", gpuPercent)
keepGoing = false

// Heuristic to avoid excessive test run time
if gpuPercent < 90 {
skipLongPrompt = true
}
}
}
}
prefillTimePerToken := float64(resp.PromptEvalDuration.Nanoseconds()) / float64(resp.PromptEvalCount)
prefillTokensPerSec := float64(resp.PromptEvalCount) / (float64(resp.PromptEvalDuration.Nanoseconds()) + 1e-12) * 1e9
fmt.Fprintf(os.Stderr, "BenchmarkModel/name=%s-%s/%d/step=%s %d %.2f ns/token %.2f token/sec\n",
model, tc.name, numCtx, "prefill", resp.PromptEvalCount, prefillTimePerToken, prefillTokensPerSec)

evalTimePerToken := float64(resp.EvalDuration.Nanoseconds()) / float64(resp.EvalCount)
evalTokensPerSec := float64(resp.EvalCount) / (float64(resp.EvalDuration.Nanoseconds()) + 1e-12) * 1e9
fmt.Fprintf(os.Stderr, "BenchmarkModel/name=%s-%s/%d/step=%s %d %.2f ns/token %.2f token/sec\n",
model, tc.name, numCtx, "generate", resp.EvalCount, evalTimePerToken, evalTokensPerSec)

fmt.Fprintf(os.Stderr, "BenchmarkMode/name=%s-%s/%d 1 %d ns/request\n",
model, tc.name, numCtx, resp.TotalDuration.Nanoseconds())
fmt.Fprintf(os.Stderr, "BenchmarkMode/name=%s-%s/%d/step=%s 1 %d ns/request\n",
model, tc.name, numCtx, "load", resp.LoadDuration.Nanoseconds())
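// Illustrative arithmetic for the metrics above (made-up numbers, not measured
// results): an EvalCount of 256 tokens over an EvalDuration of 2s works out to
// 2e9/256 = ~7.8e6 ns/token and 256/2 = 128 token/sec.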
}
}
})
}
}
@@ -74,9 +74,14 @@ func TestQuantization(t *testing.T) {
}

stream := true
genReq := api.GenerateRequest{
Model: newName,
Prompt: "why is the sky blue?",
chatReq := api.ChatRequest{
Model: newName,
Messages: []api.Message{
{
Role: "user",
Content: blueSkyPrompt,
},
},
KeepAlive: &api.Duration{Duration: 3 * time.Second},
Options: map[string]any{
"seed": 42,
@@ -88,14 +93,13 @@ func TestQuantization(t *testing.T) {

// Some smaller quantizations can cause models to have poor quality
// or get stuck in repetition loops, so we stop as soon as we have any matches
anyResp := []string{"rayleigh", "scattering", "day", "sun", "moon", "color", "nitrogen", "oxygen"}
reqCtx, reqCancel := context.WithCancel(ctx)
atLeastOne := false
var buf bytes.Buffer
genfn := func(response api.GenerateResponse) error {
buf.Write([]byte(response.Response))
chatfn := func(response api.ChatResponse) error {
buf.Write([]byte(response.Message.Content))
fullResp := strings.ToLower(buf.String())
for _, resp := range anyResp {
for _, resp := range blueSkyExpected {
if strings.Contains(fullResp, resp) {
atLeastOne = true
t.Log(fullResp)
@@ -109,14 +113,14 @@ func TestQuantization(t *testing.T) {
done := make(chan int)
var genErr error
go func() {
genErr = client.Generate(reqCtx, &genReq, genfn)
genErr = client.Chat(reqCtx, &chatReq, chatfn)
done <- 0
}()

select {
case <-done:
if genErr != nil && !atLeastOne {
t.Fatalf("failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
t.Fatalf("failed with %s request prompt %s ", chatReq.Model, chatReq.Messages[0].Content)
}
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")
integration/testdata/embed.json (vendored, 25 lines): diff suppressed because one or more lines are too long
integration/testdata/shakespeare.txt (vendored, new file, 124456 lines): diff suppressed because it is too large
@@ -9,11 +9,13 @@ import (
"fmt"
"io"
"log/slog"
"math"
"math/rand"
"net"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
@@ -23,17 +25,263 @@ import (
"time"

"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle"
"github.com/ollama/ollama/format"
"github.com/stretchr/testify/require"
)

const (
smol = "llama3.2:1b"
var (
smol = "llama3.2:1b"
stream = false
)

func Init() {
lifecycle.InitLogging()
var (
started = time.Now()

// Note: add newer models at the top of the list to test them first
ollamaEngineChatModels = []string{
"qwen3-coder:30b",
"gpt-oss:20b",
"gemma3n:e2b",
"mistral-small3.2:latest",
"deepseek-r1:1.5b",
"llama3.2-vision:latest",
"qwen2.5-coder:latest",
"qwen2.5vl:3b",
"qwen3:0.6b", // dense
"qwen3:1.7b", // dense
"qwen3:30b", // MOE
"gemma3:1b",
"llama3.1:latest",
"llama3.2:latest",
"gemma2:latest",
"minicpm-v:latest", // arch=qwen2
"granite-code:latest", // arch=llama
}
llamaRunnerChatModels = []string{
"mistral:latest",
"falcon3:latest",
"granite3-moe:latest",
"command-r:latest",
"nemotron-mini:latest",
"phi3.5:latest",
"solar-pro:latest",
"internlm2:latest",
"codellama:latest", // arch=llama
"phi3:latest",
"falcon2:latest",
"gemma:latest",
"llama2:latest",
"nous-hermes:latest",
"orca-mini:latest",
"qwen:latest",
"stablelm2:latest", // Predictions are off, crashes on small VRAM GPUs
"falcon:latest",
}

// Some library models are quite large - ensure large VRAM and sufficient disk space
// before running scenarios based on this set
libraryChatModels = []string{
"alfred",
"athene-v2",
"aya-expanse",
"aya",
"bakllava",
"bespoke-minicheck",
"codebooga",
"codegeex4",
"codegemma",
"codellama",
"codeqwen",
"codestral",
"codeup",
"cogito",
"command-a",
"command-r-plus",
"command-r",
"command-r7b-arabic",
"command-r7b",
"dbrx",
"deepcoder",
"deepscaler",
"deepseek-coder-v2",
"deepseek-coder",
"deepseek-llm",
"deepseek-r1",
// "deepseek-v2.5", // requires 155 GB VRAM
"deepseek-v2",
// "deepseek-v3", // requires 482 GB VRAM
"devstral",
"dolphin-llama3",
"dolphin-mistral",
"dolphin-mixtral",
"dolphin-phi",
"dolphin3",
"dolphincoder",
"duckdb-nsql",
"everythinglm",
"exaone-deep",
"exaone3.5",
"falcon",
"falcon2",
"falcon3",
"firefunction-v2",
"gemma",
"gemma2",
"gemma3",
"gemma3n",
"glm4",
"goliath",
"gpt-oss:20b",
"granite-code",
"granite3-dense",
"granite3-guardian",
"granite3-moe",
"granite3.1-dense",
"granite3.1-moe",
"granite3.2-vision",
"granite3.2",
"granite3.3",
"hermes3",
"internlm2",
"llama-guard3",
"llama-pro",
"llama2-chinese",
"llama2-uncensored",
"llama2",
"llama3-chatqa",
"llama3-gradient",
"llama3-groq-tool-use",
"llama3.1",
"llama3.2-vision",
"llama3.2",
"llama3.3",
"llama3",
"llama4",
"llava-llama3",
"llava-phi3",
"llava",
"magicoder",
"magistral",
"marco-o1",
"mathstral",
"meditron",
"medllama2",
"megadolphin",
"minicpm-v",
"mistral-large",
"mistral-nemo",
"mistral-openorca",
"mistral-small",
"mistral-small3.1",
"mistral-small3.2",
"mistral",
"mistrallite",
"mixtral",
"moondream",
"nemotron-mini",
"nemotron",
"neural-chat",
"nexusraven",
"notus",
"nous-hermes",
"nous-hermes2-mixtral",
"nous-hermes2",
"nuextract",
"olmo2",
"open-orca-platypus2",
"openchat",
"opencoder",
"openhermes",
"openthinker",
"orca-mini",
"orca2",
// "phi", // unreliable
"phi3.5",
"phi3",
"phi4-mini-reasoning",
"phi4-mini",
"phi4-reasoning",
"phi4",
"phind-codellama",
"qwen",
"qwen2-math",
"qwen2.5-coder",
"qwen2.5",
"qwen2.5vl",
"qwen2",
"qwen3:0.6b", // dense
"qwen3:30b", // MOE
"qwq",
"r1-1776",
"reader-lm",
"reflection",
"sailor2",
"samantha-mistral",
"shieldgemma",
"smallthinker",
"smollm",
"smollm2",
"solar-pro",
"solar",
"sqlcoder",
"stable-beluga",
"stable-code",
"stablelm-zephyr",
"stablelm2",
"starcoder",
"starcoder2",
"starling-lm",
"tinydolphin",
"tinyllama",
"tulu3",
"vicuna",
"wizard-math",
"wizard-vicuna-uncensored",
"wizard-vicuna",
"wizardcoder",
"wizardlm-uncensored",
"wizardlm2",
"xwinlm",
"yarn-llama2",
"yarn-mistral",
"yi-coder",
"yi",
"zephyr",
}
libraryEmbedModels = []string{
"all-minilm",
"bge-large",
"bge-m3",
"granite-embedding",
"mxbai-embed-large",
"nomic-embed-text",
"paraphrase-multilingual",
"snowflake-arctic-embed",
"snowflake-arctic-embed2",
}

blueSkyPrompt = "why is the sky blue? Be brief but factual in your reply"
blueSkyExpected = []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength", "interact"}

rainbowPrompt = "how do rainbows form? Be brief but factual in your reply"
rainbowFollowups = []string{
"Explain the physics involved in them. Be breif in your reply",
|
||||
"Explain the chemistry involved in them. Be breif in your reply",
|
||||
"What are common myths related to them? Be brief in your reply",
|
||||
"Can they form if there is no rain? Be breif in your reply",
|
||||
"Can they form if there are no clouds? Be breif in your reply",
|
||||
"Do they happen on other planets? Be brief in your reply",
}
rainbowExpected = []string{"water", "droplet", "mist", "glow", "refract", "reflect", "scatter", "particles", "wave", "color", "spectrum", "raindrop", "atmosphere", "frequency", "shower", "sky", "shimmer", "light", "storm", "sunny", "sunburst", "phenomenon", "mars", "venus", "jupiter"}
)

func init() {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
slog.SetDefault(logger)
custom := os.Getenv("OLLAMA_TEST_DEFAULT_MODEL")
if custom != "" {
slog.Info("setting default test model to " + custom)
smol = custom
}
}

func FindPort() string {
@@ -89,6 +337,7 @@ func GetTestEndpoint() (*api.Client, string) {

var serverMutex sync.Mutex
var serverReady bool
var serverLogFile string

func startServer(t *testing.T, ctx context.Context, ollamaHost string) error {
// Make sure the server has been built
@@ -115,8 +364,9 @@ func startServer(t *testing.T, ctx context.Context, ollamaHost string) error {
t.Setenv("OLLAMA_HOST", ollamaHost)
}

logDir := t.TempDir()
slog.Info("starting server", "url", ollamaHost)
done, err := lifecycle.SpawnServer(ctx, "../ollama")
done, err := SpawnServer(ctx, "../ollama", logDir)
if err != nil {
return fmt.Errorf("failed to start server: %w", err)
}
@@ -139,6 +389,36 @@ func startServer(t *testing.T, ctx context.Context, ollamaHost string) error {
return nil
}

func SpawnServer(ctx context.Context, command, logDir string) (chan int, error) {
done := make(chan int)
fp, err := os.CreateTemp(logDir, "ollama-server-*.log")
if err != nil {
return nil, fmt.Errorf("failed to create log file: %w", err)
}
serverLogFile = fp.Name()

cmd := exec.CommandContext(ctx, command, "serve")
cmd.Stderr = fp
cmd.Stdout = fp

go func() {
slog.Info("starting server...")
if err := cmd.Run(); err != nil {
// "signal: killed" expected
if !strings.Contains(err.Error(), "signal") {
slog.Info("failed to run server", "error", err)
}
}
var code int
if cmd.ProcessState != nil {
code = cmd.ProcessState.ExitCode()
}
slog.Info("server exited")
done <- code
}()
return done, nil
}

func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
slog.Info("checking status of model", "model", modelName)
showReq := &api.ShowRequest{Name: modelName}
@@ -199,58 +479,74 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
client, testEndpoint := GetTestEndpoint()
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
serverProcMutex.Lock()
fp, err := os.CreateTemp("", "ollama-server-*.log")
if err != nil {
t.Fatalf("failed to generate log file: %s", err)
if err := startServer(t, ctx, testEndpoint); err != nil {
t.Fatal(err)
}
lifecycle.ServerLogFile = fp.Name()
fp.Close()
require.NoError(t, startServer(t, ctx, testEndpoint))
}
// Make sure server is online and healthy before returning
listCtx, cancel := context.WithDeadlineCause(
ctx,
time.Now().Add(120*time.Second),
fmt.Errorf("list models took too long"),
)
defer cancel()
models, err := client.ListRunning(listCtx)
if err != nil {
t.Fatal(err)
}
if len(models.Models) > 0 {
names := make([]string, len(models.Models))
for i, m := range models.Models {
names[i] = m.Name
}
slog.Info("currently loaded", "models", names)
}

return client, testEndpoint, func() {
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
defer serverProcMutex.Unlock()
if t.Failed() {
fp, err := os.Open(lifecycle.ServerLogFile)
fp, err := os.Open(serverLogFile)
if err != nil {
slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
slog.Error("failed to open server log", "logfile", serverLogFile, "error", err)
return
}
defer fp.Close()
data, err := io.ReadAll(fp)
if err != nil {
slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
slog.Error("failed to read server log", "logfile", serverLogFile, "error", err)
return
}
slog.Warn("SERVER LOG FOLLOWS")
os.Stderr.Write(data)
slog.Warn("END OF SERVER")
}
err := os.Remove(lifecycle.ServerLogFile)
if err != nil && !os.IsNotExist(err) {
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
}
}
}
}

func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
func ChatTestHelper(ctx context.Context, t *testing.T, req api.ChatRequest, anyResp []string) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatal(err)
}
DoChat(ctx, t, client, req, anyResp, 30*time.Second, 10*time.Second)
}

func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) []int {
stallTimer := time.NewTimer(initialTimeout)
var buf bytes.Buffer
var context []int
fn := func(response api.GenerateResponse) error {
// fmt.Print(".")
buf.Write([]byte(response.Response))
if !stallTimer.Reset(streamTimeout) {
return errors.New("stall was detected while streaming response, aborting")
}
if len(response.Context) > 0 {
context = response.Context
}
return nil
}

@@ -263,6 +559,22 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
done <- 0
}()

var response string
verify := func() {
// Verify the response contains the expected data
response = buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
}
}

select {
case <-stallTimer.C:
if buf.Len() == 0 {
@@ -271,21 +583,23 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
t.Errorf("generate stalled. Response so far:%s", buf.String())
}
case <-done:
require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") {
slog.Warn("model is too large for the target test system", "model", genReq.Model, "error", genErr)
return context
}
require.True(t, atLeastOne, "%s: none of %v found in %s", genReq.Model, anyResp, response)
if genErr != nil {
t.Fatalf("%s failed with %s request prompt %s", genErr, genReq.Model, genReq.Prompt)
}
verify()
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")
// On slow systems, we might timeout before some models finish rambling, so check what we have so far to see
// if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass
// if they are still generating valid responses
slog.Warn("outer test context done while waiting for generate")
verify()
}
return context
}
// Generate a set of requests
@@ -294,65 +608,132 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
return []api.GenerateRequest{
{
Model: smol,
Prompt: "why is the ocean blue?",
Prompt: "why is the ocean blue? Be brief but factual in your reply",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: smol,
Prompt: "why is the color of dirt brown?",
Prompt: "why is the color of dirt brown? Be brief but factual in your reply",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: smol,
Prompt: "what is the origin of the us thanksgiving holiday?",
Prompt: rainbowPrompt,
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: smol,
Prompt: "what is the origin of independence day?",
Prompt: "what is the origin of independence day? Be brief but factual in your reply",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: smol,
Prompt: "what is the composition of air?",
Prompt: "what is the composition of air? Be brief but factual in your reply",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
},
},
[][]string{
{"sunlight"},
{"soil", "organic", "earth", "black", "tan"},
{"england", "english", "massachusetts", "pilgrims", "british"},
{"sunlight", "scatter", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorb", "wavelength", "water", "molecule"},
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigment", "particle", "iron oxide", "rust", "air", "water", "wet", "mixture", "mixing", "mineral", "element", "decomposed", "matter", "wavelength"},
rainbowExpected,
{"fourth", "july", "declaration", "independence"},
{"nitrogen", "oxygen", "carbon", "dioxide"},
{"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor", "fluid", "particles", "gas"},
}
}

func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) *api.Message {
stallTimer := time.NewTimer(initialTimeout)
var buf bytes.Buffer
role := "assistant"
fn := func(response api.ChatResponse) error {
// fmt.Print(".")
role = response.Message.Role
buf.Write([]byte(response.Message.Content))
if !stallTimer.Reset(streamTimeout) {
return errors.New("stall was detected while streaming response, aborting")
}
return nil
}

stream := true
req.Stream = &stream
done := make(chan int)
var genErr error
go func() {
genErr = client.Chat(ctx, &req, fn)
done <- 0
}()

var response string
verify := func() {
// Verify the response contains the expected data
response = buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
}
}

select {
case <-stallTimer.C:
if buf.Len() == 0 {
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
} else {
t.Errorf("generate stalled. Response so far:%s", buf.String())
}
case <-done:
if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") {
slog.Warn("model is too large for the target test system", "model", req.Model, "error", genErr)
return nil
}
if genErr != nil {
t.Fatalf("%s failed with %s request prompt %v", genErr, req.Model, req.Messages)
}
verify()
slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
case <-ctx.Done():
// On slow systems, we might timeout before some models finish rambling, so check what we have so far to see
// if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass
// if they are still generating valid responses
slog.Warn("outer test context done while waiting for chat")
verify()
}
return &api.Message{Role: role, Content: buf.String()}
}
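// Illustrative only, not part of this change: a typical caller wires DoChat up
// with one of the canned prompts defined above, e.g.
//
//	req := api.ChatRequest{
//		Model:     smol,
//		Messages:  []api.Message{{Role: "user", Content: blueSkyPrompt}},
//		KeepAlive: &api.Duration{Duration: 10 * time.Second},
//		Options:   map[string]any{"seed": 42, "temperature": 0.0},
//	}
//	DoChat(ctx, t, client, req, blueSkyExpected, 30*time.Second, 10*time.Second)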
func ChatRequests() ([]api.ChatRequest, [][]string) {
genReqs, results := GenerateRequests()
reqs := make([]api.ChatRequest, len(genReqs))
// think := api.ThinkValue{Value: "low"}
for i := range reqs {
reqs[i].Model = genReqs[i].Model
reqs[i].Stream = genReqs[i].Stream
reqs[i].KeepAlive = genReqs[i].KeepAlive
// reqs[i].Think = &think
reqs[i].Messages = []api.Message{
{
Role: "user",
Content: genReqs[i].Prompt,
},
}
}
return reqs, results
}

func skipUnderMinVRAM(t *testing.T, gb uint64) {
// TODO use info API in the future
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
maxVram, err := strconv.ParseUint(s, 10, 64)
require.NoError(t, err)
if err != nil {
t.Fatal(err)
}
// Don't hammer on small VRAM cards...
if maxVram < gb*format.GibiByte {
t.Skip("skipping with small VRAM to avoid timeouts")
@@ -360,6 +741,50 @@ func skipUnderMinVRAM(t *testing.T, gb uint64) {
}
}
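// Illustrative usage, not part of this change: skipUnderMinVRAM(t, 32) skips
// the calling test unless OLLAMA_MAX_VRAM reports at least 32 GiB.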
// Skip if the target model isn't X% GPU loaded to avoid excessive runtime
func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, model string, minPercent int) {
gpuPercent := getGPUPercent(ctx, t, client, model)
if gpuPercent < minPercent {
t.Skip(fmt.Sprintf("test requires minimum %d%% GPU load, but model %s only has %d%%", minPercent, model, gpuPercent))
}
}

func getGPUPercent(ctx context.Context, t *testing.T, client *api.Client, model string) int {
models, err := client.ListRunning(ctx)
if err != nil {
t.Fatalf("failed to list running models: %s", err)
}
loaded := []string{}
for _, m := range models.Models {
loaded = append(loaded, m.Name)
if strings.Contains(model, ":") {
if m.Name != model {
continue
}
} else if strings.Contains(m.Name, ":") {
if !strings.HasPrefix(m.Name, model+":") {
continue
}
}
gpuPercent := 0
switch {
case m.SizeVRAM == 0:
gpuPercent = 0
case m.SizeVRAM == m.Size:
gpuPercent = 100
case m.SizeVRAM > m.Size || m.Size == 0:
t.Logf("unexpected size detected: %d", m.SizeVRAM)
default:
sizeCPU := m.Size - m.SizeVRAM
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
gpuPercent = int(100 - cpuPercent)
}
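// Worked example of the split computation above (made-up sizes): with
// m.Size = 10 GiB and m.SizeVRAM = 7.5 GiB, sizeCPU is 2.5 GiB, cpuPercent
// rounds to 25 and gpuPercent comes out as 75.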
return gpuPercent
}
t.Fatalf("model %s not loaded - actually loaded: %v", model, loaded)
return 0
}

func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
deadline, hasDeadline := t.Deadline()
if !hasDeadline {