mirror of https://github.com/ollama/ollama.git (synced 2026-01-10 16:38:45 -05:00)

Compare commits: modelfile-... rmdisplayl (99 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 42f2cc408e | |
| | 9446b795b5 | |
| | 62f8cda3b3 | |
| | 6a1de23175 | |
| | a7b431e743 | |
| | 5a25f93522 | |
| | 7e33a017c0 | |
| | 8b2c10061c | |
| | c5c451ca3b | |
| | 2b4ca6cf36 | |
| | ad90b9ab3d | |
| | 4340f8eba4 | |
| | 4c7db6b7e9 | |
| | c03f0e3c3d | |
| | c5ff443b9f | |
| | 01114b4526 | |
| | 1524f323a3 | |
| | fccf3eecaa | |
| | c77d45d836 | |
| | 5ec12cec6c | |
| | d9578d2bad | |
| | cb8352d6b4 | |
| | fc6558f47f | |
| | 9502e5661f | |
| | e1c9a2a00f | |
| | 1341ee1b56 | |
| | 63efa075a0 | |
| | cb03fc9571 | |
| | a5ec9cfc0f | |
| | be517e491c | |
| | fc8e108642 | |
| | c5d5c4a96c | |
| | dfe330fa1c | |
| | 01f77ae25d | |
| | 483b81a863 | |
| | 36bd967722 | |
| | b0e7d35db8 | |
| | aeb1fb5192 | |
| | a2e60ebcaf | |
| | 883ec4d1ef | |
| | 4de0126719 | |
| | 9768e2dc75 | |
| | 08600d5bec | |
| | a624e672d2 | |
| | e4a7e5b2ca | |
| | a0a15cfd5b | |
| | 12e923e158 | |
| | cd135317d2 | |
| | 4f895d633f | |
| | 7d05a6ee8f | |
| | 464d817824 | |
| | 531324a9be | |
| | 6589eb8a8c | |
| | 90f071c658 | |
| | a039e383cd | |
| | 80163ebcb5 | |
| | a57818d93e | |
| | 841adda157 | |
| | 0035e31af8 | |
| | c863c6a96d | |
| | 1f11b52511 | |
| | 526d4eb204 | |
| | 0a74cb31d5 | |
| | 10ed1b6292 | |
| | 4fec5816d6 | |
| | 0a0e9f3e0f | |
| | 58d95cc9bd | |
| | 3b6a9154dd | |
| | d6dd2ff839 | |
| | e57a6ba89f | |
| | 12ec2346ef | |
| | 1ec0df1069 | |
| | 91b3e4d282 | |
| | d338d70492 | |
| | 011bb67351 | |
| | d124627202 | |
| | b0a8246a69 | |
| | e6fb39c182 | |
| | e1f1c374ea | |
| | 06a1508bfe | |
| | 5a5efee46b | |
| | 97ae517fbf | |
| | 44b813e459 | |
| | 539043f5e0 | |
| | dbcace6847 | |
| | c91a4ebcff | |
| | b79c7e4528 | |
| | 035b274b70 | |
| | 9c6a254945 | |
| | f31f2bedf4 | |
| | 756c257553 | |
| | 5255d0af8a | |
| | af8a8a6b59 | |
| | 461ad25015 | |
| | 8838ae787d | |
| | db75402ade | |
| | 1e85a140a3 | |
| | c363282fdc | |
| | 5b0c48d29e | |
`.github/ISSUE_TEMPLATE/90_bug_report.yml` (vendored, 2 lines changed)

```diff
@@ -19,7 +19,7 @@ body:
       label: What did you expect to see?
       description: What did you expect to see/happen instead?
     validations:
-      required: true
+      required: false
   - type: textarea
     id: steps
     attributes:
```
`.github/workflows/latest.yaml` (vendored, new file, 24 lines)

```yaml
name: latest

on:
  release:
    types: [released]

jobs:
  update-latest:
    environment: release
    runs-on: linux
    steps:
      - uses: actions/checkout@v4
      - name: Login to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ vars.DOCKER_USER }}
          password: ${{ secrets.DOCKER_ACCESS_TOKEN }}
      - name: Tag images as latest
        env:
          PUSH: "1"
        shell: bash
        run: |
          export "VERSION=${GITHUB_REF_NAME#v}"
          ./scripts/tag_latest.sh
```
`.github/workflows/release.yaml` (vendored, 90 lines changed)

```diff
@@ -8,7 +8,7 @@ on:
 jobs:
   # Full build of the Mac assets
   build-darwin:
-    runs-on: macos-latest
+    runs-on: macos-12
     environment: release
     steps:
       - uses: actions/checkout@v4
@@ -30,7 +30,7 @@ jobs:
           security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k password build.keychain
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.22'
+          go-version-file: go.mod
           cache: true
       - name: Build Darwin
         env:
@@ -38,9 +38,11 @@ jobs:
           APPLE_PASSWORD: ${{ secrets.APPLE_PASSWORD }}
           APPLE_TEAM_ID: ${{ vars.APPLE_TEAM_ID }}
           APPLE_ID: ${{ vars.APPLE_ID }}
+          SDKROOT: /Applications/Xcode_13.4.1.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
+          DEVELOPER_DIR: /Applications/Xcode_13.4.1.app/Contents/Developer
         run: |
           ./scripts/build_darwin.sh

       - uses: actions/upload-artifact@v4
         with:
           name: dist-darwin
@@ -48,7 +50,6 @@ jobs:
             dist/*arwin*
             !dist/*-cov

   # Windows builds take a long time to both install the dependencies and build, so parallelize
   # CPU generation step
   generate-windows-cpu:
@@ -85,7 +86,7 @@ jobs:
           write-host "plugin installed"
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.22'
+          go-version-file: go.mod
           cache: true
       - run: go get ./...
       - run: |
@@ -99,7 +100,9 @@ jobs:
       - uses: actions/upload-artifact@v4
         with:
           name: generate-windows-cpu
-          path: llm/llama.cpp/build/**/lib/*
+          path: |
+            llm/build/**/bin/*
+            llm/build/**/*.a

   # ROCm generation step
   generate-windows-rocm:
@@ -136,9 +139,9 @@ jobs:
           write-host "plugin installed"
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.22'
+          go-version-file: go.mod
           cache: true
-      - name: "Install ROCm"
+      - name: 'Install ROCm'
         run: |
           $ErrorActionPreference = "Stop"
           write-host "downloading AMD HIP Installer"
@@ -146,7 +149,7 @@ jobs:
           write-host "Installing AMD HIP"
           Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
           write-host "Completed AMD HIP"
-      - name: "Verify ROCm"
+      - name: 'Verify ROCm'
         run: |
           & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version
       - run: go get ./...
@@ -160,7 +163,7 @@ jobs:
           $env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
           go generate -x ./...
         name: go generate
-      - name: "gather rocm dependencies"
+      - name: 'gather rocm dependencies'
         run: |
           $HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
           md "dist\deps\bin\rocblas\library"
@@ -170,7 +173,7 @@ jobs:
       - uses: actions/upload-artifact@v4
         with:
           name: generate-windows-rocm
-          path: llm/llama.cpp/build/**/lib/*
+          path: llm/build/**/bin/*
       - uses: actions/upload-artifact@v4
         with:
           name: windows-rocm-deps
@@ -211,29 +214,36 @@ jobs:
           write-host "plugin installed"
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.22'
+          go-version-file: go.mod
           cache: true
-      # TODO - consider replacing this action with a ps1 snippet to install
-      # This actions seems to fail sometimes with "no tools in cache" but a re-run of the failed job clears it
-      # https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
-      - name: "Install CUDA"
-        uses: Jimver/cuda-toolkit@v0.2.14
-        id: cuda-toolkit
-        with:
-          cuda: '11.3.1'
-      - name: "Verify CUDA"
+      - name: 'Install CUDA'
+        run: |
+          $ErrorActionPreference = "Stop"
+          write-host "downloading CUDA Installer"
+          Invoke-WebRequest -Uri "https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe" -OutFile "${env:RUNNER_TEMP}\cuda-install.exe"
+          write-host "Installing CUDA"
+          Start-Process "${env:RUNNER_TEMP}\cuda-install.exe" -ArgumentList '-s' -NoNewWindow -Wait
+          write-host "Completed CUDA"
+          $cudaPath=((resolve-path "c:\Program Files\NVIDIA*\CUDA\v*\bin\nvcc.exe")[0].path | split-path | split-path)
+          $cudaVer=($cudaPath | split-path -leaf ) -replace 'v(\d+).(\d+)', '$1_$2'
+          echo "$cudaPath\bin" >> $env:GITHUB_PATH
+          echo "CUDA_PATH=$cudaPath" >> $env:GITHUB_ENV
+          echo "CUDA_PATH_V${cudaVer}=$cudaPath" >> $env:GITHUB_ENV
+          echo "CUDA_PATH_VX_Y=CUDA_PATH_V${cudaVer}" >> $env:GITHUB_ENV
+      - name: 'Verify CUDA'
         run: nvcc -V
       - run: go get ./...
       - name: go generate
         run: |
           $gopath=(get-command go).source | split-path -parent
+          $cudabin=(get-command nvcc).source | split-path
           & "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1"
           cd $env:GITHUB_WORKSPACE
           $env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
-          $env:PATH="$gopath;$env:PATH"
+          $env:PATH="$gopath;$cudabin;$env:PATH"
           $env:OLLAMA_SKIP_CPU_GENERATE="1"
           go generate -x ./...
-      - name: "gather cuda dependencies"
+      - name: 'gather cuda dependencies'
         run: |
           $NVIDIA_DIR=(resolve-path 'C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*\bin\')[0]
           md "dist\deps"
@@ -243,7 +253,7 @@ jobs:
       - uses: actions/upload-artifact@v4
         with:
           name: generate-windows-cuda
-          path: llm/llama.cpp/build/**/lib/*
+          path: llm/build/**/bin/*
       - uses: actions/upload-artifact@v4
         with:
           name: windows-cuda-deps
@@ -290,17 +300,17 @@ jobs:
           write-host "plugin installed"
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.22'
+          go-version-file: go.mod
           cache: true
       - run: go get
       - uses: actions/download-artifact@v4
         with:
           name: generate-windows-cpu
-          path: llm/llama.cpp/build
+          path: llm/build
       - uses: actions/download-artifact@v4
         with:
           name: generate-windows-cuda
-          path: llm/llama.cpp/build
+          path: llm/build
       - uses: actions/download-artifact@v4
         with:
           name: windows-cuda-deps
@@ -312,8 +322,8 @@ jobs:
       - uses: actions/download-artifact@v4
         with:
           name: generate-windows-rocm
-          path: llm/llama.cpp/build
-      - run: dir llm/llama.cpp/build
+          path: llm/build
+      - run: dir llm/build
       - run: |
           $gopath=(get-command go).source | split-path -parent
           & "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1"
@@ -329,14 +339,14 @@ jobs:
           name: dist-windows
           path: dist/*.exe

   # Linux x86 assets built using the container based build
   build-linux-amd64:
     environment: release
     runs-on: linux
     env:
-      OLLAMA_SKIP_MANIFEST_CREATE: "1"
+      OLLAMA_SKIP_MANIFEST_CREATE: '1'
       BUILD_ARCH: amd64
-      PUSH: "1"
+      PUSH: '1'
     steps:
       - uses: actions/checkout@v4
         with:
@@ -366,9 +376,9 @@ jobs:
     environment: release
     runs-on: linux-arm64
     env:
-      OLLAMA_SKIP_MANIFEST_CREATE: "1"
+      OLLAMA_SKIP_MANIFEST_CREATE: '1'
       BUILD_ARCH: arm64
-      PUSH: "1"
+      PUSH: '1'
     steps:
       - uses: actions/checkout@v4
         with:
@@ -376,7 +386,7 @@ jobs:
       - name: Set Version
         shell: bash
         run: echo "VERSION=${GITHUB_REF_NAME#v}" >> $GITHUB_ENV
-      - name: "Install Docker"
+      - name: 'Install Docker'
         run: |
           # Add Docker's official GPG key:
           env
@@ -413,7 +423,7 @@ jobs:
             !dist/*-cov

   # Aggregate all the assets and ship a release
   release:
     needs:
       - build-darwin
       - build-windows
@@ -424,8 +434,8 @@ jobs:
     permissions:
       contents: write
     env:
-      OLLAMA_SKIP_IMAGE_BUILD: "1"
-      PUSH: "1"
+      OLLAMA_SKIP_IMAGE_BUILD: '1'
+      PUSH: '1'
     steps:
       - uses: actions/checkout@v4
       - name: Set Version
@@ -453,11 +463,11 @@ jobs:
         with:
           name: ${{ env.RELEASE_VERSION }}
           allowUpdates: true
-          artifacts: "dist/*"
+          artifacts: 'dist/*'
           draft: true
           prerelease: true
           omitBodyDuringUpdate: true
           generateReleaseNotes: true
+          omitDraftDuringUpdate: true
+          omitPrereleaseDuringUpdate: true
           replacesArtifacts: true
```
`.github/workflows/test.yaml` (vendored, 188 lines changed)

```diff
@@ -5,11 +5,35 @@ on:
     paths:
       - '**/*'
       - '!docs/**'
       - '!examples/**'
       - '!README.md'

 jobs:
+  changes:
+    runs-on: ubuntu-latest
+    outputs:
+      GENERATE: ${{ steps.changes.outputs.GENERATE }}
+      GENERATE_CUDA: ${{ steps.changes.outputs.GENERATE_CUDA }}
+      GENERATE_ROCM: ${{ steps.changes.outputs.GENERATE_ROCM }}
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+      - id: changes
+        run: |
+          changed() {
+            git diff-tree -r --no-commit-id --name-only ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }} \
+              | xargs python3 -c "import sys; print(any([x.startswith('$1') for x in sys.argv[1:]]))"
+          }
+
+          {
+            echo GENERATE=$(changed llm/)
+            echo GENERATE_CUDA=$(changed llm/)
+            echo GENERATE_ROCM=$(changed llm/)
+          } >>$GITHUB_OUTPUT
+
   generate:
+    needs: [changes]
+    if: ${{ needs.changes.outputs.GENERATE == 'True' }}
     strategy:
       matrix:
         os: [ubuntu-latest, macos-latest, windows-2019]
@@ -26,26 +50,32 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.22'
+          go-version-file: go.mod
           cache: true
       - run: go get ./...
       - run: |
           $gopath=(get-command go).source | split-path -parent
+          $gccpath=(get-command gcc).source | split-path -parent
           & "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1"
           cd $env:GITHUB_WORKSPACE
           $env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
-          $env:PATH="$gopath;$env:PATH"
+          $env:PATH="$gopath;$gccpath;$env:PATH"
+          echo $env:PATH
           go generate -x ./...
         if: ${{ startsWith(matrix.os, 'windows-') }}
-        name: "Windows Go Generate"
+        name: 'Windows Go Generate'
       - run: go generate -x ./...
         if: ${{ ! startsWith(matrix.os, 'windows-') }}
-        name: "Unix Go Generate"
+        name: 'Unix Go Generate'
       - uses: actions/upload-artifact@v4
         with:
           name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
-          path: llm/llama.cpp/build/**/lib/*
+          path: |
+            llm/build/**/bin/*
+            llm/build/**/*.a
   generate-cuda:
+    needs: [changes]
+    if: ${{ needs.changes.outputs.GENERATE_CUDA == 'True' }}
     strategy:
       matrix:
         cuda-version:
@@ -62,7 +92,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-go@v4
         with:
-          go-version: '1.22'
+          go-version-file: go.mod
           cache: true
       - run: go get ./...
       - run: |
@@ -73,12 +103,14 @@ jobs:
       - uses: actions/upload-artifact@v4
         with:
           name: cuda-${{ matrix.cuda-version }}-libraries
-          path: llm/llama.cpp/build/**/lib/*
+          path: llm/build/**/bin/*
   generate-rocm:
+    needs: [changes]
+    if: ${{ needs.changes.outputs.GENERATE_ROCM == 'True' }}
     strategy:
       matrix:
         rocm-version:
-          - '6.0'
+          - '6.0.2'
     runs-on: linux
     container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
     steps:
@@ -91,7 +123,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-go@v4
         with:
-          go-version: '1.22'
+          go-version-file: go.mod
           cache: true
       - run: go get ./...
       - run: |
@@ -102,7 +134,87 @@ jobs:
       - uses: actions/upload-artifact@v4
         with:
           name: rocm-${{ matrix.rocm-version }}-libraries
-          path: llm/llama.cpp/build/**/lib/*
+          path: llm/build/**/bin/*
+
+  # ROCm generation step
+  generate-windows-rocm:
+    needs: [changes]
+    if: ${{ needs.changes.outputs.GENERATE_ROCM == 'True' }}
+    runs-on: windows
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/setup-go@v5
+        with:
+          go-version-file: go.mod
+          cache: true
+      - name: 'Install ROCm'
+        run: |
+          $ErrorActionPreference = "Stop"
+          write-host "downloading AMD HIP Installer"
+          Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
+          write-host "Installing AMD HIP"
+          Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
+          write-host "Completed AMD HIP"
+      - name: 'Verify ROCm'
+        run: |
+          & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version
+      - run: go get ./...
+      - run: |
+          $gopath=(get-command go).source | split-path -parent
+          & "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1"
+          cd $env:GITHUB_WORKSPACE
+          $env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
+          $env:PATH="$gopath;$env:PATH"
+          $env:OLLAMA_SKIP_CPU_GENERATE="1"
+          $env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
+          go generate -x ./...
+        name: go generate
+        env:
+          OLLAMA_SKIP_CPU_GENERATE: '1'
+      # TODO - do we need any artifacts?
+
+  # CUDA generation step
+  generate-windows-cuda:
+    needs: [changes]
+    if: ${{ needs.changes.outputs.GENERATE_CUDA == 'True' }}
+    runs-on: windows
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/setup-go@v5
+        with:
+          go-version-file: go.mod
+          cache: true
+      - name: 'Install CUDA'
+        run: |
+          $ErrorActionPreference = "Stop"
+          write-host "downloading CUDA Installer"
+          Invoke-WebRequest -Uri "https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe" -OutFile "${env:RUNNER_TEMP}\cuda-install.exe"
+          write-host "Installing CUDA"
+          Start-Process "${env:RUNNER_TEMP}\cuda-install.exe" -ArgumentList '-s' -NoNewWindow -Wait
+          write-host "Completed CUDA"
+          $cudaPath=((resolve-path "c:\Program Files\NVIDIA*\CUDA\v*\bin\nvcc.exe")[0].path | split-path | split-path)
+          $cudaVer=($cudaPath | split-path -leaf ) -replace 'v(\d+).(\d+)', '$1_$2'
+          echo "$cudaPath\bin" >> $env:GITHUB_PATH
+          echo "CUDA_PATH=$cudaPath" >> $env:GITHUB_ENV
+          echo "CUDA_PATH_V${cudaVer}=$cudaPath" >> $env:GITHUB_ENV
+          echo "CUDA_PATH_VX_Y=CUDA_PATH_V${cudaVer}" >> $env:GITHUB_ENV
+      - name: 'Verify CUDA'
+        run: nvcc -V
+      - run: go get ./...
+      - name: go generate
+        run: |
+          $gopath=(get-command go).source | split-path -parent
+          $cudabin=(get-command nvcc).source | split-path
+          & "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1"
+          cd $env:GITHUB_WORKSPACE
+          $env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
+          $env:PATH="$gopath;$cudabin;$env:PATH"
+          $env:OLLAMA_SKIP_CPU_GENERATE="1"
+          go generate -x ./...
+        env:
+          OLLAMA_SKIP_CPU_GENERATE: '1'
+      # TODO - do we need any artifacts?
+
   lint:
     strategy:
       matrix:
@@ -125,24 +237,31 @@ jobs:
           submodules: recursive
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.22'
+          go-version-file: go.mod
           cache: false
       - run: |
-          mkdir -p llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/
-          touch llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/stub.so
+          case ${{ matrix.arch }} in
+            amd64) echo ARCH=x86_64 ;;
+            arm64) echo ARCH=arm64 ;;
+          esac >>$GITHUB_ENV
+        shell: bash
+      - run: |
+          mkdir -p llm/build/linux/$ARCH/stub/bin
+          touch llm/build/linux/$ARCH/stub/bin/ollama_llama_server
         if: ${{ startsWith(matrix.os, 'ubuntu-') }}
       - run: |
-          mkdir -p llm/llama.cpp/build/darwin/${{ matrix.arch }}/stub/lib/
-          touch llm/llama.cpp/build/darwin/${{ matrix.arch }}/stub/lib/stub.dylib
-          touch llm/llama.cpp/ggml-metal.metal
+          mkdir -p llm/build/darwin/$ARCH/stub/bin
+          touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
         if: ${{ startsWith(matrix.os, 'macos-') }}
       - run: |
-          mkdir -p llm/llama.cpp/build/windows/${{ matrix.arch }}/stub/lib/
-          touch llm/llama.cpp/build/windows/${{ matrix.arch }}/stub/lib/stub.dll
+          mkdir -p llm/build/windows/$ARCH/stub/bin
+          touch llm/build/windows/$ARCH/stub/bin/ollama_llama_server
         if: ${{ startsWith(matrix.os, 'windows-') }}
-      - uses: golangci/golangci-lint-action@v3
+        shell: bash
+      - uses: golangci/golangci-lint-action@v4
         with:
           args: --timeout 8m0s
   test:
     needs: generate
     strategy:
       matrix:
         os: [ubuntu-latest, macos-latest, windows-2019]
@@ -156,19 +275,36 @@ jobs:
     env:
       GOARCH: ${{ matrix.arch }}
+      CGO_ENABLED: '1'
+      OLLAMA_CPU_TARGET: 'static'
     steps:
       - uses: actions/checkout@v4
         with:
           submodules: recursive
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.22'
+          go-version-file: go.mod
           cache: true
-      - run: go get
-      - uses: actions/download-artifact@v4
-        with:
-          name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
-          path: llm/llama.cpp/build
+      - run: |
+          case ${{ matrix.arch }} in
+            amd64) echo ARCH=x86_64 ;;
+            arm64) echo ARCH=arm64 ;;
+          esac >>$GITHUB_ENV
+        shell: bash
+      - run: |
+          mkdir -p llm/build/linux/$ARCH/stub/bin
+          touch llm/build/linux/$ARCH/stub/bin/ollama_llama_server
+        if: ${{ startsWith(matrix.os, 'ubuntu-') }}
+      - run: |
+          mkdir -p llm/build/darwin/$ARCH/stub/bin
+          touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
+        if: ${{ startsWith(matrix.os, 'macos-') }}
+      - run: |
+          mkdir -p llm/build/windows/$ARCH/stub/bin
+          touch llm/build/windows/$ARCH/stub/bin/ollama_llama_server
+        if: ${{ startsWith(matrix.os, 'windows-') }}
+        shell: bash
+      - run: go generate ./...
       - run: go build
       - run: go test -v ./...
       - uses: actions/upload-artifact@v4
```
`.gitignore` (vendored, 3 lines changed)

```diff
@@ -10,4 +10,5 @@ ggml-metal.metal
 *.exe
 .idea
 test_data
 *.crt
+llm/build
```

`.golangci.yaml` (file name inferred from the `linters:` context; the header was lost in extraction)

```diff
@@ -15,13 +15,3 @@ linters:
   - misspell
   - nilerr
   - unused
-linters-settings:
-  errcheck:
-    # exclude the following functions since we don't generally
-    # need to be concerned with the returned errors
-    exclude-functions:
-      - encoding/binary.Read
-      - (*os.File).Seek
-      - (*bufio.Writer).WriteString
-      - (*github.com/spf13/pflag.FlagSet).Set
-      - (*github.com/ollama/ollama/llm.readSeekOffset).Seek
```
`Dockerfile` (27 lines changed)

```diff
@@ -2,7 +2,7 @@ ARG GOLANG_VERSION=1.22.1
 ARG CMAKE_VERSION=3.22.1
 # this CUDA_VERSION corresponds with the one specified in docs/gpu.md
 ARG CUDA_VERSION=11.3.1
-ARG ROCM_VERSION=6.0
+ARG ROCM_VERSION=6.0.2

 # Copy the minimal context we need to run the generate scripts
 FROM scratch AS llm-code
@@ -61,6 +61,8 @@ ARG OLLAMA_CUSTOM_CPU_DEFS
 ARG CGO_CFLAGS
 WORKDIR /go/src/github.com/ollama/ollama/llm/generate

+FROM --platform=linux/amd64 cpu-builder-amd64 AS static-build-amd64
+RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh
 FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu-build-amd64
 RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
 FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx-build-amd64
@@ -68,28 +70,33 @@ RUN OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
 FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
 RUN OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh

-FROM --platform=linux/arm64 centos:7 AS cpu-build-arm64
+FROM --platform=linux/arm64 centos:7 AS cpu-builder-arm64
 ARG CMAKE_VERSION
 ARG GOLANG_VERSION
 COPY ./scripts/rh_linux_deps.sh /
 RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
 ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
 COPY --from=llm-code / /go/src/github.com/ollama/ollama/
-WORKDIR /go/src/github.com/ollama/ollama/llm/generate
 # Note, we only build the "base" CPU variant on arm since avx/avx2 are x86 features
 ARG OLLAMA_CUSTOM_CPU_DEFS
 ARG CGO_CFLAGS
 WORKDIR /go/src/github.com/ollama/ollama/llm/generate

+FROM --platform=linux/arm64 cpu-builder-arm64 AS static-build-arm64
+RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh
+FROM --platform=linux/arm64 cpu-builder-arm64 AS cpu-build-arm64
+RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh

 # Intermediate stage used for ./scripts/build_linux.sh
 FROM --platform=linux/amd64 cpu-build-amd64 AS build-amd64
 ENV CGO_ENABLED 1
 WORKDIR /go/src/github.com/ollama/ollama
 COPY . .
-COPY --from=cpu_avx-build-amd64 /go/src/github.com/ollama/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
-COPY --from=cpu_avx2-build-amd64 /go/src/github.com/ollama/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
-COPY --from=cuda-build-amd64 /go/src/github.com/ollama/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
-COPY --from=rocm-build-amd64 /go/src/github.com/ollama/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
+COPY --from=static-build-amd64 /go/src/github.com/ollama/ollama/llm/build/linux/ llm/build/linux/
+COPY --from=cpu_avx-build-amd64 /go/src/github.com/ollama/ollama/llm/build/linux/ llm/build/linux/
+COPY --from=cpu_avx2-build-amd64 /go/src/github.com/ollama/ollama/llm/build/linux/ llm/build/linux/
+COPY --from=cuda-build-amd64 /go/src/github.com/ollama/ollama/llm/build/linux/ llm/build/linux/
+COPY --from=rocm-build-amd64 /go/src/github.com/ollama/ollama/llm/build/linux/ llm/build/linux/
 COPY --from=rocm-build-amd64 /go/src/github.com/ollama/ollama/dist/deps/ ./dist/deps/
 ARG GOFLAGS
 ARG CGO_CFLAGS
@@ -101,8 +108,8 @@ ENV CGO_ENABLED 1
 ARG GOLANG_VERSION
 WORKDIR /go/src/github.com/ollama/ollama
 COPY . .
-COPY --from=cuda-build-arm64 /go/src/github.com/ollama/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
+RUN mkdir -p /go/src/github.com/ollama/ollama/dist/deps/
+COPY --from=static-build-arm64 /go/src/github.com/ollama/ollama/llm/build/linux/ llm/build/linux/
+COPY --from=cuda-build-arm64 /go/src/github.com/ollama/ollama/llm/build/linux/ llm/build/linux/
 ARG GOFLAGS
 ARG CGO_CFLAGS
 RUN go build -trimpath .
```
`README.md` (file name inferred from the integrations-list context; the header was lost in extraction)

```diff
@@ -259,6 +259,7 @@ See the [API documentation](./docs/api.md) for all endpoints.

 ### Web & Desktop

+- [Lollms-Webui](https://github.com/ParisNeo/lollms-webui)
 - [LibreChat](https://github.com/danny-avila/LibreChat)
 - [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
 - [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
@@ -289,6 +290,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [AnythingLLM (Docker + MacOs/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm)
 - [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
 - [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
+- [ChatOllama: Open Source Chatbot based on Ollama with Knowledge Bases](https://github.com/sugarforever/chat-ollama)
+- [CRAG Ollama Chat: Simple Web Search with Corrective RAG](https://github.com/Nagi-ovo/CRAG-Ollama-Chat)
+- [RAGFlow: Open-source Retrieval-Augmented Generation engine based on deep document understanding](https://github.com/infiniflow/ragflow)

 ### Terminal

@@ -313,6 +317,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 ### Database

 - [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md)
+- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)

 ### Package managers

@@ -371,3 +376,4 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and HuggingFace)
 - [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension)
 - [AI Telegram Bot](https://github.com/tusharhero/aitelegrambot) (Telegram bot using Ollama in backend)
+- [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support)
```
`api/client.go` (file name inferred from the `package api` context; the header was lost in extraction)

```diff
@@ -1,3 +1,9 @@
+// Package api implements the client-side API for code wishing to interact
+// with the ollama service. The methods of the [Client] type correspond to
+// the ollama REST API as described in https://github.com/ollama/ollama/blob/main/docs/api.md
+//
+// The ollama command-line client itself uses this package to interact with
+// the backend service.
 package api

 import (
@@ -5,7 +11,6 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
-	"errors"
 	"fmt"
 	"io"
 	"net"
@@ -19,6 +24,8 @@ import (
 	"github.com/ollama/ollama/version"
 )

+// Client encapsulates client state for interacting with the ollama
+// service. Use [ClientFromEnvironment] to create new Clients.
 type Client struct {
 	base *url.URL
 	http *http.Client
@@ -40,6 +47,15 @@ func checkError(resp *http.Response, body []byte) error {
 	return apiError
 }

+// ClientFromEnvironment creates a new [Client] using configuration from the
+// environment variable OLLAMA_HOST, which points to the network host and
+// port on which the ollama service is listening. The format of this variable
+// is:
+//
+//	<scheme>://<host>:<port>
+//
+// If the variable is not specified, a default ollama host and port will be
+// used.
 func ClientFromEnvironment() (*Client, error) {
 	defaultPort := "11434"
@@ -191,8 +207,14 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
 	return nil
 }

+// GenerateResponseFunc is a function that [Client.Generate] invokes every time
+// a response is received from the service. If this function returns an error,
+// [Client.Generate] will stop generating and return this error.
 type GenerateResponseFunc func(GenerateResponse) error

+// Generate generates a response for a given prompt. The req parameter should
+// be populated with prompt details. fn is called for each response (there may
+// be multiple responses, e.g. in case streaming is enabled).
 func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error {
 	return c.stream(ctx, http.MethodPost, "/api/generate", req, func(bts []byte) error {
 		var resp GenerateResponse
@@ -204,8 +226,15 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error {
 	})
 }

+// ChatResponseFunc is a function that [Client.Chat] invokes every time
+// a response is received from the service. If this function returns an error,
+// [Client.Chat] will stop generating and return this error.
 type ChatResponseFunc func(ChatResponse) error

+// Chat generates the next message in a chat. [ChatRequest] may contain a
+// sequence of messages which can be used to maintain chat history with a model.
+// fn is called for each response (there may be multiple responses, e.g. in case
+// streaming is enabled).
 func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
 	return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error {
 		var resp ChatResponse
@@ -217,8 +246,14 @@ func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
 	})
 }

+// PullProgressFunc is a function that [Client.Pull] invokes every time there
+// is progress with a "pull" request sent to the service. If this function
+// returns an error, [Client.Pull] will stop the process and return this error.
 type PullProgressFunc func(ProgressResponse) error

+// Pull downloads a model from the ollama library. fn is called each time
+// progress is made on the request and can be used to display a progress bar,
+// etc.
 func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
 	return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error {
 		var resp ProgressResponse
@@ -301,18 +336,7 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
 }

 func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
-	if err := c.do(ctx, http.MethodHead, fmt.Sprintf("/api/blobs/%s", digest), nil, nil); err != nil {
-		var statusError StatusError
-		if !errors.As(err, &statusError) || statusError.StatusCode != http.StatusNotFound {
-			return err
-		}
-
-		if err := c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil); err != nil {
-			return err
-		}
-	}
-
-	return nil
+	return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
 }

 func (c *Client) Version(ctx context.Context) (string, error) {
```
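The new doc comments describe the intended call pattern. As a hedged illustration (the model name and prompt are placeholders, not taken from the diff), a minimal streaming chat client might look like:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	// Reads OLLAMA_HOST (<scheme>://<host>:<port>) and falls back to the
	// default host and port when it is unset, per the doc comment above.
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.ChatRequest{
		Model: "llama2", // placeholder; any locally available model works
		Messages: []api.Message{
			{Role: "user", Content: "Why is the sky blue?"},
		},
	}

	// The callback runs once per streamed response chunk; returning a
	// non-nil error stops generation, as documented for ChatResponseFunc.
	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		fmt.Print(resp.Message.Content)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```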
`api/types.go` (110 lines changed)

```diff
@@ -33,18 +33,46 @@ func (e StatusError) Error() string {

 type ImageData []byte

+// GenerateRequest describes a request sent by [Client.Generate]. While you
+// have to specify the Model and Prompt fields, all the other fields have
+// reasonable defaults for basic uses.
 type GenerateRequest struct {
-	Model     string      `json:"model"`
-	Prompt    string      `json:"prompt"`
-	System    string      `json:"system"`
-	Template  string      `json:"template"`
-	Context   []int       `json:"context,omitempty"`
-	Stream    *bool       `json:"stream,omitempty"`
-	Raw       bool        `json:"raw,omitempty"`
-	Format    string      `json:"format"`
-	KeepAlive *Duration   `json:"keep_alive,omitempty"`
-	Images    []ImageData `json:"images,omitempty"`
+	// Model is the model name; it should be a name familiar to Ollama from
+	// the library at https://ollama.com/library
+	Model string `json:"model"`
+
+	// Prompt is the textual prompt to send to the model.
+	Prompt string `json:"prompt"`
+
+	// System overrides the model's default system message/prompt.
+	System string `json:"system"`
+
+	// Template overrides the model's default prompt template.
+	Template string `json:"template"`
+
+	// Context is the context parameter returned from a previous call to
+	// [Client.Generate]. It can be used to keep a short conversational memory.
+	Context []int `json:"context,omitempty"`
+
+	// Stream specifies whether the response is streaming; it is true by default.
+	Stream *bool `json:"stream,omitempty"`
+
+	// Raw set to true means that no formatting will be applied to the prompt.
+	Raw bool `json:"raw,omitempty"`
+
+	// Format specifies the format to return a response in.
+	Format string `json:"format"`
+
+	// KeepAlive controls how long the model will stay loaded in memory following
+	// this request.
+	KeepAlive *Duration `json:"keep_alive,omitempty"`
+
+	// Images is an optional list of base64-encoded images accompanying this
+	// request, for multimodal models.
+	Images []ImageData `json:"images,omitempty"`

+	// Options lists model-specific options. For example, temperature can be
+	// set through this field, if the model supports it.
 	Options map[string]interface{} `json:"options"`
 }
@@ -109,21 +137,24 @@ type Options struct {

 // Runner options which must be set when the model is loaded into memory
 type Runner struct {
-	UseNUMA            bool    `json:"numa,omitempty"`
-	NumCtx             int     `json:"num_ctx,omitempty"`
-	NumBatch           int     `json:"num_batch,omitempty"`
-	NumGQA             int     `json:"num_gqa,omitempty"`
-	NumGPU             int     `json:"num_gpu,omitempty"`
-	MainGPU            int     `json:"main_gpu,omitempty"`
-	LowVRAM            bool    `json:"low_vram,omitempty"`
-	F16KV              bool    `json:"f16_kv,omitempty"`
-	LogitsAll          bool    `json:"logits_all,omitempty"`
-	VocabOnly          bool    `json:"vocab_only,omitempty"`
-	UseMMap            bool    `json:"use_mmap,omitempty"`
-	UseMLock           bool    `json:"use_mlock,omitempty"`
-	RopeFrequencyBase  float32 `json:"rope_frequency_base,omitempty"`
-	RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
-	NumThread          int     `json:"num_thread,omitempty"`
+	UseNUMA   bool `json:"numa,omitempty"`
+	NumCtx    int  `json:"num_ctx,omitempty"`
+	NumBatch  int  `json:"num_batch,omitempty"`
+	NumGQA    int  `json:"num_gqa,omitempty"`
+	NumGPU    int  `json:"num_gpu,omitempty"`
+	MainGPU   int  `json:"main_gpu,omitempty"`
+	LowVRAM   bool `json:"low_vram,omitempty"`
+	F16KV     bool `json:"f16_kv,omitempty"`
+	LogitsAll bool `json:"logits_all,omitempty"`
+	VocabOnly bool `json:"vocab_only,omitempty"`
+	UseMMap   bool `json:"use_mmap,omitempty"`
+	UseMLock  bool `json:"use_mlock,omitempty"`
+	NumThread int  `json:"num_thread,omitempty"`
+
+	// Unused: RopeFrequencyBase is ignored. Instead the value in the model will be used
+	RopeFrequencyBase float32 `json:"rope_frequency_base,omitempty"`
+	// Unused: RopeFrequencyScale is ignored. Instead the value in the model will be used
+	RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
 }

 type EmbeddingRequest struct {
@@ -139,10 +170,11 @@ type EmbeddingResponse struct {
 }

 type CreateRequest struct {
-	Model     string `json:"model"`
-	Path      string `json:"path"`
-	Modelfile string `json:"modelfile"`
-	Stream    *bool  `json:"stream,omitempty"`
+	Model        string `json:"model"`
+	Path         string `json:"path"`
+	Modelfile    string `json:"modelfile"`
+	Stream       *bool  `json:"stream,omitempty"`
+	Quantization string `json:"quantization,omitempty"`

 	// Name is deprecated, see Model
 	Name string `json:"name"`
@@ -382,18 +414,16 @@ func DefaultOptions() Options {

 		Runner: Runner{
 			// options set when the model is loaded
-			NumCtx:             2048,
-			RopeFrequencyBase:  10000.0,
-			RopeFrequencyScale: 1.0,
-			NumBatch:           512,
-			NumGPU:             -1, // -1 here indicates that NumGPU should be set dynamically
-			NumGQA:             1,
-			NumThread:          0, // let the runtime decide
-			LowVRAM:            false,
-			F16KV:              true,
-			UseMLock:           false,
-			UseMMap:            true,
-			UseNUMA:            false,
+			NumCtx:    2048,
+			NumBatch:  512,
+			NumGPU:    -1, // -1 here indicates that NumGPU should be set dynamically
+			NumGQA:    1,
+			NumThread: 0, // let the runtime decide
+			LowVRAM:   false,
+			F16KV:     true,
+			UseMLock:  false,
+			UseMMap:   true,
+			UseNUMA:   false,
 		},
 	}
 }
```
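To make the field documentation concrete, here is a hedged sketch of a non-streaming Generate call; the model name, prompt, and temperature value are illustrative assumptions, not part of the diff:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	stream := false // streaming is on by default, so opt out explicitly
	req := &api.GenerateRequest{
		Model:  "llama2", // illustrative model name
		Prompt: "Name three prime numbers.",
		Stream: &stream,
		// model-specific options are passed through the Options map
		Options: map[string]interface{}{
			"temperature": 0.0,
		},
	}

	// with Stream disabled, fn is invoked once with the full response
	err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
		fmt.Println(resp.Response)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```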
`app/lifecycle/server.go` (file name inferred from the `SpawnServer` context; the header was lost in extraction)

```diff
@@ -9,6 +9,7 @@ import (
 	"os"
 	"os/exec"
 	"path/filepath"
+	"syscall"
 	"time"

 	"github.com/ollama/ollama/api"
@@ -83,6 +84,28 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
 		io.Copy(logFile, stderr) //nolint:errcheck
 	}()

+	// Re-wire context done behavior to attempt a graceful shutdown of the server
+	cmd.Cancel = func() error {
+		if cmd.Process != nil {
+			cmd.Process.Signal(os.Interrupt) //nolint:errcheck
+			tick := time.NewTicker(10 * time.Millisecond)
+			defer tick.Stop()
+			for {
+				select {
+				case <-tick.C:
+					// OS agnostic "is it still running"
+					if proc, err := os.FindProcess(int(cmd.Process.Pid)); err != nil || errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
+						return nil //nolint:nilerr
+					}
+				case <-time.After(5 * time.Second):
+					slog.Warn("graceful server shutdown timeout, killing", "pid", cmd.Process.Pid)
+					cmd.Process.Kill() //nolint:errcheck
+				}
+			}
+		}
+		return nil
+	}
+
 	// run the command and wait for it to finish
 	if err := cmd.Start(); err != nil {
 		return done, fmt.Errorf("failed to start server %w", err)
@@ -105,7 +128,7 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {

 	select {
 	case <-ctx.Done():
-		slog.Debug(fmt.Sprintf("server shutdown with exit code %d", code))
+		slog.Info(fmt.Sprintf("server shutdown with exit code %d", code))
 		done <- code
 		return
 	default:
```
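The liveness probe inside `cmd.Cancel` generalizes beyond this diff. A hedged standalone sketch of the same "signal 0" check:

```go
package main

import (
	"errors"
	"fmt"
	"os"
	"syscall"
)

// isRunning reports whether pid refers to a live process using the classic
// "signal 0" probe: no signal is actually delivered, but the OS still
// performs the liveness check. Go's os package reports an exited process
// as os.ErrProcessDone, which is what makes the check in cmd.Cancel above
// OS agnostic.
func isRunning(pid int) bool {
	proc, err := os.FindProcess(pid)
	if err != nil {
		return false
	}
	return !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone)
}

func main() {
	fmt.Println(isRunning(os.Getpid())) // prints true for the current process
}
```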
`app/tray/tray.go` (file name inferred from the `NewTray` context; the header was lost in extraction)

```diff
@@ -24,10 +24,5 @@ func NewTray() (commontray.OllamaTray, error) {
 		return nil, fmt.Errorf("failed to load icon %s: %w", iconName, err)
 	}

-	tray, err := InitPlatformTray(icon, updateIcon)
-	if err != nil {
-		return nil, err
-	}
-
-	return tray, nil
+	return InitPlatformTray(icon, updateIcon)
 }
```
`cmd/cmd.go` (28 lines changed)

```diff
@@ -194,7 +194,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 		return nil
 	}

-	request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile)}
+	quantization, _ := cmd.Flags().GetString("quantization")
+
+	request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization}
 	if err := client.Create(cmd.Context(), &request, fn); err != nil {
 		return err
 	}
@@ -213,7 +215,10 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
 	if _, err := io.Copy(hash, bin); err != nil {
 		return "", err
 	}
-	bin.Seek(0, io.SeekStart)
+
+	if _, err := bin.Seek(0, io.SeekStart); err != nil {
+		return "", err
+	}

 	digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
 	if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
@@ -223,6 +228,14 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
 }

 func RunHandler(cmd *cobra.Command, args []string) error {
+	if os.Getenv("OLLAMA_MODELS") != "" {
+		return errors.New("OLLAMA_MODELS must only be set for 'ollama serve'")
+	}
+
+	if err := checkServerHeartbeat(cmd, args); err != nil {
+		return err
+	}
+
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err
@@ -932,6 +945,7 @@ func NewCLI() *cobra.Command {
 	}

 	createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile (default \"Modelfile\")")
+	createCmd.Flags().StringP("quantization", "q", "", "Quantization level.")

 	showCmd := &cobra.Command{
 		Use: "show MODEL",
@@ -948,11 +962,10 @@ func NewCLI() *cobra.Command {
 	showCmd.Flags().Bool("system", false, "Show system message of a model")

 	runCmd := &cobra.Command{
-		Use:     "run MODEL [PROMPT]",
-		Short:   "Run a model",
-		Args:    cobra.MinimumNArgs(1),
-		PreRunE: checkServerHeartbeat,
-		RunE:    RunHandler,
+		Use:   "run MODEL [PROMPT]",
+		Short: "Run a model",
+		Args:  cobra.MinimumNArgs(1),
+		RunE:  RunHandler,
 	}

 	runCmd.Flags().Bool("verbose", false, "Show timings for response")
@@ -973,6 +986,7 @@ Environment Variables:
       OLLAMA_ORIGINS         A comma separated list of allowed origins.
       OLLAMA_MODELS          The path to the models directory (default is "~/.ollama/models")
       OLLAMA_KEEP_ALIVE      The duration that models stay loaded in memory (default is "5m")
+      OLLAMA_DEBUG           Set to 1 to enable additional debug logging
 `)

 	pullCmd := &cobra.Command{
```
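The createBlob hunk hashes the file before upload and now checks the Seek that rewinds it. A hedged sketch of that hash-then-rewind step in isolation (the helper name is invented for illustration):

```go
package main

import (
	"crypto/sha256"
	"fmt"
	"io"
	"os"
)

// digestFile returns the "sha256:<hex>" digest of path and an open handle
// rewound to the start, ready to be streamed to CreateBlob. Ignoring the
// Seek error, as the old code did, could silently upload from the wrong
// offset; the diff makes that failure explicit.
func digestFile(path string) (string, *os.File, error) {
	bin, err := os.Open(path)
	if err != nil {
		return "", nil, err
	}

	hash := sha256.New()
	if _, err := io.Copy(hash, bin); err != nil {
		bin.Close()
		return "", nil, err
	}

	// rewind so the caller can stream the same file handle afterwards
	if _, err := bin.Seek(0, io.SeekStart); err != nil {
		bin.Close()
		return "", nil, err
	}

	return fmt.Sprintf("sha256:%x", hash.Sum(nil)), bin, nil
}

func main() {
	digest, f, err := digestFile("Modelfile")
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	defer f.Close()
	fmt.Println(digest)
}
```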
`cmd/interactive.go` (file name inferred from the `generateInteractive` context; the header was lost in extraction)

```diff
@@ -295,10 +295,14 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 				opts.WordWrap = false
 				fmt.Println("Set 'nowordwrap' mode.")
 			case "verbose":
-				cmd.Flags().Set("verbose", "true")
+				if err := cmd.Flags().Set("verbose", "true"); err != nil {
+					return err
+				}
 				fmt.Println("Set 'verbose' mode.")
 			case "quiet":
-				cmd.Flags().Set("verbose", "false")
+				if err := cmd.Flags().Set("verbose", "false"); err != nil {
+					return err
+				}
 				fmt.Println("Set 'quiet' mode.")
 			case "format":
 				if len(args) < 3 || args[2] != "json" {
```
@@ -13,7 +13,9 @@ import (
|
||||
"regexp"
|
||||
"slices"
|
||||
|
||||
"github.com/d4l3k/go-bfloat16"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/x448/float16"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/ollama/ollama/convert/sentencepiece"
|
||||
@@ -30,9 +32,17 @@ type Params struct {
|
||||
AttentionHeads int `json:"num_attention_heads"` // n_head
|
||||
KeyValHeads int `json:"num_key_value_heads"`
|
||||
NormEPS float64 `json:"rms_norm_eps"`
|
||||
RopeFreqBase float64 `json:"rope_theta"`
|
||||
BoSTokenID int `json:"bos_token_id"`
|
||||
EoSTokenID int `json:"eos_token_id"`
|
||||
HeadDimension int `json:"head_dim"`
|
||||
PaddingTokenID int `json:"pad_token_id"`
|
||||
|
||||
ByteOrder
|
||||
}
|
||||
|
||||
type ByteOrder interface {
|
||||
binary.ByteOrder
|
||||
binary.AppendByteOrder
|
||||
}
|
||||
|
||||
type MetaData struct {
|
||||
@@ -41,27 +51,43 @@ type MetaData struct {
|
||||
Offsets []int `mapstructure:"data_offsets"`
|
||||
}
|
||||
|
||||
func ReadSafeTensors(fn string, offset uint64) ([]llm.Tensor, uint64, error) {
|
||||
type ModelArch interface {
|
||||
GetTensors() error
|
||||
LoadVocab() error
|
||||
WriteGGUF() (string, error)
|
||||
}
|
||||
|
||||
type ModelData struct {
|
||||
Path string
|
||||
Name string
|
||||
Params *Params
|
||||
Vocab *Vocab
|
||||
Tensors []llm.Tensor
|
||||
}
|
||||
|
||||
func ReadSafeTensors(fn string, offset uint64, params *Params) ([]llm.Tensor, uint64, error) {
|
||||
f, err := os.Open(fn)
|
||||
if err != nil {
|
||||
return []llm.Tensor{}, 0, err
|
||||
return nil, 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var jsonSize uint64
|
||||
binary.Read(f, binary.LittleEndian, &jsonSize)
|
||||
if err := binary.Read(f, binary.LittleEndian, &jsonSize); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
buf := make([]byte, jsonSize)
|
||||
_, err = io.ReadFull(f, buf)
|
||||
if err != nil {
|
||||
return []llm.Tensor{}, 0, err
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
d := json.NewDecoder(bytes.NewBuffer(buf))
|
||||
d.UseNumber()
|
||||
var parsed map[string]interface{}
|
||||
if err = d.Decode(&parsed); err != nil {
|
||||
return []llm.Tensor{}, 0, err
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var keys []string
|
||||
@@ -78,7 +104,7 @@ func ReadSafeTensors(fn string, offset uint64) ([]llm.Tensor, uint64, error) {
|
||||
vals := parsed[k].(map[string]interface{})
|
||||
var data MetaData
|
||||
if err = mapstructure.Decode(vals, &data); err != nil {
|
||||
return []llm.Tensor{}, 0, err
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var size uint64
|
||||
@@ -100,7 +126,7 @@ func ReadSafeTensors(fn string, offset uint64) ([]llm.Tensor, uint64, error) {
|
||||
ggufName, err := GetTensorName(k)
|
||||
if err != nil {
|
||||
slog.Error("%v", err)
|
||||
return []llm.Tensor{}, 0, err
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
shape := []uint64{0, 0, 0, 0}
|
||||
@@ -109,14 +135,22 @@ func ReadSafeTensors(fn string, offset uint64) ([]llm.Tensor, uint64, error) {
|
||||
}
|
||||
|
||||
t := llm.Tensor{
|
||||
Name: ggufName,
|
||||
Kind: kind,
|
||||
Offset: offset,
|
||||
Shape: shape[:],
|
||||
FileName: fn,
|
||||
OffsetPadding: 8 + jsonSize,
|
||||
FileOffsets: []uint64{uint64(data.Offsets[0]), uint64(data.Offsets[1])},
|
||||
Name: ggufName,
|
||||
Kind: kind,
|
||||
Offset: offset,
|
||||
Shape: shape[:],
|
||||
}
|
||||
|
||||
t.WriterTo = safetensorWriterTo{
|
||||
t: &t,
|
||||
params: params,
|
||||
bo: params.ByteOrder,
|
||||
filename: fn,
|
||||
start: uint64(data.Offsets[0]),
|
||||
end: uint64(data.Offsets[1]),
|
||||
padding: 8 + jsonSize,
|
||||
}
|
||||
|
||||
slog.Debug(fmt.Sprintf("%v", t))
|
||||
tensors = append(tensors, t)
|
||||
offset += size
|
||||
@@ -124,21 +158,21 @@ func ReadSafeTensors(fn string, offset uint64) ([]llm.Tensor, uint64, error) {
|
||||
return tensors, offset, nil
|
||||
}
|
||||
|
||||
func GetSafeTensors(dirpath string) ([]llm.Tensor, error) {
|
||||
func GetSafeTensors(dirpath string, params *Params) ([]llm.Tensor, error) {
|
||||
var tensors []llm.Tensor
|
||||
files, err := filepath.Glob(filepath.Join(dirpath, "/model-*.safetensors"))
|
||||
if err != nil {
|
||||
return []llm.Tensor{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var offset uint64
|
||||
for _, f := range files {
|
||||
var t []llm.Tensor
|
||||
var err error
|
||||
t, offset, err = ReadSafeTensors(f, offset)
|
||||
t, offset, err = ReadSafeTensors(f, offset, params)
|
||||
if err != nil {
|
||||
slog.Error("%v", err)
|
||||
return []llm.Tensor{}, err
|
||||
return nil, err
|
||||
}
|
||||
tensors = append(tensors, t...)
|
||||
}
|
||||
@@ -160,6 +194,7 @@ func GetParams(dirpath string) (*Params, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params.ByteOrder = binary.LittleEndian
|
||||
return ¶ms, nil
|
||||
}
|
||||
|
||||
@@ -171,7 +206,7 @@ type Vocab struct {
|
||||
Types []int32
|
||||
}
|
||||
|
||||
func LoadTokens(dirpath string) (*Vocab, error) {
|
||||
func LoadSentencePieceTokens(dirpath string, vocabSize int) (*Vocab, error) {
|
||||
slog.Info(fmt.Sprintf("reading vocab from %s", filepath.Join(dirpath, "tokenizer.model")))
|
||||
in, err := os.ReadFile(filepath.Join(dirpath, "tokenizer.model"))
|
||||
if err != nil {
|
||||
@@ -196,6 +231,14 @@ func LoadTokens(dirpath string) (*Vocab, error) {
|
||||
v.Tokens = append(v.Tokens, p.GetPiece())
|
||||
v.Scores = append(v.Scores, p.GetScore())
|
||||
t := p.GetType()
|
||||
switch t {
|
||||
case sentencepiece.ModelProto_SentencePiece_UNKNOWN:
|
||||
case sentencepiece.ModelProto_SentencePiece_CONTROL:
|
||||
case sentencepiece.ModelProto_SentencePiece_UNUSED:
|
||||
case sentencepiece.ModelProto_SentencePiece_BYTE:
|
||||
default:
|
||||
t = sentencepiece.ModelProto_SentencePiece_NORMAL
|
||||
}
|
||||
v.Types = append(v.Types, int32(t))
|
||||
}
|
||||
|
||||
@@ -243,6 +286,16 @@ func LoadTokens(dirpath string) (*Vocab, error) {
|
||||
}
|
||||
slog.Info(fmt.Sprintf("vocab size w/ extra tokens: %d", len(v.Tokens)))
|
||||
|
||||
if vocabSize > len(v.Tokens) {
|
||||
missingTokens := vocabSize - len(v.Tokens)
|
||||
slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens))
|
||||
for cnt := 0; cnt < missingTokens; cnt++ {
|
||||
v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1))
|
||||
v.Scores = append(v.Scores, -1)
|
||||
v.Types = append(v.Types, int32(llm.GGUFTokenUserDefined))
|
||||
}
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
@@ -279,53 +332,102 @@ func GetTensorName(n string) (string, error) {
|
||||
return "", fmt.Errorf("couldn't find a layer name for '%s'", n)
|
||||
}
|
||||
|
||||
func WriteGGUF(name string, tensors []llm.Tensor, params *Params, vocab *Vocab) (string, error) {
|
||||
c := llm.ContainerGGUF{
|
||||
ByteOrder: binary.LittleEndian,
|
||||
}
|
||||
type safetensorWriterTo struct {
|
||||
t *llm.Tensor
|
||||
|
||||
m := llm.NewGGUFModel(&c)
|
||||
m.Tensors = tensors
|
||||
m.KV["general.architecture"] = "llama"
|
||||
m.KV["general.name"] = name
|
||||
m.KV["llama.context_length"] = uint32(params.ContextSize)
|
||||
m.KV["llama.embedding_length"] = uint32(params.HiddenSize)
|
||||
m.KV["llama.block_count"] = uint32(params.HiddenLayers)
|
||||
m.KV["llama.feed_forward_length"] = uint32(params.IntermediateSize)
|
||||
m.KV["llama.rope.dimension_count"] = uint32(128)
|
||||
m.KV["llama.attention.head_count"] = uint32(params.AttentionHeads)
|
||||
m.KV["llama.attention.head_count_kv"] = uint32(params.KeyValHeads)
|
||||
m.KV["llama.attention.layer_norm_rms_epsilon"] = float32(params.NormEPS)
|
||||
m.KV["llama.rope.freq_base"] = float32(params.RopeFreqBase)
|
||||
m.KV["general.file_type"] = uint32(1)
|
||||
m.KV["tokenizer.ggml.model"] = "llama"
|
||||
params *Params
|
||||
bo ByteOrder
|
||||
|
||||
m.KV["tokenizer.ggml.tokens"] = vocab.Tokens
|
||||
m.KV["tokenizer.ggml.scores"] = vocab.Scores
|
||||
m.KV["tokenizer.ggml.token_type"] = vocab.Types
|
||||
filename string
|
||||
|
||||
m.KV["tokenizer.ggml.bos_token_id"] = uint32(params.BoSTokenID)
|
||||
m.KV["tokenizer.ggml.eos_token_id"] = uint32(params.EoSTokenID)
|
||||
m.KV["tokenizer.ggml.unknown_token_id"] = uint32(0)
|
||||
m.KV["tokenizer.ggml.add_bos_token"] = true
|
||||
m.KV["tokenizer.ggml.add_eos_token"] = false
|
||||
start, end, padding uint64
|
||||
handler func(w io.Writer, r safetensorWriterTo, f *os.File) error
|
||||
}
|
||||
|
||||
// llamacpp sets the chat template, however we don't need to set it since we pass it in through a layer
|
||||
// m.KV["tokenizer.chat_template"] = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" // XXX removeme
|
||||
|
||||
c.V3.NumTensor = uint64(len(tensors))
|
||||
c.V3.NumKV = uint64(len(m.KV))
|
||||
|
||||
f, err := os.CreateTemp("", "ollama-gguf")
|
||||
func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
|
||||
f, err := os.Open(r.filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
err = m.Encode(f)
|
||||
if err != nil {
|
||||
return "", err
|
||||
if _, err = f.Seek(int64(r.padding+r.start), 0); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return f.Name(), nil
|
||||
// use the handler if one is present
|
||||
if r.handler != nil {
|
||||
return 0, r.handler(w, r, f)
|
||||
}
|
||||
|
||||
remaining := r.end - r.start
|
||||
|
||||
bufSize := uint64(10240)
|
||||
var finished bool
|
||||
for {
|
||||
data := make([]byte, min(bufSize, remaining))
|
||||
|
||||
b, err := io.ReadFull(f, data)
|
||||
remaining -= uint64(b)
|
||||
|
||||
if err == io.EOF || remaining <= 0 {
|
||||
finished = true
|
||||
} else if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// convert bfloat16 -> ieee float32
|
||||
tDataF32 := bfloat16.DecodeFloat32(data)
|
||||
|
||||
switch r.t.Kind {
|
||||
case 0:
|
||||
if err := binary.Write(w, r.bo, tDataF32); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
case 1:
|
||||
// convert float32 -> float16
|
||||
tempBuf := make([]uint16, len(data)/2)
|
||||
for cnt, v := range tDataF32 {
|
||||
tDataF16 := float16.Fromfloat32(v)
|
||||
tempBuf[cnt] = uint16(tDataF16)
|
||||
}
|
||||
if err := binary.Write(w, binary.LittleEndian, tempBuf); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if finished {
|
||||
break
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func GetModelArchFromParams(name, dirPath string, params *Params) (ModelArch, error) {
switch len(params.Architectures) {
case 0:
return nil, fmt.Errorf("No architecture specified to convert")
case 1:
switch params.Architectures[0] {
case "MistralForCausalLM":
return &MistralModel{
ModelData{
Name: name,
Path: dirPath,
Params: params,
},
}, nil
case "GemmaForCausalLM":
return &GemmaModel{
ModelData{
Name: name,
Path: dirPath,
Params: params,
},
}, nil
default:
return nil, fmt.Errorf("Models based on '%s' are not yet supported", params.Architectures[0])
}
}

return nil, fmt.Errorf("Unknown error")
}
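A hypothetical end-to-end caller, to show how the returned `ModelArch` is meant to be driven (method names taken from the converters added below; error handling abbreviated, and this is a sketch rather than the actual call site):

```go
// Sketch only: convert a safetensors checkpoint into a temporary GGUF file.
arch, err := GetModelArchFromParams("mistral-7b", "/path/to/hf-checkpoint", params)
if err != nil {
	log.Fatal(err)
}
if err := arch.GetTensors(); err != nil { // locate tensors, attach any layer handlers
	log.Fatal(err)
}
if err := arch.LoadVocab(); err != nil { // read the SentencePiece vocabulary
	log.Fatal(err)
}
ggufPath, err := arch.WriteGGUF() // write KV metadata plus tensors to a temp file
if err != nil {
	log.Fatal(err)
}
fmt.Println("wrote", ggufPath)
```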
136
convert/gemma.go
Normal file
136
convert/gemma.go
Normal file
@@ -0,0 +1,136 @@
package convert

import (
"encoding/binary"
"fmt"
"io"
"log/slog"
"os"
"strings"

"github.com/d4l3k/go-bfloat16"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"

"github.com/ollama/ollama/llm"
)

type GemmaModel struct {
ModelData
}

func gemmaLayerHandler(w io.Writer, r safetensorWriterTo, f *os.File) error {
slog.Debug(fmt.Sprintf("converting '%s'", r.t.Name))

data := make([]byte, r.end-r.start)
if err := binary.Read(f, r.bo, data); err != nil {
return err
}

tDataF32 := bfloat16.DecodeFloat32(data)

var err error
tDataF32, err = addOnes(tDataF32, int(r.t.Shape[0]))
if err != nil {
return err
}

if err := binary.Write(w, r.bo, tDataF32); err != nil {
return err
}
return nil
}

func addOnes(data []float32, vectorSize int) ([]float32, error) {
n := tensor.New(tensor.WithShape(vectorSize), tensor.WithBacking(data))
ones := tensor.Ones(tensor.Float32, vectorSize)

var err error
n, err = n.Add(ones)
if err != nil {
return []float32{}, err
}

newN, err := native.SelectF32(n, 0)
if err != nil {
return []float32{}, err
}

var fullTensor []float32
for _, v := range newN {
fullTensor = append(fullTensor, v...)
}

return fullTensor, nil
}
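Gemma checkpoints store each RMSNorm weight as `w - 1`, so the converter adds one back before writing the tensor. The tensor package keeps this consistent with the other conversions, but the effect is a plain element-wise loop, sketched here for clarity (hypothetical equivalent, not part of the change):

```go
// addOnesSimple is an illustrative plain-Go equivalent of addOnes:
// shift every element up by one so the written weights match what
// llama.cpp's RMSNorm expects.
func addOnesSimple(data []float32) []float32 {
	out := make([]float32, len(data))
	for i, v := range data {
		out[i] = v + 1
	}
	return out
}
```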
func (m *GemmaModel) GetTensors() error {
t, err := GetSafeTensors(m.Path, m.Params)
if err != nil {
return err
}

m.Tensors = []llm.Tensor{}

for _, l := range t {
if strings.HasSuffix(l.Name, "norm.weight") {
wt := l.WriterTo.(safetensorWriterTo)
wt.handler = gemmaLayerHandler
l.WriterTo = wt
}
m.Tensors = append(m.Tensors, l)
}

return nil
}

func (m *GemmaModel) LoadVocab() error {
v, err := LoadSentencePieceTokens(m.Path, m.Params.VocabSize)
if err != nil {
return err
}
m.Vocab = v
return nil
}

func (m *GemmaModel) WriteGGUF() (string, error) {
kv := llm.KV{
"general.architecture": "gemma",
"general.name": m.Name,
"gemma.context_length": uint32(m.Params.ContextSize),
"gemma.embedding_length": uint32(m.Params.HiddenSize),
"gemma.block_count": uint32(m.Params.HiddenLayers),
"gemma.feed_forward_length": uint32(m.Params.IntermediateSize),
"gemma.attention.head_count": uint32(m.Params.AttentionHeads),
"gemma.attention.head_count_kv": uint32(m.Params.KeyValHeads),
"gemma.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
"gemma.attention.key_length": uint32(m.Params.HeadDimension),
"gemma.attention.value_length": uint32(m.Params.HeadDimension),
"general.file_type": uint32(1),
"tokenizer.ggml.model": "llama",

"tokenizer.ggml.tokens": m.Vocab.Tokens,
"tokenizer.ggml.scores": m.Vocab.Scores,
"tokenizer.ggml.token_type": m.Vocab.Types,

"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
"tokenizer.ggml.padding_token_id": uint32(m.Params.PaddingTokenID),
"tokenizer.ggml.unknown_token_id": uint32(3),
"tokenizer.ggml.add_bos_token": true,
"tokenizer.ggml.add_eos_token": false,
}

f, err := os.CreateTemp("", "ollama-gguf")
if err != nil {
return "", err
}
defer f.Close()

mod := llm.NewGGUFV3(m.Params.ByteOrder)
if err := mod.Encode(f, kv, m.Tensors); err != nil {
return "", err
}

return f.Name(), nil
}
173
convert/mistral.go
Normal file
173
convert/mistral.go
Normal file
@@ -0,0 +1,173 @@
package convert

import (
"encoding/binary"
"fmt"
"io"
"os"
"regexp"
"strings"

"github.com/d4l3k/go-bfloat16"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/x448/float16"

"github.com/ollama/ollama/llm"
)

type MistralModel struct {
ModelData
}

func mistralLayerHandler(w io.Writer, r safetensorWriterTo, f *os.File) error {
layerSize := r.end - r.start

var err error
tData := make([]uint16, layerSize/2)
if err = binary.Read(f, r.bo, tData); err != nil {
return err
}

var heads uint32
if strings.Contains(r.t.Name, "attn_q") {
heads = uint32(r.params.AttentionHeads)
} else if strings.Contains(r.t.Name, "attn_k") {
heads = uint32(r.params.KeyValHeads)
if heads == 0 {
heads = uint32(r.params.AttentionHeads)
}
} else {
return fmt.Errorf("unknown layer type")
}

tData, err = repack(tData, int(heads), r.t.Shape)
if err != nil {
return err
}

var buf []byte
for _, n := range tData {
buf = r.bo.AppendUint16(buf, n)
}

tempBuf := make([]uint16, len(tData))
tDataF32 := bfloat16.DecodeFloat32(buf)
for cnt, v := range tDataF32 {
tDataF16 := float16.Fromfloat32(v)
tempBuf[cnt] = uint16(tDataF16)
}

if err = binary.Write(w, r.bo, tempBuf); err != nil {
return err
}
return nil
}

func repack(data []uint16, heads int, shape []uint64) ([]uint16, error) {
n := tensor.New(tensor.WithShape(int(shape[0]), int(shape[1])), tensor.WithBacking(data))
origShape := n.Shape().Clone()

// reshape the tensor and swap axes 1 and 2 to unpack the layer for gguf
if err := n.Reshape(heads, 2, origShape[0]/heads/2, origShape[1]); err != nil {
return nil, err
}

if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}

if err := n.Reshape(origShape...); err != nil {
return nil, err
}

if err := n.Transpose(); err != nil {
return nil, err
}
newN, err := native.SelectU16(n, 1)
if err != nil {
return nil, err
}

var fullTensor []uint16
for _, v := range newN {
fullTensor = append(fullTensor, v...)
}
return fullTensor, nil
}
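The net effect of `repack` is the usual permutation applied when converting Hugging Face attention weights to the GGUF rotary layout: view the `(rows, cols)` matrix as `(heads, 2, rows/heads/2, cols)` and swap the two middle axes. A plain-Go sketch of that index shuffle (hypothetical helper assuming row-major data; the code above reaches the same result through the tensor package):

```go
// permuteQK reorders attn_q/attn_k weight rows from the interleaved
// Hugging Face layout to the layout llama.cpp expects for RoPE.
func permuteQK(w []uint16, rows, cols, heads int) []uint16 {
	out := make([]uint16, len(w))
	half := rows / heads / 2
	for h := 0; h < heads; h++ {
		for i := 0; i < 2; i++ {
			for j := 0; j < half; j++ {
				src := ((h*2+i)*half + j) * cols // offset in (heads, 2, half, cols)
				dst := ((h*half+j)*2 + i) * cols // offset in (heads, half, 2, cols)
				copy(out[dst:dst+cols], w[src:src+cols])
			}
		}
	}
	return out
}
```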
func (m *MistralModel) GetTensors() error {
t, err := GetSafeTensors(m.Path, m.Params)
if err != nil {
return err
}

m.Tensors = []llm.Tensor{}

pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
re, err := regexp.Compile(pattern)
if err != nil {
return err
}

for _, l := range t {
matches := re.FindAllStringSubmatch(l.Name, -1)
if len(matches) > 0 {
wt := l.WriterTo.(safetensorWriterTo)
wt.handler = mistralLayerHandler
l.WriterTo = wt
}
m.Tensors = append(m.Tensors, l)
}

return nil
}

func (m *MistralModel) LoadVocab() error {
v, err := LoadSentencePieceTokens(m.Path, m.Params.VocabSize)
if err != nil {
return err
}
m.Vocab = v
return nil
}

func (m *MistralModel) WriteGGUF() (string, error) {
kv := llm.KV{
"general.architecture": "llama",
"general.name": m.Name,
"llama.context_length": uint32(m.Params.ContextSize),
"llama.embedding_length": uint32(m.Params.HiddenSize),
"llama.block_count": uint32(m.Params.HiddenLayers),
"llama.feed_forward_length": uint32(m.Params.IntermediateSize),
"llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
"llama.attention.head_count": uint32(m.Params.AttentionHeads),
"llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
"general.file_type": uint32(1),
"tokenizer.ggml.model": "llama",

"tokenizer.ggml.tokens": m.Vocab.Tokens,
"tokenizer.ggml.scores": m.Vocab.Scores,
"tokenizer.ggml.token_type": m.Vocab.Types,

"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
"tokenizer.ggml.add_bos_token": true,
"tokenizer.ggml.add_eos_token": false,
"tokenizer.ggml.unknown_token_id": uint32(0),
}

f, err := os.CreateTemp("", "ollama-gguf")
if err != nil {
return "", err
}
defer f.Close()

mod := llm.NewGGUFV3(m.Params.ByteOrder)
if err := mod.Encode(f, kv, m.Tensors); err != nil {
return "", err
}

return f.Name(), nil
}
@@ -394,7 +394,6 @@ Advanced parameters (optional):

- `format`: the format to return a response in. Currently the only accepted value is `json`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
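As a quick illustration, here is how a few of these parameters map onto the Go client used in the examples further down in this diff (illustrative values; field names assumed from this commit's `api` package):

```go
req := &api.GenerateRequest{
	Model:  "llama2",
	Prompt: "Respond with a JSON object describing the sky.",
	Format: "json",    // currently "json" is the only accepted value
	Stream: new(bool), // pointer to false: one response object instead of a stream
	Options: map[string]interface{}{
		"temperature": 0.7, // Modelfile-style parameter override
	},
}
```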
@@ -71,9 +71,12 @@ More examples are available in the [examples directory](../examples).

There are two ways to view `Modelfile`s underlying the models in [ollama.com/library][1]:

- Option 1: view a model's data:
1. Go to a particular model page (e.g. https://ollama.com/library/llama2)
2. There is a table that displays the model's different components
- Option 1: view a details page from a model's tags page:
1. Go to a particular model's tags (e.g. https://ollama.com/library/llama2/tags)
2. Click on a tag (e.g. https://ollama.com/library/llama2:13b)
3. Scroll down to "Layers"
- Note: if the [`FROM` instruction](#from-required) is not present,
it means the model was created from a local file
- Option 2: use `ollama show` to print the `Modelfile` for any local models like so:

```bash
@@ -212,6 +215,7 @@ MESSAGE <role> <message>
| user | An example message of what the user could have asked. |
| assistant | An example message of how the model should respond. |

#### Example conversation

```modelfile
@@ -223,6 +227,7 @@ MESSAGE user Is Ontario in Canada?
MESSAGE assistant yes
```

## Notes

- the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish them from arguments.
@@ -76,3 +76,10 @@ install script which version to install.

```sh
curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION="0.1.29" sh
```

## Linux tmp noexec

If your system is configured with the "noexec" flag where Ollama stores its
temporary executable files, you can specify an alternate location by setting
`OLLAMA_TMPDIR` to a location writable by the user ollama runs as, for example:
`OLLAMA_TMPDIR=/usr/share/ollama/`.
40
examples/go-generate-streaming/main.go
Normal file
40
examples/go-generate-streaming/main.go
Normal file
@@ -0,0 +1,40 @@
package main

import (
"context"
"fmt"
"log"

"github.com/ollama/ollama/api"
)

func main() {
client, err := api.ClientFromEnvironment()
if err != nil {
log.Fatal(err)
}

// By default, GenerateRequest is streaming.
req := &api.GenerateRequest{
Model: "gemma",
Prompt: "how many planets are there?",
}

ctx := context.Background()
respFunc := func(resp api.GenerateResponse) error {
// Only print the response here; GenerateResponse has a number of other
// interesting fields you may want to examine.
// In streaming mode, responses are partial so we call fmt.Print (and not
// Println) in order to avoid spurious newlines being introduced. The
// model will insert its own newlines if it wants.
fmt.Print(resp.Response)
return nil
}

err = client.Generate(ctx, req, respFunc)
if err != nil {
log.Fatal(err)
}
fmt.Println()
}
37
examples/go-generate/main.go
Normal file
37
examples/go-generate/main.go
Normal file
@@ -0,0 +1,37 @@
package main

import (
"context"
"fmt"
"log"

"github.com/ollama/ollama/api"
)

func main() {
client, err := api.ClientFromEnvironment()
if err != nil {
log.Fatal(err)
}

req := &api.GenerateRequest{
Model: "gemma",
Prompt: "how many planets are there?",

// set streaming to false
Stream: new(bool),
}

ctx := context.Background()
respFunc := func(resp api.GenerateResponse) error {
// Only print the response here; GenerateResponse has a number of other
// interesting fields you may want to examine.
fmt.Println(resp.Response)
return nil
}

err = client.Generate(ctx, req, respFunc)
if err != nil {
log.Fatal(err)
}
}
@@ -6,11 +6,15 @@ import (
)

const (
Byte = 1
Byte = 1

KiloByte = Byte * 1000
MegaByte = KiloByte * 1000
GigaByte = MegaByte * 1000
TeraByte = GigaByte * 1000

KibiByte = Byte * 1024
MebiByte = KibiByte * 1024
)

func HumanBytes(b int64) string {
@@ -45,3 +49,14 @@ func HumanBytes(b int64) string {
return fmt.Sprintf("%d %s", int(value), unit)
}
}

func HumanBytes2(b uint64) string {
switch {
case b >= MebiByte:
return fmt.Sprintf("%.1f MiB", float64(b)/MebiByte)
case b >= KibiByte:
return fmt.Sprintf("%.1f KiB", float64(b)/KibiByte)
default:
return fmt.Sprintf("%d B", b)
}
}
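A quick sanity check of the new binary-unit helper, following the definitions above (1572864 = 1.5 × 1024 × 1024):

```go
fmt.Println(format.HumanBytes2(512))     // "512 B"
fmt.Println(format.HumanBytes2(2048))    // "2.0 KiB"
fmt.Println(format.HumanBytes2(1572864)) // "1.5 MiB"
```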
2
go.mod
2
go.mod
@@ -9,7 +9,7 @@ require (
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/emirpasic/gods v1.18.1
github.com/gin-gonic/gin v1.9.1
github.com/golang/protobuf v1.5.0
github.com/golang/protobuf v1.5.0 // indirect
github.com/google/uuid v1.0.0
github.com/mitchellh/mapstructure v1.5.0
github.com/olekukonko/tablewriter v0.0.5
@@ -100,6 +100,8 @@ func AMDGetGPUInfo(resp *GpuInfo) {
return
}

updateLibPath(libDir)

gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
if gfxOverride == "" {
supported, err := GetSupportedGFX(libDir)
@@ -113,7 +115,7 @@ func AMDGetGPUInfo(resp *GpuInfo) {
if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
// TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
skip[i] = struct{}{}
} else {
slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
@@ -143,6 +145,21 @@ func AMDGetGPUInfo(resp *GpuInfo) {
}
}

func updateLibPath(libDir string) {
ldPaths := []string{}
if val, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
ldPaths = strings.Split(val, ":")
}
for _, d := range ldPaths {
if d == libDir {
return
}
}
val := strings.Join(append(ldPaths, libDir), ":")
slog.Debug("updated lib path", "LD_LIBRARY_PATH", val)
os.Setenv("LD_LIBRARY_PATH", val)
}

// Walk the sysfs nodes for the available GPUs and gather information from them
// skipping over any devices in the skip map
func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
@@ -11,6 +11,7 @@ import (
"strings"
"sync"
"syscall"
"time"
)

var (
@@ -21,11 +22,20 @@ var (
func PayloadsDir() (string, error) {
lock.Lock()
defer lock.Unlock()
var err error
if payloadsDir == "" {
cleanupTmpDirs()
tmpDir, err := os.MkdirTemp("", "ollama")
if err != nil {
return "", fmt.Errorf("failed to generate tmp dir: %w", err)
tmpDir := os.Getenv("OLLAMA_TMPDIR")
if tmpDir == "" {
tmpDir, err = os.MkdirTemp("", "ollama")
if err != nil {
return "", fmt.Errorf("failed to generate tmp dir: %w", err)
}
} else {
err = os.MkdirAll(tmpDir, 0755)
if err != nil {
return "", fmt.Errorf("failed to generate tmp dir %s: %w", tmpDir, err)
}
}

// Track our pid so we can clean up orphaned tmpdirs
@@ -84,7 +94,12 @@ func Cleanup() {
slog.Debug("cleaning up", "dir", tmpDir)
err := os.RemoveAll(tmpDir)
if err != nil {
slog.Warn("failed to clean up", "dir", tmpDir, "err", err)
// On windows, if we remove too quickly the llama.dll may still be in-use and fail to remove
time.Sleep(1000 * time.Millisecond)
err = os.RemoveAll(tmpDir)
if err != nil {
slog.Warn("failed to clean up", "dir", tmpDir, "err", err)
}
}
}
}
55
gpu/gpu.go
55
gpu/gpu.go
@@ -20,6 +20,8 @@ import (
"strings"
"sync"
"unsafe"

"github.com/ollama/ollama/format"
)

type handles struct {
@@ -27,8 +29,12 @@ type handles struct {
cudart *C.cudart_handle_t
}

const (
cudaMinimumMemory = 457 * format.MebiByte
rocmMinimumMemory = 457 * format.MebiByte
)

var gpuMutex sync.Mutex
var gpuHandles *handles = nil

// With our current CUDA compile flags, older than 5.0 will not work properly
var CudaComputeMin = [2]C.int{5, 0}
@@ -78,11 +84,11 @@ var CudartWindowsGlobs = []string{
var CudaTegra string = os.Getenv("JETSON_JETPACK")

// Note: gpuMutex must already be held
func initGPUHandles() {
func initGPUHandles() *handles {

// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing

gpuHandles = &handles{nil, nil}
gpuHandles := &handles{nil, nil}
var nvmlMgmtName string
var nvmlMgmtPatterns []string
var cudartMgmtName string
@@ -109,7 +115,7 @@ func initGPUHandles() {
}
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartLinuxGlobs...)
default:
return
return gpuHandles
}

slog.Info("Detecting GPU type")
@@ -119,7 +125,7 @@ func initGPUHandles() {
if cudart != nil {
slog.Info("Nvidia GPU detected via cudart")
gpuHandles.cudart = cudart
return
return gpuHandles
}
}

@@ -130,10 +136,10 @@ func initGPUHandles() {
if nvml != nil {
slog.Info("Nvidia GPU detected via nvidia-ml")
gpuHandles.nvml = nvml
return
return gpuHandles
}
}

return gpuHandles
}

func GetGPUInfo() GpuInfo {
@@ -141,9 +147,16 @@ func GetGPUInfo() GpuInfo {
// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
gpuMutex.Lock()
defer gpuMutex.Unlock()
if gpuHandles == nil {
initGPUHandles()
}

gpuHandles := initGPUHandles()
defer func() {
if gpuHandles.nvml != nil {
C.nvml_release(*gpuHandles.nvml)
}
if gpuHandles.cudart != nil {
C.cudart_release(*gpuHandles.cudart)
}
}()

// All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX
cpuVariant := GetCPUVariant()
@@ -168,6 +181,7 @@ func GetGPUInfo() GpuInfo {
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
slog.Info(fmt.Sprintf("[nvidia-ml] NVML CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
resp.Library = "cuda"
resp.MinimumMemory = cudaMinimumMemory
} else {
slog.Info(fmt.Sprintf("[nvidia-ml] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
}
@@ -187,6 +201,7 @@ func GetGPUInfo() GpuInfo {
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
slog.Info(fmt.Sprintf("[cudart] CUDART CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
resp.Library = "cuda"
resp.MinimumMemory = cudaMinimumMemory
} else {
slog.Info(fmt.Sprintf("[cudart] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
}
@@ -194,6 +209,7 @@ func GetGPUInfo() GpuInfo {
} else {
AMDGetGPUInfo(&resp)
if resp.Library != "" {
resp.MinimumMemory = rocmMinimumMemory
return resp
}
}
@@ -227,7 +243,7 @@ func getCPUMem() (memInfo, error) {
return ret, nil
}

func CheckVRAM() (int64, error) {
func CheckVRAM() (uint64, error) {
userLimit := os.Getenv("OLLAMA_MAX_VRAM")
if userLimit != "" {
avail, err := strconv.ParseInt(userLimit, 10, 64)
@@ -235,24 +251,11 @@ func CheckVRAM() (int64, error) {
return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
}
slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
return avail, nil
return uint64(avail), nil
}
gpuInfo := GetGPUInfo()
if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
// leave 10% or 1024MiB of VRAM free per GPU to handle unaccounted for overhead
overhead := gpuInfo.FreeMemory / 10
gpus := uint64(gpuInfo.DeviceCount)
if overhead < gpus*1024*1024*1024 {
overhead = gpus * 1024 * 1024 * 1024
}
// Assigning full reported free memory for Tegras due to OS controlled caching.
if CudaTegra != "" {
// Setting overhead for non-Tegra devices
overhead = 0
}
avail := int64(gpuInfo.FreeMemory - overhead)
slog.Debug(fmt.Sprintf("%s detected %d devices with %dM available memory", gpuInfo.Library, gpuInfo.DeviceCount, avail/1024/1024))
return avail, nil
return gpuInfo.FreeMemory, nil
}

return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
)

// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
func CheckVRAM() (int64, error) {
func CheckVRAM() (uint64, error) {
userLimit := os.Getenv("OLLAMA_MAX_VRAM")
if userLimit != "" {
avail, err := strconv.ParseInt(userLimit, 10, 64)
@@ -25,15 +25,14 @@ func CheckVRAM() (int64, error) {
return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
}
slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
return avail, nil
return uint64(avail), nil
}

if runtime.GOARCH == "amd64" {
// gpu not supported, this may not be metal
return 0, nil
}
recommendedMaxVRAM := int64(C.getRecommendedMaxVRAM())
return recommendedMaxVRAM, nil
return uint64(C.getRecommendedMaxVRAM()), nil
}

func GetGPUInfo() GpuInfo {
@@ -62,6 +62,10 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
LOG(resp->ch.verbose, "cudaSetDevice err: %d\n", ret);
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
resp->err = strdup("your nvidia driver is too old or missing, please upgrade to run ollama");
return;
}
snprintf(buf, buflen, "cudart init failure: %d", ret);
resp->err = strdup(buf);
return;
@@ -187,4 +191,10 @@ void cudart_compute_capability(cudart_handle_t h, cudart_compute_capability_t *r
}
}

void cudart_release(cudart_handle_t h) {
LOG(h.verbose, "releasing cudart library\n");
UNLOAD_LIBRARY(h.handle);
h.handle = NULL;
}

#endif // __APPLE__
@@ -7,6 +7,7 @@
typedef enum cudartReturn_enum {
CUDART_SUCCESS = 0,
CUDART_UNSUPPORTED = 1,
CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
// Other values omitted for now...
} cudartReturn_t;

@@ -54,6 +55,7 @@ typedef struct cudart_compute_capability {
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp);
void cudart_check_vram(cudart_handle_t ch, mem_info_t *resp);
void cudart_compute_capability(cudart_handle_t th, cudart_compute_capability_t *cc);
void cudart_release(cudart_handle_t ch);

#endif // __GPU_INFO_CUDART_H__
#endif // __APPLE__
@@ -211,4 +211,11 @@ void nvml_compute_capability(nvml_handle_t h, nvml_compute_capability_t *resp) {
}
}
}

void nvml_release(nvml_handle_t h) {
LOG(h.verbose, "releasing nvml library\n");
UNLOAD_LIBRARY(h.handle);
h.handle = NULL;
}

#endif // __APPLE__
@@ -51,6 +51,7 @@ typedef struct nvml_compute_capability {
void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp);
void nvml_check_vram(nvml_handle_t ch, mem_info_t *resp);
void nvml_compute_capability(nvml_handle_t ch, nvml_compute_capability_t *cc);
void nvml_release(nvml_handle_t ch);

#endif // __GPU_INFO_NVML_H__
#endif // __APPLE__
@@ -14,6 +14,9 @@ type GpuInfo struct {
// Optional variant to select (e.g. versions, cpu feature flags)
Variant string `json:"variant,omitempty"`

// MinimumMemory represents the minimum memory required to use the GPU
MinimumMemory uint64 `json:"-"`

// TODO add other useful attributes about the card here for discovery information
}

@@ -24,5 +24,5 @@ func TestOrcaMiniBlueSky(t *testing.T) {
"seed": 123,
},
}
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh"})
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh", "scattering"})
}
29
integration/context_test.go
Normal file
29
integration/context_test.go
Normal file
@@ -0,0 +1,29 @@
//go:build integration

package integration

import (
"context"
"net/http"
"testing"
"time"

"github.com/ollama/ollama/api"
)

func TestContextExhaustion(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) // TODO maybe shorter?
defer cancel()
// Set up the test data
req := api.GenerateRequest{
Model: "llama2",
Prompt: "Write me a story with a ton of emojis?",
Stream: &stream,
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
"num_ctx": 128,
},
}
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"once", "upon", "lived"})
}
@@ -15,10 +15,6 @@ import (
// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
// package to avoid circular dependencies

// WARNING - these tests will fail on mac if you don't manually copy ggml-metal.metal to this dir (./server)
//
// TODO - Fix this ^^

var (
stream = false
req = [2]api.GenerateRequest{

@@ -126,7 +126,7 @@ func StartServer(ctx context.Context, ollamaHost string) error {
}

func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error {
slog.Debug("checking status of model", "model", modelName)
slog.Info("checking status of model", "model", modelName)
showReq := &api.ShowRequest{Name: modelName}
requestJSON, err := json.Marshal(showReq)
if err != nil {
@@ -174,36 +174,51 @@ func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoin
return nil
}

var serverProcMutex sync.Mutex

func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {

// TODO maybe stuff in an init routine?
lifecycle.InitLogging()

requestJSON, err := json.Marshal(genReq)
if err != nil {
t.Fatalf("Error serializing request: %v", err)
}
defer func() {
if t.Failed() && os.Getenv("OLLAMA_TEST_EXISTING") == "" {
// TODO
fp, err := os.Open(lifecycle.ServerLogFile)
if err != nil {
slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
return
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
defer serverProcMutex.Unlock()
if t.Failed() {
fp, err := os.Open(lifecycle.ServerLogFile)
if err != nil {
slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
return
}
data, err := io.ReadAll(fp)
if err != nil {
slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
return
}
slog.Warn("SERVER LOG FOLLOWS")
os.Stderr.Write(data)
slog.Warn("END OF SERVER")
}
data, err := io.ReadAll(fp)
if err != nil {
slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
return
err = os.Remove(lifecycle.ServerLogFile)
if err != nil && !os.IsNotExist(err) {
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
}
slog.Warn("SERVER LOG FOLLOWS")
os.Stderr.Write(data)
slog.Warn("END OF SERVER")
}
err = os.Remove(lifecycle.ServerLogFile)
if err != nil && !os.IsNotExist(err) {
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
}
}()
scheme, testEndpoint := GetTestEndpoint()

if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
serverProcMutex.Lock()
fp, err := os.CreateTemp("", "ollama-server-*.log")
if err != nil {
t.Fatalf("failed to generate log file: %s", err)
}
lifecycle.ServerLogFile = fp.Name()
fp.Close()
assert.NoError(t, StartServer(ctx, testEndpoint))
}
@@ -1,142 +0,0 @@
#include "dyn_ext_server.h"

#include <stdio.h>
#include <string.h>

#ifdef __linux__
#include <dlfcn.h>
#define LOAD_LIBRARY(lib, flags) dlopen(lib, flags)
#define LOAD_SYMBOL(handle, sym) dlsym(handle, sym)
#define LOAD_ERR() strdup(dlerror())
#define UNLOAD_LIBRARY(handle) dlclose(handle)
#elif _WIN32
#include <windows.h>
#define LOAD_LIBRARY(lib, flags) LoadLibrary(lib)
#define LOAD_SYMBOL(handle, sym) GetProcAddress(handle, sym)
#define UNLOAD_LIBRARY(handle) FreeLibrary(handle)
#define LOAD_ERR() ({\
LPSTR messageBuffer = NULL; \
size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, \
NULL, GetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&messageBuffer, 0, NULL); \
char *resp = strdup(messageBuffer); \
LocalFree(messageBuffer); \
resp; \
})
#else
#include <dlfcn.h>
#define LOAD_LIBRARY(lib, flags) dlopen(lib, flags)
#define LOAD_SYMBOL(handle, sym) dlsym(handle, sym)
#define LOAD_ERR() strdup(dlerror())
#define UNLOAD_LIBRARY(handle) dlclose(handle)
#endif

void dyn_init(const char *libPath, struct dynamic_llama_server *s,
ext_server_resp_t *err) {
int i = 0;
struct lookup {
char *s;
void **p;
} l[] = {
{"llama_server_init", (void *)&s->llama_server_init},
{"llama_server_start", (void *)&s->llama_server_start},
{"llama_server_stop", (void *)&s->llama_server_stop},
{"llama_server_completion", (void *)&s->llama_server_completion},
{"llama_server_completion_next_result",
(void *)&s->llama_server_completion_next_result},
{"llama_server_completion_cancel",
(void *)&s->llama_server_completion_cancel},
{"llama_server_release_task_result",
(void *)&s->llama_server_release_task_result},
{"llama_server_tokenize", (void *)&s->llama_server_tokenize},
{"llama_server_detokenize", (void *)&s->llama_server_detokenize},
{"llama_server_embedding", (void *)&s->llama_server_embedding},
{"llama_server_release_json_resp",
(void *)&s->llama_server_release_json_resp},
{"", NULL},
};

printf("loading library %s\n", libPath);
s->handle = LOAD_LIBRARY(libPath, RTLD_LOCAL|RTLD_NOW);
if (!s->handle) {
err->id = -1;
char *msg = LOAD_ERR();
snprintf(err->msg, err->msg_len,
"Unable to load dynamic server library: %s", msg);
free(msg);
return;
}

for (i = 0; l[i].p != NULL; i++) {
*l[i].p = LOAD_SYMBOL(s->handle, l[i].s);
if (!l[i].p) {
UNLOAD_LIBRARY(s->handle);
err->id = -1;
char *msg = LOAD_ERR();
snprintf(err->msg, err->msg_len, "symbol lookup for %s failed: %s",
l[i].s, msg);
free(msg);
return;
}
}
}

inline void dyn_llama_server_init(struct dynamic_llama_server s,
ext_server_params_t *sparams,
ext_server_resp_t *err) {
s.llama_server_init(sparams, err);
}

inline void dyn_llama_server_start(struct dynamic_llama_server s) {
s.llama_server_start();
}

inline void dyn_llama_server_stop(struct dynamic_llama_server s) {
s.llama_server_stop();
}

inline void dyn_llama_server_completion(struct dynamic_llama_server s,
const char *json_req,
ext_server_resp_t *resp) {
s.llama_server_completion(json_req, resp);
}

inline void dyn_llama_server_completion_next_result(
struct dynamic_llama_server s, const int task_id,
ext_server_task_result_t *result) {
s.llama_server_completion_next_result(task_id, result);
}

inline void dyn_llama_server_completion_cancel(
struct dynamic_llama_server s, const int task_id, ext_server_resp_t *err) {
s.llama_server_completion_cancel(task_id, err);
}
inline void dyn_llama_server_release_task_result(
struct dynamic_llama_server s, ext_server_task_result_t *result) {
s.llama_server_release_task_result(result);
}

inline void dyn_llama_server_tokenize(struct dynamic_llama_server s,
const char *json_req,
char **json_resp,
ext_server_resp_t *err) {
s.llama_server_tokenize(json_req, json_resp, err);
}

inline void dyn_llama_server_detokenize(struct dynamic_llama_server s,
const char *json_req,
char **json_resp,
ext_server_resp_t *err) {
s.llama_server_detokenize(json_req, json_resp, err);
}

inline void dyn_llama_server_embedding(struct dynamic_llama_server s,
const char *json_req,
char **json_resp,
ext_server_resp_t *err) {
s.llama_server_embedding(json_req, json_resp, err);
}

inline void dyn_llama_server_release_json_resp(
struct dynamic_llama_server s, char **json_resp) {
s.llama_server_release_json_resp(json_resp);
}
@@ -1,388 +0,0 @@
package llm

/*
#cgo CFLAGS: -I${SRCDIR}/ext_server -I${SRCDIR}/llama.cpp -I${SRCDIR}/llama.cpp/common -I${SRCDIR}/llama.cpp/examples/server
#cgo CFLAGS: -DNDEBUG -DLLAMA_SERVER_LIBRARY=1 -D_XOPEN_SOURCE=600 -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
#cgo CFLAGS: -Wmissing-noreturn -Wextra -Wcast-qual -Wno-unused-function -Wno-array-bounds
#cgo CPPFLAGS: -Ofast -Wextra -Wno-unused-function -Wno-unused-variable -Wno-deprecated-declarations
#cgo darwin CFLAGS: -D_DARWIN_C_SOURCE
#cgo darwin CPPFLAGS: -DGGML_USE_ACCELERATE
#cgo darwin CPPFLAGS: -DGGML_USE_METAL -DGGML_METAL_NDEBUG
#cgo darwin LDFLAGS: -lc++ -framework Accelerate
#cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
#cgo linux CFLAGS: -D_GNU_SOURCE
#cgo linux LDFLAGS: -lrt -ldl -lstdc++ -lm
#cgo linux windows LDFLAGS: -lpthread

#include <stdlib.h>
#include "dyn_ext_server.h"

*/
import "C"

import (
"bytes"
"context"
"encoding/json"
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
"sync"
"time"
"unsafe"

"github.com/ollama/ollama/api"
"github.com/ollama/ollama/gpu"
)

type dynExtServer struct {
s C.struct_dynamic_llama_server
options api.Options
}

// Note: current implementation does not support concurrent instantiations
var mutex sync.Mutex

func newExtServerResp(len C.size_t) C.ext_server_resp_t {
var resp C.ext_server_resp_t
resp.msg_len = len
bytes := make([]byte, len)
resp.msg = (*C.char)(C.CBytes(bytes))
return resp
}

func freeExtServerResp(resp C.ext_server_resp_t) {
if resp.msg_len == 0 {
return
}
C.free(unsafe.Pointer(resp.msg))
}

func extServerResponseToErr(resp C.ext_server_resp_t) error {
return fmt.Errorf(C.GoString(resp.msg))
}

func newDynExtServer(library, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
if !mutex.TryLock() {
slog.Info("concurrent llm servers not yet supported, waiting for prior server to complete")
mutex.Lock()
}
gpu.UpdatePath(filepath.Dir(library))
libPath := C.CString(library)
defer C.free(unsafe.Pointer(libPath))
resp := newExtServerResp(512)
defer freeExtServerResp(resp)
var srv C.struct_dynamic_llama_server
C.dyn_init(libPath, &srv, &resp)
if resp.id < 0 {
mutex.Unlock()
return nil, fmt.Errorf("Unable to load dynamic library: %s", C.GoString(resp.msg))
}
llm := dynExtServer{
s: srv,
options: opts,
}
slog.Info(fmt.Sprintf("Loading Dynamic llm server: %s", library))

var sparams C.ext_server_params_t
sparams.model = C.CString(model)
defer C.free(unsafe.Pointer(sparams.model))

sparams.embedding = true
sparams.n_ctx = C.uint(opts.NumCtx)
sparams.n_batch = C.uint(opts.NumBatch)
sparams.n_gpu_layers = C.int(opts.NumGPU)
sparams.main_gpu = C.int(opts.MainGPU)
sparams.n_parallel = 1 // TODO - wire up concurrency

// Always use the value encoded in the model
sparams.rope_freq_base = 0.0
sparams.rope_freq_scale = 0.0
sparams.memory_f16 = C.bool(opts.F16KV)
sparams.use_mlock = C.bool(opts.UseMLock)
sparams.use_mmap = C.bool(opts.UseMMap)

if opts.UseNUMA {
sparams.numa = C.int(1)
} else {
sparams.numa = C.int(0)
}

sparams.lora_adapters = nil
for i := 0; i < len(adapters); i++ {
la := (*C.ext_server_lora_adapter_t)(C.malloc(C.sizeof_ext_server_lora_adapter_t))
defer C.free(unsafe.Pointer(la))
la.adapter = C.CString(adapters[i])
defer C.free(unsafe.Pointer(la.adapter))
la.scale = C.float(1.0) // TODO expose scale/weights up through ollama UX
la.next = nil
if i == 0 {
sparams.lora_adapters = la
} else {
tmp := sparams.lora_adapters
for ; tmp.next != nil; tmp = tmp.next {
}
tmp.next = la
}
}

if len(projectors) > 0 {
// TODO: applying multiple projectors is not supported by the llama.cpp server yet
sparams.mmproj = C.CString(projectors[0])
defer C.free(unsafe.Pointer(sparams.mmproj))
} else {
sparams.mmproj = nil
}

sparams.n_threads = C.uint(opts.NumThread)

if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
sparams.verbose_logging = C.bool(true)
} else {
sparams.verbose_logging = C.bool(false)
}

slog.Info("Initializing llama server")
slog.Debug(fmt.Sprintf("server params: %+v", sparams))
initResp := newExtServerResp(512)
defer freeExtServerResp(initResp)
C.dyn_llama_server_init(llm.s, &sparams, &initResp)
if initResp.id < 0 {
mutex.Unlock()
err := extServerResponseToErr(initResp)
slog.Debug(fmt.Sprintf("failure during initialization: %s", err))
return nil, err
}

slog.Info("Starting llama main loop")
C.dyn_llama_server_start(llm.s)
return &llm, nil
}

func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
resp := newExtServerResp(128)
defer freeExtServerResp(resp)

if len(predict.Images) > 0 {
slog.Info(fmt.Sprintf("loaded %d images", len(predict.Images)))
}

request := map[string]any{
"prompt": predict.Prompt,
"stream": true,
"n_predict": predict.Options.NumPredict,
"n_keep": predict.Options.NumKeep,
"temperature": predict.Options.Temperature,
"top_k": predict.Options.TopK,
"top_p": predict.Options.TopP,
"tfs_z": predict.Options.TFSZ,
"typical_p": predict.Options.TypicalP,
"repeat_last_n": predict.Options.RepeatLastN,
"repeat_penalty": predict.Options.RepeatPenalty,
"presence_penalty": predict.Options.PresencePenalty,
"frequency_penalty": predict.Options.FrequencyPenalty,
"mirostat": predict.Options.Mirostat,
"mirostat_tau": predict.Options.MirostatTau,
"mirostat_eta": predict.Options.MirostatEta,
"penalize_nl": predict.Options.PenalizeNewline,
"seed": predict.Options.Seed,
"stop": predict.Options.Stop,
"image_data": predict.Images,
"cache_prompt": true,
}

if predict.Format == "json" {
request["grammar"] = jsonGrammar
if !strings.Contains(strings.ToLower(predict.Prompt), "json") {
slog.Warn("Prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
}
}

retryDelay := 100 * time.Microsecond
for retries := 0; retries < maxRetries; retries++ {
if retries > 0 {
time.Sleep(retryDelay) // wait before retrying
retryDelay *= 2 // exponential backoff
}

// Handling JSON marshaling with special characters unescaped.
buffer := &bytes.Buffer{}
enc := json.NewEncoder(buffer)
enc.SetEscapeHTML(false)

if err := enc.Encode(request); err != nil {
return fmt.Errorf("failed to marshal data: %w", err)
}

req := C.CString(buffer.String())
defer C.free(unsafe.Pointer(req))

C.dyn_llama_server_completion(llm.s, req, &resp)
if resp.id < 0 {
return extServerResponseToErr(resp)
}

retryNeeded := false
// keep track of the last token generated, this is used to abort if the model starts looping
var lastToken string
var tokenRepeat int
out:
for {
select {
case <-ctx.Done():
return cancelCompletion(llm, resp)
default:
var result C.ext_server_task_result_t
C.dyn_llama_server_completion_next_result(llm.s, resp.id, &result)
json_resp := C.GoString(result.json_resp)
C.dyn_llama_server_release_task_result(llm.s, &result)

var p prediction
if err := json.Unmarshal([]byte(json_resp), &p); err != nil {
C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
if resp.id < 0 {
return fmt.Errorf("error unmarshaling llm prediction response: %w and cancel %s", err, C.GoString(resp.msg))
} else {
return fmt.Errorf("error unmarshaling llm prediction response: %w", err)
}
}

if bool(result.error) && strings.Contains(json_resp, "slot unavailable") {
retryNeeded = true
// task will already be canceled
break out
}

switch {
case strings.TrimSpace(p.Content) == lastToken:
tokenRepeat++
default:
lastToken = strings.TrimSpace(p.Content)
tokenRepeat = 0
}

// 30 picked as an arbitrary max token repeat limit, modify as needed
if tokenRepeat > 30 {
slog.Debug("prediction aborted, token repeat limit reached")
return cancelCompletion(llm, resp)
}

if p.Content != "" {
fn(PredictResult{
Content: p.Content,
})
}

if p.Stop || bool(result.stop) {
fn(PredictResult{
Done: true,
PromptEvalCount: p.Timings.PromptN,
PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
EvalCount: p.Timings.PredictedN,
EvalDuration: parseDurationMs(p.Timings.PredictedMS),
})
return nil
}
}
}
if !retryNeeded {
return nil // success
}
}

// should never reach here ideally
return fmt.Errorf("max retries exceeded")
}

func cancelCompletion(llm *dynExtServer, resp C.ext_server_resp_t) error {
C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
if resp.id < 0 {
return extServerResponseToErr(resp)
} else {
return nil
}
}

func (llm *dynExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
data, err := json.Marshal(TokenizeRequest{Content: prompt})
if err != nil {
return nil, fmt.Errorf("marshaling encode data: %w", err)
}
req := C.CString(string(data))
defer C.free(unsafe.Pointer(req))
var json_resp *C.char
resp := newExtServerResp(128)
defer freeExtServerResp(resp)
C.dyn_llama_server_tokenize(llm.s, req, &json_resp, &resp)
if resp.id < 0 {
return nil, extServerResponseToErr(resp)
}
defer C.dyn_llama_server_release_json_resp(llm.s, &json_resp)

var encoded TokenizeResponse
if err2 := json.Unmarshal([]byte(C.GoString(json_resp)), &encoded); err2 != nil {
return nil, fmt.Errorf("unmarshal encode response: %w", err2)
}

return encoded.Tokens, err
}

func (llm *dynExtServer) Decode(ctx context.Context, tokens []int) (string, error) {
if len(tokens) == 0 {
return "", nil
}
data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
if err != nil {
return "", fmt.Errorf("marshaling decode data: %w", err)
}

req := C.CString(string(data))
defer C.free(unsafe.Pointer(req))
var json_resp *C.char
resp := newExtServerResp(128)
defer freeExtServerResp(resp)
C.dyn_llama_server_detokenize(llm.s, req, &json_resp, &resp)
if resp.id < 0 {
return "", extServerResponseToErr(resp)
}
defer C.dyn_llama_server_release_json_resp(llm.s, &json_resp)

var decoded DetokenizeResponse
if err2 := json.Unmarshal([]byte(C.GoString(json_resp)), &decoded); err2 != nil {
return "", fmt.Errorf("unmarshal encode response: %w", err2)
}

return decoded.Content, err
}

func (llm *dynExtServer) Embedding(ctx context.Context, input string) ([]float64, error) {
data, err := json.Marshal(TokenizeRequest{Content: input})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}

req := C.CString(string(data))
defer C.free(unsafe.Pointer(req))
var json_resp *C.char
resp := newExtServerResp(128)
defer freeExtServerResp(resp)
C.dyn_llama_server_embedding(llm.s, req, &json_resp, &resp)
if resp.id < 0 {
return nil, extServerResponseToErr(resp)
}
defer C.dyn_llama_server_release_json_resp(llm.s, &json_resp)

var embedding EmbeddingResponse
if err := json.Unmarshal([]byte(C.GoString(json_resp)), &embedding); err != nil {
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
}

return embedding.Embedding, nil
}

func (llm *dynExtServer) Close() {
C.dyn_llama_server_stop(llm.s)
mutex.Unlock()
}
@@ -1,74 +0,0 @@
#include <stdlib.h>

#include "ext_server.h"

#ifdef __cplusplus
extern "C" {
#endif
struct dynamic_llama_server {
void *handle;
void (*llama_server_init)(ext_server_params_t *sparams,
ext_server_resp_t *err);
void (*llama_server_start)();
void (*llama_server_stop)();
void (*llama_server_completion)(const char *json_req,
ext_server_resp_t *resp);
void (*llama_server_completion_next_result)(const int task_id,
ext_server_task_result_t *result);
void (*llama_server_completion_cancel)(const int task_id,
ext_server_resp_t *err);
void (*llama_server_release_task_result)(ext_server_task_result_t *result);
void (*llama_server_tokenize)(const char *json_req, char **json_resp,
ext_server_resp_t *err);
void (*llama_server_detokenize)(const char *json_req, char **json_resp,
ext_server_resp_t *err);
void (*llama_server_embedding)(const char *json_req, char **json_resp,
ext_server_resp_t *err);
void (*llama_server_release_json_resp)(char **json_resp);
};

void dyn_init(const char *libPath, struct dynamic_llama_server *s,
ext_server_resp_t *err);

// No good way to call C function pointers from Go so inline the indirection
void dyn_llama_server_init(struct dynamic_llama_server s,
ext_server_params_t *sparams,
ext_server_resp_t *err);

void dyn_llama_server_start(struct dynamic_llama_server s);

void dyn_llama_server_stop(struct dynamic_llama_server s);

void dyn_llama_server_completion(struct dynamic_llama_server s,
const char *json_req,
ext_server_resp_t *resp);

void dyn_llama_server_completion_next_result(
struct dynamic_llama_server s, const int task_id,
ext_server_task_result_t *result);

void dyn_llama_server_completion_cancel(struct dynamic_llama_server s,
const int task_id,
ext_server_resp_t *err);

void dyn_llama_server_release_task_result(
struct dynamic_llama_server s, ext_server_task_result_t *result);

void dyn_llama_server_tokenize(struct dynamic_llama_server s,
const char *json_req, char **json_resp,
ext_server_resp_t *err);

void dyn_llama_server_detokenize(struct dynamic_llama_server s,
const char *json_req,
char **json_resp,
ext_server_resp_t *err);

void dyn_llama_server_embedding(struct dynamic_llama_server s,
const char *json_req, char **json_resp,
ext_server_resp_t *err);
void dyn_llama_server_release_json_resp(struct dynamic_llama_server s,
char **json_resp);

#ifdef __cplusplus
}
#endif
27
llm/ext_server/CMakeLists.txt
vendored
27
llm/ext_server/CMakeLists.txt
vendored
@@ -1,21 +1,14 @@

set(TARGET ext_server)
set(TARGET ollama_llama_server)
option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h)
install(TARGETS ${TARGET} RUNTIME)
target_compile_definitions(${TARGET} PRIVATE
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
)
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
if (WIN32)
add_library(${TARGET} SHARED ext_server.cpp ../llama.cpp/llama.cpp)
else()
add_library(${TARGET} STATIC ext_server.cpp ../llama.cpp/llama.cpp)
TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
endif()
target_compile_features(${TARGET} PRIVATE cxx_std_11)
target_compile_definitions(${TARGET} PUBLIC LLAMA_SERVER_LIBRARY=1)
target_link_libraries(${TARGET} PRIVATE ggml llava common )
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(${TARGET} PRIVATE SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>)
install(TARGETS ext_server LIBRARY)

if (CUDAToolkit_FOUND)
target_include_directories(${TARGET} PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
if (WIN32)
target_link_libraries(${TARGET} PRIVATE nvml)
endif()
endif()
target_compile_features(${TARGET} PRIVATE cxx_std_11)
18
llm/ext_server/README.md
vendored
18
llm/ext_server/README.md
vendored
@@ -1,18 +0,0 @@
# Extern C Server

This directory contains a thin facade we layer on top of the Llama.cpp server to
expose `extern C` interfaces to access the functionality through direct API
calls in-process. The llama.cpp code uses compile time macros to configure GPU
type along with other settings. During the `go generate ./...` execution, the
build will generate one or more copies of the llama.cpp `extern C` server based
on what GPU libraries are detected to support multiple GPU types as well as CPU
only support. The Ollama go build then embeds these different servers to support
different GPUs and settings at runtime.

If you are making changes to the code in this directory, make sure to disable
caching during your go build to ensure you pick up your changes. A typical
iteration cycle from the top of the source tree looks like:

```
go generate ./... && go build -a .
```
377  llm/ext_server/ext_server.cpp  vendored
@@ -1,377 +0,0 @@

#include "ext_server.h"
#include <atomic>

// Necessary evil since the server types are not defined in a header
#include "server.cpp"

// Low level API access to verify GPU access
#if defined(GGML_USE_CUBLAS)
#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#ifdef __HIP_PLATFORM_AMD__
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
#define cudaGetDevice hipGetDevice
#define cudaError_t hipError_t
#define cudaSuccess hipSuccess
#define cudaGetErrorString hipGetErrorString
#else
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#endif // defined(GGML_USE_HIPBLAS)
#endif // GGML_USE_CUBLAS

// Expose the llama server as a callable extern "C" API
llama_server_context *llama = NULL;
std::thread ext_server_thread;
bool shutting_down = false;
std::atomic_int recv_counter;

// RAII wrapper for tracking in-flight recv calls
class atomicRecv {
 public:
  atomicRecv(std::atomic<int> &atomic) : atomic(atomic) {
    ++this->atomic;
  }
  ~atomicRecv() {
    --this->atomic;
  }
 private:
  std::atomic<int> &atomic;
};

void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) {
  recv_counter = 0;
  assert(err != NULL && sparams != NULL);
  log_set_target(stderr);
  if (!sparams->verbose_logging) {
    server_verbose = true;
    log_disable();
  }

  LOG_TEE("system info: %s\n", llama_print_system_info());
  err->id = 0;
  err->msg[0] = '\0';
  try {
    llama = new llama_server_context;
    gpt_params params;
    params.n_ctx = sparams->n_ctx;
    params.n_batch = sparams->n_batch;
    if (sparams->n_threads > 0) {
      params.n_threads = sparams->n_threads;
    }
    params.n_parallel = sparams->n_parallel;
    params.rope_freq_base = sparams->rope_freq_base;
    params.rope_freq_scale = sparams->rope_freq_scale;

    if (sparams->memory_f16) {
      params.cache_type_k = "f16";
      params.cache_type_v = "f16";
    } else {
      params.cache_type_k = "f32";
      params.cache_type_v = "f32";
    }

    params.n_gpu_layers = sparams->n_gpu_layers;
    params.main_gpu = sparams->main_gpu;
    params.use_mlock = sparams->use_mlock;
    params.use_mmap = sparams->use_mmap;
    params.numa = (ggml_numa_strategy)sparams->numa;
    params.embedding = sparams->embedding;
    if (sparams->model != NULL) {
      params.model = sparams->model;
    }

    if (sparams->lora_adapters != NULL) {
      for (ext_server_lora_adapter *la = sparams->lora_adapters; la != NULL;
           la = la->next) {
        params.lora_adapter.push_back(std::make_tuple(la->adapter, la->scale));
      }

      params.use_mmap = false;
    }

    if (sparams->mmproj != NULL) {
      params.mmproj = std::string(sparams->mmproj);
    }

#if defined(GGML_USE_CUBLAS)
    // Before attempting to init the backend which will assert on error, verify the CUDA/ROCM GPU is accessible
    LOG_TEE("Performing pre-initialization of GPU\n");
    int id;
    cudaError_t cudaErr = cudaGetDevice(&id);
    if (cudaErr != cudaSuccess) {
      err->id = -1;
      snprintf(err->msg, err->msg_len, "Unable to init GPU: %s", cudaGetErrorString(cudaErr));
      return;
    }
#endif

    llama_backend_init();
    llama_numa_init(params.numa);

    if (!llama->load_model(params)) {
      // an error occurred that was not thrown
      err->id = -1;
      snprintf(err->msg, err->msg_len, "error loading model %s", params.model.c_str());
      return;
    }

    llama->initialize();
  } catch (std::exception &e) {
    err->id = -1;
    snprintf(err->msg, err->msg_len, "exception %s", e.what());
  } catch (...) {
    err->id = -1;
    snprintf(err->msg, err->msg_len,
             "Unknown exception initializing llama server");
  }
}

void llama_server_start() {
  assert(llama != NULL);
  // TODO mutex to protect thread creation
  ext_server_thread = std::thread([&]() {
    try {
      LOG_TEE("llama server main loop starting\n");
      ggml_time_init();
      llama->queue_tasks.on_new_task(std::bind(
          &llama_server_context::process_single_task, llama, std::placeholders::_1));
      llama->queue_tasks.on_finish_multitask(std::bind(
          &llama_server_context::on_finish_multitask, llama, std::placeholders::_1));
      llama->queue_tasks.on_run_slots(std::bind(
          &llama_server_context::update_slots, llama));
      llama->queue_results.on_multitask_update(std::bind(
          &llama_server_queue::update_multitask,
          &llama->queue_tasks,
          std::placeholders::_1,
          std::placeholders::_2,
          std::placeholders::_3
      ));
      llama->queue_tasks.start_loop();
    } catch (std::exception &e) {
      LOG_TEE("caught exception in llama server main loop: %s\n", e.what());
    } catch (...) {
      LOG_TEE("caught unknown exception in llama server main loop\n");
    }
    LOG_TEE("\nllama server shutting down\n");
    llama_backend_free();
  });
}

void llama_server_stop() {
  assert(llama != NULL);
  // Shutdown any in-flight requests and block incoming requests.
  LOG_TEE("\ninitiating shutdown - draining remaining tasks...\n");
  shutting_down = true;

  while (recv_counter.load() > 0) {
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
  }

  // This may take a while for any pending tasks to drain
  // TODO - consider a timeout to cancel tasks if it's taking too long
  llama->queue_tasks.terminate();
  ext_server_thread.join();
  delete llama;
  llama = NULL;
  LOG_TEE("llama server shutdown complete\n");
  shutting_down = false;
}

void llama_server_completion(const char *json_req, ext_server_resp_t *resp) {
  assert(llama != NULL && json_req != NULL && resp != NULL);
  resp->id = -1;
  resp->msg[0] = '\0';
  try {
    if (shutting_down) {
      throw std::runtime_error("server shutting down");
    }
    json data = json::parse(json_req);
    resp->id = llama->queue_tasks.get_new_id();
    llama->queue_results.add_waiting_task_id(resp->id);
    llama->request_completion(resp->id, data, false, false, -1);
  } catch (std::exception &e) {
    snprintf(resp->msg, resp->msg_len, "exception %s", e.what());
  } catch (...) {
    snprintf(resp->msg, resp->msg_len, "Unknown exception during completion");
  }
}

void llama_server_completion_next_result(const int task_id,
                                         ext_server_task_result_t *resp) {
  assert(llama != NULL && resp != NULL);
  resp->id = -1;
  resp->stop = false;
  resp->error = false;
  resp->json_resp = NULL;
  std::string result_json;
  try {
    atomicRecv ar(recv_counter);
    task_result result = llama->queue_results.recv(task_id);
    result_json =
        result.result_json.dump(-1, ' ', false, json::error_handler_t::replace);
    resp->id = result.id;
    resp->stop = result.stop;
    resp->error = result.error;
    if (result.error) {
      LOG_TEE("next result cancel on error\n");
      llama->request_cancel(task_id);
      LOG_TEE("next result removing waiting task ID: %d\n", task_id);
      llama->queue_results.remove_waiting_task_id(task_id);
    } else if (result.stop) {
      LOG_TEE("next result cancel on stop\n");
      llama->request_cancel(task_id);
      LOG_TEE("next result removing waiting task ID: %d\n", task_id);
      llama->queue_results.remove_waiting_task_id(task_id);
    } else if (shutting_down) {
      LOG_TEE("aborting completion due to shutdown %d\n", task_id);
      llama->request_cancel(task_id);
      llama->queue_results.remove_waiting_task_id(task_id);
      resp->stop = true;
    }
  } catch (std::exception &e) {
    resp->error = true;
    resp->id = -1;
    result_json = "{\"error\":\"exception " + std::string(e.what()) + "\"}";
    LOG_TEE("llama server completion exception %s\n", e.what());
  } catch (...) {
    resp->error = true;
    resp->id = -1;
    result_json = "{\"error\":\"Unknown exception during completion\"}";
    LOG_TEE("llama server completion unknown exception\n");
  }
  const std::string::size_type size = result_json.size() + 1;
  resp->json_resp = new char[size];
  snprintf(resp->json_resp, size, "%s", result_json.c_str());
}

void llama_server_release_task_result(ext_server_task_result_t *result) {
  if (result == NULL || result->json_resp == NULL) {
    return;
  }
  delete[] result->json_resp;
}

void llama_server_completion_cancel(const int task_id, ext_server_resp_t *err) {
  assert(llama != NULL && err != NULL);
  err->id = 0;
  err->msg[0] = '\0';
  try {
    llama->request_cancel(task_id);
    llama->queue_results.remove_waiting_task_id(task_id);
  } catch (std::exception &e) {
    err->id = -1;
    snprintf(err->msg, err->msg_len, "exception %s", e.what());
  } catch (...) {
    err->id = -1;
    snprintf(err->msg, err->msg_len,
             "Unknown exception completion cancel in llama server");
  }
}

void llama_server_tokenize(const char *json_req, char **json_resp,
                           ext_server_resp_t *err) {
  assert(llama != NULL && json_req != NULL && json_resp != NULL && err != NULL);
  *json_resp = NULL;
  err->id = 0;
  err->msg[0] = '\0';
  try {
    if (shutting_down) {
      throw std::runtime_error("server shutting down");
    }
    const json body = json::parse(json_req);
    std::vector<llama_token> tokens;
    if (body.count("content") != 0) {
      tokens = llama->tokenize(body["content"], false);
    }
    const json data = format_tokenizer_response(tokens);
    std::string result_json = data.dump();
    const std::string::size_type size = result_json.size() + 1;
    *json_resp = new char[size];
    snprintf(*json_resp, size, "%s", result_json.c_str());
  } catch (std::exception &e) {
    err->id = -1;
    snprintf(err->msg, err->msg_len, "exception %s", e.what());
  } catch (...) {
    err->id = -1;
    snprintf(err->msg, err->msg_len, "Unknown exception during tokenize");
  }
}

void llama_server_release_json_resp(char **json_resp) {
  if (json_resp == NULL || *json_resp == NULL) {
    return;
  }
  delete[] *json_resp;
}

void llama_server_detokenize(const char *json_req, char **json_resp,
                             ext_server_resp_t *err) {
  assert(llama != NULL && json_req != NULL && json_resp != NULL && err != NULL);
  *json_resp = NULL;
  err->id = 0;
  err->msg[0] = '\0';
  try {
    if (shutting_down) {
      throw std::runtime_error("server shutting down");
    }
    const json body = json::parse(json_req);
    std::string content;
    if (body.count("tokens") != 0) {
      const std::vector<llama_token> tokens = body["tokens"];
      content = tokens_to_str(llama->ctx, tokens.cbegin(), tokens.cend());
    }
    const json data = format_detokenized_response(content);
    std::string result_json = data.dump();
    const std::string::size_type size = result_json.size() + 1;
    *json_resp = new char[size];
    snprintf(*json_resp, size, "%s", result_json.c_str());
  } catch (std::exception &e) {
    err->id = -1;
    snprintf(err->msg, err->msg_len, "exception %s", e.what());
  } catch (...) {
    err->id = -1;
    snprintf(err->msg, err->msg_len, "Unknown exception during detokenize");
  }
}

void llama_server_embedding(const char *json_req, char **json_resp,
                            ext_server_resp_t *err) {
  assert(llama != NULL && json_req != NULL && json_resp != NULL && err != NULL);
  *json_resp = NULL;
  err->id = 0;
  err->msg[0] = '\0';
  try {
    if (shutting_down) {
      throw std::runtime_error("server shutting down");
    }
    const json body = json::parse(json_req);
    json prompt;
    if (body.count("content") != 0) {
      prompt = body["content"];
    } else {
      prompt = "";
    }
    const int task_id = llama->queue_tasks.get_new_id();
    llama->queue_results.add_waiting_task_id(task_id);
    llama->request_completion(task_id, {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
    atomicRecv ar(recv_counter);
    task_result result = llama->queue_results.recv(task_id);
    std::string result_json = result.result_json.dump();
    const std::string::size_type size = result_json.size() + 1;
    *json_resp = new char[size];
    snprintf(*json_resp, size, "%s", result_json.c_str());
    llama->queue_results.remove_waiting_task_id(task_id);
  } catch (std::exception &e) {
    err->id = -1;
    snprintf(err->msg, err->msg_len, "exception %s", e.what());
  } catch (...) {
    err->id = -1;
    snprintf(err->msg, err->msg_len, "Unknown exception during embedding");
  }
}
95  llm/ext_server/ext_server.h  vendored
@@ -1,95 +0,0 @@

#if defined(LLAMA_SERVER_LIBRARY)
#ifndef LLAMA_SERVER_H
#define LLAMA_SERVER_H
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

int __main(int argc, char **argv);

// This exposes extern C entrypoints into the llama_server
// To enable the server compile with LLAMA_SERVER_LIBRARY

#ifdef __cplusplus
extern "C" {
#endif
typedef struct ext_server_resp {
  int id;          // < 0 on error
  size_t msg_len;  // caller must allocate msg and set msg_len
  char *msg;
} ext_server_resp_t;

// Allocated and freed by caller
typedef struct ext_server_lora_adapter {
  char *adapter;
  float scale;
  struct ext_server_lora_adapter *next;
} ext_server_lora_adapter_t;

// Allocated and freed by caller
typedef struct ext_server_params {
  char *model;
  uint32_t n_ctx;          // token context window, 0 = from model
  uint32_t n_batch;        // prompt processing maximum batch size
  uint32_t n_threads;      // number of threads to use for generation
  int32_t n_parallel;      // number of parallel sequences to decode
  float rope_freq_base;    // RoPE base frequency, 0 = from model
  float rope_freq_scale;   // RoPE frequency scaling factor, 0 = from model
  bool memory_f16;         // use f16 instead of f32 for memory kv
  int32_t n_gpu_layers;    // number of layers to store in VRAM (-1 - use default)
  int32_t main_gpu;        // the GPU that is used for scratch and small tensors
  bool use_mlock;          // force system to keep model in RAM
  bool use_mmap;           // use mmap if possible
  int numa;                // attempt optimizations that help on some NUMA systems
  bool embedding;          // get only sentence embedding
  ext_server_lora_adapter_t *lora_adapters;
  char *mmproj;
  bool verbose_logging;    // Enable verbose logging of the server
} ext_server_params_t;

typedef struct ext_server_task_result {
  int id;
  bool stop;
  bool error;
  char *json_resp;  // null terminated, memory managed by ext_server
} ext_server_task_result_t;

// Initialize the server once per process
// err->id = 0 for success and err->msg[0] = NULL
// err->id != 0 for failure, and err->msg contains error message
void llama_server_init(ext_server_params_t *sparams, ext_server_resp_t *err);

// Run the main loop, called once per init
void llama_server_start();
// Stop the main loop and free up resources allocated in init and start. Init
// must be called again to reuse
void llama_server_stop();

// json_req null terminated string, memory managed by caller
// resp->id >= 0 on success (task ID)
// resp->id < 0 on error, and resp->msg contains error message
void llama_server_completion(const char *json_req, ext_server_resp_t *resp);

// Caller must call llama_server_release_task_result to free resp->json_resp
void llama_server_completion_next_result(const int task_id,
                                         ext_server_task_result_t *result);
void llama_server_completion_cancel(const int task_id, ext_server_resp_t *err);
void llama_server_release_task_result(ext_server_task_result_t *result);

// Caller must call llama_server_release_json_resp to free json_resp if
// err.id < 0
void llama_server_tokenize(const char *json_req, char **json_resp,
                           ext_server_resp_t *err);
void llama_server_detokenize(const char *json_req, char **json_resp,
                             ext_server_resp_t *err);
void llama_server_embedding(const char *json_req, char **json_resp,
                            ext_server_resp_t *err);
void llama_server_release_json_resp(char **json_resp);

#ifdef __cplusplus
}
#endif

#endif
#endif  // LLAMA_SERVER_LIBRARY
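The header above spells out the calling conventions of the removed in-process API: the caller owns `msg`/`msg_len` on every `ext_server_resp_t`, and `err->id` of 0 means success. A minimal sketch of driving it from Go via cgo, written against the removed header rather than taken from this diff (the model path and parameter values are placeholders):

```go
package main

/*
#cgo CFLAGS: -DLLAMA_SERVER_LIBRARY=1 -I.
#include <stdlib.h>
#include "ext_server.h"
*/
import "C"

import (
	"fmt"
	"unsafe"
)

func main() {
	// Per the header's convention, the caller allocates msg and sets
	// msg_len before each call; err.id != 0 signals failure.
	var err C.ext_server_resp_t
	err.msg_len = 512
	err.msg = (*C.char)(C.malloc(C.size_t(err.msg_len)))
	defer C.free(unsafe.Pointer(err.msg))

	var sparams C.ext_server_params_t
	sparams.model = C.CString("/path/to/model.gguf") // placeholder path
	defer C.free(unsafe.Pointer(sparams.model))
	sparams.n_ctx = 2048
	sparams.n_batch = 512
	sparams.n_parallel = 1
	sparams.use_mmap = true

	C.llama_server_init(&sparams, &err)
	if err.id != 0 {
		fmt.Println("init failed:", C.GoString(err.msg))
		return
	}
	C.llama_server_start()
	defer C.llama_server_stop()
}
```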
6  llm/ext_server/server.cpp  vendored
@@ -1007,13 +1007,15 @@ struct llama_server_context

            slot.n_sent_text += result.text_to_send.size();
            // add the token to slot queue and cache
        }
        slot.add_token_string(result);

        if (slot.params.stream)
        {
            send_partial_response(slot, result);
        }
    }

    slot.add_token_string(result);

    if (incomplete)
    {
        slot.has_next_token = true;

@@ -2768,7 +2770,7 @@ inline void signal_handler(int signal) {

    shutdown_handler(signal);
}

int _main(int argc, char **argv)
int main(int argc, char **argv)
{
#if SERVER_VERBOSE != 1
    log_disable();

@@ -14,7 +14,7 @@ init_vars() {

LLAMACPP_DIR=../llama.cpp
CMAKE_DEFS=""
CMAKE_TARGETS="--target ext_server"
CMAKE_TARGETS="--target ollama_llama_server"
if echo "${CGO_CFLAGS}" | grep -- '-g' >/dev/null; then
CMAKE_DEFS="-DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_VERBOSE_MAKEFILE=on -DLLAMA_GPROF=on -DLLAMA_SERVER_VERBOSE=on ${CMAKE_DEFS}"
else

@@ -81,27 +81,24 @@ apply_patches() {

build() {
cmake -S ${LLAMACPP_DIR} -B ${BUILD_DIR} ${CMAKE_DEFS}
cmake --build ${BUILD_DIR} ${CMAKE_TARGETS} -j8
mkdir -p ${BUILD_DIR}/lib/
ls ${BUILD_DIR}
g++ -fPIC -g -shared -o ${BUILD_DIR}/lib/libext_server.${LIB_EXT} \
${GCC_ARCH} \
${WHOLE_ARCHIVE} ${BUILD_DIR}/ext_server/libext_server.a ${NO_WHOLE_ARCHIVE} \
${BUILD_DIR}/common/libcommon.a \
${BUILD_DIR}/libllama.a \
-Wl,-rpath,\$ORIGIN \
-lpthread -ldl -lm \
${EXTRA_LIBS}
}

compress_libs() {
compress() {
echo "Compressing payloads to reduce overall binary size..."
pids=""
rm -rf ${BUILD_DIR}/lib/*.${LIB_EXT}*.gz
for lib in ${BUILD_DIR}/lib/*.${LIB_EXT}* ; do
gzip -n --best -f ${lib} &
rm -rf ${BUILD_DIR}/bin/*.gz
for f in ${BUILD_DIR}/bin/* ; do
gzip -n --best -f ${f} &
pids+=" $!"
done
echo
# check for lib directory
if [ -d ${BUILD_DIR}/lib ]; then
for f in ${BUILD_DIR}/lib/* ; do
gzip -n --best -f ${f} &
pids+=" $!"
done
fi
echo
for pid in ${pids}; do
wait $pid
done

@@ -18,21 +18,31 @@ sign() {

fi
}

COMMON_DARWIN_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.0 -DCMAKE_SYSTEM_NAME=Darwin"
COMMON_DARWIN_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DLLAMA_METAL_MACOSX_VERSION_MIN=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DLLAMA_METAL_EMBED_LIBRARY=on"

case "${GOARCH}" in
"amd64")
COMMON_CPU_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=off -DLLAMA_NATIVE=off"

# Static build for linking into the Go binary
init_vars
CMAKE_TARGETS="--target llama --target ggml"
CMAKE_DEFS="${COMMON_CPU_DEFS} -DBUILD_SHARED_LIBS=off -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}_static"
echo "Building static library"
build


#
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu"
BUILD_DIR="../build/darwin/${ARCH}/cpu"
echo "Building LCD CPU"
build
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu/lib/libext_server.dylib
compress_libs
sign ${BUILD_DIR}/bin/ollama_llama_server
compress

#
# ~2011 CPU Dynamic library with more capabilities turned on to optimize performance

@@ -40,11 +50,11 @@ case "${GOARCH}" in

#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx"
BUILD_DIR="../build/darwin/${ARCH}/cpu_avx"
echo "Building AVX CPU"
build
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx/lib/libext_server.dylib
compress_libs
sign ${BUILD_DIR}/bin/ollama_llama_server
compress

#
# ~2013 CPU Dynamic library

@@ -52,20 +62,30 @@ case "${GOARCH}" in

#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=on -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx2"
BUILD_DIR="../build/darwin/${ARCH}/cpu_avx2"
echo "Building AVX2 CPU"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation"
build
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx2/lib/libext_server.dylib
compress_libs
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
;;
"arm64")
CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_METAL_EMBED_LIBRARY=on -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/metal"

# Static build for linking into the Go binary
init_vars
CMAKE_TARGETS="--target llama --target ggml"
CMAKE_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DBUILD_SHARED_LIBS=off -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=off -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}_static"
echo "Building static library"
build

init_vars
CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/metal"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders"
build
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/metal/lib/libext_server.dylib
compress_libs
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
;;
*)
echo "GOARCH must be set"

@@ -75,3 +95,4 @@ case "${GOARCH}" in

esac

cleanup
echo "go generate completed. LLM runners: $(cd ${BUILD_DIR}/..; echo *)"

@@ -57,16 +57,31 @@ init_vars

git_module_setup
apply_patches


init_vars
if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then

if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "static" ]; then
# Static build for linking into the Go binary
init_vars
CMAKE_TARGETS="--target llama --target ggml"
CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/linux/${ARCH}_static"
echo "Building static library"
build
fi


# Users building from source can tune the exact flags we pass to cmake for configuring
# llama.cpp, and we'll build only 1 CPU variant in that case as the default.
if [ -n "${OLLAMA_CUSTOM_CPU_DEFS}" ]; then
init_vars
echo "OLLAMA_CUSTOM_CPU_DEFS=\"${OLLAMA_CUSTOM_CPU_DEFS}\""
CMAKE_DEFS="${OLLAMA_CUSTOM_CPU_DEFS} -DCMAKE_POSITION_INDEPENDENT_CODE=on ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu"
BUILD_DIR="../build/linux/${ARCH}/cpu"
echo "Building custom CPU"
build
compress_libs
compress
else
# Darwin Rosetta x86 emulation does NOT support AVX, AVX2, AVX512
# -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer

@@ -83,11 +98,12 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then

#
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu"
BUILD_DIR="../build/linux/${ARCH}/cpu"
echo "Building LCD CPU"
build
compress_libs
compress
fi

if [ "${ARCH}" == "x86_64" ]; then

@@ -101,10 +117,10 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then

#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu_avx"
BUILD_DIR="../build/linux/${ARCH}/cpu_avx"
echo "Building AVX CPU"
build
compress_libs
compress
fi

if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "cpu_avx2" ]; then

@@ -114,10 +130,10 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then

#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu_avx2"
BUILD_DIR="../build/linux/${ARCH}/cpu_avx2"
echo "Building AVX2 CPU"
build
compress_libs
compress
fi
fi
fi

@@ -156,8 +172,8 @@ if [ -d "${CUDA_LIB_DIR}" ]; then

# Disabling has minimal performance effect while maintaining compatibility.
ARM64_DEFS="-DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_CUDA_F16=off"
fi
CMAKE_DEFS="-DLLAMA_CUBLAS=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cuda${CUDA_VARIANT}"
CMAKE_DEFS="-DLLAMA_CUDA=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS}"
BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda"
build

@@ -165,20 +181,20 @@ if [ -d "${CUDA_LIB_DIR}" ]; then

#
# TODO - in the future we may shift to packaging these separately and conditionally
# downloading them in the install script.
DEPS="$(ldd ${BUILD_DIR}/lib/libext_server.so )"
DEPS="$(ldd ${BUILD_DIR}/bin/ollama_llama_server )"
for lib in libcudart.so libcublas.so libcublasLt.so ; do
DEP=$(echo "${DEPS}" | grep ${lib} | cut -f1 -d' ' | xargs || true)
if [ -n "${DEP}" -a -e "${CUDA_LIB_DIR}/${DEP}" ]; then
cp "${CUDA_LIB_DIR}/${DEP}" "${BUILD_DIR}/lib/"
cp "${CUDA_LIB_DIR}/${DEP}" "${BUILD_DIR}/bin/"
elif [ -e "${CUDA_LIB_DIR}/${lib}.${CUDA_MAJOR}" ]; then
cp "${CUDA_LIB_DIR}/${lib}.${CUDA_MAJOR}" "${BUILD_DIR}/lib/"
cp "${CUDA_LIB_DIR}/${lib}.${CUDA_MAJOR}" "${BUILD_DIR}/bin/"
elif [ -e "${CUDART_LIB_DIR}/${lib}" ]; then
cp -d ${CUDART_LIB_DIR}/${lib}* "${BUILD_DIR}/lib/"
cp -d ${CUDART_LIB_DIR}/${lib}* "${BUILD_DIR}/bin/"
else
cp -d "${CUDA_LIB_DIR}/${lib}*" "${BUILD_DIR}/lib/"
cp -d "${CUDA_LIB_DIR}/${lib}*" "${BUILD_DIR}/bin/"
fi
done
compress_libs
compress

fi

@@ -201,23 +217,24 @@ if [ -d "${ROCM_PATH}" ]; then

fi
init_vars
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/rocm${ROCM_VARIANT}"
BUILD_DIR="../build/linux/${ARCH}/rocm${ROCM_VARIANT}"
EXTRA_LIBS="-L${ROCM_PATH}/lib -L/opt/amdgpu/lib/x86_64-linux-gnu/ -Wl,-rpath,\$ORIGIN/../../rocm/ -lhipblas -lrocblas -lamdhip64 -lrocsolver -lamd_comgr -lhsa-runtime64 -lrocsparse -ldrm -ldrm_amdgpu"
build

# Record the ROCM dependencies
rm -f "${BUILD_DIR}/lib/deps.txt"
touch "${BUILD_DIR}/lib/deps.txt"
for dep in $(ldd "${BUILD_DIR}/lib/libext_server.so" | grep "=>" | cut -f2 -d= | cut -f2 -d' ' | grep -e rocm -e amdgpu -e libtinfo ); do
echo "${dep}" >> "${BUILD_DIR}/lib/deps.txt"
rm -f "${BUILD_DIR}/bin/deps.txt"
touch "${BUILD_DIR}/bin/deps.txt"
for dep in $(ldd "${BUILD_DIR}/bin/ollama_llama_server" | grep "=>" | cut -f2 -d= | cut -f2 -d' ' | grep -e rocm -e amdgpu -e libtinfo ); do
echo "${dep}" >> "${BUILD_DIR}/bin/deps.txt"
done
# bomb out if for some reason we didn't get a few deps
if [ $(cat "${BUILD_DIR}/lib/deps.txt" | wc -l ) -lt 8 ] ; then
cat "${BUILD_DIR}/lib/deps.txt"
if [ $(cat "${BUILD_DIR}/bin/deps.txt" | wc -l ) -lt 8 ] ; then
cat "${BUILD_DIR}/bin/deps.txt"
echo "ERROR: deps file short"
exit 1
fi
compress_libs
compress
fi

cleanup
echo "go generate completed. LLM runners: $(cd ${BUILD_DIR}/..; echo *)"

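The scripts above build one runner per CPU capability tier (lowest common denominator, AVX, AVX2), and something at runtime has to pick the best one the host supports. A minimal sketch of that selection using `golang.org/x/sys/cpu` (the function name is hypothetical, and F16C detection is omitted for simplicity since that package does not expose it):

```go
package main

import (
	"fmt"

	"golang.org/x/sys/cpu"
)

// cpuVariant picks the most capable runner the host CPU supports,
// mirroring the cpu / cpu_avx / cpu_avx2 build split in the scripts above.
func cpuVariant() string {
	switch {
	case cpu.X86.HasAVX2 && cpu.X86.HasFMA:
		return "cpu_avx2"
	case cpu.X86.HasAVX:
		return "cpu_avx"
	default:
		return "cpu"
	}
}

func main() {
	fmt.Println(cpuVariant())
}
```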
@@ -33,7 +33,7 @@ function init_vars {

"-DBUILD_SHARED_LIBS=on",
"-DLLAMA_NATIVE=off"
)
$script:cmakeTargets = @("ext_server")
$script:cmakeTargets = @("ollama_llama_server")
$script:ARCH = "amd64" # arm not yet supported.
if ($env:CGO_CFLAGS -contains "-g") {
$script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on", "-DLLAMA_SERVER_VERBOSE=on", "-DCMAKE_BUILD_TYPE=RelWithDebInfo")

@@ -97,16 +97,14 @@ function apply_patches {

}

# Checkout each file
Set-Location -Path ${script:llamacppDir}
foreach ($file in $filePaths) {
git checkout $file
git -C "${script:llamacppDir}" checkout $file
}
}

# Apply each patch
foreach ($patch in $patches) {
Set-Location -Path ${script:llamacppDir}
git apply $patch.FullName
git -C "${script:llamacppDir}" apply $patch.FullName
}
}

@@ -115,41 +113,41 @@ function build {

& cmake --version
& cmake -S "${script:llamacppDir}" -B $script:buildDir $script:cmakeDefs
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
write-host "building with: cmake --build $script:buildDir --config $script:config ($script:cmakeTargets | ForEach-Object { "--target", $_ })"
write-host "building with: cmake --build $script:buildDir --config $script:config $($script:cmakeTargets | ForEach-Object { `"--target`", $_ })"
& cmake --build $script:buildDir --config $script:config ($script:cmakeTargets | ForEach-Object { "--target", $_ })
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}

function install {
rm -ea 0 -recurse -force -path "${script:buildDir}/lib"
md "${script:buildDir}/lib" -ea 0 > $null
cp "${script:buildDir}/bin/${script:config}/ext_server.dll" "${script:buildDir}/lib"
cp "${script:buildDir}/bin/${script:config}/llama.dll" "${script:buildDir}/lib"
# Display the dll dependencies in the build log
if ($script:DUMPBIN -ne $null) {
& "$script:DUMPBIN" /dependents "${script:buildDir}/bin/${script:config}/ext_server.dll" | select-string ".dll"
# Rearrange output to be consistent between different generators
if ($null -ne ${script:config} -And (test-path -path "${script:buildDir}/bin/${script:config}" ) ) {
mv -force "${script:buildDir}/bin/${script:config}/*" "${script:buildDir}/bin/"
remove-item "${script:buildDir}/bin/${script:config}"
}
}

function sign {
if ("${env:KEY_CONTAINER}") {
write-host "Signing ${script:buildDir}/lib/*.dll"
foreach ($file in (get-childitem "${script:buildDir}/lib/*.dll")){
& "${script:SignTool}" sign /v /debug /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
write-host "Signing ${script:buildDir}/bin/*.exe ${script:buildDir}/bin/*.dll"
foreach ($file in @(get-childitem "${script:buildDir}/bin/*.exe") + @(get-childitem "${script:buildDir}/bin/*.dll")){
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
/csp "Google Cloud KMS Provider" /kc "${env:KEY_CONTAINER}" $file
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
}
}

function compress_libs {
function compress {
if ($script:GZIP -eq $null) {
write-host "gzip not installed, not compressing files"
return
}
write-host "Compressing binaries..."
$binaries = dir "${script:buildDir}/bin/*.exe"
foreach ($file in $binaries) {
& "$script:GZIP" --best -f $file
}

write-host "Compressing dlls..."
$libs = dir "${script:buildDir}/lib/*.dll"
foreach ($file in $libs) {
$dlls = dir "${script:buildDir}/bin/*.dll"
foreach ($file in $dlls) {
& "$script:GZIP" --best -f $file
}
}

@@ -164,14 +162,11 @@ function cleanup {

}

# Checkout each file
Set-Location -Path ${script:llamacppDir}
foreach ($file in $filePaths) {
git checkout $file
git -C "${script:llamacppDir}" checkout $file
}
git -C "${script:llamacppDir}" checkout CMakeLists.txt
}
Set-Location "${script:llamacppDir}/"
git checkout CMakeLists.txt

}

init_vars

@@ -179,7 +174,6 @@ git_module_setup

apply_patches

# -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
# -DLLAMA_F16C -- 2012 Intel Ivy Bridge & AMD 2011 Bulldozer (No significant improvement over just AVX)
# -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
# -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver

@@ -187,32 +181,54 @@ $script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")

if ($null -eq ${env:OLLAMA_SKIP_CPU_GENERATE}) {

# GCC build for direct linking into the Go binary
init_vars
# cmake will silently fallback to msvc compilers if mingw isn't in the path, so detect and fail fast
# as we need this to be compiled by gcc for golang to be able to link with it
write-host "Checking for MinGW..."
# error action ensures we exit on failure
get-command gcc
get-command mingw32-make
$script:cmakeTargets = @("llama", "ggml")
$script:cmakeDefs = @(
"-G", "MinGW Makefiles"
"-DCMAKE_C_COMPILER=gcc.exe",
"-DCMAKE_CXX_COMPILER=g++.exe",
"-DBUILD_SHARED_LIBS=off",
"-DLLAMA_NATIVE=off",
"-DLLAMA_AVX=off",
"-DLLAMA_AVX2=off",
"-DLLAMA_AVX512=off",
"-DLLAMA_F16C=off",
"-DLLAMA_FMA=off")
$script:buildDir="../build/windows/${script:ARCH}_static"
write-host "Building static library"
build

# remaining llama.cpp builds use MSVC
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cpu"
$script:buildDir="../build/windows/${script:ARCH}/cpu"
write-host "Building LCD CPU"
build
install
sign
compress_libs
compress

init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cpu_avx"
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx"
write-host "Building AVX CPU"
build
install
sign
compress_libs
compress

init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cpu_avx2"
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx2"
write-host "Building AVX2 CPU"
build
install
sign
compress_libs
compress
} else {
write-host "Skipping CPU generation step as requested"
}

@@ -225,13 +241,11 @@ if ($null -ne $script:CUDA_LIB_DIR) {

$script:CUDA_VARIANT="_"+$script:CUDA_VERSION
}
init_vars
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
$script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
write-host "Building CUDA"
$script:buildDir="../build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
$script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUDA=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
build
install
sign
compress_libs
compress
}

if ($null -ne $env:HIP_PATH) {

@@ -241,12 +255,13 @@ if ($null -ne $env:HIP_PATH) {

}

init_vars
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/rocm$script:ROCM_VARIANT"
$script:buildDir="../build/windows/${script:ARCH}/rocm$script:ROCM_VARIANT"
$script:cmakeDefs += @(
"-G", "Ninja",
"-DCMAKE_C_COMPILER=clang.exe",
"-DCMAKE_CXX_COMPILER=clang++.exe",
"-DLLAMA_HIPBLAS=on",
"-DHIP_PLATFORM=amd",
"-DLLAMA_AVX=on",
"-DLLAMA_AVX2=off",
"-DCMAKE_POSITION_INDEPENDENT_CODE=on",

@@ -264,13 +279,13 @@ if ($null -ne $env:HIP_PATH) {

build
# Ninja doesn't prefix with config name
${script:config}=""
install
if ($null -ne $script:DUMPBIN) {
& "$script:DUMPBIN" /dependents "${script:buildDir}/bin/${script:config}/ext_server.dll" | select-string ".dll"
& "$script:DUMPBIN" /dependents "${script:buildDir}/bin/ollama_llama_server.exe" | select-string ".dll"
}
sign
compress_libs
compress
}


cleanup
write-host "`ngo generate completed. LLM runners: $(get-childitem -path ${script:SRC_DIR}\llm\llama.cpp\build\windows\${script:ARCH})"
write-host "`ngo generate completed. LLM runners: $(get-childitem -path ${script:SRC_DIR}\llm\build\windows\${script:ARCH})"

@@ -1,3 +1,3 @@

package generate

//go:generate sh ./gen_darwin.sh
//go:generate bash ./gen_darwin.sh

86  llm/ggla.go
@@ -7,16 +7,18 @@ import (

	"slices"
)

type ContainerGGLA struct {
type containerGGLA struct {
	version uint32
}

func (c *ContainerGGLA) Name() string {
func (c *containerGGLA) Name() string {
	return "ggla"
}

func (c *ContainerGGLA) Decode(rs io.ReadSeeker) (model, error) {
	binary.Read(rs, binary.LittleEndian, &c.version)
func (c *containerGGLA) Decode(rs io.ReadSeeker) (model, error) {
	if err := binary.Read(rs, binary.LittleEndian, &c.version); err != nil {
		return nil, err
	}

	switch c.version {
	case 1:

@@ -24,37 +26,45 @@ func (c *ContainerGGLA) Decode(rs io.ReadSeeker) (model, error) {

		return nil, errors.New("invalid version")
	}

	model := newModelGGLA(c)
	model := newGGLA(c)
	err := model.decode(rs)
	return model, err
}

type ModelGGLA struct {
	*ContainerGGLA
type ggla struct {
	*containerGGLA

	kv KV
	tensors []Tensor
	tensors []*Tensor
}

func newModelGGLA(container *ContainerGGLA) *ModelGGLA {
	return &ModelGGLA{
		ContainerGGLA: container,
func newGGLA(container *containerGGLA) *ggla {
	return &ggla{
		containerGGLA: container,
		kv: make(KV),
	}
}

func (m *ModelGGLA) decode(rs io.ReadSeeker) error {
func (llm *ggla) KV() KV {
	return llm.kv
}

func (llm *ggla) Tensors() Tensors {
	return llm.tensors
}

func (llm *ggla) decode(rs io.ReadSeeker) error {
	var r uint32
	if err := binary.Read(rs, binary.LittleEndian, &r); err != nil {
		return err
	}
	m.kv["r"] = r
	llm.kv["r"] = r

	var alpha uint32
	if err := binary.Read(rs, binary.LittleEndian, &alpha); err != nil {
		return err
	}
	m.kv["alpha"] = alpha
	llm.kv["alpha"] = alpha

	for {
		var dims uint32

@@ -109,54 +119,10 @@ func (m *ModelGGLA) decode(rs io.ReadSeeker) error {

		t.Offset = uint64(offset)

		if _, err := rs.Seek(int64(t.Size()), io.SeekCurrent); err != nil {
		if _, err := rs.Seek(int64(t.size()), io.SeekCurrent); err != nil {
			return err
		}

		m.tensors = append(m.tensors, t)
		llm.tensors = append(llm.tensors, &t)
	}
}

func (m *ModelGGLA) KV() KV {
	return m.kv
}

func (m *ModelGGLA) Tensor() []Tensor {
	return m.tensors
}

func (*ModelGGLA) ModelFamily() string {
	return "ggla"
}

func (*ModelGGLA) ModelType() string {
	panic("not implemented")
}

func (*ModelGGLA) FileType() string {
	panic("not implemented")
}

func (*ModelGGLA) NumLayers() uint32 {
	panic("not implemented")
}

func (*ModelGGLA) NumGQA() uint32 {
	panic("not implemented")
}

func (*ModelGGLA) NumEmbed() uint32 {
	panic("not implemented")
}

func (*ModelGGLA) NumHead() uint32 {
	panic("not implemented")
}

func (*ModelGGLA) NumHeadKv() uint32 {
	panic("not implemented")
}

func (*ModelGGLA) NumCtx() uint32 {
	panic("not implemented")
}

256  llm/ggml.go
@@ -3,14 +3,14 @@ package llm

import (
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"strings"
)

type GGML struct {
	container
	model

	Size int64
}

const (

@@ -90,15 +90,178 @@ func fileType(fileType uint32) string {

}

type model interface {
	ModelFamily() string
	ModelType() string
	FileType() string
	NumLayers() uint32
	NumGQA() uint32
	NumEmbed() uint32
	NumHead() uint32
	NumHeadKv() uint32
	NumCtx() uint32
	KV() KV
	Tensors() Tensors
}

type KV map[string]any

func (kv KV) u64(key string) uint64 {
	switch v := kv[key].(type) {
	case uint64:
		return v
	case uint32:
		return uint64(v)
	case float64:
		return uint64(v)
	default:
		return 0
	}
}

func (kv KV) Architecture() string {
	if s, ok := kv["general.architecture"].(string); ok {
		return s
	}

	return "unknown"
}

func (kv KV) ParameterCount() uint64 {
	return kv.u64("general.parameter_count")
}

func (kv KV) FileType() string {
	if u64 := kv.u64("general.file_type"); u64 > 0 {
		return fileType(uint32(u64))
	}

	return "unknown"
}

func (kv KV) BlockCount() uint64 {
	return kv.u64(fmt.Sprintf("%s.block_count", kv.Architecture()))
}

func (kv KV) HeadCount() uint64 {
	return kv.u64(fmt.Sprintf("%s.attention.head_count", kv.Architecture()))
}

func (kv KV) HeadCountKV() uint64 {
	if headCountKV := kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture())); headCountKV > 0 {
		return headCountKV
	}

	return 1
}

func (kv KV) GQA() uint64 {
	return kv.HeadCount() / kv.HeadCountKV()
}

func (kv KV) EmbeddingLength() uint64 {
	return kv.u64(fmt.Sprintf("%s.embedding_length", kv.Architecture()))
}

func (kv KV) ContextLength() uint64 {
	return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture()))
}
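The accessors above resolve metadata keys under the `general.` and `<architecture>.` prefixes defined by GGUF. An illustrative sketch with hypothetical metadata values (the counts below are made up for the example):

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/llm"
)

func main() {
	// hypothetical metadata for a llama-architecture GGUF; the KV
	// accessors build the lookup key from the architecture name
	kv := llm.KV{
		"general.architecture":          "llama",
		"llama.block_count":             uint32(32),
		"llama.attention.head_count":    uint32(32),
		"llama.attention.head_count_kv": uint32(8),
		"llama.embedding_length":        uint32(4096),
		"llama.context_length":          uint32(4096),
	}
	fmt.Println(kv.BlockCount())    // 32
	fmt.Println(kv.GQA())           // 32 / 8 = 4 (grouped-query attention factor)
	fmt.Println(kv.ContextLength()) // 4096
}
```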

type Tensors []*Tensor

func (ts Tensors) Layers() map[string]Layer {
	layers := make(map[string]Layer)
	for _, t := range ts {
		parts := strings.Split(t.Name, ".")
		if parts[0] == "blk" {
			parts = parts[1:]
		}

		if _, ok := layers[parts[0]]; !ok {
			layers[parts[0]] = make(Layer)
		}

		layers[parts[0]][strings.Join(parts[1:], ".")] = t
	}

	return layers
}
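`Layers` groups tensors by the first dotted path component after an optional `blk` prefix, so `blk.0.attn_q.weight` lands in `layers["0"]["attn_q.weight"]` and `output.weight` in `layers["output"]["weight"]`. A small sketch with hypothetical tensor names:

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/llm"
)

func main() {
	// hypothetical tensor names as they might appear in a GGUF file
	ts := llm.Tensors{
		{Name: "blk.0.attn_q.weight"},
		{Name: "blk.0.attn_k.weight"},
		{Name: "output.weight"},
	}
	for name, layer := range ts.Layers() {
		fmt.Println(name, len(layer)) // "0" has 2 entries, "output" has 1
	}
}
```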

type Layer map[string]*Tensor

func (l Layer) size() (size uint64) {
	for _, t := range l {
		size += t.size()
	}

	return size
}

type Tensor struct {
	Name   string `json:"name"`
	Kind   uint32 `json:"kind"`
	Offset uint64 `json:"-"`

	// Shape is the number of elements in each dimension
	Shape []uint64 `json:"shape"`

	io.WriterTo `json:"-"`
}

func (t Tensor) blockSize() uint64 {
	switch {
	case t.Kind < 2:
		return 1
	case t.Kind < 10:
		return 32
	default:
		return 256
	}
}

func (t Tensor) typeSize() uint64 {
	blockSize := t.blockSize()

	switch t.Kind {
	case 0: // FP32
		return 4
	case 1: // FP16
		return 2
	case 2: // Q4_0
		return 2 + blockSize/2
	case 3: // Q4_1
		return 2 + 2 + blockSize/2
	case 6: // Q5_0
		return 2 + 4 + blockSize/2
	case 7: // Q5_1
		return 2 + 2 + 4 + blockSize/2
	case 8: // Q8_0
		return 2 + blockSize
	case 9: // Q8_1
		return 4 + 4 + blockSize
	case 10: // Q2_K
		return blockSize/16 + blockSize/4 + 2 + 2
	case 11: // Q3_K
		return blockSize/8 + blockSize/4 + 12 + 2
	case 12: // Q4_K
		return 2 + 2 + 12 + blockSize/2
	case 13: // Q5_K
		return 2 + 2 + 12 + blockSize/8 + blockSize/2
	case 14: // Q6_K
		return blockSize/2 + blockSize/4 + blockSize/16 + 2
	case 15: // Q8_K
		return 2 + blockSize + 2*blockSize/16
	case 16: // IQ2_XXS
		return 2 + 2*blockSize/8
	case 17: // IQ2_XS
		return 2 + 2*blockSize/8 + blockSize/32
	case 18: // IQ3_XXS
		return 2 + 3*blockSize/8
	default:
		return 0
	}
}

func (t Tensor) parameters() uint64 {
	var count uint64 = 1
	for _, n := range t.Shape {
		count *= n
	}
	return count
}

func (t Tensor) size() uint64 {
	return t.parameters() * t.typeSize() / t.blockSize()
}
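The size formula is `parameters * typeSize / blockSize`: each block of `blockSize` elements costs `typeSize` bytes. A worked example for a hypothetical 4096×4096 Q4_0 tensor (Kind 2, so blockSize 32 and typeSize 2 + 32/2 = 18 bytes, about 4.5 bits per weight):

```go
package main

import "fmt"

func main() {
	// hypothetical 4096x4096 Q4_0 tensor: each 32-element block stores
	// an fp16 scale (2 bytes) plus packed 4-bit values (16 bytes)
	params := uint64(4096 * 4096)       // 16,777,216 elements
	blockSize := uint64(32)
	typeSize := uint64(2 + blockSize/2) // 18 bytes per block
	fmt.Println(params * typeSize / blockSize) // 9437184 bytes = 9 MiB
}
```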

type container interface {

@@ -122,42 +285,91 @@ const (

var ErrUnsupportedFormat = errors.New("unsupported model format")

func DecodeGGML(rs io.ReadSeeker) (*GGML, error) {
func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
	var magic uint32
	if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
		return nil, err
		return nil, 0, err
	}

	var c container
	switch magic {
	case FILE_MAGIC_GGML, FILE_MAGIC_GGMF, FILE_MAGIC_GGJT:
		return nil, ErrUnsupportedFormat
		return nil, 0, ErrUnsupportedFormat
	case FILE_MAGIC_GGLA:
		c = &ContainerGGLA{}
		c = &containerGGLA{}
	case FILE_MAGIC_GGUF_LE:
		c = &ContainerGGUF{ByteOrder: binary.LittleEndian}
		c = &containerGGUF{ByteOrder: binary.LittleEndian}
	case FILE_MAGIC_GGUF_BE:
		c = &ContainerGGUF{ByteOrder: binary.BigEndian}
		c = &containerGGUF{ByteOrder: binary.BigEndian}
	default:
		return nil, errors.New("invalid file magic")
		return nil, 0, errors.New("invalid file magic")
	}

	model, err := c.Decode(rs)
	if errors.Is(err, io.EOF) {
		// noop
	} else if err != nil {
		return nil, err
		return nil, 0, err
	}

	offset, err := rs.Seek(0, io.SeekCurrent)
	if err != nil {
		return nil, err
		return nil, 0, err
	}

	// final model type
	return &GGML{
		container: c,
		model: model,
		Size: offset,
	}, nil
	}, offset, nil
}
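With this change, `DecodeGGML` reports the offset where decoding stopped as a second return value instead of stashing it on the struct. A minimal usage sketch of the new signature (the file path is a placeholder):

```go
package main

import (
	"fmt"
	"log"
	"os"

	"github.com/ollama/ollama/llm"
)

func main() {
	f, err := os.Open("model.gguf") // placeholder path
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// offset is where the metadata ends and tensor data begins
	ggml, offset, err := llm.DecodeGGML(f)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(ggml.KV().Architecture(), offset)
}
```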

func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload uint64) {
	embedding := llm.KV().EmbeddingLength()
	heads := llm.KV().HeadCount()
	headsKV := llm.KV().HeadCountKV()
	vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any)))

	switch llm.KV().Architecture() {
	case "llama":
		fullOffload = 4 * batch * (1 + 4*embedding + context*(1+heads))

		partialOffload = 4 * batch * embedding
		partialOffload += max(
			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
		)
	case "gemma":
		fullOffload = 4 * batch * (embedding + vocab)
		partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128
	case "command-r":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(2+4*embedding+context*(1+heads)),
		)

		partialOffload = max(
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
			4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
		)
	case "qwen2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(1+2*embedding+context+context*heads),
		)

		partialOffload = max(
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
			4*(batch*(1+2*embedding+context*(1+heads))+embedding*(1+context)),
		)
	case "phi2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(1+4*embedding+context+context*heads),
		)

		partialOffload = 4*batch*(2*embedding+vocab) + embedding*vocab*105/128
	}

	return
}

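To get a feel for these estimates, here is the `llama` fullOffload formula evaluated for hypothetical dimensions (the numbers below are illustrative, not from any shipped model):

```go
package main

import "fmt"

func main() {
	// plugging hypothetical llama-style dimensions into the fullOffload
	// formula above: batch=512, context=2048, embedding=4096, heads=32
	var batch, context, embedding, heads uint64 = 512, 2048, 4096, 32
	fullOffload := 4 * batch * (1 + 4*embedding + context*(1+heads))
	fmt.Println(fullOffload) // 171968512 bytes, roughly 164 MiB of graph scratch
}
```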
1243  llm/gguf.go
File diff suppressed because it is too large
Submodule llm/llama.cpp updated: ad3a0505e3...1b67731e18
100  llm/llama.go
@@ -1,100 +0,0 @@

package llm

import (
	_ "embed"
	"fmt"
	"time"

	"github.com/ollama/ollama/api"
)

const jsonGrammar = `
root   ::= object
value  ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
  "{" ws (
    string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array  ::=
  "[" ws (
    value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
`

type ImageData struct {
	Data []byte `json:"data"`
	ID   int    `json:"id"`
}

var payloadMissing = fmt.Errorf("expected dynamic library payloads not included in this build of ollama")

type prediction struct {
	Content string `json:"content"`
	Model   string `json:"model"`
	Prompt  string `json:"prompt"`
	Stop    bool   `json:"stop"`

	Timings struct {
		PredictedN  int     `json:"predicted_n"`
		PredictedMS float64 `json:"predicted_ms"`
		PromptN     int     `json:"prompt_n"`
		PromptMS    float64 `json:"prompt_ms"`
	}
}

const maxRetries = 3

type PredictOpts struct {
	Prompt  string
	Format  string
	Images  []ImageData
	Options api.Options
}

type PredictResult struct {
	Content            string
	Done               bool
	PromptEvalCount    int
	PromptEvalDuration time.Duration
	EvalCount          int
	EvalDuration       time.Duration
}

type TokenizeRequest struct {
	Content string `json:"content"`
}

type TokenizeResponse struct {
	Tokens []int `json:"tokens"`
}

type DetokenizeRequest struct {
	Tokens []int `json:"tokens"`
}

type DetokenizeResponse struct {
	Content string `json:"content"`
}

type EmbeddingRequest struct {
	Content string `json:"content"`
}

type EmbeddingResponse struct {
	Embedding []float64 `json:"embedding"`
}
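The removed `prediction` struct carries server timings in milliseconds while `PredictResult` uses `time.Duration`. A hypothetical helper inside package llm sketching that conversion (the function name is made up for illustration):

```go
// toPredictResult maps a decoded server prediction to the public result
// type, converting millisecond timings to time.Duration values.
func toPredictResult(p prediction) PredictResult {
	return PredictResult{
		Content:            p.Content,
		Done:               p.Stop,
		PromptEvalCount:    p.Timings.PromptN,
		PromptEvalDuration: time.Duration(p.Timings.PromptMS * float64(time.Millisecond)),
		EvalCount:          p.Timings.PredictedN,
		EvalDuration:       time.Duration(p.Timings.PredictedMS * float64(time.Millisecond)),
	}
}
```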
231  llm/llm.go
@@ -1,175 +1,86 @@

package llm

// #cgo CFLAGS: -Illama.cpp
// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/build/darwin/arm64_static/libllama.a -lstdc++
// #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/build/darwin/x86_64_static/libllama.a -lstdc++
// #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++
// #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++
// #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++
// #include <stdlib.h>
// #include "llama.h"
import "C"
import (
	"context"
	"fmt"
	"log/slog"
	"os"
	"runtime"
	"slices"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/gpu"
	"unsafe"
)

type LLM interface {
	Predict(context.Context, PredictOpts, func(PredictResult)) error
	Embedding(context.Context, string) ([]float64, error)
	Encode(context.Context, string) ([]int, error)
	Decode(context.Context, []int) (string, error)
	Close()
// SystemInfo is an unused example of calling llama.cpp functions using CGo
func SystemInfo() string {
	return C.GoString(C.llama_print_system_info())
}

var cpuOnlyFamilies = []string{
	"mamba",
}
func Quantize(infile, outfile, filetype string) error {
	cinfile := C.CString(infile)
	defer C.free(unsafe.Pointer(cinfile))

func New(model string, adapters, projectors []string, opts api.Options) (LLM, error) {
	if _, err := os.Stat(model); err != nil {
		return nil, err
	}
	coutfile := C.CString(outfile)
	defer C.free(unsafe.Pointer(coutfile))

	f, err := os.Open(model)
	if err != nil {
		return nil, err
	}
	defer f.Close()
	params := C.llama_model_quantize_default_params()
	params.nthread = -1

	ggml, err := DecodeGGML(f)
	if err != nil {
		return nil, err
	}

	if opts.NumCtx > int(ggml.NumCtx()) {
		slog.Warn(fmt.Sprintf("requested context length is greater than model's max context length (%d > %d), using %d instead", opts.NumCtx, ggml.NumCtx(), ggml.NumCtx()))
		opts.NumCtx = int(ggml.NumCtx())
	}

	if opts.NumCtx < 4 {
		opts.NumCtx = 4
	}

	vram, _ := gpu.CheckVRAM()
	size := ggml.Size

	// fp16 k,v matrices require = n_ctx * n_layer * n_embd / n_head * n_head_kv * 2 bytes each * 2 key and value
	kv := 2 * 2 * int64(opts.NumCtx) * int64(ggml.NumLayers()) * int64(ggml.NumEmbed()) * int64(ggml.NumHeadKv()) / int64(max(ggml.NumHead(), 1))

	// this amount is the overhead + tensors in memory
	// TODO: get this from the llama.cpp's graph calculations instead of
	// estimating it's 1/6 * kv_cache_size * num_gqa
	graph := int64(ggml.NumGQA()) * kv / 6

	// certain model architectures don't support gpu inference yet
	if slices.Contains(cpuOnlyFamilies, ggml.ModelFamily()) {
		opts.NumGPU = 0
	}

	info := gpu.GetGPUInfo()
	switch runtime.GOOS {
	case "darwin":
		if opts.NumGPU == 0 {
			break
		}

		if size+kv+graph > vram {
			slog.Info("not enough vram available, setting num_gpu=0")
			opts.NumGPU = 0
			break
		}

		// TODO: implement layer splitting on macOS
		opts.NumGPU = 999
	switch filetype {
	case "F32":
		params.ftype = fileTypeF32
	case "F16":
		params.ftype = fileTypeF16
	case "Q4_0":
		params.ftype = fileTypeQ4_0
	case "Q4_1":
		params.ftype = fileTypeQ4_1
	case "Q4_1_F16":
		params.ftype = fileTypeQ4_1_F16
	case "Q8_0":
		params.ftype = fileTypeQ8_0
|
||||
case "Q5_0":
|
||||
params.ftype = fileTypeQ5_0
|
||||
case "Q5_1":
|
||||
params.ftype = fileTypeQ5_1
|
||||
case "Q2_K":
|
||||
params.ftype = fileTypeQ2_K
|
||||
case "Q3_K_S":
|
||||
params.ftype = fileTypeQ3_K_S
|
||||
case "Q3_K_M":
|
||||
params.ftype = fileTypeQ3_K_M
|
||||
case "Q3_K_L":
|
||||
params.ftype = fileTypeQ3_K_L
|
||||
case "Q4_K_S":
|
||||
params.ftype = fileTypeQ4_K_S
|
||||
case "Q4_K_M":
|
||||
params.ftype = fileTypeQ4_K_M
|
||||
case "Q5_K_S":
|
||||
params.ftype = fileTypeQ5_K_S
|
||||
case "Q5_K_M":
|
||||
params.ftype = fileTypeQ5_K_M
|
||||
case "Q6_K":
|
||||
params.ftype = fileTypeQ6_K
|
||||
case "IQ2_XXS":
|
||||
params.ftype = fileTypeIQ2_XXS
|
||||
case "IQ2_XS":
|
||||
params.ftype = fileTypeIQ2_XS
|
||||
case "Q2_K_S":
|
||||
params.ftype = fileTypeQ2_K_S
|
||||
case "Q3_K_XS":
|
||||
params.ftype = fileTypeQ3_K_XS
|
||||
case "IQ3_XXS":
|
||||
params.ftype = fileTypeIQ3_XXS
|
||||
default:
|
||||
if info.Library == "cpu" {
|
||||
slog.Info("GPU not available, falling back to CPU")
|
||||
opts.NumGPU = 0
|
||||
break
|
||||
}
|
||||
|
||||
// don't use GPU at all if no layers are loaded
|
||||
if opts.NumGPU == 0 {
|
||||
info.Library = "cpu"
|
||||
info.Variant = gpu.GetCPUVariant()
|
||||
break
|
||||
}
|
||||
|
||||
// user-defined GPU count
|
||||
if opts.NumGPU != -1 {
|
||||
break
|
||||
}
|
||||
|
||||
// the "main" GPU needs the most memory and determines the limit
|
||||
// of how many layers can be loaded. It needs to fit:
|
||||
// 1. the full compute graph allocation for all devices (graph)
|
||||
// 2. the proportional kv cache for all devices (kv * % layers)
|
||||
// 3. the proportional model (size * % layers / # devices)
|
||||
// This estimates the number of layers
|
||||
maxlayers := int64(ggml.NumLayers()) + 1
|
||||
devices := int64(info.DeviceCount)
|
||||
avg := vram / devices
|
||||
layers := maxlayers * (avg - graph) / (kv + size/devices)
|
||||
if layers > maxlayers {
|
||||
layers = maxlayers
|
||||
}
|
||||
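To make the layer estimate above concrete, here is a small self-contained Go sketch of the same arithmetic with invented sizes (none of the numbers come from this diff; they only illustrate the shape of the formula):

```go
package main

import "fmt"

func main() {
	// All sizes are hypothetical: one GPU with 4 GiB of free VRAM, a 4 GiB model,
	// a 1 GiB KV cache and a 0.5 GiB compute graph allocation.
	const GiB = int64(1 << 30)
	var (
		vram      = 4 * GiB
		size      = 4 * GiB
		kv        = 1 * GiB
		graph     = GiB / 2
		maxlayers = int64(33) // block count + 1 for the output layer
		devices   = int64(1)
	)

	avg := vram / devices
	layers := maxlayers * (avg - graph) / (kv + size/devices)
	if layers > maxlayers {
		layers = maxlayers
	}
	fmt.Println("estimated offloadable layers:", layers) // prints 23 with these numbers
}
```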
|
||||
// 1 + 2 must fit on the main gpu
|
||||
min := graph + kv*layers/maxlayers
|
||||
if layers <= 0 || min > avg {
|
||||
slog.Info("not enough vram available, falling back to CPU only")
|
||||
info.Library = "cpu"
|
||||
info.Variant = gpu.GetCPUVariant()
|
||||
opts.NumGPU = 0
|
||||
break
|
||||
}
|
||||
|
||||
opts.NumGPU = int(layers)
|
||||
return fmt.Errorf("unknown filetype: %s", filetype)
|
||||
}
|
||||
|
||||
opts.RopeFrequencyBase = 0.0
|
||||
opts.RopeFrequencyScale = 0.0
|
||||
return newLlmServer(info, model, adapters, projectors, opts)
|
||||
}
|
||||
|
||||
// Give any native cgo implementations an opportunity to initialize
|
||||
func Init() error {
|
||||
return nativeInit()
|
||||
}
|
||||
|
||||
func newLlmServer(gpuInfo gpu.GpuInfo, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
|
||||
dynLibs := getDynLibs(gpuInfo)
|
||||
|
||||
// Check to see if the user has requested a specific library instead of auto-detecting
|
||||
demandLib := os.Getenv("OLLAMA_LLM_LIBRARY")
|
||||
if demandLib != "" {
|
||||
libPath := availableDynLibs[demandLib]
|
||||
if libPath == "" {
|
||||
slog.Info(fmt.Sprintf("Invalid OLLAMA_LLM_LIBRARY %s - not found", demandLib))
|
||||
} else {
|
||||
slog.Info(fmt.Sprintf("Loading OLLAMA_LLM_LIBRARY=%s", demandLib))
|
||||
dynLibs = []string{libPath}
|
||||
}
|
||||
}
|
||||
|
||||
// We stage into a temp directory, and if we've been idle for a while, it may have been reaped
|
||||
_, err := os.Stat(dynLibs[0])
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("%s has disappeared, reloading libraries", dynLibs[0]))
|
||||
err = nativeInit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err2 := fmt.Errorf("unable to locate suitable llm library")
|
||||
for _, dynLib := range dynLibs {
|
||||
srv, err := newDynExtServer(dynLib, model, adapters, projectors, opts)
|
||||
if err == nil {
|
||||
return srv, nil
|
||||
}
|
||||
slog.Warn(fmt.Sprintf("Failed to load dynamic library %s %s", dynLib, err))
|
||||
err2 = err
|
||||
}
|
||||
|
||||
return nil, err2
|
||||
if retval := C.llama_model_quantize(cinfile, coutfile, ¶ms); retval != 0 {
|
||||
return fmt.Errorf("llama_model_quantize: %d", retval)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,5 +4,5 @@ import (
|
||||
"embed"
|
||||
)
|
||||
|
||||
//go:embed llama.cpp/build/linux/*/*/lib/*
|
||||
//go:embed build/darwin/x86_64/*/bin/*
|
||||
var libEmbed embed.FS
|
||||
@@ -4,5 +4,5 @@ import (
|
||||
"embed"
|
||||
)
|
||||
|
||||
//go:embed llama.cpp/build/windows/*/*/lib/*.dll*
|
||||
//go:embed build/darwin/arm64/*/bin/*
|
||||
var libEmbed embed.FS
|
||||
llm/llm_linux.go: 6 changes (new file)
@@ -0,0 +1,6 @@
|
||||
package llm
|
||||
|
||||
import "embed"
|
||||
|
||||
//go:embed build/linux/*/*/bin/*
|
||||
var libEmbed embed.FS
|
||||
llm/llm_windows.go: 6 changes (new file)
@@ -0,0 +1,6 @@
|
||||
package llm
|
||||
|
||||
import "embed"
|
||||
|
||||
//go:embed build/windows/*/*/bin/*
|
||||
var libEmbed embed.FS
|
||||
@@ -1,13 +0,0 @@
|
||||
diff --git a/llama.cpp b/llama.cpp
|
||||
index b27aa272..99372f9c 100644
|
||||
--- a/llama.cpp
|
||||
+++ b/llama.cpp
|
||||
@@ -9360,7 +9360,7 @@ struct llm_tokenizer_wpm {
|
||||
}
|
||||
|
||||
uint32_t to_lower(uint32_t code) {
|
||||
- static const std::locale locale("en_US.UTF-8");
|
||||
+ static const std::locale locale("");
|
||||
#if defined(_WIN32)
|
||||
if (code > 0xFFFF) {
|
||||
return code;
|
||||
llm/payload.go: 211 changes (new file)
@@ -0,0 +1,211 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/gpu"
|
||||
)
|
||||
|
||||
var errPayloadMissing = fmt.Errorf("expected payloads not included in this build of ollama")
|
||||
|
||||
func Init() error {
|
||||
payloadsDir, err := gpu.PayloadsDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
slog.Info("extracting embedded files", "dir", payloadsDir)
|
||||
binGlob := "build/*/*/*/bin/*"
|
||||
|
||||
// extract server libraries
|
||||
err = extractFiles(payloadsDir, binGlob)
|
||||
if err != nil {
|
||||
return fmt.Errorf("extract binaries: %v", err)
|
||||
}
|
||||
|
||||
var variants []string
|
||||
for v := range availableServers() {
|
||||
variants = append(variants, v)
|
||||
}
|
||||
slog.Info(fmt.Sprintf("Dynamic LLM libraries %v", variants))
|
||||
slog.Debug("Override detection logic by setting OLLAMA_LLM_LIBRARY")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// binary names may contain an optional variant separated by '_'
|
||||
// For example, "ollama_rocm_v6" and "ollama_rocm_v5" or "ollama_cpu" and "ollama_cpu_avx2"
|
||||
// Any library without a variant is the lowest common denominator
|
||||
func availableServers() map[string]string {
|
||||
payloadsDir, err := gpu.PayloadsDir()
|
||||
if err != nil {
|
||||
slog.Error("payload lookup error", "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// glob payloadsDir for files that start with ollama_
|
||||
pattern := filepath.Join(payloadsDir, "*")
|
||||
|
||||
files, err := filepath.Glob(pattern)
|
||||
if err != nil {
|
||||
slog.Debug("could not glob", "pattern", pattern, "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
servers := make(map[string]string)
|
||||
for _, file := range files {
|
||||
slog.Debug("availableServers : found", "file", file)
|
||||
servers[filepath.Base(file)] = file
|
||||
}
|
||||
|
||||
return servers
|
||||
}
|
||||
|
||||
// serversForGpu returns a list of compatible servers given the provided GPU
|
||||
// info, ordered by performance. assumes Init() has been called
|
||||
// TODO - switch to metadata based mapping
|
||||
func serversForGpu(info gpu.GpuInfo) []string {
|
||||
// glob workDir for files that start with ollama_
|
||||
availableServers := availableServers()
|
||||
requested := info.Library
|
||||
if info.Variant != "" {
|
||||
requested += "_" + info.Variant
|
||||
}
|
||||
|
||||
servers := []string{}
|
||||
|
||||
// exact match first
|
||||
for a := range availableServers {
|
||||
if a == requested {
|
||||
servers = []string{a}
|
||||
|
||||
if a == "metal" {
|
||||
return servers
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
alt := []string{}
|
||||
|
||||
// Then for GPUs load alternates and sort the list for consistent load ordering
|
||||
if info.Library != "cpu" {
|
||||
for a := range availableServers {
|
||||
if info.Library == strings.Split(a, "_")[0] && a != requested {
|
||||
alt = append(alt, a)
|
||||
}
|
||||
}
|
||||
|
||||
slices.Sort(alt)
|
||||
servers = append(servers, alt...)
|
||||
}
|
||||
|
||||
// Load up the best CPU variant if not primary requested
|
||||
if info.Library != "cpu" {
|
||||
variant := gpu.GetCPUVariant()
|
||||
// If no variant, then we fall back to default
|
||||
// If we have a variant, try that if we find an exact match
|
||||
// Attempting to run the wrong CPU instructions will panic the
|
||||
// process
|
||||
if variant != "" {
|
||||
for cmp := range availableServers {
|
||||
if cmp == "cpu_"+variant {
|
||||
servers = append(servers, cmp)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
servers = append(servers, "cpu")
|
||||
}
|
||||
}
|
||||
|
||||
if len(servers) == 0 {
|
||||
servers = []string{"cpu"}
|
||||
}
|
||||
|
||||
return servers
|
||||
}
|
||||
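The ordering this produces is easiest to see on a toy example. The sketch below mimics the selection logic for a hypothetical set of extracted runners (the names, variants, and the simplified helper are illustrative, not part of the diff):

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// order is a stripped-down imitation of the selection above: exact match first,
// then same-library alternates in sorted order, then the best CPU fallback.
func order(available map[string]bool, library, variant, cpuVariant string) []string {
	requested := library
	if variant != "" {
		requested += "_" + variant
	}

	var servers []string
	if available[requested] {
		servers = append(servers, requested)
	}

	var alt []string
	for name := range available {
		if strings.Split(name, "_")[0] == library && name != requested {
			alt = append(alt, name)
		}
	}
	sort.Strings(alt)
	servers = append(servers, alt...)

	if cpuVariant != "" && available["cpu_"+cpuVariant] {
		servers = append(servers, "cpu_"+cpuVariant)
	} else {
		servers = append(servers, "cpu")
	}
	return servers
}

func main() {
	available := map[string]bool{"cuda_v12": true, "cuda_v11": true, "cpu": true, "cpu_avx2": true}
	fmt.Println(order(available, "cuda", "v12", "avx2")) // [cuda_v12 cuda_v11 cpu_avx2]
}
```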
|
||||
// extract extracts the embedded files to the target directory
|
||||
func extractFiles(targetDir string, glob string) error {
|
||||
files, err := fs.Glob(libEmbed, glob)
|
||||
if err != nil || len(files) == 0 {
|
||||
return errPayloadMissing
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(targetDir, 0o755); err != nil {
|
||||
return fmt.Errorf("extractFiles could not mkdir %s: %v", targetDir, err)
|
||||
}
|
||||
|
||||
g := new(errgroup.Group)
|
||||
|
||||
// build/$OS/$GOARCH/$VARIANT/{bin,lib}/$FILE
|
||||
for _, file := range files {
|
||||
filename := file
|
||||
|
||||
variant := filepath.Base(filepath.Dir(filepath.Dir(filename)))
|
||||
|
||||
slog.Debug("extracting", "variant", variant, "file", filename)
|
||||
|
||||
g.Go(func() error {
|
||||
srcf, err := libEmbed.Open(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer srcf.Close()
|
||||
|
||||
src := io.Reader(srcf)
|
||||
if strings.HasSuffix(filename, ".gz") {
|
||||
src, err = gzip.NewReader(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decompress payload %s: %v", filename, err)
|
||||
}
|
||||
filename = strings.TrimSuffix(filename, ".gz")
|
||||
}
|
||||
|
||||
variantDir := filepath.Join(targetDir, variant)
|
||||
if err := os.MkdirAll(variantDir, 0o755); err != nil {
|
||||
return fmt.Errorf("extractFiles could not mkdir %s: %v", variantDir, err)
|
||||
}
|
||||
|
||||
base := filepath.Base(filename)
|
||||
destFilename := filepath.Join(variantDir, base)
|
||||
|
||||
_, err = os.Stat(destFilename)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
destFile, err := os.OpenFile(destFilename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write payload %s: %v", filename, err)
|
||||
}
|
||||
defer destFile.Close()
|
||||
if _, err := io.Copy(destFile, src); err != nil {
|
||||
return fmt.Errorf("copy payload %s: %v", filename, err)
|
||||
}
|
||||
case err != nil:
|
||||
return fmt.Errorf("stat payload %s: %v", filename, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
err = g.Wait()
|
||||
if err != nil {
|
||||
// If we fail to extract, the payload dir is unusable, so cleanup whatever we extracted
|
||||
gpu.Cleanup()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,233 +0,0 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/gpu"
|
||||
)
|
||||
|
||||
// Libraries names may contain an optional variant separated by '_'
|
||||
// For example, "rocm_v6" and "rocm_v5" or "cpu" and "cpu_avx2"
|
||||
// Any library without a variant is the lowest common denominator
|
||||
var availableDynLibs = map[string]string{}
|
||||
|
||||
const pathComponentCount = 7
|
||||
|
||||
// getDynLibs returns an ordered list of LLM libraries to try, starting with the best
|
||||
func getDynLibs(gpuInfo gpu.GpuInfo) []string {
|
||||
// Short circuit if we know we're using the default built-in (darwin only)
|
||||
if gpuInfo.Library == "default" {
|
||||
return []string{"default"}
|
||||
}
|
||||
// TODO - temporary until we have multiple CPU variations for Darwin
|
||||
// Short circuit on darwin with metal only
|
||||
if len(availableDynLibs) == 1 {
|
||||
if _, onlyMetal := availableDynLibs["metal"]; onlyMetal {
|
||||
return []string{availableDynLibs["metal"]}
|
||||
}
|
||||
}
|
||||
|
||||
exactMatch := ""
|
||||
dynLibs := []string{}
|
||||
altDynLibs := []string{}
|
||||
requested := gpuInfo.Library
|
||||
if gpuInfo.Variant != "" {
|
||||
requested += "_" + gpuInfo.Variant
|
||||
}
|
||||
// Try to find an exact match
|
||||
for cmp := range availableDynLibs {
|
||||
if requested == cmp {
|
||||
exactMatch = cmp
|
||||
dynLibs = []string{availableDynLibs[cmp]}
|
||||
break
|
||||
}
|
||||
}
|
||||
// Then for GPUs load alternates and sort the list for consistent load ordering
|
||||
if gpuInfo.Library != "cpu" {
|
||||
for cmp := range availableDynLibs {
|
||||
if gpuInfo.Library == strings.Split(cmp, "_")[0] && cmp != exactMatch {
|
||||
altDynLibs = append(altDynLibs, cmp)
|
||||
}
|
||||
}
|
||||
slices.Sort(altDynLibs)
|
||||
for _, altDynLib := range altDynLibs {
|
||||
dynLibs = append(dynLibs, availableDynLibs[altDynLib])
|
||||
}
|
||||
}
|
||||
|
||||
// Load up the best CPU variant if not primary requested
|
||||
if gpuInfo.Library != "cpu" {
|
||||
variant := gpu.GetCPUVariant()
|
||||
// If no variant, then we fall back to default
|
||||
// If we have a variant, try that if we find an exact match
|
||||
// Attempting to run the wrong CPU instructions will panic the
|
||||
// process
|
||||
if variant != "" {
|
||||
for cmp := range availableDynLibs {
|
||||
if cmp == "cpu_"+variant {
|
||||
dynLibs = append(dynLibs, availableDynLibs[cmp])
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
dynLibs = append(dynLibs, availableDynLibs["cpu"])
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, if we didn't find any matches, LCD CPU FTW
|
||||
if len(dynLibs) == 0 {
|
||||
dynLibs = []string{availableDynLibs["cpu"]}
|
||||
}
|
||||
slog.Debug(fmt.Sprintf("ordered list of LLM libraries to try %v", dynLibs))
|
||||
return dynLibs
|
||||
}
|
||||
|
||||
func rocmDynLibPresent() bool {
|
||||
for dynLibName := range availableDynLibs {
|
||||
if strings.HasPrefix(dynLibName, "rocm") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func nativeInit() error {
|
||||
payloadsDir, err := gpu.PayloadsDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
slog.Info(fmt.Sprintf("Extracting dynamic libraries to %s ...", payloadsDir))
|
||||
|
||||
libs, err := extractDynamicLibs(payloadsDir, "llama.cpp/build/*/*/*/lib/*")
|
||||
if err != nil {
|
||||
if errors.Is(err, payloadMissing) {
|
||||
slog.Info(fmt.Sprintf("%s", payloadMissing))
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
for _, lib := range libs {
|
||||
// The last dir component is the variant name
|
||||
variant := filepath.Base(filepath.Dir(lib))
|
||||
availableDynLibs[variant] = lib
|
||||
}
|
||||
|
||||
if err := verifyDriverAccess(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Report which dynamic libraries we have loaded to assist troubleshooting
|
||||
variants := make([]string, len(availableDynLibs))
|
||||
i := 0
|
||||
for variant := range availableDynLibs {
|
||||
variants[i] = variant
|
||||
i++
|
||||
}
|
||||
slog.Info(fmt.Sprintf("Dynamic LLM libraries %v", variants))
|
||||
slog.Debug("Override detection logic by setting OLLAMA_LLM_LIBRARY")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractDynamicLibs(payloadsDir, glob string) ([]string, error) {
|
||||
files, err := fs.Glob(libEmbed, glob)
|
||||
if err != nil || len(files) == 0 {
|
||||
return nil, payloadMissing
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
var libs []string
|
||||
var g errgroup.Group
|
||||
for _, file := range files {
|
||||
pathComps := strings.Split(file, "/")
|
||||
if len(pathComps) != pathComponentCount {
|
||||
slog.Error(fmt.Sprintf("unexpected payload components: %v", pathComps))
|
||||
continue
|
||||
}
|
||||
|
||||
file := file
|
||||
g.Go(func() error {
|
||||
// llama.cpp/build/$OS/$GOARCH/$VARIANT/lib/$LIBRARY
|
||||
// Include the variant in the path to avoid conflicts between multiple server libs
|
||||
targetDir := filepath.Join(payloadsDir, pathComps[pathComponentCount-3])
|
||||
srcFile, err := libEmbed.Open(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read payload %s: %v", file, err)
|
||||
}
|
||||
defer srcFile.Close()
|
||||
if err := os.MkdirAll(targetDir, 0o755); err != nil {
|
||||
return fmt.Errorf("create payload lib dir %s: %v", payloadsDir, err)
|
||||
}
|
||||
src := io.Reader(srcFile)
|
||||
filename := file
|
||||
if strings.HasSuffix(file, ".gz") {
|
||||
src, err = gzip.NewReader(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decompress payload %s: %v", file, err)
|
||||
}
|
||||
filename = strings.TrimSuffix(filename, ".gz")
|
||||
}
|
||||
|
||||
destFile := filepath.Join(targetDir, filepath.Base(filename))
|
||||
if strings.Contains(destFile, "server") {
|
||||
mu.Lock()
|
||||
libs = append(libs, destFile)
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
destFp, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write payload %s: %v", file, err)
|
||||
}
|
||||
defer destFp.Close()
|
||||
if _, err := io.Copy(destFp, src); err != nil {
|
||||
return fmt.Errorf("copy payload %s: %v", file, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
err = g.Wait()
|
||||
if err != nil {
|
||||
// If we fail to extract, the payload dir is unusable, so cleanup whatever we extracted
|
||||
gpu.Cleanup()
|
||||
return nil, err
|
||||
}
|
||||
return libs, nil
|
||||
}
|
||||
|
||||
func verifyDriverAccess() error {
|
||||
if runtime.GOOS != "linux" {
|
||||
return nil
|
||||
}
|
||||
// Only check ROCm access if we have the dynamic lib loaded
|
||||
if rocmDynLibPresent() {
|
||||
// Verify we have permissions - either running as root, or we have group access to the driver
|
||||
fd, err := os.OpenFile("/dev/kfd", os.O_RDWR, 0666)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrPermission) {
|
||||
return fmt.Errorf("Radeon card detected, but permissions not set up properly. Either run ollama as root, or add you user account to the render group.")
|
||||
} else if errors.Is(err, fs.ErrNotExist) {
|
||||
// expected behavior without a radeon card
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to check permission on /dev/kfd: %w", err)
|
||||
}
|
||||
fd.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"embed"
|
||||
)
|
||||
|
||||
//go:embed llama.cpp/build/darwin/x86_64/*/lib/*.dylib*
|
||||
var libEmbed embed.FS
|
||||
@@ -1,8 +0,0 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"embed"
|
||||
)
|
||||
|
||||
//go:embed llama.cpp/ggml-metal.metal llama.cpp/build/darwin/arm64/*/lib/*.dylib*
|
||||
var libEmbed embed.FS
|
||||
@@ -1,58 +0,0 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/gpu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetDynLibs(t *testing.T) {
|
||||
availableDynLibs = map[string]string{
|
||||
"cpu": "X_cpu",
|
||||
}
|
||||
assert.Equal(t, false, rocmDynLibPresent())
|
||||
res := getDynLibs(gpu.GpuInfo{Library: "cpu"})
|
||||
assert.Len(t, res, 1)
|
||||
assert.Equal(t, availableDynLibs["cpu"], res[0])
|
||||
|
||||
variant := gpu.GetCPUVariant()
|
||||
if variant != "" {
|
||||
variant = "_" + variant
|
||||
}
|
||||
availableDynLibs = map[string]string{
|
||||
"rocm_v5": "X_rocm_v5",
|
||||
"rocm_v6": "X_rocm_v6",
|
||||
"cpu" + variant: "X_cpu",
|
||||
}
|
||||
assert.Equal(t, true, rocmDynLibPresent())
|
||||
res = getDynLibs(gpu.GpuInfo{Library: "rocm"})
|
||||
assert.Len(t, res, 3)
|
||||
assert.Equal(t, availableDynLibs["rocm_v5"], res[0])
|
||||
assert.Equal(t, availableDynLibs["rocm_v6"], res[1])
|
||||
assert.Equal(t, availableDynLibs["cpu"+variant], res[2])
|
||||
|
||||
res = getDynLibs(gpu.GpuInfo{Library: "rocm", Variant: "v6"})
|
||||
assert.Len(t, res, 3)
|
||||
assert.Equal(t, availableDynLibs["rocm_v6"], res[0])
|
||||
assert.Equal(t, availableDynLibs["rocm_v5"], res[1])
|
||||
assert.Equal(t, availableDynLibs["cpu"+variant], res[2])
|
||||
|
||||
res = getDynLibs(gpu.GpuInfo{Library: "cuda"})
|
||||
assert.Len(t, res, 1)
|
||||
assert.Equal(t, availableDynLibs["cpu"+variant], res[0])
|
||||
|
||||
res = getDynLibs(gpu.GpuInfo{Library: "default"})
|
||||
assert.Len(t, res, 1)
|
||||
assert.Equal(t, "default", res[0])
|
||||
|
||||
availableDynLibs = map[string]string{
|
||||
"rocm": "X_rocm_v5",
|
||||
"cpu" + variant: "X_cpu",
|
||||
}
|
||||
assert.Equal(t, true, rocmDynLibPresent())
|
||||
res = getDynLibs(gpu.GpuInfo{Library: "rocm", Variant: "v6"})
|
||||
assert.Len(t, res, 2)
|
||||
assert.Equal(t, availableDynLibs["rocm"], res[0])
|
||||
assert.Equal(t, availableDynLibs["cpu"+variant], res[1])
|
||||
}
|
||||
llm/server.go: 850 changes (new file)
@@ -0,0 +1,850 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/gpu"
|
||||
)
|
||||
|
||||
// LlamaServer is an instance of the llama.cpp server
|
||||
type LlamaServer struct {
|
||||
port int
|
||||
cmd *exec.Cmd
|
||||
done chan error // Channel to signal when the process exits
|
||||
status *StatusWriter
|
||||
options api.Options
|
||||
}
|
||||
|
||||
var cpuOnlyFamilies = []string{
|
||||
"mamba",
|
||||
}
|
||||
|
||||
func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) {
|
||||
f, err := os.Open(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
ggml, _, err := DecodeGGML(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if opts.NumCtx > int(ggml.KV().ContextLength()) {
|
||||
slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
|
||||
opts.NumCtx = int(ggml.KV().ContextLength())
|
||||
}
|
||||
|
||||
if opts.NumCtx < 4 {
|
||||
opts.NumCtx = 4
|
||||
}
|
||||
|
||||
memoryAvailable, _ := gpu.CheckVRAM()
|
||||
info := gpu.GetGPUInfo()
|
||||
|
||||
memoryMinimum := info.MinimumMemory
|
||||
for _, projector := range projectors {
|
||||
memoryMinimum += projectorMemoryRequirements(projector)
|
||||
|
||||
// multimodal models require at least 2048 context
|
||||
opts.NumCtx = max(opts.NumCtx, 2048)
|
||||
}
|
||||
|
||||
// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
|
||||
var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
|
||||
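For a sense of scale, the KV-cache estimate above works out to about 1 GiB for a hypothetical 7B-style model at a 2048-token context. A minimal sketch with assumed hyperparameters (not read from any GGUF file):

```go
package main

import "fmt"

func main() {
	// Assumed 7B-style hyperparameters, purely illustrative.
	var (
		numCtx      uint64 = 2048 // n_ctx
		blockCount  uint64 = 32   // n_layer
		embedding   uint64 = 4096 // n_embd
		headCount   uint64 = 32   // n_head
		headCountKV uint64 = 32   // n_head_kv (equals n_head when GQA is not used)
	)

	// 2 tensors (K and V) * 2 bytes per fp16 element, same shape as the estimate above.
	kv := 2 * 2 * numCtx * blockCount * embedding / headCount * headCountKV
	fmt.Printf("estimated fp16 KV cache: %d bytes (%d MiB)\n", kv, kv>>20) // 1024 MiB here
}
```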
|
||||
graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
|
||||
if graphPartialOffload == 0 {
|
||||
graphPartialOffload = ggml.KV().GQA() * kv / 6
|
||||
}
|
||||
|
||||
if graphFullOffload == 0 {
|
||||
graphFullOffload = graphPartialOffload
|
||||
}
|
||||
|
||||
// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
|
||||
memoryRequiredTotal := memoryMinimum + graphFullOffload
|
||||
|
||||
// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
|
||||
memoryRequiredPartial := memoryMinimum + graphPartialOffload
|
||||
|
||||
if info.Library != "metal" {
|
||||
if memoryRequiredPartial > memoryAvailable || slices.Contains(cpuOnlyFamilies, ggml.KV().Architecture()) {
|
||||
info.Library = "cpu"
|
||||
}
|
||||
}
|
||||
|
||||
var layerCount int
|
||||
layers := ggml.Tensors().Layers()
|
||||
for i := 0; i < int(ggml.KV().BlockCount()); i++ {
|
||||
memoryLayer := layers[fmt.Sprintf("%d", i)].size()
|
||||
|
||||
// KV is proportional to the number of layers
|
||||
memoryLayer += kv / ggml.KV().BlockCount()
|
||||
|
||||
memoryRequiredTotal += memoryLayer
|
||||
if memoryAvailable > memoryRequiredPartial+memoryLayer {
|
||||
memoryRequiredPartial += memoryLayer
|
||||
layerCount++
|
||||
}
|
||||
}
|
||||
|
||||
memoryLayerOutput := layers["output"].size()
|
||||
memoryRequiredTotal += memoryLayerOutput
|
||||
if memoryAvailable > memoryRequiredTotal {
|
||||
layerCount = int(ggml.KV().BlockCount()) + 1
|
||||
memoryRequiredPartial = memoryRequiredTotal
|
||||
}
|
||||
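As a toy illustration of this greedy fit, the sketch below walks a few invented per-layer sizes against an assumed memory budget (every number is made up for the example):

```go
package main

import "fmt"

func main() {
	// Invented sizes: four 400 MiB transformer blocks, a 300 MiB output layer,
	// a 64 MiB KV-cache share per layer, 1500 MiB of free GPU memory and a
	// 200 MiB baseline overhead.
	const MiB = uint64(1 << 20)
	layerSizes := []uint64{400 * MiB, 400 * MiB, 400 * MiB, 400 * MiB}
	outputLayer := 300 * MiB
	kvPerLayer := 64 * MiB
	available := 1500 * MiB
	minimum := 200 * MiB

	requiredPartial, requiredTotal := minimum, minimum
	layerCount := 0
	for _, layer := range layerSizes {
		layer += kvPerLayer
		requiredTotal += layer
		if available > requiredPartial+layer {
			requiredPartial += layer
			layerCount++
		}
	}

	requiredTotal += outputLayer
	if available > requiredTotal {
		layerCount = len(layerSizes) + 1 // everything, including the output layer, fits
	}
	fmt.Println("layers offloaded:", layerCount) // prints 2 with these numbers
}
```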
|
||||
if opts.NumGPU < 0 {
|
||||
opts.NumGPU = layerCount
|
||||
}
|
||||
|
||||
slog.Info(
|
||||
"offload to gpu",
|
||||
"reallayers", opts.NumGPU,
|
||||
"layers", layerCount,
|
||||
"required", format.HumanBytes2(memoryRequiredTotal),
|
||||
"used", format.HumanBytes2(memoryRequiredPartial),
|
||||
"available", format.HumanBytes2(memoryAvailable),
|
||||
"kv", format.HumanBytes2(kv),
|
||||
"fulloffload", format.HumanBytes2(graphFullOffload),
|
||||
"partialoffload", format.HumanBytes2(graphPartialOffload),
|
||||
)
|
||||
|
||||
if len(adapters) > 1 {
|
||||
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
|
||||
}
|
||||
|
||||
availableServers := availableServers()
|
||||
servers := serversForGpu(info)
|
||||
|
||||
demandLib := os.Getenv("OLLAMA_LLM_LIBRARY")
|
||||
if demandLib != "" {
|
||||
serverPath := availableServers[demandLib]
|
||||
if serverPath == "" {
|
||||
slog.Info(fmt.Sprintf("Invalid OLLAMA_LLM_LIBRARY %s - not found", demandLib))
|
||||
} else {
|
||||
slog.Info("user override", "OLLAMA_LLM_LIBRARY", demandLib, "path", serverPath)
|
||||
servers = []string{demandLib}
|
||||
}
|
||||
}
|
||||
|
||||
if len(servers) == 0 {
|
||||
return nil, fmt.Errorf("no servers found for %v", info)
|
||||
}
|
||||
|
||||
params := []string{
|
||||
"--model", model,
|
||||
"--ctx-size", fmt.Sprintf("%d", opts.NumCtx),
|
||||
"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
|
||||
"--embedding",
|
||||
}
|
||||
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
||||
params = append(params, "--log-format", "json")
|
||||
} else {
|
||||
params = append(params, "--log-disable")
|
||||
}
|
||||
|
||||
if opts.NumGPU >= 0 {
|
||||
params = append(params, "--n-gpu-layers", fmt.Sprintf("%d", opts.NumGPU))
|
||||
}
|
||||
|
||||
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
||||
params = append(params, "--verbose")
|
||||
}
|
||||
|
||||
if opts.MainGPU > 0 {
|
||||
params = append(params, "--main-gpu", fmt.Sprintf("%d", opts.MainGPU))
|
||||
}
|
||||
|
||||
if len(adapters) > 0 {
|
||||
// TODO: applying multiple adapters is not supported by the llama.cpp server yet
|
||||
params = append(params, "--lora", adapters[0])
|
||||
}
|
||||
|
||||
if len(projectors) > 0 {
|
||||
// TODO: applying multiple projectors is not supported by the llama.cpp server yet
|
||||
params = append(params, "--mmproj", projectors[0])
|
||||
}
|
||||
|
||||
if opts.NumThread > 0 {
|
||||
params = append(params, "--threads", fmt.Sprintf("%d", opts.NumThread))
|
||||
}
|
||||
|
||||
if !opts.F16KV {
|
||||
params = append(params, "--memory-f32")
|
||||
}
|
||||
|
||||
if opts.UseMLock {
|
||||
params = append(params, "--mlock")
|
||||
}
|
||||
|
||||
if !opts.UseMMap {
|
||||
params = append(params, "--no-mmap")
|
||||
}
|
||||
|
||||
if opts.UseNUMA {
|
||||
params = append(params, "--numa")
|
||||
}
|
||||
|
||||
// Loop through potential servers
|
||||
var finalErr error
|
||||
for i := 0; i < len(servers); i++ {
|
||||
dir := availableServers[servers[i]]
|
||||
|
||||
// Find an available port, retrying on each iteration in case the failure was a port conflict race
|
||||
port := 0
|
||||
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
|
||||
var l *net.TCPListener
|
||||
if l, err = net.ListenTCP("tcp", a); err == nil {
|
||||
port = l.Addr().(*net.TCPAddr).Port
|
||||
l.Close()
|
||||
}
|
||||
}
|
||||
if port == 0 {
|
||||
slog.Debug("ResolveTCPAddr failed ", "error", err)
|
||||
port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
|
||||
}
|
||||
finalParams := append(params, "--port", strconv.Itoa(port))
|
||||
|
||||
pathEnv := "LD_LIBRARY_PATH"
|
||||
if runtime.GOOS == "windows" {
|
||||
pathEnv = "PATH"
|
||||
}
|
||||
// append the server directory to LD_LIBRARY_PATH/PATH
|
||||
libraryPaths := []string{dir}
|
||||
if libraryPath, ok := os.LookupEnv(pathEnv); ok {
|
||||
// Append our runner directory to the path
|
||||
// This will favor system libraries over our bundled library dependencies
|
||||
libraryPaths = append(filepath.SplitList(libraryPath), libraryPaths...)
|
||||
}
|
||||
|
||||
server := filepath.Join(dir, "ollama_llama_server")
|
||||
if runtime.GOOS == "windows" {
|
||||
server = server + ".exe"
|
||||
}
|
||||
|
||||
s := &LlamaServer{
|
||||
port: port,
|
||||
cmd: exec.Command(server, finalParams...),
|
||||
status: NewStatusWriter(os.Stderr),
|
||||
options: opts,
|
||||
}
|
||||
libEnv := fmt.Sprintf("%s=%s", pathEnv, strings.Join(libraryPaths, string(filepath.ListSeparator)))
|
||||
slog.Debug(libEnv)
|
||||
s.cmd.Env = append(os.Environ(), libEnv)
|
||||
s.cmd.Stdout = os.Stdout
|
||||
s.cmd.Stderr = s.status
|
||||
|
||||
slog.Info("starting llama server", "cmd", s.cmd.String())
|
||||
|
||||
if err = s.cmd.Start(); err != nil {
|
||||
msg := ""
|
||||
if s.status != nil && s.status.LastErrMsg != "" {
|
||||
msg = s.status.LastErrMsg
|
||||
}
|
||||
err = fmt.Errorf("error starting the external llama server: %v %s", err, msg)
|
||||
finalErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
// reap subprocess when it exits
|
||||
go func() {
|
||||
// Exit status managed via getServerStatus
|
||||
_ = s.cmd.Wait()
|
||||
}()
|
||||
|
||||
if err = s.waitUntilRunning(); err != nil {
|
||||
slog.Error("error starting llama server", "server", servers[i], "error", err)
|
||||
s.Close()
|
||||
finalErr = err
|
||||
continue
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
slog.Error("unable to load any llama server", "error", finalErr)
|
||||
return nil, finalErr
|
||||
}
|
||||
|
||||
func projectorMemoryRequirements(filename string) uint64 {
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
ggml, _, err := DecodeGGML(file)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
var mem uint64
|
||||
for _, layer := range ggml.Tensors().Layers() {
|
||||
mem += layer.size()
|
||||
}
|
||||
|
||||
return mem
|
||||
}
|
||||
|
||||
type ServerStatus int
|
||||
|
||||
const ( // iota is reset to 0
|
||||
ServerStatusReady ServerStatus = iota
|
||||
ServerStatusNoSlotsAvaialble
|
||||
ServerStatusLoadingModel
|
||||
ServerStatusNotResponding
|
||||
ServerStatusError
|
||||
)
|
||||
|
||||
type ServerStatusResp struct {
|
||||
Status string `json:"status"`
|
||||
SlotsIdle int `json:"slots_idle"`
|
||||
SlotsProcessing int `json:"slots_processing"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
||||
// Fail fast if it's exited
|
||||
if s.cmd.ProcessState != nil {
|
||||
msg := ""
|
||||
if s.status != nil && s.status.LastErrMsg != "" {
|
||||
msg = s.status.LastErrMsg
|
||||
}
|
||||
return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/health", s.port), nil)
|
||||
if err != nil {
|
||||
return ServerStatusError, fmt.Errorf("error creating GET request: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return ServerStatusNotResponding, fmt.Errorf("server not responding")
|
||||
}
|
||||
return ServerStatusError, fmt.Errorf("health resp: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return ServerStatusError, fmt.Errorf("read health request: %w", err)
|
||||
}
|
||||
|
||||
var status ServerStatusResp
|
||||
if err := json.Unmarshal(body, &status); err != nil {
|
||||
return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err)
|
||||
}
|
||||
|
||||
switch status.Status {
|
||||
case "ok":
|
||||
return ServerStatusReady, nil
|
||||
case "no slot available":
|
||||
return ServerStatusNoSlotsAvaialble, nil
|
||||
case "loading model":
|
||||
return ServerStatusLoadingModel, nil
|
||||
default:
|
||||
return ServerStatusError, fmt.Errorf("server error: %+v", status)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *LlamaServer) Ping(ctx context.Context) error {
|
||||
_, err := s.getServerStatus(ctx)
|
||||
if err != nil {
|
||||
slog.Debug("server unhealthy", "error", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *LlamaServer) waitUntilRunning() error {
|
||||
start := time.Now()
|
||||
// TODO we need to wire up a better way to detect hangs during model load and startup of the server
|
||||
expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
|
||||
ticker := time.NewTicker(50 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
slog.Info("waiting for llama runner to start responding")
|
||||
var lastStatus ServerStatus = -1
|
||||
for {
|
||||
select {
|
||||
case err := <-s.done:
|
||||
msg := ""
|
||||
if s.status != nil && s.status.LastErrMsg != "" {
|
||||
msg = s.status.LastErrMsg
|
||||
}
|
||||
return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
|
||||
case <-ticker.C:
|
||||
if time.Now().After(expiresAt) {
|
||||
// timeout
|
||||
msg := ""
|
||||
if s.status != nil && s.status.LastErrMsg != "" {
|
||||
msg = s.status.LastErrMsg
|
||||
}
|
||||
return fmt.Errorf("timed out waiting for llama runner to start: %s", msg)
|
||||
}
|
||||
if s.cmd.ProcessState != nil {
|
||||
msg := ""
|
||||
if s.status != nil && s.status.LastErrMsg != "" {
|
||||
msg = s.status.LastErrMsg
|
||||
}
|
||||
return fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
status, err := s.getServerStatus(ctx)
|
||||
if err != nil && lastStatus != status {
|
||||
slog.Debug("server not yet available", "error", err)
|
||||
lastStatus = status
|
||||
continue
|
||||
}
|
||||
|
||||
switch status {
|
||||
case ServerStatusLoadingModel:
|
||||
// TODO - this state never seems to happen with the current server.cpp code (bug?)
|
||||
// it doesn't respond to the health endpoint until after the model is loaded
|
||||
slog.Debug("loading model")
|
||||
case ServerStatusReady:
|
||||
slog.Debug(fmt.Sprintf("llama runner started in %f seconds", time.Since(start).Seconds()))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const jsonGrammar = `
|
||||
root ::= object
|
||||
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||
|
||||
object ::=
|
||||
"{" ws (
|
||||
string ":" ws value
|
||||
("," ws string ":" ws value)*
|
||||
)? "}" ws
|
||||
|
||||
array ::=
|
||||
"[" ws (
|
||||
value
|
||||
("," ws value)*
|
||||
)? "]" ws
|
||||
|
||||
string ::=
|
||||
"\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||
)* "\"" ws
|
||||
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||
|
||||
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||
ws ::= ([ \t\n] ws)?
|
||||
`
|
||||
|
||||
const maxBufferSize = 512 * format.KiloByte
|
||||
const maxRetries = 3
|
||||
|
||||
type ImageData struct {
|
||||
Data []byte `json:"data"`
|
||||
ID int `json:"id"`
|
||||
}
|
||||
|
||||
type completion struct {
|
||||
Content string `json:"content"`
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Stop bool `json:"stop"`
|
||||
|
||||
Timings struct {
|
||||
PredictedN int `json:"predicted_n"`
|
||||
PredictedMS float64 `json:"predicted_ms"`
|
||||
PromptN int `json:"prompt_n"`
|
||||
PromptMS float64 `json:"prompt_ms"`
|
||||
}
|
||||
}
|
||||
|
||||
type CompletionRequest struct {
|
||||
Prompt string
|
||||
Format string
|
||||
Images []ImageData
|
||||
Options api.Options
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
Content string
|
||||
Done bool
|
||||
PromptEvalCount int
|
||||
PromptEvalDuration time.Duration
|
||||
EvalCount int
|
||||
EvalDuration time.Duration
|
||||
}
|
||||
|
||||
func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
||||
request := map[string]any{
|
||||
"prompt": req.Prompt,
|
||||
"stream": true,
|
||||
"n_predict": req.Options.NumPredict,
|
||||
"n_keep": req.Options.NumKeep,
|
||||
"main_gpu": req.Options.MainGPU,
|
||||
"temperature": req.Options.Temperature,
|
||||
"top_k": req.Options.TopK,
|
||||
"top_p": req.Options.TopP,
|
||||
"tfs_z": req.Options.TFSZ,
|
||||
"typical_p": req.Options.TypicalP,
|
||||
"repeat_last_n": req.Options.RepeatLastN,
|
||||
"repeat_penalty": req.Options.RepeatPenalty,
|
||||
"presence_penalty": req.Options.PresencePenalty,
|
||||
"frequency_penalty": req.Options.FrequencyPenalty,
|
||||
"mirostat": req.Options.Mirostat,
|
||||
"mirostat_tau": req.Options.MirostatTau,
|
||||
"mirostat_eta": req.Options.MirostatEta,
|
||||
"penalize_nl": req.Options.PenalizeNewline,
|
||||
"seed": req.Options.Seed,
|
||||
"stop": req.Options.Stop,
|
||||
"image_data": req.Images,
|
||||
"cache_prompt": true,
|
||||
}
|
||||
|
||||
// Make sure the server is ready
|
||||
status, err := s.getServerStatus(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if status != ServerStatusReady {
|
||||
return fmt.Errorf("unexpected server status: %d", status)
|
||||
}
|
||||
|
||||
if req.Format == "json" {
|
||||
request["grammar"] = jsonGrammar
|
||||
if !strings.Contains(strings.ToLower(req.Prompt), "json") {
|
||||
slog.Warn("Prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
|
||||
}
|
||||
}
|
||||
|
||||
retryDelay := 100 * time.Microsecond
|
||||
for retries := 0; retries < maxRetries; retries++ {
|
||||
if retries > 0 {
|
||||
time.Sleep(retryDelay) // wait before retrying
|
||||
retryDelay *= 2 // exponential backoff
|
||||
}
|
||||
|
||||
// Handling JSON marshaling with special characters unescaped.
|
||||
buffer := &bytes.Buffer{}
|
||||
enc := json.NewEncoder(buffer)
|
||||
enc.SetEscapeHTML(false)
|
||||
|
||||
if err := enc.Encode(request); err != nil {
|
||||
return fmt.Errorf("failed to marshal data: %v", err)
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating POST request: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("POST predict: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading llm error response: %w", err)
|
||||
}
|
||||
log.Printf("llm predict error: %s", bodyBytes)
|
||||
return fmt.Errorf("%s", bodyBytes)
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
buf := make([]byte, 0, maxBufferSize)
|
||||
scanner.Buffer(buf, maxBufferSize)
|
||||
|
||||
retryNeeded := false
|
||||
// keep track of the last token generated; this is used to abort if the model starts looping
|
||||
var lastToken string
|
||||
var tokenRepeat int
|
||||
|
||||
for scanner.Scan() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// This handles the request cancellation
|
||||
return ctx.Err()
|
||||
default:
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// try again on slot unavailable
|
||||
if bytes.Contains(line, []byte("slot unavailable")) {
|
||||
retryNeeded = true
|
||||
break
|
||||
}
|
||||
|
||||
evt, ok := bytes.CutPrefix(line, []byte("data: "))
|
||||
if !ok {
|
||||
return fmt.Errorf("error parsing llm response stream: %s", line)
|
||||
}
|
||||
|
||||
var c completion
|
||||
if err := json.Unmarshal(evt, &c); err != nil {
|
||||
return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.TrimSpace(c.Content) == lastToken:
|
||||
tokenRepeat++
|
||||
default:
|
||||
lastToken = strings.TrimSpace(c.Content)
|
||||
tokenRepeat = 0
|
||||
}
|
||||
|
||||
// 30 picked as an arbitrary max token repeat limit, modify as needed
|
||||
if tokenRepeat > 30 {
|
||||
slog.Debug("prediction aborted, token repeat limit reached")
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
if c.Content != "" {
|
||||
fn(CompletionResponse{
|
||||
Content: c.Content,
|
||||
})
|
||||
}
|
||||
|
||||
if c.Stop {
|
||||
fn(CompletionResponse{
|
||||
Done: true,
|
||||
PromptEvalCount: c.Timings.PromptN,
|
||||
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
|
||||
EvalCount: c.Timings.PredictedN,
|
||||
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if strings.Contains(err.Error(), "unexpected EOF") {
|
||||
s.Close()
|
||||
msg := ""
|
||||
if s.status != nil && s.status.LastErrMsg != "" {
|
||||
msg = s.status.LastErrMsg
|
||||
}
|
||||
|
||||
return fmt.Errorf("an unknown error was encountered while running the model %s", msg)
|
||||
}
|
||||
return fmt.Errorf("error reading llm response: %v", err)
|
||||
}
|
||||
|
||||
if !retryNeeded {
|
||||
return nil // success
|
||||
}
|
||||
}
|
||||
|
||||
// should never reach here ideally
|
||||
return fmt.Errorf("max retries exceeded")
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
}
|
||||
|
||||
func (s *LlamaServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
|
||||
// Make sure the server is ready
|
||||
status, err := s.getServerStatus(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if status != ServerStatusReady {
|
||||
return nil, fmt.Errorf("unexpected server status: %d", status)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(TokenizeRequest{Content: prompt})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling embed data: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating embed request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do embedding request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading embed response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
log.Printf("llm encode error: %s", body)
|
||||
return nil, fmt.Errorf("%s", body)
|
||||
}
|
||||
|
||||
var embedding EmbeddingResponse
|
||||
if err := json.Unmarshal(body, &embedding); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
|
||||
}
|
||||
|
||||
return embedding.Embedding, nil
|
||||
}
|
||||
|
||||
type TokenizeRequest struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type TokenizeResponse struct {
|
||||
Tokens []int `json:"tokens"`
|
||||
}
|
||||
|
||||
func (s *LlamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
// Make sure the server is ready
|
||||
status, err := s.getServerStatus(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if status != ServerStatusReady {
|
||||
return nil, fmt.Errorf("unexpected server status: %d", status)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(TokenizeRequest{Content: content})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshaling encode data: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port), bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encode request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do encode request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read encode request: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
log.Printf("llm encode error: %s", body)
|
||||
return nil, fmt.Errorf("%s", body)
|
||||
}
|
||||
|
||||
var encoded TokenizeResponse
|
||||
if err := json.Unmarshal(body, &encoded); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal encode response: %w", err)
|
||||
}
|
||||
|
||||
return encoded.Tokens, nil
|
||||
}
|
||||
|
||||
type DetokenizeRequest struct {
|
||||
Tokens []int `json:"tokens"`
|
||||
}
|
||||
|
||||
type DetokenizeResponse struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
||||
// Make sure the server is ready
|
||||
status, err := s.getServerStatus(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else if status != ServerStatusReady {
|
||||
return "", fmt.Errorf("unexpected server status: %d", status)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshaling decode data: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/detokenize", s.port), bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("do decode request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read decode request: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
log.Printf("llm decode error: %s", body)
|
||||
return "", fmt.Errorf("%s", body)
|
||||
}
|
||||
|
||||
var decoded DetokenizeResponse
|
||||
if err := json.Unmarshal(body, &decoded); err != nil {
|
||||
return "", fmt.Errorf("unmarshal encode response: %w", err)
|
||||
}
|
||||
|
||||
return decoded.Content, nil
|
||||
}
|
||||
|
||||
func (s *LlamaServer) Close() error {
|
||||
if s.cmd != nil {
|
||||
slog.Debug("stopping llama server")
|
||||
return s.cmd.Process.Kill()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseDurationMs(ms float64) time.Duration {
|
||||
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return dur
|
||||
}
|
||||
llm/status.go: 42 changes (new file)
@@ -0,0 +1,42 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
)
|
||||
|
||||
// StatusWriter is a writer that captures error messages from the llama runner process
|
||||
type StatusWriter struct {
|
||||
LastErrMsg string
|
||||
out *os.File
|
||||
}
|
||||
|
||||
func NewStatusWriter(out *os.File) *StatusWriter {
|
||||
return &StatusWriter{
|
||||
out: out,
|
||||
}
|
||||
}
|
||||
|
||||
// TODO - regex matching to detect errors like
|
||||
// libcublasLt.so.11: cannot open shared object file: No such file or directory
|
||||
|
||||
var errorPrefixes = []string{
|
||||
"error:",
|
||||
"CUDA error",
|
||||
"cudaMalloc failed",
|
||||
"\"ERR\"",
|
||||
}
|
||||
|
||||
func (w *StatusWriter) Write(b []byte) (int, error) {
|
||||
var errMsg string
|
||||
for _, prefix := range errorPrefixes {
|
||||
if _, after, ok := bytes.Cut(b, []byte(prefix)); ok {
|
||||
errMsg = prefix + string(bytes.TrimSpace(after))
|
||||
}
|
||||
}
|
||||
if errMsg != "" {
|
||||
w.LastErrMsg = errMsg
|
||||
}
|
||||
|
||||
return w.out.Write(b)
|
||||
}
|
||||
llm/utils.go: 15 changes
@@ -1,15 +0,0 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func parseDurationMs(ms float64) time.Duration {
|
||||
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return dur
|
||||
}
|
||||
@@ -14,7 +14,7 @@ go build .
|
||||
Then run the desktop app with `npm start`:
|
||||
|
||||
```
|
||||
cd app
|
||||
cd macapp
|
||||
npm install
|
||||
npm start
|
||||
```
|
||||
|
||||
@@ -142,7 +142,9 @@ func (h *History) Save() error {
|
||||
for cnt := 0; cnt < h.Size(); cnt++ {
|
||||
v, _ := h.Buf.Get(cnt)
|
||||
line, _ := v.([]rune)
|
||||
buf.WriteString(string(line) + "\n")
|
||||
if _, err := buf.WriteString(string(line) + "\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
buf.Flush()
|
||||
f.Close()
|
||||
|
||||
@@ -10,7 +10,7 @@ export GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=$V
|
||||
# For developers, you can override the DOCKER_ORG to generate multiarch manifests
|
||||
# DOCKER_ORG=jdoe PUSH=1 ./scripts/build_docker.sh
|
||||
DOCKER_ORG=${DOCKER_ORG:-"ollama"}
|
||||
ARCH_IMAGE_REPO=${ARCH_IMAGE_REPO:-"${DOCKER_ORG}/release"}
|
||||
RELEASE_IMAGE_REPO=${RELEASE_IMAGE_REPO:-"${DOCKER_ORG}/release"}
|
||||
FINAL_IMAGE_REPO=${FINAL_IMAGE_REPO:-"${DOCKER_ORG}/ollama"}
|
||||
|
||||
BUILD_ARCH=${BUILD_ARCH:-"amd64 arm64"}
|
||||
@@ -25,7 +25,7 @@ OLLAMA_SKIP_IMAGE_BUILD=${OLLAMA_SKIP_IMAGE_BUILD:-""}
|
||||
if [ -z "${PUSH}" ] ; then
|
||||
LOAD_OR_PUSH="--load"
|
||||
else
|
||||
echo "Will be pushing ${ARCH_IMAGE_REPO}:$VERSION for ${BUILD_ARCH}"
|
||||
echo "Will be pushing ${RELEASE_IMAGE_REPO}:$VERSION for ${BUILD_ARCH}"
|
||||
LOAD_OR_PUSH="--push"
|
||||
fi
|
||||
|
||||
@@ -37,7 +37,7 @@ if [ -z "${OLLAMA_SKIP_IMAGE_BUILD}" ]; then
|
||||
--build-arg=VERSION \
|
||||
--build-arg=GOFLAGS \
|
||||
-f Dockerfile \
|
||||
-t ${ARCH_IMAGE_REPO}:$VERSION-${TARGETARCH} \
|
||||
-t ${RELEASE_IMAGE_REPO}:$VERSION-${TARGETARCH} \
|
||||
.
|
||||
done
|
||||
|
||||
@@ -49,7 +49,7 @@ if [ -z "${OLLAMA_SKIP_IMAGE_BUILD}" ]; then
|
||||
--build-arg=GOFLAGS \
|
||||
--target runtime-rocm \
|
||||
-f Dockerfile \
|
||||
-t ${ARCH_IMAGE_REPO}:$VERSION-rocm \
|
||||
-t ${RELEASE_IMAGE_REPO}:$VERSION-rocm \
|
||||
.
|
||||
fi
|
||||
fi
|
||||
@@ -57,21 +57,21 @@ fi
|
||||
if [ -z "${OLLAMA_SKIP_MANIFEST_CREATE}" ]; then
|
||||
if [ -n "${PUSH}" ]; then
|
||||
docker manifest create ${FINAL_IMAGE_REPO}:$VERSION \
|
||||
${ARCH_IMAGE_REPO}:$VERSION-amd64 \
|
||||
${ARCH_IMAGE_REPO}:$VERSION-arm64
|
||||
${RELEASE_IMAGE_REPO}:$VERSION-amd64 \
|
||||
${RELEASE_IMAGE_REPO}:$VERSION-arm64
|
||||
docker manifest push ${FINAL_IMAGE_REPO}:$VERSION
|
||||
|
||||
# For symmetry, tag/push the rocm image
|
||||
if [ "${ARCH_IMAGE_REPO}" != "${FINAL_IMAGE_REPO}" ]; then
|
||||
if [ "${RELEASE_IMAGE_REPO}" != "${FINAL_IMAGE_REPO}" ]; then
|
||||
echo "Tagging and pushing rocm image"
|
||||
docker pull ${ARCH_IMAGE_REPO}:$VERSION-rocm
|
||||
docker tag ${ARCH_IMAGE_REPO}:$VERSION-rocm ${FINAL_IMAGE_REPO}:$VERSION-rocm
|
||||
docker pull ${RELEASE_IMAGE_REPO}:$VERSION-rocm
|
||||
docker tag ${RELEASE_IMAGE_REPO}:$VERSION-rocm ${FINAL_IMAGE_REPO}:$VERSION-rocm
|
||||
docker push ${FINAL_IMAGE_REPO}:$VERSION-rocm
|
||||
fi
|
||||
else
|
||||
echo "Skipping manifest generation when not pushing images are available locally as "
|
||||
echo " ${ARCH_IMAGE_REPO}:$VERSION-amd64"
|
||||
echo " ${ARCH_IMAGE_REPO}:$VERSION-arm64"
|
||||
echo " ${ARCH_IMAGE_REPO}:$VERSION-rocm"
|
||||
echo " ${RELEASE_IMAGE_REPO}:$VERSION-amd64"
|
||||
echo " ${RELEASE_IMAGE_REPO}:$VERSION-arm64"
|
||||
echo " ${RELEASE_IMAGE_REPO}:$VERSION-rocm"
|
||||
fi
|
||||
fi
|
||||
|
||||
33 scripts/tag_latest.sh (Executable file)
@@ -0,0 +1,33 @@
#!/bin/sh

set -eu

# We use 2 different image repositories to handle combining architecture images into multiarch manifest
# (The ROCm image is x86 only and is not a multiarch manifest)
# For developers, you can override the DOCKER_ORG to generate multiarch manifests
# DOCKER_ORG=jdoe VERSION=0.1.30 PUSH=1 ./scripts/tag_latest.sh
DOCKER_ORG=${DOCKER_ORG:-"ollama"}
RELEASE_IMAGE_REPO=${RELEASE_IMAGE_REPO:-"${DOCKER_ORG}/release"}
FINAL_IMAGE_REPO=${FINAL_IMAGE_REPO:-"${DOCKER_ORG}/ollama"}

# Set PUSH to a non-empty string to trigger push instead of load
PUSH=${PUSH:-""}

echo "Assembling manifest and tagging latest"
docker manifest rm ${FINAL_IMAGE_REPO}:latest || true
docker manifest create ${FINAL_IMAGE_REPO}:latest \
    ${RELEASE_IMAGE_REPO}:$VERSION-amd64 \
    ${RELEASE_IMAGE_REPO}:$VERSION-arm64

docker pull ${RELEASE_IMAGE_REPO}:$VERSION-rocm
docker tag ${RELEASE_IMAGE_REPO}:$VERSION-rocm ${FINAL_IMAGE_REPO}:rocm

if [ -n "${PUSH}" ]; then
    echo "Pushing latest tags up..."
    docker manifest push ${FINAL_IMAGE_REPO}:latest
    docker push ${FINAL_IMAGE_REPO}:rocm
else
    echo "Not pushing ${FINAL_IMAGE_REPO}:latest and ${FINAL_IMAGE_REPO}:rocm"
fi
|
||||
|
||||
|
||||
@@ -247,7 +247,8 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
|
||||
}
|
||||
|
||||
if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second {
|
||||
slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
|
||||
const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
|
||||
slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
|
||||
// reset last updated
|
||||
part.lastUpdated = time.Time{}
|
||||
return errPartStalled
|
||||
|
||||
106 server/images.go
@@ -26,6 +26,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/version"
|
||||
@@ -283,7 +284,7 @@ func realpath(mfDir, from string) string {
|
||||
return abspath
|
||||
}
|
||||
|
||||
func CreateModel(ctx context.Context, name, modelFileDir string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
|
||||
func CreateModel(ctx context.Context, name, modelFileDir, quantization string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
|
||||
deleteMap := make(map[string]struct{})
|
||||
if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
|
||||
for _, layer := range append(manifest.Layers, manifest.Config) {
|
||||
@@ -321,7 +322,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
|
||||
pathName := realpath(modelFileDir, c.Args)
|
||||
|
||||
ggufName, err := convertSafetensors(name, pathName)
|
||||
ggufName, err := convertSafetensors(name, pathName, fn)
|
||||
if err != nil {
|
||||
var pathErr *fs.PathError
|
||||
switch {
|
||||
@@ -337,6 +338,26 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
if ggufName != "" {
|
||||
pathName = ggufName
|
||||
defer os.RemoveAll(ggufName)
|
||||
|
||||
if quantization != "" {
|
||||
quantization = strings.ToUpper(quantization)
|
||||
fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", "F16", quantization)})
|
||||
tempfile, err := os.CreateTemp(filepath.Dir(ggufName), quantization)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer os.RemoveAll(tempfile.Name())
|
||||
|
||||
if err := llm.Quantize(ggufName, tempfile.Name(), quantization); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tempfile.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pathName = tempfile.Name()
|
||||
}
|
||||
}
|
||||
|
||||
bin, err := os.Open(pathName)
|
||||
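The new quantization branch above writes the quantized model to a temp file beside the converted F16 GGUF and then continues with that path. A small sketch of the same create-temp-then-swap-path flow; `llm.Quantize` is taken from the diff, while the stand-in below only copies bytes so the sketch stays runnable without a real model:

```go
package main

import (
	"fmt"
	"os"
	"path/filepath"
	"strings"
)

// quantizeFile stands in for llm.Quantize(src, dst, level) from the diff above.
func quantizeFile(src, dst, level string) error {
	data, err := os.ReadFile(src)
	if err != nil {
		return err
	}
	return os.WriteFile(dst, data, 0o644)
}

func main() {
	ggufName := filepath.Join(os.TempDir(), "converted-f16.gguf")
	if err := os.WriteFile(ggufName, []byte("GGUF"), 0o644); err != nil {
		panic(err)
	}
	defer os.Remove(ggufName)

	// Same shape as the CreateModel change: uppercase the level, create the
	// temp file next to the source, quantize into it, then use that path.
	quantization := strings.ToUpper("q4_0")
	tempfile, err := os.CreateTemp(filepath.Dir(ggufName), quantization)
	if err != nil {
		panic(err)
	}
	defer os.Remove(tempfile.Name())

	if err := quantizeFile(ggufName, tempfile.Name(), quantization); err != nil {
		panic(err)
	}
	if err := tempfile.Close(); err != nil {
		panic(err)
	}
	fmt.Println("model path now:", tempfile.Name())
}
```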
@@ -419,34 +440,32 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
defer bin.Close()
|
||||
|
||||
var offset int64
|
||||
CREATE:
|
||||
for {
|
||||
fn(api.ProgressResponse{Status: "creating model layer"})
|
||||
if _, err := bin.Seek(offset, io.SeekStart); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bin.Seek(offset, io.SeekStart)
|
||||
ggml, err := llm.DecodeGGML(bin)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
break CREATE
|
||||
case errors.Is(err, llm.ErrUnsupportedFormat):
|
||||
return fmt.Errorf("model binary specified in FROM field is not a valid gguf format model, %w", err)
|
||||
default:
|
||||
return err
|
||||
}
|
||||
ggml, size, err := llm.DecodeGGML(bin)
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
} else if errors.Is(err, llm.ErrUnsupportedFormat) {
|
||||
return fmt.Errorf("model binary specified in FROM field is not a valid gguf format model, %w", err)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config.SetModelFormat(ggml.Name())
|
||||
config.SetModelFamily(ggml.ModelFamily())
|
||||
config.SetModelType(ggml.ModelType())
|
||||
config.SetFileType(ggml.FileType())
|
||||
config.SetModelFamily(ggml.KV().Architecture())
|
||||
config.SetModelType(format.HumanNumber(ggml.KV().ParameterCount()))
|
||||
config.SetFileType(ggml.KV().FileType())
|
||||
|
||||
mediatype := mediatype
|
||||
if ggml.ModelFamily() == "clip" {
|
||||
if ggml.KV().Architecture() == "clip" {
|
||||
mediatype = "application/vnd.ollama.image.projector"
|
||||
}
|
||||
|
||||
sr := io.NewSectionReader(bin, offset, ggml.Size)
|
||||
sr := io.NewSectionReader(bin, offset, size)
|
||||
layer, err := NewLayer(sr, mediatype)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -454,7 +473,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
|
||||
layers.Add(layer)
|
||||
|
||||
offset += ggml.Size
|
||||
offset += size
|
||||
}
|
||||
case "adapter":
|
||||
if strings.HasPrefix(c.Args, "@") {
|
||||
@@ -473,12 +492,12 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
}
|
||||
defer bin.Close()
|
||||
|
||||
ggml, err := llm.DecodeGGML(bin)
|
||||
_, size, err := llm.DecodeGGML(bin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sr := io.NewSectionReader(bin, 0, ggml.Size)
|
||||
sr := io.NewSectionReader(bin, 0, size)
|
||||
layer, err := NewLayer(sr, mediatype)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -550,13 +569,6 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
}
|
||||
}
|
||||
|
||||
// xxx - can this be removed?
|
||||
if config.ModelType == "65B" {
|
||||
if gqa, ok := formattedParams["gqa"].(int); ok && gqa == 8 {
|
||||
config.ModelType = "70B"
|
||||
}
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(formattedParams); err != nil {
|
||||
return err
|
||||
@@ -621,8 +633,8 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertSafetensors(name, fn string) (string, error) {
|
||||
r, err := zip.OpenReader(fn)
|
||||
func convertSafetensors(name, path string, fn func(resp api.ProgressResponse)) (string, error) {
|
||||
r, err := zip.OpenReader(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -634,6 +646,7 @@ func convertSafetensors(name, fn string) (string, error) {
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
fn(api.ProgressResponse{Status: "unpacking model metadata"})
|
||||
for _, f := range r.File {
|
||||
fpath := filepath.Join(tempDir, f.Name)
|
||||
outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
|
||||
@@ -660,32 +673,27 @@ func convertSafetensors(name, fn string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
SupportedArchs := []string{
|
||||
"MistralForCausalLM",
|
||||
}
|
||||
|
||||
for _, arch := range params.Architectures {
|
||||
if !slices.Contains(SupportedArchs, arch) {
|
||||
return "", fmt.Errorf("this safetensors model is not yet supported")
|
||||
}
|
||||
}
|
||||
|
||||
t, err := convert.GetSafeTensors(tempDir)
|
||||
mArch, err := convert.GetModelArchFromParams(name, tempDir, params)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
vocab, err := convert.LoadTokens(tempDir)
|
||||
fn(api.ProgressResponse{Status: "processing safetensors"})
|
||||
if err := mArch.GetTensors(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := mArch.LoadVocab(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "converting model"})
|
||||
path, err = mArch.WriteGGUF()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fn, err = convert.WriteGGUF(name, t, params, vocab)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return fn, nil
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func CopyModel(src, dest string) error {
|
||||
|
||||
112 server/routes.go
@@ -56,12 +56,13 @@ func init() {
|
||||
var loaded struct {
|
||||
mu sync.Mutex
|
||||
|
||||
runner llm.LLM
|
||||
llama *llm.LlamaServer
|
||||
|
||||
expireAt time.Time
|
||||
expireTimer *time.Timer
|
||||
|
||||
*Model
|
||||
model string
|
||||
adapters []string
|
||||
projectors []string
|
||||
*api.Options
|
||||
}
|
||||
|
||||
@@ -69,21 +70,28 @@ var defaultSessionDuration = 5 * time.Minute
|
||||
|
||||
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
|
||||
func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
|
||||
needLoad := loaded.runner == nil || // is there a model loaded?
|
||||
loaded.ModelPath != model.ModelPath || // has the base model changed?
|
||||
!reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
|
||||
!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) // have the runner options changed?
|
||||
ctx, cancel := context.WithTimeout(c, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
needLoad := loaded.llama == nil || // is there a model loaded?
|
||||
loaded.model != model.ModelPath || // has the base model changed?
|
||||
!reflect.DeepEqual(loaded.adapters, model.AdapterPaths) || // have the adapters changed?
|
||||
!reflect.DeepEqual(loaded.projectors, model.ProjectorPaths) || // have the projectors changed?
|
||||
!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) || // have the runner options changed?
|
||||
loaded.llama.Ping(ctx) != nil
|
||||
|
||||
if needLoad {
|
||||
if loaded.runner != nil {
|
||||
if loaded.llama != nil {
|
||||
slog.Info("changing loaded model")
|
||||
loaded.runner.Close()
|
||||
loaded.runner = nil
|
||||
loaded.Model = nil
|
||||
loaded.llama.Close()
|
||||
loaded.llama = nil
|
||||
loaded.model = ""
|
||||
loaded.adapters = nil
|
||||
loaded.projectors = nil
|
||||
loaded.Options = nil
|
||||
}
|
||||
|
||||
llmRunner, err := llm.New(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
|
||||
llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
|
||||
if err != nil {
|
||||
// some older models are not compatible with newer versions of llama.cpp
|
||||
// show a generalized compatibility error until there is a better way to
|
||||
@@ -95,28 +103,26 @@ func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.D
|
||||
return err
|
||||
}
|
||||
|
||||
loaded.Model = model
|
||||
loaded.runner = llmRunner
|
||||
loaded.model = model.ModelPath
|
||||
loaded.adapters = model.AdapterPaths
|
||||
loaded.projectors = model.ProjectorPaths
|
||||
loaded.llama = llama
|
||||
loaded.Options = &opts
|
||||
}
|
||||
|
||||
loaded.expireAt = time.Now().Add(sessionDuration)
|
||||
|
||||
if loaded.expireTimer == nil {
|
||||
loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
|
||||
loaded.mu.Lock()
|
||||
defer loaded.mu.Unlock()
|
||||
|
||||
if time.Now().Before(loaded.expireAt) {
|
||||
return
|
||||
if loaded.llama != nil {
|
||||
loaded.llama.Close()
|
||||
}
|
||||
|
||||
if loaded.runner != nil {
|
||||
loaded.runner.Close()
|
||||
}
|
||||
|
||||
loaded.runner = nil
|
||||
loaded.Model = nil
|
||||
loaded.llama = nil
|
||||
loaded.model = ""
|
||||
loaded.adapters = nil
|
||||
loaded.projectors = nil
|
||||
loaded.Options = nil
|
||||
})
|
||||
}
|
||||
@@ -265,7 +271,7 @@ func GenerateHandler(c *gin.Context) {
|
||||
|
||||
sb.Reset()
|
||||
if req.Context != nil {
|
||||
prev, err := loaded.runner.Decode(c.Request.Context(), req.Context)
|
||||
prev, err := loaded.llama.Detokenize(c.Request.Context(), req.Context)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -286,9 +292,8 @@ func GenerateHandler(c *gin.Context) {
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
fn := func(r llm.PredictResult) {
|
||||
fn := func(r llm.CompletionResponse) {
|
||||
// Update model expiration
|
||||
loaded.expireAt = time.Now().Add(sessionDuration)
|
||||
loaded.expireTimer.Reset(sessionDuration)
|
||||
|
||||
// Build up the full response
|
||||
@@ -322,7 +327,7 @@ func GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
// TODO (jmorganca): encode() should not strip special tokens
|
||||
tokens, err := loaded.runner.Encode(c.Request.Context(), p)
|
||||
tokens, err := loaded.llama.Tokenize(c.Request.Context(), p)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
@@ -344,13 +349,13 @@ func GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Start prediction
|
||||
predictReq := llm.PredictOpts{
|
||||
req := llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Format: req.Format,
|
||||
Images: images,
|
||||
Options: opts,
|
||||
}
|
||||
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
||||
if err := loaded.llama.Completion(c.Request.Context(), req, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
@@ -471,7 +476,7 @@ func EmbeddingsHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
embedding, err := loaded.runner.Embedding(c.Request.Context(), req.Prompt)
|
||||
embedding, err := loaded.llama.Embedding(c.Request.Context(), req.Prompt)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
||||
@@ -642,7 +647,7 @@ func CreateModelHandler(c *gin.Context) {
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
if err := CreateModel(ctx, model, filepath.Dir(req.Path), commands, fn); err != nil {
|
||||
if err := CreateModel(ctx, model, filepath.Dir(req.Path), req.Quantization, commands, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
@@ -908,6 +913,24 @@ func HeadBlobHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
func CreateBlobHandler(c *gin.Context) {
|
||||
path, err := GetBlobsPath(c.Param("digest"))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
_, err = os.Stat(path)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
// noop
|
||||
case err != nil:
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
default:
|
||||
c.Status(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
layer, err := NewLayer(c.Request.Body, "")
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -1013,16 +1036,14 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
|
||||
}
|
||||
|
||||
func (s *Server) GenerateRoutes() http.Handler {
|
||||
var origins []string
|
||||
if o := os.Getenv("OLLAMA_ORIGINS"); o != "" {
|
||||
origins = strings.Split(o, ",")
|
||||
}
|
||||
|
||||
config := cors.DefaultConfig()
|
||||
config.AllowWildcard = true
|
||||
config.AllowBrowserExtensions = true
|
||||
|
||||
config.AllowOrigins = origins
|
||||
if allowedOrigins := strings.Trim(os.Getenv("OLLAMA_ORIGINS"), "\"'"); allowedOrigins != "" {
|
||||
config.AllowOrigins = strings.Split(allowedOrigins, ",")
|
||||
}
|
||||
|
||||
for _, allowOrigin := range defaultAllowOrigins {
|
||||
config.AllowOrigins = append(config.AllowOrigins,
|
||||
fmt.Sprintf("http://%s", allowOrigin),
|
||||
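The CORS change above now strips surrounding quotes from OLLAMA_ORIGINS before splitting it, so values quoted in a shell or a service file still parse into individual origins. A standalone sketch of just that parsing step:

```go
package main

import (
	"fmt"
	"os"
	"strings"
)

func main() {
	// Quotes around the whole value (common in systemd unit files or shells)
	// are trimmed before splitting on commas, matching the diff above.
	os.Setenv("OLLAMA_ORIGINS", `"http://a.example,https://b.example"`)

	var origins []string
	if allowed := strings.Trim(os.Getenv("OLLAMA_ORIGINS"), "\"'"); allowed != "" {
		origins = strings.Split(allowed, ",")
	}
	fmt.Println(origins) // [http://a.example https://b.example]
}
```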
@@ -1125,8 +1146,8 @@ func Serve(ln net.Listener) error {
|
||||
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-signals
|
||||
if loaded.runner != nil {
|
||||
loaded.runner.Close()
|
||||
if loaded.llama != nil {
|
||||
loaded.llama.Close()
|
||||
}
|
||||
gpu.Cleanup()
|
||||
os.Exit(0)
|
||||
@@ -1198,7 +1219,7 @@ func streamResponse(c *gin.Context, ch chan any) {
|
||||
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
|
||||
func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
|
||||
encode := func(s string) ([]int, error) {
|
||||
return loaded.runner.Encode(ctx, s)
|
||||
return loaded.llama.Tokenize(ctx, s)
|
||||
}
|
||||
|
||||
prompt, err := ChatPrompt(template, messages, numCtx, encode)
|
||||
@@ -1328,9 +1349,8 @@ func ChatHandler(c *gin.Context) {
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
fn := func(r llm.PredictResult) {
|
||||
fn := func(r llm.CompletionResponse) {
|
||||
// Update model expiration
|
||||
loaded.expireAt = time.Now().Add(sessionDuration)
|
||||
loaded.expireTimer.Reset(sessionDuration)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
@@ -1354,14 +1374,12 @@ func ChatHandler(c *gin.Context) {
|
||||
ch <- resp
|
||||
}
|
||||
|
||||
// Start prediction
|
||||
predictReq := llm.PredictOpts{
|
||||
if err := loaded.llama.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Format: req.Format,
|
||||
Images: images,
|
||||
Options: opts,
|
||||
}
|
||||
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
||||
}, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -16,7 +17,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
@@ -31,13 +31,22 @@ func Test_Routes(t *testing.T) {
|
||||
}
|
||||
|
||||
createTestFile := func(t *testing.T, name string) string {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.CreateTemp(t.TempDir(), name)
|
||||
assert.Nil(t, err)
|
||||
defer f.Close()
|
||||
|
||||
_, err = f.Write([]byte("GGUF"))
|
||||
err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
|
||||
assert.Nil(t, err)
|
||||
_, err = f.Write([]byte{0x2, 0})
|
||||
|
||||
err = binary.Write(f, binary.LittleEndian, uint32(3))
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = binary.Write(f, binary.LittleEndian, uint64(0))
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = binary.Write(f, binary.LittleEndian, uint64(0))
|
||||
assert.Nil(t, err)
|
||||
|
||||
return f.Name()
|
||||
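For reference, the bytes the updated test writes form a minimal little-endian GGUF header: the magic, a uint32 version (3 here), then uint64 tensor and key/value counts, both zero. A standalone sketch of the same 24-byte header (the field names follow the GGUF format generally, not this diff specifically):

```go
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

func main() {
	var buf bytes.Buffer
	buf.WriteString("GGUF")                            // magic
	binary.Write(&buf, binary.LittleEndian, uint32(3)) // format version
	binary.Write(&buf, binary.LittleEndian, uint64(0)) // tensor count
	binary.Write(&buf, binary.LittleEndian, uint64(0)) // metadata key/value count
	fmt.Println(buf.Len(), "bytes:", buf.Bytes())      // 24 bytes: [71 71 85 70 3 0 ...]
}
```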
@@ -52,7 +61,7 @@ func Test_Routes(t *testing.T) {
|
||||
fn := func(resp api.ProgressResponse) {
|
||||
t.Logf("Status: %s", resp.Status)
|
||||
}
|
||||
err = CreateModel(context.TODO(), name, "", commands, fn)
|
||||
err = CreateModel(context.TODO(), name, "", "", commands, fn)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
@@ -201,7 +210,7 @@ func Test_Routes(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{}
|
||||
s := &Server{}
|
||||
router := s.GenerateRoutes()
|
||||
|
||||
httpSrv := httptest.NewServer(router)
|
||||
@@ -232,27 +241,3 @@ func Test_Routes(t *testing.T) {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
type MockLLM struct {
|
||||
encoding []int
|
||||
}
|
||||
|
||||
func (llm *MockLLM) Predict(ctx context.Context, pred llm.PredictOpts, fn func(llm.PredictResult)) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (llm *MockLLM) Encode(ctx context.Context, prompt string) ([]int, error) {
|
||||
return llm.encoding, nil
|
||||
}
|
||||
|
||||
func (llm *MockLLM) Decode(ctx context.Context, tokens []int) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (llm *MockLLM) Embedding(ctx context.Context, input string) ([]float64, error) {
|
||||
return []float64{}, nil
|
||||
}
|
||||
|
||||
func (llm *MockLLM) Close() {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
83 types/model/digest.go (Normal file)
@@ -0,0 +1,83 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// Digest represents a digest of a model Manifest. It is a comparable value
|
||||
// type and is immutable.
|
||||
//
|
||||
// The zero Digest is not a valid digest.
|
||||
type Digest struct {
|
||||
s string
|
||||
}
|
||||
|
||||
// Type returns the digest type of the digest.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ParseDigest("sha256-1234").Type() // returns "sha256"
|
||||
func (d Digest) Type() string {
|
||||
typ, _, _ := strings.Cut(d.s, "-")
|
||||
return typ
|
||||
}
|
||||
|
||||
// String returns the digest in the form of "<digest-type>-<digest>", or the
|
||||
// empty string if the digest is invalid.
|
||||
func (d Digest) String() string { return d.s }
|
||||
|
||||
// IsValid returns true if the digest is valid (not zero).
|
||||
//
|
||||
// A valid digest may be created only by ParseDigest, or
|
||||
// ParseName(name).Digest().
|
||||
func (d Digest) IsValid() bool { return d.s != "" }
|
||||
|
||||
// LogValue implements slog.Value.
|
||||
func (d Digest) LogValue() slog.Value {
|
||||
return slog.StringValue(d.String())
|
||||
}
|
||||
|
||||
var (
|
||||
_ slog.LogValuer = Digest{}
|
||||
)
|
||||
|
||||
// ParseDigest parses a string in the form of "<digest-type>-<digest>" into a
|
||||
// Digest.
|
||||
func ParseDigest(s string) Digest {
|
||||
typ, digest, ok := strings.Cut(s, "-")
|
||||
if ok && isValidDigestType(typ) && isValidHex(digest) {
|
||||
return Digest{s: s}
|
||||
}
|
||||
return Digest{}
|
||||
}
|
||||
|
||||
func isValidDigestType(s string) bool {
|
||||
if len(s) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, r := range s {
|
||||
if !unicode.IsLower(r) && !unicode.IsDigit(r) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isValidHex(s string) bool {
|
||||
if len(s) == 0 {
|
||||
return false
|
||||
}
|
||||
for i := range s {
|
||||
c := s[i]
|
||||
if c < '0' || c > '9' && c < 'a' || c > 'f' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
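A short usage sketch of the new Digest type, assuming the import path github.com/ollama/ollama/types/model introduced by this diff:

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/types/model"
)

func main() {
	d := model.ParseDigest("sha256-1234abcd")
	fmt.Println(d.IsValid(), d.Type(), d) // true sha256 sha256-1234abcd

	// Anything not in "<type>-<hex>" form parses to the zero (invalid) Digest.
	fmt.Println(model.ParseDigest("sha256-XYZ").IsValid()) // false
}
```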
46 types/model/digest_test.go (Normal file)
@@ -0,0 +1,46 @@
|
||||
package model
|
||||
|
||||
import "testing"
|
||||
|
||||
var testDigests = map[string]Digest{
|
||||
"": {},
|
||||
"sha256-1234": {s: "sha256-1234"},
|
||||
"sha256-5678": {s: "sha256-5678"},
|
||||
"blake2-9abc": {s: "blake2-9abc"},
|
||||
"-1234": {},
|
||||
"sha256-": {},
|
||||
"sha256-1234-5678": {},
|
||||
"sha256-P": {}, // invalid hex
|
||||
"sha256-1234P": {},
|
||||
"---": {},
|
||||
}
|
||||
|
||||
func TestDigestParse(t *testing.T) {
|
||||
// Test cases.
|
||||
for s, want := range testDigests {
|
||||
got := ParseDigest(s)
|
||||
t.Logf("ParseDigest(%q) = %#v", s, got)
|
||||
if got != want {
|
||||
t.Errorf("ParseDigest(%q) = %q; want %q", s, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDigestString(t *testing.T) {
|
||||
// Test cases.
|
||||
for s, d := range testDigests {
|
||||
want := s
|
||||
if !d.IsValid() {
|
||||
want = ""
|
||||
}
|
||||
got := d.String()
|
||||
if got != want {
|
||||
t.Errorf("ParseDigest(%q).String() = %q; want %q", s, got, want)
|
||||
}
|
||||
|
||||
got = ParseDigest(s).String()
|
||||
if got != want {
|
||||
t.Errorf("roundtrip ParseDigest(%q).String() = %q; want %q", s, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
572 types/model/name.go (Normal file)
@@ -0,0 +1,572 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"errors"
|
||||
"hash/maphash"
|
||||
"io"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/types/structs"
|
||||
)
|
||||
|
||||
// Errors
|
||||
var (
|
||||
// ErrInvalidName, ErrIncompleteName, and ErrInvalidDigest are not
|
||||
// used by this package, but are exported so that other packages can
|
||||
// use them, instead of defining their own errors for them.
|
||||
ErrInvalidName = errors.New("invalid model name")
|
||||
ErrIncompleteName = errors.New("incomplete model name")
|
||||
ErrInvalidDigest = errors.New("invalid digest")
|
||||
)
|
||||
|
||||
// Defaults
|
||||
const (
|
||||
// DefaultMask is the default mask used by [Name.DisplayShortest].
|
||||
DefaultMask = "registry.ollama.ai/library/_:latest"
|
||||
|
||||
// DefaultFill is the default fill used by [ParseName].
|
||||
DefaultFill = "registry.ollama.ai/library/_:latest"
|
||||
)
|
||||
|
||||
const MaxNamePartLen = 128
|
||||
|
||||
type PartKind int
|
||||
|
||||
// Levels of concreteness
|
||||
const (
|
||||
// Each value aligns with its index in the Name.parts array.
|
||||
|
||||
PartHost PartKind = iota
|
||||
PartNamespace
|
||||
PartModel
|
||||
PartTag
|
||||
PartBuild
|
||||
PartDigest
|
||||
|
||||
// Invalid is a special part that is used to indicate that a part is
|
||||
// invalid. It is not a valid part of a Name.
|
||||
//
|
||||
// It should be kept as the last part in the list.
|
||||
PartInvalid
|
||||
)
|
||||
|
||||
var kindNames = map[PartKind]string{
|
||||
PartHost: "Host",
|
||||
PartNamespace: "Namespace",
|
||||
PartModel: "Name",
|
||||
PartTag: "Tag",
|
||||
PartBuild: "Build",
|
||||
PartDigest: "Digest",
|
||||
PartInvalid: "Invalid",
|
||||
}
|
||||
|
||||
func (k PartKind) String() string {
|
||||
return cmp.Or(kindNames[k], "Unknown")
|
||||
}
|
||||
|
||||
// Name is an opaque reference to a model. It holds the parts of a model
|
||||
// with the case preserved, but is not directly comparable with other Names
|
||||
// since model names can be represented with different casing depending on
|
||||
// the use case. For instance, "Mistral" and "mistral" are the same model
|
||||
// but each version may have come from different sources (e.g. copied from a
|
||||
// Web page, or from a file path).
|
||||
//
|
||||
// Valid Names can ONLY be constructed by calling [ParseName].
|
||||
//
|
||||
// A Name is valid if and only if it has a valid Model part. The other parts
|
||||
// are optional.
|
||||
//
|
||||
// A Name is considered "complete" if it has all parts present. To check if a
|
||||
// Name is complete, use [Name.IsComplete].
|
||||
//
|
||||
// To compare two names in a case-insensitive manner, use [Name.EqualFold].
|
||||
//
|
||||
// The parts of a Name are:
|
||||
//
|
||||
// - Host: the domain of the model (optional)
|
||||
// - Namespace: the namespace of the model (optional)
|
||||
// - Model: the name of the model (required)
|
||||
// - Tag: the tag of the model (optional)
|
||||
// - Build: the build of the model; usually the quantization or "file type" (optional)
|
||||
//
|
||||
// The parts can be obtained in their original form by calling [Name.Parts].
|
||||
//
|
||||
// To check if a Name has at minimum a valid model part, use [Name.IsValid].
|
||||
//
|
||||
// To make a Name by filling in missing parts from another Name, use [Fill].
|
||||
type Name struct {
|
||||
_ structs.Incomparable
|
||||
parts [6]string // host, namespace, model, tag, build, digest
|
||||
|
||||
// TODO(bmizerany): track offsets and hold s (raw string) here? We
|
||||
// could pack the offsets all into a single uint64 since the first
|
||||
// parts take less bits since their max offset is less than the max
|
||||
// offset of the next part. This would save a ton of bytes per Name
|
||||
// and mean zero allocations for String.
|
||||
}
|
||||
|
||||
// ParseNameFill parses s into a Name, and returns the result of filling it with
|
||||
// defaults. The input string must be a valid string
|
||||
// representation of a model name in the form:
|
||||
//
|
||||
// [host/][namespace/]<model>[:tag][+build][@<digest-type>-<digest>]
|
||||
//
|
||||
// The name part is required, all others are optional. If a part is missing,
|
||||
// it is left empty in the returned Name. If a part is invalid, the zero Ref
|
||||
// value is returned.
|
||||
//
|
||||
// The build part is normalized to uppercase.
|
||||
//
|
||||
// Examples of valid paths:
|
||||
//
|
||||
// "example.com/library/mistral:7b+x"
|
||||
// "example.com/eva/mistral:7b+Q4_0"
|
||||
// "mistral:7b+x"
|
||||
// "example.com/mike/mistral:latest+Q4_0"
|
||||
// "example.com/bruce/mistral:latest"
|
||||
// "example.com/pdevine/thisisfine:7b+Q4_0@sha256-1234567890abcdef"
|
||||
//
|
||||
// Examples of invalid paths:
|
||||
//
|
||||
// "example.com/mistral:7b+"
|
||||
// "example.com/mistral:7b+Q4_0+"
|
||||
// "x/y/z/z:8n+I"
|
||||
// ""
|
||||
//
|
||||
// It returns the zero value if any part is invalid.
|
||||
//
|
||||
// As a rule of thumb, a valid name is one that can be round-tripped with
|
||||
// the [Name.String] method. That means ("x+") is invalid because
|
||||
// [Name.String] will not print a "+" if the build is empty.
|
||||
//
|
||||
// For more about filling in missing parts, see [Fill].
|
||||
func ParseNameFill(s, defaults string) Name {
|
||||
var r Name
|
||||
parts(s)(func(kind PartKind, part string) bool {
|
||||
if kind == PartInvalid {
|
||||
r = Name{}
|
||||
return false
|
||||
}
|
||||
if kind == PartDigest && !ParseDigest(part).IsValid() {
|
||||
r = Name{}
|
||||
return false
|
||||
}
|
||||
r.parts[kind] = part
|
||||
return true
|
||||
})
|
||||
if r.IsValid() || r.IsResolved() {
|
||||
if defaults == "" {
|
||||
return r
|
||||
}
|
||||
return Fill(r, ParseNameFill(defaults, ""))
|
||||
}
|
||||
return Name{}
|
||||
}
|
||||
|
||||
// ParseName is equal to ParseNameFill(s, DefaultFill).
|
||||
func ParseName(s string) Name {
|
||||
return ParseNameFill(s, DefaultFill)
|
||||
}
|
||||
|
||||
func MustParseNameFill(s, defaults string) Name {
|
||||
r := ParseNameFill(s, "")
|
||||
if !r.IsValid() {
|
||||
panic("model.MustParseName: invalid name: " + s)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Fill fills in the missing parts of dst with the parts of src.
|
||||
//
|
||||
// The returned Name will only be valid if dst is valid.
|
||||
func Fill(dst, src Name) Name {
|
||||
var r Name
|
||||
for i := range r.parts {
|
||||
r.parts[i] = cmp.Or(dst.parts[i], src.parts[i])
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// WithBuild returns a copy of r with the build set to the given string.
|
||||
func (r Name) WithBuild(build string) Name {
|
||||
r.parts[PartBuild] = build
|
||||
return r
|
||||
}
|
||||
|
||||
func (r Name) WithDigest(digest Digest) Name {
|
||||
r.parts[PartDigest] = digest.String()
|
||||
return r
|
||||
}
|
||||
|
||||
var mapHashSeed = maphash.MakeSeed()
|
||||
|
||||
// MapHash returns a case insensitive hash for use in maps and equality
|
||||
// checks. For a convenient way to compare names, use [Name.EqualFold].
|
||||
//
|
||||
//nolint:errcheck
|
||||
func (r Name) MapHash() uint64 {
|
||||
// correctly hash the parts with case insensitive comparison
|
||||
var h maphash.Hash
|
||||
h.SetSeed(mapHashSeed)
|
||||
for _, part := range r.Parts() {
|
||||
// downcase the part for hashing
|
||||
for i := range part {
|
||||
c := part[i]
|
||||
if c >= 'A' && c <= 'Z' {
|
||||
c = c - 'A' + 'a'
|
||||
}
|
||||
h.WriteByte(c)
|
||||
}
|
||||
}
|
||||
return h.Sum64()
|
||||
}
|
||||
|
||||
func (r Name) slice(from, to PartKind) Name {
|
||||
var v Name
|
||||
copy(v.parts[from:to+1], r.parts[from:to+1])
|
||||
return v
|
||||
}
|
||||
|
||||
// DisplayShortest returns the shortest possible display string in form:
|
||||
//
|
||||
// [host/][<namespace>/]<model>[:<tag>]
|
||||
//
|
||||
// The host is omitted if the mask's host is the same as r's.
|
||||
// The namespace is omitted if the mask's host and namespace are the same as r's.
|
||||
// The tag is omitted if the mask's tag is the same as r's.
|
||||
func (r Name) DisplayShortest(mask string) string {
|
||||
mask = cmp.Or(mask, DefaultMask)
|
||||
d := ParseName(mask)
|
||||
if !d.IsValid() {
|
||||
panic("mask is an invalid Name")
|
||||
}
|
||||
equalSlice := func(form, to PartKind) bool {
|
||||
return r.slice(form, to).EqualFold(d.slice(form, to))
|
||||
}
|
||||
if equalSlice(PartHost, PartNamespace) {
|
||||
r.parts[PartNamespace] = ""
|
||||
}
|
||||
if equalSlice(PartHost, PartHost) {
|
||||
r.parts[PartHost] = ""
|
||||
}
|
||||
if equalSlice(PartTag, PartTag) {
|
||||
r.parts[PartTag] = ""
|
||||
}
|
||||
return r.slice(PartHost, PartTag).String()
|
||||
}
|
||||
|
||||
var seps = [...]string{
|
||||
PartHost: "/",
|
||||
PartNamespace: "/",
|
||||
PartModel: ":",
|
||||
PartTag: "+",
|
||||
PartBuild: "@",
|
||||
PartDigest: "",
|
||||
}
|
||||
|
||||
// WriteTo implements io.WriterTo. It writes the fullest possible display
|
||||
// string in form:
|
||||
//
|
||||
// <host>/<namespace>/<model>:<tag>+<build>@<digest-type>-<digest>
|
||||
//
|
||||
// Missing parts and their separators are not written.
|
||||
//
|
||||
// The full digest is always prefixed with "@". That is if [Name.IsValid]
|
||||
// reports false and [Name.IsResolved] reports true, then the string is
|
||||
// returned as "@<digest-type>-<digest>".
|
||||
func (r Name) writeTo(w io.StringWriter) error {
|
||||
var partsWritten int
|
||||
for i := range r.parts {
|
||||
if r.parts[i] == "" {
|
||||
continue
|
||||
}
|
||||
if partsWritten > 0 || i == int(PartDigest) {
|
||||
if _, err := w.WriteString(seps[i-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := w.WriteString(r.parts[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
partsWritten++
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var builderPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &strings.Builder{}
|
||||
},
|
||||
}
|
||||
|
||||
// String returns the fullest possible display string in form:
|
||||
//
|
||||
// <host>/<namespace>/<model>:<tag>+<build>
|
||||
//
|
||||
// If any part is missing, it is omitted from the display string.
|
||||
//
|
||||
// For the fullest possible display string without the build, use
|
||||
// [Name.DisplayFullest].
|
||||
func (r Name) String() string {
|
||||
b := builderPool.Get().(*strings.Builder)
|
||||
defer builderPool.Put(b)
|
||||
b.Reset()
|
||||
b.Grow(50) // arbitrarily long enough for most names
|
||||
_ = r.writeTo(b)
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// GoString implements fmt.GoStringer. It returns a string suitable for
|
||||
// debugging and logging. It is similar to [Name.String] but it always
|
||||
// returns a string that includes all parts of the Name, with missing parts
|
||||
// replaced with a ("?").
|
||||
func (r Name) GoString() string {
|
||||
for i := range r.parts {
|
||||
r.parts[i] = cmp.Or(r.parts[i], "?")
|
||||
}
|
||||
return r.String()
|
||||
}
|
||||
|
||||
// LogValue implements slog.Valuer.
|
||||
func (r Name) LogValue() slog.Value {
|
||||
return slog.StringValue(r.GoString())
|
||||
}
|
||||
|
||||
// IsComplete reports whether the Name is fully qualified. That is it has a
|
||||
// domain, namespace, name, tag, and build.
|
||||
func (r Name) IsComplete() bool {
|
||||
return !slices.Contains(r.parts[:PartDigest], "")
|
||||
}
|
||||
|
||||
// IsCompleteNoBuild is like [Name.IsComplete] but it does not require the
|
||||
// build part to be present.
|
||||
func (r Name) IsCompleteNoBuild() bool {
|
||||
return !slices.Contains(r.parts[:PartBuild], "")
|
||||
}
|
||||
|
||||
// IsResolved reports true if the Name has a valid digest.
|
||||
//
|
||||
// It is possible to have a valid Name, or a complete Name that is not
|
||||
// resolved.
|
||||
func (r Name) IsResolved() bool {
|
||||
return r.Digest().IsValid()
|
||||
}
|
||||
|
||||
// Digest returns the digest part of the Name, if any.
|
||||
//
|
||||
// If Digest returns a non-empty string, then [Name.IsResolved] will return
|
||||
// true, and digest is considered valid.
|
||||
func (r Name) Digest() Digest {
|
||||
// This was already validated by ParseName, so we can just return it.
|
||||
return Digest{r.parts[PartDigest]}
|
||||
}
|
||||
|
||||
// EqualFold reports whether r and o are equivalent model names, ignoring
|
||||
// case.
|
||||
func (r Name) EqualFold(o Name) bool {
|
||||
return r.CompareFold(o) == 0
|
||||
}
|
||||
|
||||
// CompareFold performs a case-insensitive cmp.Compare on r and o.
|
||||
//
|
||||
// This can be used with [slices.SortFunc].
|
||||
//
|
||||
// For simple equality checks, use [Name.EqualFold].
|
||||
func (r Name) CompareFold(o Name) int {
|
||||
return slices.CompareFunc(r.parts[:], o.parts[:], compareFold)
|
||||
}
|
||||
|
||||
func compareFold(a, b string) int {
|
||||
return slices.CompareFunc([]rune(a), []rune(b), func(a, b rune) int {
|
||||
return cmp.Compare(downcase(a), downcase(b))
|
||||
})
|
||||
}
|
||||
|
||||
func downcase(r rune) rune {
|
||||
if r >= 'A' && r <= 'Z' {
|
||||
return r - 'A' + 'a'
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// TODO(bmizerany): driver.Value? (MarshalText etc should be enough)
|
||||
|
||||
// Parts returns the parts of the Name in order of concreteness.
|
||||
//
|
||||
// The length of the returned slice is always 6.
|
||||
func (r Name) Parts() []string {
|
||||
return slices.Clone(r.parts[:])
|
||||
}
|
||||
|
||||
// iter_Seq2 is an iter.Seq2 defined here to avoid the current build
|
||||
// restrictions in the go1.22 iter package requiring the
|
||||
// goexperiment.rangefunc tag to be set via the GOEXPERIMENT=rangefunc flag,
|
||||
// which we are not yet ready to support.
|
||||
//
|
||||
// Once we are ready to support rangefunc, this can be removed and replaced
|
||||
// with the iter.Seq2 type.
|
||||
type iter_Seq2[A, B any] func(func(A, B) bool)
|
||||
|
||||
// Parts returns a sequence of the parts of a Name string from most specific
|
||||
// to least specific.
|
||||
//
|
||||
// It normalizes the input string by removing "http://" and "https://" only.
|
||||
// No other normalizations are performed.
|
||||
func parts(s string) iter_Seq2[PartKind, string] {
|
||||
return func(yield func(PartKind, string) bool) {
|
||||
//nolint:gosimple
|
||||
if strings.HasPrefix(s, "http://") {
|
||||
s = s[len("http://"):]
|
||||
}
|
||||
//nolint:gosimple
|
||||
if strings.HasPrefix(s, "https://") {
|
||||
s = s[len("https://"):]
|
||||
}
|
||||
|
||||
if len(s) > MaxNamePartLen || len(s) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
yieldValid := func(kind PartKind, part string) bool {
|
||||
if !isValidPart(kind, part) {
|
||||
yield(PartInvalid, "")
|
||||
return false
|
||||
}
|
||||
return yield(kind, part)
|
||||
}
|
||||
|
||||
numConsecutiveDots := 0
|
||||
partLen := 0
|
||||
state, j := PartDigest, len(s)
|
||||
for i := len(s) - 1; i >= 0; i-- {
|
||||
if partLen++; partLen > MaxNamePartLen {
|
||||
// catch a part that is too long early, so
|
||||
// we don't keep spinning on it, waiting for
|
||||
// an isInValidPart check which would scan
|
||||
// over it again.
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
|
||||
switch s[i] {
|
||||
case '@':
|
||||
switch state {
|
||||
case PartDigest:
|
||||
if !yieldValid(PartDigest, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
if i == 0 {
|
||||
// This is the form
|
||||
// "@<digest>" which is valid.
|
||||
//
|
||||
// We're done.
|
||||
return
|
||||
}
|
||||
state, j, partLen = PartBuild, i, 0
|
||||
default:
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
case '+':
|
||||
switch state {
|
||||
case PartBuild, PartDigest:
|
||||
if !yieldValid(PartBuild, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
state, j, partLen = PartTag, i, 0
|
||||
default:
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
case ':':
|
||||
switch state {
|
||||
case PartTag, PartBuild, PartDigest:
|
||||
if !yieldValid(PartTag, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
state, j, partLen = PartModel, i, 0
|
||||
default:
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
case '/':
|
||||
switch state {
|
||||
case PartModel, PartTag, PartBuild, PartDigest:
|
||||
if !yieldValid(PartModel, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
state, j = PartNamespace, i
|
||||
case PartNamespace:
|
||||
if !yieldValid(PartNamespace, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
state, j, partLen = PartHost, i, 0
|
||||
default:
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
default:
|
||||
if s[i] == '.' {
|
||||
if numConsecutiveDots++; numConsecutiveDots > 1 {
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
numConsecutiveDots = 0
|
||||
}
|
||||
if !isValidByteFor(state, s[i]) {
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if state <= PartNamespace {
|
||||
yieldValid(state, s[:j])
|
||||
} else {
|
||||
yieldValid(PartModel, s[:j])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r Name) IsZero() bool {
|
||||
return r.parts == [6]string{}
|
||||
}
|
||||
|
||||
// IsValid reports if a model has at minimum a valid model part.
|
||||
func (r Name) IsValid() bool {
|
||||
// Parts ensures we only have valid parts, so no need to validate
|
||||
// them here, only check if we have a name or not.
|
||||
return r.parts[PartModel] != ""
|
||||
}
|
||||
|
||||
// isValidPart reports if s contains all valid characters for the given
|
||||
// part kind.
|
||||
func isValidPart(kind PartKind, s string) bool {
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
for _, c := range []byte(s) {
|
||||
if !isValidByteFor(kind, c) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isValidByteFor(kind PartKind, c byte) bool {
|
||||
if kind == PartNamespace && c == '.' {
|
||||
return false
|
||||
}
|
||||
if c == '.' || c == '-' {
|
||||
return true
|
||||
}
|
||||
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || c == '_' {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
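And a companion sketch for the Name type: ParseName fills missing parts from the default mask registry.ollama.ai/library/_:latest, and DisplayShortest trims the parts that match a mask (same assumed import path as the Digest example above):

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/types/model"
)

func main() {
	n := model.ParseName("mistral")
	fmt.Println(n)                     // registry.ollama.ai/library/mistral:latest
	fmt.Println(n.IsCompleteNoBuild()) // true (no build part, so IsComplete() is false)
	fmt.Println(n.DisplayShortest("")) // mistral
}
```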
490 types/model/name_test.go (Normal file)
@@ -0,0 +1,490 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type fields struct {
|
||||
host, namespace, model, tag, build string
|
||||
digest string
|
||||
}
|
||||
|
||||
func fieldsFromName(p Name) fields {
|
||||
return fields{
|
||||
host: p.parts[PartHost],
|
||||
namespace: p.parts[PartNamespace],
|
||||
model: p.parts[PartModel],
|
||||
tag: p.parts[PartTag],
|
||||
build: p.parts[PartBuild],
|
||||
digest: p.parts[PartDigest],
|
||||
}
|
||||
}
|
||||
|
||||
var testNames = map[string]fields{
|
||||
"mistral:latest": {model: "mistral", tag: "latest"},
|
||||
"mistral": {model: "mistral"},
|
||||
"mistral:30B": {model: "mistral", tag: "30B"},
|
||||
"mistral:7b": {model: "mistral", tag: "7b"},
|
||||
"mistral:7b+Q4_0": {model: "mistral", tag: "7b", build: "Q4_0"},
|
||||
"mistral+KQED": {model: "mistral", build: "KQED"},
|
||||
"mistral.x-3:7b+Q4_0": {model: "mistral.x-3", tag: "7b", build: "Q4_0"},
|
||||
"mistral:7b+q4_0": {model: "mistral", tag: "7b", build: "q4_0"},
|
||||
"llama2": {model: "llama2"},
|
||||
"user/model": {namespace: "user", model: "model"},
|
||||
"example.com/ns/mistral:7b+Q4_0": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "Q4_0"},
|
||||
"example.com/ns/mistral:7b+X": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"},
|
||||
|
||||
// invalid digest
|
||||
"mistral:latest@invalid256-": {},
|
||||
"mistral:latest@-123": {},
|
||||
"mistral:latest@!-123": {},
|
||||
"mistral:latest@1-!": {},
|
||||
"mistral:latest@": {},
|
||||
|
||||
// resolved
|
||||
"x@sha123-1": {model: "x", digest: "sha123-1"},
|
||||
"@sha456-2": {digest: "sha456-2"},
|
||||
|
||||
"@@sha123-1": {},
|
||||
|
||||
// preserves case for build
|
||||
"x+b": {model: "x", build: "b"},
|
||||
|
||||
// invalid (includes fuzzing trophies)
|
||||
" / / : + ": {},
|
||||
" / : + ": {},
|
||||
" : + ": {},
|
||||
" + ": {},
|
||||
" : ": {},
|
||||
" / ": {},
|
||||
" /": {},
|
||||
"/ ": {},
|
||||
"/": {},
|
||||
":": {},
|
||||
"+": {},
|
||||
|
||||
// (".") in namepsace is not allowed
|
||||
"invalid.com/7b+x": {},
|
||||
|
||||
"invalid:7b+Q4_0:latest": {},
|
||||
"in valid": {},
|
||||
"invalid/y/z/foo": {},
|
||||
"/0": {},
|
||||
"0 /0": {},
|
||||
"0 /": {},
|
||||
"0/": {},
|
||||
":/0": {},
|
||||
"+0/00000": {},
|
||||
"0+.\xf2\x80\xf6\x9d00000\xe5\x99\xe6\xd900\xd90\xa60\x91\xdc0\xff\xbf\x99\xe800\xb9\xdc\xd6\xc300\x970\xfb\xfd0\xe0\x8a\xe1\xad\xd40\x9700\xa80\x980\xdd0000\xb00\x91000\xfe0\x89\x9b\x90\x93\x9f0\xe60\xf7\x84\xb0\x87\xa5\xff0\xa000\x9a\x85\xf6\x85\xfe\xa9\xf9\xe9\xde00\xf4\xe0\x8f\x81\xad\xde00\xd700\xaa\xe000000\xb1\xee0\x91": {},
|
||||
"0//0": {},
|
||||
"m+^^^": {},
|
||||
"file:///etc/passwd": {},
|
||||
"file:///etc/passwd:latest": {},
|
||||
"file:///etc/passwd:latest+u": {},
|
||||
|
||||
":x": {},
|
||||
"+x": {},
|
||||
"x+": {},
|
||||
|
||||
// Disallow ("\.+") in any part to prevent path traversal anywhere
|
||||
// we convert the name to a path.
|
||||
"../etc/passwd": {},
|
||||
".../etc/passwd": {},
|
||||
"./../passwd": {},
|
||||
"./0+..": {},
|
||||
|
||||
strings.Repeat("a", MaxNamePartLen): {model: strings.Repeat("a", MaxNamePartLen)},
|
||||
strings.Repeat("a", MaxNamePartLen+1): {},
|
||||
}
|
||||
|
||||
// TestConsecutiveDots tests that consecutive dots are not allowed in any
|
||||
// part, to avoid path traversal. There are also some tests in testNames, but
|
||||
// this test is more exhaustive and exists to emphasize the importance of
|
||||
// preventing path traversal.
|
||||
func TestNameConsecutiveDots(t *testing.T) {
|
||||
for i := 1; i < 10; i++ {
|
||||
s := strings.Repeat(".", i)
|
||||
if i > 1 {
|
||||
if g := ParseNameFill(s, "").String(); g != "" {
|
||||
t.Errorf("ParseName(%q) = %q; want empty string", s, g)
|
||||
}
|
||||
} else {
|
||||
if g := ParseNameFill(s, "").String(); g != s {
|
||||
t.Errorf("ParseName(%q) = %q; want %q", s, g, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameParts(t *testing.T) {
|
||||
var p Name
|
||||
if w, g := int(PartDigest+1), len(p.Parts()); w != g {
|
||||
t.Errorf("Parts() = %d; want %d", g, w)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamePartString(t *testing.T) {
|
||||
if g := PartKind(-2).String(); g != "Unknown" {
|
||||
t.Errorf("Unknown part = %q; want %q", g, "Unknown")
|
||||
}
|
||||
for kind, name := range kindNames {
|
||||
if g := kind.String(); g != name {
|
||||
t.Errorf("%s = %q; want %q", kind, g, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseName(t *testing.T) {
|
||||
for baseName, want := range testNames {
|
||||
for _, prefix := range []string{"", "https://", "http://"} {
|
||||
// We should get the same results with or without the
|
||||
// http(s) prefixes
|
||||
s := prefix + baseName
|
||||
|
||||
t.Run(s, func(t *testing.T) {
|
||||
name := ParseNameFill(s, "")
|
||||
got := fieldsFromName(name)
|
||||
if got != want {
|
||||
t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
|
||||
}
|
||||
|
||||
// test round-trip
|
||||
if !ParseNameFill(name.String(), "").EqualFold(name) {
|
||||
t.Errorf("ParseName(%q).String() = %s; want %s", s, name.String(), baseName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompleteWithAndWithoutBuild(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
complete bool
|
||||
completeNoBuild bool
|
||||
}{
|
||||
{"", false, false},
|
||||
{"incomplete/mistral:7b+x", false, false},
|
||||
{"incomplete/mistral:7b+Q4_0", false, false},
|
||||
{"incomplete:7b+x", false, false},
|
||||
{"complete.com/x/mistral:latest+Q4_0", true, true},
|
||||
{"complete.com/x/mistral:latest", false, true},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
p := ParseNameFill(tt.in, "")
|
||||
t.Logf("ParseName(%q) = %#v", tt.in, p)
|
||||
if g := p.IsComplete(); g != tt.complete {
|
||||
t.Errorf("Complete(%q) = %v; want %v", tt.in, g, tt.complete)
|
||||
}
|
||||
if g := p.IsCompleteNoBuild(); g != tt.completeNoBuild {
|
||||
t.Errorf("CompleteNoBuild(%q) = %v; want %v", tt.in, g, tt.completeNoBuild)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Complete uses Parts which returns a slice, but it should be
|
||||
// inlined when used in Complete, preventing any allocations or
|
||||
// escaping to the heap.
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
keep(ParseNameFill("complete.com/x/mistral:latest+Q4_0", "").IsComplete())
|
||||
})
|
||||
if allocs > 0 {
|
||||
t.Errorf("Complete allocs = %v; want 0", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameLogValue(t *testing.T) {
|
||||
cases := []string{
|
||||
"example.com/library/mistral:latest+Q4_0",
|
||||
"mistral:latest",
|
||||
"mistral:7b+Q4_0",
|
||||
}
|
||||
for _, s := range cases {
|
||||
t.Run(s, func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
log := slog.New(slog.NewTextHandler(&b, nil))
|
||||
name := ParseNameFill(s, "")
|
||||
log.Info("", "name", name)
|
||||
want := fmt.Sprintf("name=%s", name.GoString())
|
||||
got := b.String()
|
||||
if !strings.Contains(got, want) {
|
||||
t.Errorf("expected log output to contain %q; got %q", want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameGoString(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in string
|
||||
wantString string
|
||||
wantGoString string // default is tt.in
|
||||
}{
|
||||
{
|
||||
name: "Complete Name",
|
||||
in: "example.com/library/mistral:latest+Q4_0",
|
||||
wantGoString: "example.com/library/mistral:latest+Q4_0@?",
|
||||
},
|
||||
{
|
||||
name: "Short Name",
|
||||
in: "mistral:latest",
|
||||
wantGoString: "?/?/mistral:latest+?@?",
|
||||
},
|
||||
{
|
||||
name: "Long Name",
|
||||
in: "library/mistral:latest",
|
||||
wantGoString: "?/library/mistral:latest+?@?",
|
||||
},
|
||||
{
|
||||
name: "Case Preserved",
|
||||
in: "Library/Mistral:Latest",
|
||||
wantGoString: "?/Library/Mistral:Latest+?@?",
|
||||
},
|
||||
{
|
||||
name: "With digest",
|
||||
in: "Library/Mistral:Latest@sha256-123456",
|
||||
wantGoString: "?/Library/Mistral:Latest+?@sha256-123456",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := ParseNameFill(tt.in, "")
|
||||
tt.wantGoString = cmp.Or(tt.wantGoString, tt.in)
|
||||
if g := fmt.Sprintf("%#v", p); g != tt.wantGoString {
|
||||
t.Errorf("GoString() = %q; want %q", g, tt.wantGoString)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisplayShortest(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
mask string
|
||||
want string
|
||||
wantPanic bool
|
||||
}{
|
||||
{"example.com/library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
|
||||
{"example.com/library/mistral:latest+Q4_0", "example.com/_/_:latest", "library/mistral", false},
|
||||
{"example.com/library/mistral:latest+Q4_0", "", "example.com/library/mistral", false},
|
||||
{"example.com/library/mistral:latest+Q4_0", "", "example.com/library/mistral", false},
|
||||
|
||||
        // case-insensitive
        {"Example.com/library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
        {"example.com/Library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
        {"example.com/library/Mistral:latest+Q4_0", "example.com/library/_:latest", "Mistral", false},
        {"example.com/library/mistral:Latest+Q4_0", "example.com/library/_:latest", "mistral", false},
        {"example.com/library/mistral:Latest+q4_0", "example.com/library/_:latest", "mistral", false},

        // invalid mask
        {"example.com/library/mistral:latest+Q4_0", "example.com/mistral", "", true},

        // DefaultMask
        {"registry.ollama.ai/library/mistral:latest+Q4_0", DefaultMask, "mistral", false},

        // Auto-Fill
        {"x", "example.com/library/_:latest", "x", false},
        {"x", "example.com/library/_:latest+Q4_0", "x", false},
        {"x/y:z", "a.com/library/_:latest+Q4_0", "x/y:z", false},
        {"x/y:z", "a.com/library/_:latest+Q4_0", "x/y:z", false},
    }

    for _, tt := range cases {
        t.Run("", func(t *testing.T) {
            defer func() {
                if tt.wantPanic {
                    if recover() == nil {
                        t.Errorf("expected panic")
                    }
                }
            }()

            p := ParseNameFill(tt.in, "")
            t.Logf("ParseName(%q) = %#v", tt.in, p)
            if g := p.DisplayShortest(tt.mask); g != tt.want {
                t.Errorf("got = %q; want %q", g, tt.want)
            }
        })
    }
}

func TestParseNameAllocs(t *testing.T) {
    allocs := testing.AllocsPerRun(1000, func() {
        keep(ParseNameFill("example.com/mistral:7b+Q4_0", ""))
    })
    if allocs > 0 {
        t.Errorf("ParseName allocs = %v; want 0", allocs)
    }
}

func BenchmarkParseName(b *testing.B) {
    b.ReportAllocs()

    for range b.N {
        keep(ParseNameFill("example.com/mistral:7b+Q4_0", ""))
    }
}

func FuzzParseName(f *testing.F) {
    f.Add("example.com/mistral:7b+Q4_0")
    f.Add("example.com/mistral:7b+q4_0")
    f.Add("example.com/mistral:7b+x")
    f.Add("x/y/z:8n+I")
    f.Add(":x")
    f.Add("@sha256-123456")
    f.Add("example.com/mistral:latest+Q4_0@sha256-123456")
    f.Add(":@!@")
    f.Add("...")
    f.Fuzz(func(t *testing.T, s string) {
        r0 := ParseNameFill(s, "")

        if strings.Contains(s, "..") && !r0.IsZero() {
            t.Fatalf("non-zero value for path with '..': %q", s)
        }

        if !r0.IsValid() && !r0.IsResolved() {
            if !r0.EqualFold(Name{}) {
                t.Errorf("expected invalid path to be zero value; got %#v", r0)
            }
            t.Skipf("invalid path: %q", s)
        }

        for _, p := range r0.Parts() {
            if len(p) > MaxNamePartLen {
                t.Errorf("part too long: %q", p)
            }
        }

        if !strings.EqualFold(r0.String(), s) {
            t.Errorf("String() did not round-trip with case insensitivity: %q\ngot = %q\nwant = %q", s, r0.String(), s)
        }

        r1 := ParseNameFill(r0.String(), "")
        if !r0.EqualFold(r1) {
            t.Errorf("round-trip mismatch: %+v != %+v", r0, r1)
        }
    })
}

func TestFill(t *testing.T) {
    cases := []struct {
        dst  string
        src  string
        want string
    }{
        {"mistral", "o.com/library/PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
        {"o.com/library/mistral", "PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
        {"", "o.com/library/mistral:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
    }

    for _, tt := range cases {
        t.Run(tt.dst, func(t *testing.T) {
            r := Fill(ParseNameFill(tt.dst, ""), ParseNameFill(tt.src, ""))
            if r.String() != tt.want {
                t.Errorf("Fill(%q, %q) = %q; want %q", tt.dst, tt.src, r, tt.want)
            }
        })
    }
}

func TestNameStringAllocs(t *testing.T) {
    name := ParseNameFill("example.com/ns/mistral:latest+Q4_0", "")
    allocs := testing.AllocsPerRun(1000, func() {
        keep(name.String())
    })
    if allocs > 1 {
t.Errorf("String allocs = %v; want 0", allocs)
|
||||
    }
}

func ExampleFill() {
    defaults := ParseNameFill("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0", "")
    r := Fill(ParseNameFill("mistral", ""), defaults)
    fmt.Println(r)

    // Output:
    // registry.ollama.com/library/mistral:latest+Q4_0
}

func ExampleName_MapHash() {
    m := map[uint64]bool{}

    // key 1
    m[ParseNameFill("mistral:latest+q4", "").MapHash()] = true
    m[ParseNameFill("miSTRal:latest+Q4", "").MapHash()] = true
    m[ParseNameFill("mistral:LATest+Q4", "").MapHash()] = true

    // key 2
    m[ParseNameFill("mistral:LATest", "").MapHash()] = true

    fmt.Println(len(m))
    // Output:
    // 2
}

func ExampleName_CompareFold_sort() {
    names := []Name{
        ParseNameFill("mistral:latest", ""),
        ParseNameFill("mistRal:7b+q4", ""),
        ParseNameFill("MIstral:7b", ""),
    }

    slices.SortFunc(names, Name.CompareFold)

    for _, n := range names {
        fmt.Println(n)
    }

    // Output:
    // MIstral:7b
    // mistRal:7b+q4
    // mistral:latest
}

func ExampleName_completeAndResolved() {
    for _, s := range []string{
        "x/y/z:latest+q4_0@sha123-1",
        "x/y/z:latest+q4_0",
        "@sha123-1",
    } {
        name := ParseNameFill(s, "")
        fmt.Printf("complete:%v resolved:%v digest:%s\n", name.IsComplete(), name.IsResolved(), name.Digest())
    }

    // Output:
    // complete:true resolved:true digest:sha123-1
    // complete:true resolved:false digest:
    // complete:false resolved:true digest:sha123-1
}

func ExampleName_DisplayShortest() {
    name := ParseNameFill("example.com/jmorganca/mistral:latest+Q4_0", "")

    fmt.Println(name.DisplayShortest("example.com/jmorganca/_:latest"))
    fmt.Println(name.DisplayShortest("example.com/_/_:latest"))
    fmt.Println(name.DisplayShortest("example.com/_/_:_"))
    fmt.Println(name.DisplayShortest("_/_/_:_"))

    // Default
    name = ParseNameFill("registry.ollama.ai/library/mistral:latest+Q4_0", "")
    fmt.Println(name.DisplayShortest(""))

    // Output:
    // mistral
    // jmorganca/mistral
    // jmorganca/mistral:latest
    // example.com/jmorganca/mistral:latest
    // mistral
}

// keep is a sink that uses a value so the compiler does not optimize the
// computation away in the allocation tests and benchmark above.
func keep[T any](v T) T { return v }
2
types/model/testdata/fuzz/FuzzParseRef/1d43ee52085cb4aa
vendored
Normal file
2
types/model/testdata/fuzz/FuzzParseRef/1d43ee52085cb4aa
vendored
Normal file
@@ -0,0 +1,2 @@
go test fuzz v1
string("/0")
2
types/model/testdata/fuzz/FuzzParseRef/27fd759314f0e6d6
vendored
Normal file
2
types/model/testdata/fuzz/FuzzParseRef/27fd759314f0e6d6
vendored
Normal file
@@ -0,0 +1,2 @@
go test fuzz v1
string("0//0")
2
types/model/testdata/fuzz/FuzzParseRef/3e3b70dba384074d
vendored
Normal file
2
types/model/testdata/fuzz/FuzzParseRef/3e3b70dba384074d
vendored
Normal file
@@ -0,0 +1,2 @@
go test fuzz v1
string("0 /0")
2
types/model/testdata/fuzz/FuzzParseRef/71f1fdff711b6dab
vendored
Normal file
2
types/model/testdata/fuzz/FuzzParseRef/71f1fdff711b6dab
vendored
Normal file
@@ -0,0 +1,2 @@
go test fuzz v1
string("+0/00000")
2
types/model/testdata/fuzz/FuzzParseRef/82c2975c430ac608
vendored
Normal file
2
types/model/testdata/fuzz/FuzzParseRef/82c2975c430ac608
vendored
Normal file
@@ -0,0 +1,2 @@
go test fuzz v1
string(":")
2
types/model/testdata/fuzz/FuzzParseRef/b51b1c875e61a948
vendored
Normal file
2
types/model/testdata/fuzz/FuzzParseRef/b51b1c875e61a948
vendored
Normal file
@@ -0,0 +1,2 @@
go test fuzz v1
string("0+.\xf2\x80\xf6\x9d00000\xe5\x99\xe6\xd900\xd90\xa60\x91\xdc0\xff\xbf\x99\xe800\xb9\xdc\xd6\xc300\x970\xfb\xfd0\xe0\x8a\xe1\xad\xd40\x9700\xa80\x980\xdd0000\xb00\x91000\xfe0\x89\x9b\x90\x93\x9f0\xe60\xf7\x84\xb0\x87\xa5\xff0\xa000\x9a\x85\xf6\x85\xfe\xa9\xf9\xe9\xde00\xf4\xe0\x8f\x81\xad\xde00\xd700\xaa\xe000000\xb1\xee0\x91")
15
types/structs/structs.go
Normal file
15
types/structs/structs.go
Normal file
@@ -0,0 +1,15 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

// Package structs contains the Incomparable type.
package structs

// Incomparable is a zero-width incomparable type. If added as the
// first field in a struct, it marks that struct as not comparable
// (can't do == or be a map key) and usually doesn't add any width to
// the struct (unless the struct has only small fields).
//
// By making a struct incomparable, you can prevent misuse (prevent
// people from using ==), but also you can shrink generated binaries,
// as the compiler can omit equality funcs from the binary.
type Incomparable [0]func()
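For context, a minimal sketch of how the Incomparable marker is typically used; the Handle type, its fields, and the import path are illustrative assumptions, not part of this changeset. Declaring it as the first, blank field makes == on the struct (and use as a map key) a compile-time error while adding no width:

package main

import (
    "fmt"

    "github.com/ollama/ollama/types/structs" // assumed import path for the new types/structs package
)

// Handle is a hypothetical example type. Its first field is
// structs.Incomparable, so comparing two Handle values with == or using
// Handle as a map key fails to compile.
type Handle struct {
    _  structs.Incomparable
    ID string
}

func main() {
    h := Handle{ID: "mistral"}
    fmt.Println(h.ID)
    // fmt.Println(h == Handle{ID: "mistral"}) // compile error: struct containing [0]func() cannot be compared
}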