mirror of
https://github.com/ollama/ollama.git
synced 2026-01-19 04:51:17 -05:00
Compare commits
89 Commits
brucemacd/
...
progress-f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fcfbb06f1b | ||
|
|
e8d35d0de0 | ||
|
|
e13e7c8d94 | ||
|
|
78f403ff45 | ||
|
|
08a299e1d0 | ||
|
|
7b5d916a9a | ||
|
|
33ad61b112 | ||
|
|
716e365615 | ||
|
|
3b4424ff98 | ||
|
|
f9c7ead160 | ||
|
|
5930aaeb1a | ||
|
|
faf67db089 | ||
|
|
0667baddc6 | ||
|
|
d006e1e09b | ||
|
|
df2680b4b9 | ||
|
|
010313bb63 | ||
|
|
5296f487a8 | ||
|
|
f05774b04c | ||
|
|
6600bd7d91 | ||
|
|
ed443a0393 | ||
|
|
6945617af5 | ||
|
|
7916f55009 | ||
|
|
d650ad398f | ||
|
|
d223f3b697 | ||
|
|
60830695c2 | ||
|
|
01d9a46854 | ||
|
|
d773b7d671 | ||
|
|
4d4463b2bd | ||
|
|
0e38297f87 | ||
|
|
7e13f568dc | ||
|
|
58245413f4 | ||
|
|
8cf16063a5 | ||
|
|
3a4449e2f1 | ||
|
|
10d59d5f90 | ||
|
|
a4f69a0191 | ||
|
|
82658c3eec | ||
|
|
378d6e1e6a | ||
|
|
afa55bc70c | ||
|
|
49df03da9a | ||
|
|
0189bdd0b7 | ||
|
|
f4711da7bd | ||
|
|
38117fba83 | ||
|
|
1f766c36fb | ||
|
|
484a99e428 | ||
|
|
ec6121c331 | ||
|
|
b86c0a1500 | ||
|
|
7e402ebb8c | ||
|
|
b901a712c6 | ||
|
|
abb8dd57f8 | ||
|
|
a400df48c0 | ||
|
|
6ab4ba4c26 | ||
|
|
e8d4eb3e68 | ||
|
|
ae7e368f75 | ||
|
|
31acd1ebf9 | ||
|
|
9a4757ae66 | ||
|
|
7814019708 | ||
|
|
b698f9a0d8 | ||
|
|
32285a6d19 | ||
|
|
1c198977ec | ||
|
|
330b6c50b0 | ||
|
|
928911bc68 | ||
|
|
5b446cc815 | ||
|
|
451c1596af | ||
|
|
932bded12f | ||
|
|
070ad913ac | ||
|
|
8d8b9f83ae | ||
|
|
f00d359a67 | ||
|
|
291def6adb | ||
|
|
cd3fbf1c49 | ||
|
|
c852b8e021 | ||
|
|
d8932c55e7 | ||
|
|
63f0269f7f | ||
|
|
4759ecae19 | ||
|
|
65b7ecac7b | ||
|
|
f9d2d89135 | ||
|
|
669dc31cf3 | ||
|
|
d4d338c224 | ||
|
|
bfdeffc375 | ||
|
|
e806184023 | ||
|
|
50566113ac | ||
|
|
ad22ace439 | ||
|
|
f4321a421c | ||
|
|
475333d533 | ||
|
|
39fd89308c | ||
|
|
548a9f56a6 | ||
|
|
3f0cb36bdb | ||
|
|
bea1f1fac6 | ||
|
|
5d75d837ef | ||
|
|
711648c9bb |
4
.gitattributes
vendored
4
.gitattributes
vendored
@@ -15,6 +15,10 @@ ml/backend/**/*.cu linguist-vendored
|
||||
ml/backend/**/*.cuh linguist-vendored
|
||||
ml/backend/**/*.m linguist-vendored
|
||||
ml/backend/**/*.metal linguist-vendored
|
||||
ml/backend/**/CMakeLists.txt linguist-vendored
|
||||
|
||||
llama/build-info.cpp linguist-generated
|
||||
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.s linguist-generated
|
||||
|
||||
* text=auto
|
||||
*.go text eol=lf
|
||||
|
||||
8
.github/ISSUE_TEMPLATE/10_bug_report.yml
vendored
8
.github/ISSUE_TEMPLATE/10_bug_report.yml
vendored
@@ -9,6 +9,14 @@ body:
|
||||
description: What happened? What did you expect to happen?
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: logs
|
||||
attributes:
|
||||
label: Relevant log output
|
||||
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
|
||||
render: shell
|
||||
validations:
|
||||
required: false
|
||||
- type: dropdown
|
||||
id: os
|
||||
attributes:
|
||||
|
||||
206
.github/workflows/release.yaml
vendored
206
.github/workflows/release.yaml
vendored
@@ -5,6 +5,10 @@ on:
|
||||
tags:
|
||||
- 'v*'
|
||||
|
||||
env:
|
||||
CGO_CFLAGS: '-O3'
|
||||
CGO_CXXFLAGS: '-O3'
|
||||
|
||||
jobs:
|
||||
setup-environment:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -77,7 +81,7 @@ jobs:
|
||||
path: dist/darwin-arm64
|
||||
- run: |
|
||||
export VERSION=${GITHUB_REF_NAME#v}
|
||||
./scripts/build_darwin.sh macapp sign
|
||||
./scripts/build_darwin.sh sign macapp
|
||||
env:
|
||||
APPLE_IDENTITY: ${{ secrets.APPLE_IDENTITY }}
|
||||
APPLE_PASSWORD: ${{ secrets.APPLE_PASSWORD }}
|
||||
@@ -193,33 +197,38 @@ jobs:
|
||||
env:
|
||||
GOFLAGS: ${{ needs.setup-environment.outputs.GOFLAGS }}
|
||||
steps:
|
||||
- name: Install system dependencies
|
||||
- name: Install AMD64 system dependencies
|
||||
if: matrix.arch == 'amd64'
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
if ("${{ matrix.arch }}" -eq 'amd64') {
|
||||
Start-Process "C:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait
|
||||
echo "C:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
} elseif ("${{ matrix.arch }}" -eq 'arm64') {
|
||||
Set-ExecutionPolicy Bypass -Scope Process -Force
|
||||
[System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072
|
||||
iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1'))
|
||||
echo "C:\ProgramData\chocolatey\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
Start-Process "C:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait
|
||||
echo "C:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
- name: Install ARM64 system dependencies
|
||||
if: matrix.arch == 'arm64'
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
Set-ExecutionPolicy Bypass -Scope Process -Force
|
||||
[System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072
|
||||
iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1'))
|
||||
echo "C:\ProgramData\chocolatey\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
|
||||
choco install -y --no-progress git gzip
|
||||
echo "C:\Program Files\Git\cmd" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
choco install -y --no-progress git gzip
|
||||
echo "C:\Program Files\Git\cmd" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
|
||||
Invoke-WebRequest -Uri "https://github.com/mstorsjo/llvm-mingw/releases/download/20240619/llvm-mingw-20240619-ucrt-aarch64.zip" -OutFile "${{ runner.temp }}\llvm-mingw-ucrt-aarch64.zip"
|
||||
Expand-Archive -Path ${{ runner.temp }}\llvm-mingw-ucrt-aarch64.zip -DestinationPath "C:\Program Files\"
|
||||
$installPath=(Resolve-Path -Path "C:\Program Files\llvm-mingw-*-ucrt-aarch64").path
|
||||
echo $installPath\bin | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
}
|
||||
Invoke-WebRequest -Uri "https://github.com/mstorsjo/llvm-mingw/releases/download/20240619/llvm-mingw-20240619-ucrt-aarch64.zip" -OutFile "${{ runner.temp }}\llvm-mingw-ucrt-aarch64.zip"
|
||||
Expand-Archive -Path ${{ runner.temp }}\llvm-mingw-ucrt-aarch64.zip -DestinationPath "C:\Program Files\"
|
||||
$installPath=(Resolve-Path -Path "C:\Program Files\llvm-mingw-*-ucrt-aarch64").path
|
||||
echo $installPath\bin | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
- run: |
|
||||
go build -o dist/${{ matrix.os }}-${{ matrix.arch }}/ .
|
||||
- if: matrix.arch == 'arm64'
|
||||
run: |
|
||||
Invoke-WebRequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "dist\windows-arm64\vc_redist.arm64.exe"
|
||||
- run: |
|
||||
$env:VERSION='${{ github.ref_name }}' -Replace "v(.*)", '$1'
|
||||
& .\scripts\build_windows.ps1 buildApp
|
||||
@@ -233,7 +242,7 @@ jobs:
|
||||
dist\${{ matrix.os }}-${{ matrix.arch }}-app.exe
|
||||
|
||||
windows-sign:
|
||||
runs-on: windows
|
||||
runs-on: windows-2022
|
||||
environment: release
|
||||
needs: [windows-depends, windows-build]
|
||||
steps:
|
||||
@@ -254,16 +263,18 @@ jobs:
|
||||
echo "${{ vars.OLLAMA_CERT }}" >ollama_inc.crt
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: build-windows-*
|
||||
pattern: build-windows-*
|
||||
path: dist\
|
||||
merge-multiple: true
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: depends-windows-amd64-*
|
||||
pattern: depends-windows-amd64-*
|
||||
path: dist\windows-amd64\
|
||||
merge-multiple: true
|
||||
- run: |
|
||||
& .\scripts\build_windows.ps1 gatherDependencies sign buildInstaller distZip
|
||||
env:
|
||||
KEY_CONTAINER: ${{ vars.KEY_CONTAINER }}
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist-windows
|
||||
@@ -277,10 +288,13 @@ jobs:
|
||||
include:
|
||||
- os: linux
|
||||
arch: amd64
|
||||
targets: 'archive rocm'
|
||||
target: archive
|
||||
- os: linux
|
||||
arch: amd64
|
||||
target: rocm
|
||||
- os: linux
|
||||
arch: arm64
|
||||
targets: archive
|
||||
target: archive
|
||||
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
||||
environment: release
|
||||
needs: setup-environment
|
||||
@@ -289,38 +303,106 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: docker/setup-buildx-action@v3
|
||||
- uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.os }}/${{ matrix.arch }}
|
||||
target: ${{ matrix.target }}
|
||||
build-args: |
|
||||
GOFLAGS=${{ env.GOFLAGS }}
|
||||
CGO_CFLAGS=${{ env.CGO_CFLAGS }}
|
||||
CGO_CXXFLAGS=${{ env.CGO_CXXFLAGS }}
|
||||
outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
|
||||
cache-from: type=registry,ref=ollama/ollama:latest
|
||||
cache-to: type=inline
|
||||
- run: |
|
||||
apt-get update && apt-get install pigz
|
||||
for TARGET in ${{ matrix.targets }}; do docker buildx build --platform $PLATFORM --target $TARGET --output type=local,dest=dist/$PLATFORM .; done
|
||||
tar c -C dist/$PLATFORM . | pigz -9cv >dist/ollama-${PLATFORM//\//-}.tgz
|
||||
env:
|
||||
PLATFORM: ${{ matrix.os }}/${{ matrix.arch }}
|
||||
for COMPONENT in bin/* lib/ollama/*; do
|
||||
case "$COMPONENT" in
|
||||
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/*.so) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_v11) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_v12) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
||||
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
||||
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
||||
esac
|
||||
done
|
||||
working-directory: dist/${{ matrix.os }}-${{ matrix.arch }}
|
||||
- run: |
|
||||
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
|
||||
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
|
||||
done
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist-${{ matrix.os }}-${{ matrix.arch }}
|
||||
name: dist-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.target }}
|
||||
path: |
|
||||
dist/ollama-${{ matrix.os }}-${{ matrix.arch }}.tgz
|
||||
*.tgz
|
||||
|
||||
docker-build:
|
||||
# Build each Docker variant (OS, arch, and flavor) separately. Using QEMU is unreliable and slower.
|
||||
docker-build-push:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- flavor: 'latest=false'
|
||||
platforms: linux/amd64,linux/arm64
|
||||
- os: linux
|
||||
arch: arm64
|
||||
build-args: |
|
||||
GOFLAGS=${{ needs.setup-environment.outputs.GOFLAGS }}
|
||||
- flavor: 'latest=false,suffix=rocm'
|
||||
platforms: linux/amd64
|
||||
CGO_CFLAGS
|
||||
CGO_CXXFLAGS
|
||||
GOFLAGS
|
||||
- os: linux
|
||||
arch: amd64
|
||||
build-args: |
|
||||
GOFLAGS=${{ needs.setup-environment.outputs.GOFLAGS }}
|
||||
CGO_CFLAGS
|
||||
CGO_CXXFLAGS
|
||||
GOFLAGS
|
||||
- os: linux
|
||||
arch: amd64
|
||||
suffix: '-rocm'
|
||||
build-args: |
|
||||
CGO_CFLAGS
|
||||
CGO_CXXFLAGS
|
||||
GOFLAGS
|
||||
FLAVOR=rocm
|
||||
runs-on: linux
|
||||
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
||||
environment: release
|
||||
needs: setup-environment
|
||||
env:
|
||||
GOFLAGS: ${{ needs.setup-environment.outputs.GOFLAGS }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: docker/setup-qemu-action@v2
|
||||
- uses: docker/setup-buildx-action@v2
|
||||
- uses: docker/setup-buildx-action@v3
|
||||
- uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ vars.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_ACCESS_TOKEN }}
|
||||
- id: build-push
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.os }}/${{ matrix.arch }}
|
||||
build-args: ${{ matrix.build-args }}
|
||||
outputs: type=image,name=ollama/ollama,push-by-digest=true,name-canonical=true,push=true
|
||||
cache-from: type=registry,ref=ollama/ollama:latest
|
||||
cache-to: type=inline
|
||||
- run: |
|
||||
mkdir -p ${{ matrix.os }}-${{ matrix.arch }}
|
||||
echo "${{ steps.build-push.outputs.digest }}" >${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.suffix }}.txt
|
||||
working-directory: ${{ runner.temp }}
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: digest-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.suffix }}
|
||||
path: |
|
||||
${{ runner.temp }}/${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.suffix }}.txt
|
||||
|
||||
# Merge Docker images for the same flavor into a single multi-arch manifest
|
||||
docker-merge-push:
|
||||
strategy:
|
||||
matrix:
|
||||
suffix: ['', '-rocm']
|
||||
runs-on: linux
|
||||
environment: release
|
||||
needs: [docker-build-push]
|
||||
steps:
|
||||
- uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ vars.DOCKER_USER }}
|
||||
@@ -328,22 +410,23 @@ jobs:
|
||||
- id: metadata
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
flavor: ${{ matrix.flavor }}
|
||||
flavor: |
|
||||
latest=false
|
||||
suffix=${{ matrix.suffix }}
|
||||
images: |
|
||||
ollama/ollama
|
||||
tags: |
|
||||
type=ref,enable=true,priority=600,prefix=pr-,event=pr
|
||||
type=semver,pattern={{version}}
|
||||
- uses: docker/build-push-action@v6
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
platforms: ${{ matrix.platforms }}
|
||||
build-args: ${{ matrix.build-args }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
cache-from: type=registry,ref=ollama/ollama:latest
|
||||
cache-to: type=inline
|
||||
provenance: false
|
||||
pattern: digest-*
|
||||
path: ${{ runner.temp }}
|
||||
merge-multiple: true
|
||||
- run: |
|
||||
docker buildx imagetools create $(echo '${{ steps.metadata.outputs.json }}' | jq -cr '.tags | map("-t", .) | join(" ")') $(cat *-${{ matrix.suffix }}.txt | xargs printf 'ollama/ollama@%s ')
|
||||
docker buildx imagetools inspect ollama/ollama:${{ steps.metadata.outputs.version }}
|
||||
working-directory: ${{ runner.temp }}
|
||||
|
||||
# Aggregate all the assets and ship a release
|
||||
release:
|
||||
@@ -356,33 +439,24 @@ jobs:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set Version
|
||||
shell: bash
|
||||
run: |
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: dist-darwin
|
||||
path: dist
|
||||
pattern: dist-darwin
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: dist-windows
|
||||
path: dist
|
||||
pattern: dist-windows
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: dist
|
||||
pattern: dist-linux-*
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: dist
|
||||
pattern: dist-windows
|
||||
- run: |
|
||||
ls -lh dist/
|
||||
(cd dist; find . -type f | xargs sha256sum > ../sha256sum.txt)
|
||||
mv sha256sum.txt dist/
|
||||
cat dist/sha256sum.txt
|
||||
merge-multiple: true
|
||||
- run: find . -type f -not -name 'sha256sum.txt' | xargs sha256sum | tee sha256sum.txt
|
||||
working-directory: dist
|
||||
- name: Create or update Release
|
||||
run: |
|
||||
RELEASE_VERSION=$(echo ${GITHUB_REF_NAME} | cut -f1 -d-)"
|
||||
RELEASE_VERSION="$(echo ${GITHUB_REF_NAME} | cut -f1 -d-)"
|
||||
|
||||
echo "Looking for existing release for ${RELEASE_VERSION}"
|
||||
OLD_TAG=$(gh release ls --json name,tagName | jq -r ".[] | select(.name == \"${RELEASE_VERSION}\") | .tagName")
|
||||
|
||||
2
.github/workflows/test.yaml
vendored
2
.github/workflows/test.yaml
vendored
@@ -163,5 +163,5 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Verify patches apply cleanly and do not change files
|
||||
run: |
|
||||
make -f Makefile.sync clean checkout sync
|
||||
make -f Makefile.sync clean sync
|
||||
git diff --compact-summary --exit-code
|
||||
|
||||
@@ -24,11 +24,16 @@ set(GGML_LLAMAFILE ON)
|
||||
set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
|
||||
set(GGML_CUDA_GRAPHS ON)
|
||||
|
||||
if((NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
|
||||
if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
|
||||
OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|ARMv[0-9]+"))
|
||||
set(GGML_CPU_ALL_VARIANTS ON)
|
||||
endif()
|
||||
|
||||
if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
|
||||
set(CMAKE_BUILD_RPATH "@loader_path")
|
||||
set(CMAKE_INSTALL_RPATH "@loader_path")
|
||||
endif()
|
||||
|
||||
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
|
||||
set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama)
|
||||
|
||||
@@ -80,6 +85,11 @@ if(CMAKE_CUDA_COMPILER)
|
||||
)
|
||||
endif()
|
||||
|
||||
set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a):xnack[+-]$"
|
||||
CACHE STRING
|
||||
"Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a):xnack[+-]$\"."
|
||||
)
|
||||
|
||||
check_language(HIP)
|
||||
if(CMAKE_HIP_COMPILER)
|
||||
set(HIP_PLATFORM "amd")
|
||||
@@ -87,15 +97,22 @@ if(CMAKE_HIP_COMPILER)
|
||||
find_package(hip REQUIRED)
|
||||
if(NOT AMDGPU_TARGETS)
|
||||
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012])$")
|
||||
elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
|
||||
list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
|
||||
endif()
|
||||
|
||||
if(AMDGPU_TARGETS)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
|
||||
|
||||
if (WIN32)
|
||||
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY=1)
|
||||
endif()
|
||||
|
||||
set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
|
||||
install(TARGETS ggml-hip
|
||||
RUNTIME_DEPENDENCIES
|
||||
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
|
||||
PRE_INCLUDE_REGEXES amdhip64 hipblas rocblas amd_comgr hsa_runtime64 rocprofiler-register drm_amdgpu drm numa
|
||||
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
POST_EXCLUDE_REGEXES "system32"
|
||||
RUNTIME DESTINATION ${OLLAMA_HIP_INSTALL_DIR} COMPONENT HIP
|
||||
|
||||
@@ -56,7 +56,7 @@
|
||||
"name": "ROCm 6",
|
||||
"inherits": [ "ROCm" ],
|
||||
"cacheVariables": {
|
||||
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102"
|
||||
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
|
||||
}
|
||||
}
|
||||
],
|
||||
|
||||
@@ -15,7 +15,11 @@ help:
|
||||
@echo " make -f $(lastword $(MAKEFILE_LIST)) clean sync"
|
||||
|
||||
.PHONY: sync
|
||||
sync: llama/llama.cpp ml/backend/ggml/ggml apply-patches
|
||||
sync: llama/build-info.cpp llama/llama.cpp ml/backend/ggml/ggml apply-patches
|
||||
|
||||
.PHONY: llama/build-info.cpp
|
||||
llama/build-info.cpp: llama/build-info.cpp.in
|
||||
sed -e 's|@FETCH_HEAD@|$(FETCH_HEAD)|' $< > $@
|
||||
|
||||
.PHONY: llama/llama.cpp
|
||||
llama/llama.cpp: llama/vendor/ apply-patches
|
||||
|
||||
64
README.md
64
README.md
@@ -18,7 +18,7 @@ Get up and running with large language models.
|
||||
|
||||
### Linux
|
||||
|
||||
```
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/install.sh | sh
|
||||
```
|
||||
|
||||
@@ -42,7 +42,7 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
|
||||
|
||||
To run and chat with [Llama 3.2](https://ollama.com/library/llama3.2):
|
||||
|
||||
```
|
||||
```shell
|
||||
ollama run llama3.2
|
||||
```
|
||||
|
||||
@@ -54,6 +54,8 @@ Here are some example models that can be downloaded:
|
||||
|
||||
| Model | Parameters | Size | Download |
|
||||
| ------------------ | ---------- | ----- | -------------------------------- |
|
||||
| DeepSeek-R1 | 7B | 4.7GB | `ollama run deepseek-r1` |
|
||||
| DeepSeek-R1 | 671B | 404GB | `ollama run deepseek-r1:671b` |
|
||||
| Llama 3.3 | 70B | 43GB | `ollama run llama3.3` |
|
||||
| Llama 3.2 | 3B | 2.0GB | `ollama run llama3.2` |
|
||||
| Llama 3.2 | 1B | 1.3GB | `ollama run llama3.2:1b` |
|
||||
@@ -92,13 +94,13 @@ Ollama supports importing GGUF models in the Modelfile:
|
||||
|
||||
2. Create the model in Ollama
|
||||
|
||||
```
|
||||
```shell
|
||||
ollama create example -f Modelfile
|
||||
```
|
||||
|
||||
3. Run the model
|
||||
|
||||
```
|
||||
```shell
|
||||
ollama run example
|
||||
```
|
||||
|
||||
@@ -110,7 +112,7 @@ See the [guide](docs/import.md) on importing models for more information.
|
||||
|
||||
Models from the Ollama library can be customized with a prompt. For example, to customize the `llama3.2` model:
|
||||
|
||||
```
|
||||
```shell
|
||||
ollama pull llama3.2
|
||||
```
|
||||
|
||||
@@ -145,13 +147,13 @@ For more information on working with a Modelfile, see the [Modelfile](docs/model
|
||||
|
||||
`ollama create` is used to create a model from a Modelfile.
|
||||
|
||||
```
|
||||
```shell
|
||||
ollama create mymodel -f ./Modelfile
|
||||
```
|
||||
|
||||
### Pull a model
|
||||
|
||||
```
|
||||
```shell
|
||||
ollama pull llama3.2
|
||||
```
|
||||
|
||||
@@ -159,13 +161,13 @@ ollama pull llama3.2
|
||||
|
||||
### Remove a model
|
||||
|
||||
```
|
||||
```shell
|
||||
ollama rm llama3.2
|
||||
```
|
||||
|
||||
### Copy a model
|
||||
|
||||
```
|
||||
```shell
|
||||
ollama cp llama3.2 my-model
|
||||
```
|
||||
|
||||
@@ -184,37 +186,39 @@ I'm a basic program that prints the famous "Hello, world!" message to the consol
|
||||
|
||||
```
|
||||
ollama run llava "What's in this image? /Users/jmorgan/Desktop/smile.png"
|
||||
The image features a yellow smiley face, which is likely the central focus of the picture.
|
||||
```
|
||||
|
||||
> **Output**: The image features a yellow smiley face, which is likely the central focus of the picture.
|
||||
|
||||
### Pass the prompt as an argument
|
||||
|
||||
```shell
|
||||
ollama run llama3.2 "Summarize this file: $(cat README.md)"
|
||||
```
|
||||
$ ollama run llama3.2 "Summarize this file: $(cat README.md)"
|
||||
Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
|
||||
```
|
||||
|
||||
> **Output**: Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
|
||||
|
||||
### Show model information
|
||||
|
||||
```
|
||||
```shell
|
||||
ollama show llama3.2
|
||||
```
|
||||
|
||||
### List models on your computer
|
||||
|
||||
```
|
||||
```shell
|
||||
ollama list
|
||||
```
|
||||
|
||||
### List which models are currently loaded
|
||||
|
||||
```
|
||||
```shell
|
||||
ollama ps
|
||||
```
|
||||
|
||||
### Stop a model which is currently running
|
||||
|
||||
```
|
||||
```shell
|
||||
ollama stop llama3.2
|
||||
```
|
||||
|
||||
@@ -230,13 +234,13 @@ See the [developer guide](https://github.com/ollama/ollama/blob/main/docs/develo
|
||||
|
||||
Next, start the server:
|
||||
|
||||
```
|
||||
```shell
|
||||
./ollama serve
|
||||
```
|
||||
|
||||
Finally, in a separate shell, run a model:
|
||||
|
||||
```
|
||||
```shell
|
||||
./ollama run llama3.2
|
||||
```
|
||||
|
||||
@@ -246,7 +250,7 @@ Ollama has a REST API for running and managing models.
|
||||
|
||||
### Generate a response
|
||||
|
||||
```
|
||||
```shell
|
||||
curl http://localhost:11434/api/generate -d '{
|
||||
"model": "llama3.2",
|
||||
"prompt":"Why is the sky blue?"
|
||||
@@ -255,7 +259,7 @@ curl http://localhost:11434/api/generate -d '{
|
||||
|
||||
### Chat with a model
|
||||
|
||||
```
|
||||
```shell
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "llama3.2",
|
||||
"messages": [
|
||||
@@ -353,6 +357,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
|
||||
- [Promptery](https://github.com/promptery/promptery) (desktop client for Ollama.)
|
||||
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
|
||||
- [chat-ollama](https://github.com/annilq/chat-ollama) (a React Native client for Ollama)
|
||||
- [SpaceLlama](https://github.com/tcsenpai/spacellama) (Firefox and Chrome extension to quickly summarize web pages with ollama in a sidebar)
|
||||
- [YouLama](https://github.com/tcsenpai/youlama) (Webapp to quickly summarize any YouTube video, supporting Invidious as well)
|
||||
- [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
|
||||
@@ -369,7 +374,14 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Minima](https://github.com/dmayboroda/minima) (RAG with on-premises or fully local workflow)
|
||||
- [aidful-ollama-model-delete](https://github.com/AidfulAI/aidful-ollama-model-delete) (User interface for simplified model cleanup)
|
||||
- [Perplexica](https://github.com/ItzCrazyKns/Perplexica) (An AI-powered search engine & an open-source alternative to Perplexity AI)
|
||||
- [Ollama Chat WebUI for Docker ](https://github.com/oslook/ollama-webui) (Support for local docker deployment, lightweight ollama webui)
|
||||
- [AI Toolkit for Visual Studio Code](https://aka.ms/ai-tooklit/ollama-docs) (Microsoft-official VSCode extension to chat, test, evaluate models with Ollama support, and use them in your AI applications.)
|
||||
- [MinimalNextOllamaChat](https://github.com/anilkay/MinimalNextOllamaChat) (Minimal Web UI for Chat and Model Control)
|
||||
- [Chipper](https://github.com/TilmanGriesel/chipper) AI interface for tinkerers (Ollama, Haystack RAG, Python)
|
||||
- [ChibiChat](https://github.com/CosmicEventHorizon/ChibiChat) (Kotlin-based Android app to chat with Ollama and Koboldcpp API endpoints)
|
||||
- [LocalLLM](https://github.com/qusaismael/localllm) (Minimal Web-App to run ollama models on it with a GUI)
|
||||
- [Ollamazing](https://github.com/buiducnhat/ollamazing) (Web extension to run Ollama models)
|
||||
- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivent endpoint with Ollama support for running locally)
|
||||
|
||||
### Cloud
|
||||
|
||||
@@ -427,9 +439,10 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
|
||||
- [Pacman](https://archlinux.org/packages/extra/x86_64/ollama/)
|
||||
- [Gentoo](https://github.com/gentoo/guru/tree/master/app-misc/ollama)
|
||||
- [Homebrew](https://formulae.brew.sh/formula/ollama)
|
||||
- [Helm Chart](https://artifacthub.io/packages/helm/ollama-helm/ollama)
|
||||
- [Guix channel](https://codeberg.org/tusharhero/ollama-guix)
|
||||
- [Nix package](https://search.nixos.org/packages?channel=24.05&show=ollama&from=0&size=50&sort=relevance&type=packages&query=ollama)
|
||||
- [Nix package](https://search.nixos.org/packages?show=ollama&from=0&size=50&sort=relevance&type=packages&query=ollama)
|
||||
- [Flox](https://flox.dev/blog/ollama-part-one)
|
||||
|
||||
### Libraries
|
||||
@@ -483,6 +496,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
|
||||
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in unified API)
|
||||
- [LlmTornado](https://github.com/lofcz/llmtornado) (C# library providing a unified interface for major FOSS & Commercial inference APIs)
|
||||
- [Ollama for Zig](https://github.com/dravenk/ollama-zig)
|
||||
- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
|
||||
|
||||
### Mobile
|
||||
|
||||
@@ -533,13 +548,16 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [TextCraft](https://github.com/suncloudsmoon/TextCraft) (Copilot in Word alternative using Ollama)
|
||||
- [Alfred Ollama](https://github.com/zeitlings/alfred-ollama) (Alfred Workflow)
|
||||
- [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language
|
||||
- [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai)
|
||||
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
|
||||
|
||||
### Supported backends
|
||||
|
||||
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
|
||||
|
||||
### Observability
|
||||
|
||||
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
|
||||
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
|
||||
- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.
|
||||
- [Langfuse](https://langfuse.com/docs/integrations/ollama) is an open source LLM observability platform that enables teams to collaboratively monitor, evaluate and debug AI applications.
|
||||
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
|
||||
|
||||
@@ -126,7 +126,8 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
const maxBufferSize = 512 * format.KiloByte
|
||||
@@ -189,7 +190,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// GenerateResponseFunc is a function that [Client.Generate] invokes every time
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
|
||||
Run the examples in this directory with:
|
||||
|
||||
```
|
||||
```shell
|
||||
go run example_name/main.go
|
||||
```
|
||||
|
||||
## Chat - Chat with a model
|
||||
- [chat/main.go](chat/main.go)
|
||||
|
||||
|
||||
@@ -17,6 +17,6 @@ If you want to build the installer, youll need to install
|
||||
In the top directory of this repo, run the following powershell script
|
||||
to build the ollama CLI, ollama app, and ollama installer.
|
||||
|
||||
```
|
||||
```powershell
|
||||
powershell -ExecutionPolicy Bypass -File .\scripts\build_windows.ps1
|
||||
```
|
||||
|
||||
63
cache/cache.go
vendored
63
cache/cache.go
vendored
@@ -1,63 +0,0 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
Position int
|
||||
}
|
||||
|
||||
type Cache interface {
|
||||
Sub(i int) Cache
|
||||
Put(ctx ml.Context, key, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor)
|
||||
}
|
||||
|
||||
type Simple struct {
|
||||
DType ml.DType
|
||||
Capacity int
|
||||
|
||||
keys, values []ml.Tensor
|
||||
}
|
||||
|
||||
func (c *Simple) Sub(i int) Cache {
|
||||
if i >= len(c.keys) {
|
||||
c.keys = append(c.keys, make([]ml.Tensor, i-len(c.keys)+1)...)
|
||||
c.values = append(c.values, make([]ml.Tensor, i-len(c.values)+1)...)
|
||||
}
|
||||
|
||||
return &Simple{
|
||||
keys: c.keys[i : i+1],
|
||||
values: c.values[i : i+1],
|
||||
Capacity: c.Capacity,
|
||||
DType: c.DType,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Simple) Put(ctx ml.Context, key, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor) {
|
||||
if c.keys[0] == nil || c.values[0] == nil {
|
||||
c.keys[0] = ctx.Zeros(c.DType, int(key.Dim(0)*key.Dim(1))*c.Capacity)
|
||||
c.values[0] = ctx.Zeros(c.DType, int(value.Dim(0)*value.Dim(1))*c.Capacity)
|
||||
}
|
||||
|
||||
ctx.Forward(key.Copy(ctx, c.keys[0].View(ctx, int(key.Stride(2))*opts.Position, int(key.Dim(0)*key.Dim(1)*key.Dim(2)))))
|
||||
ctx.Forward(value.Copy(ctx, c.values[0].View(ctx, int(value.Stride(2))*opts.Position, int(value.Dim(0)*value.Dim(1)*value.Dim(2)))))
|
||||
|
||||
n := min(c.Capacity, int(key.Dim(2))+opts.Position)
|
||||
|
||||
key = c.keys[0].View(ctx, 0,
|
||||
int(key.Dim(0)), int(key.Stride(1)),
|
||||
int(key.Dim(1)), int(key.Stride(2)),
|
||||
n,
|
||||
)
|
||||
|
||||
value = c.values[0].View(ctx, 0,
|
||||
int(value.Dim(0)), int(value.Stride(1)),
|
||||
int(value.Dim(1)), int(value.Stride(2)),
|
||||
n,
|
||||
)
|
||||
|
||||
// TODO shift context if necessary
|
||||
|
||||
return key, value
|
||||
}
|
||||
45
cmd/cmd.go
45
cmd/cmd.go
@@ -15,13 +15,11 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/containerd/console"
|
||||
@@ -35,9 +33,9 @@ import (
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/llama/runner"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/runner"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
@@ -330,6 +328,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
if err := PullHandler(cmd, []string{name}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
|
||||
}
|
||||
return info, err
|
||||
@@ -338,7 +337,10 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
opts.MultiModal = len(info.ProjectorInfo) != 0
|
||||
// TODO(jessegross): We should either find another way to know if this is
|
||||
// a vision model or remove the logic. Also consider that other modalities will
|
||||
// need different behavior anyways.
|
||||
opts.MultiModal = len(info.ProjectorInfo) != 0 || envconfig.NewEngine()
|
||||
opts.ParentModel = info.Details.ParentModel
|
||||
|
||||
if interactive {
|
||||
@@ -855,17 +857,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
spinner := progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
|
||||
cancelCtx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT)
|
||||
|
||||
go func() {
|
||||
<-sigChan
|
||||
cancel()
|
||||
}()
|
||||
|
||||
var state *displayResponseState = &displayResponseState{}
|
||||
var latest api.ChatResponse
|
||||
var fullResponse strings.Builder
|
||||
@@ -900,10 +891,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
req.KeepAlive = opts.KeepAlive
|
||||
}
|
||||
|
||||
if err := client.Chat(cancelCtx, req, fn); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, nil
|
||||
}
|
||||
if err := client.Chat(cmd.Context(), req, fn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -943,17 +931,6 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
generateContext = []int{}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT)
|
||||
|
||||
go func() {
|
||||
<-sigChan
|
||||
cancel()
|
||||
}()
|
||||
|
||||
var state *displayResponseState = &displayResponseState{}
|
||||
|
||||
fn := func(response api.GenerateResponse) error {
|
||||
@@ -989,10 +966,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
KeepAlive: opts.KeepAlive,
|
||||
}
|
||||
|
||||
if err := client.Generate(ctx, &request, fn); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil
|
||||
}
|
||||
if err := client.Generate(cmd.Context(), &request, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1014,8 +988,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
latest.Summary()
|
||||
}
|
||||
|
||||
ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)
|
||||
cmd.SetContext(ctx)
|
||||
cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/llama/runner"
|
||||
"github.com/ollama/ollama/runner"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
@@ -19,17 +19,18 @@ var LibOllamaPath string = func() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
exe, err = filepath.EvalSymlinks(exe)
|
||||
if err != nil {
|
||||
return ""
|
||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||
exe = eval
|
||||
}
|
||||
|
||||
libPath := filepath.Dir(exe)
|
||||
var libPath string
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
libPath = filepath.Join(filepath.Dir(exe), "lib", "ollama")
|
||||
case "linux":
|
||||
libPath = filepath.Join(filepath.Dir(exe), "..", "lib", "ollama")
|
||||
case "darwin":
|
||||
libPath = filepath.Dir(exe)
|
||||
}
|
||||
|
||||
cwd, err := os.Getwd()
|
||||
@@ -37,17 +38,19 @@ var LibOllamaPath string = func() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// build paths for development
|
||||
buildPaths := []string{
|
||||
paths := []string{
|
||||
libPath,
|
||||
|
||||
// build paths for development
|
||||
filepath.Join(filepath.Dir(exe), "build", "lib", "ollama"),
|
||||
filepath.Join(cwd, "build", "lib", "ollama"),
|
||||
}
|
||||
|
||||
for _, p := range buildPaths {
|
||||
for _, p := range paths {
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return p
|
||||
}
|
||||
}
|
||||
|
||||
return libPath
|
||||
return filepath.Dir(exe)
|
||||
}()
|
||||
|
||||
45
docs/api.md
45
docs/api.md
@@ -31,7 +31,7 @@ Certain endpoints stream responses as JSON objects. Streaming can be disabled by
|
||||
|
||||
## Generate a completion
|
||||
|
||||
```shell
|
||||
```
|
||||
POST /api/generate
|
||||
```
|
||||
|
||||
@@ -306,7 +306,7 @@ curl http://localhost:11434/api/generate -d '{
|
||||
|
||||
#### Response
|
||||
|
||||
```
|
||||
```json
|
||||
{
|
||||
"model": "llava",
|
||||
"created_at": "2023-11-03T15:36:02.583064Z",
|
||||
@@ -485,7 +485,7 @@ A single JSON object is returned:
|
||||
|
||||
## Generate a chat completion
|
||||
|
||||
```shell
|
||||
```
|
||||
POST /api/chat
|
||||
```
|
||||
|
||||
@@ -495,14 +495,14 @@ Generate the next message in a chat with a provided model. This is a streaming e
|
||||
|
||||
- `model`: (required) the [model name](#model-names)
|
||||
- `messages`: the messages of the chat, this can be used to keep a chat memory
|
||||
- `tools`: tools for the model to use if supported. Requires `stream` to be set to `false`
|
||||
- `tools`: list of tools in JSON for the model to use if supported
|
||||
|
||||
The `message` object has the following fields:
|
||||
|
||||
- `role`: the role of the message, either `system`, `user`, `assistant`, or `tool`
|
||||
- `content`: the content of the message
|
||||
- `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
|
||||
- `tool_calls` (optional): a list of tools the model wants to use
|
||||
- `tool_calls` (optional): a list of tools in JSON that the model wants to use
|
||||
|
||||
Advanced parameters (optional):
|
||||
|
||||
@@ -795,7 +795,7 @@ curl http://localhost:11434/api/chat -d '{
|
||||
|
||||
##### Request
|
||||
|
||||
```
|
||||
```shell
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "llama3.2",
|
||||
"messages": [
|
||||
@@ -870,7 +870,7 @@ If the messages array is empty, the model will be loaded into memory.
|
||||
|
||||
##### Request
|
||||
|
||||
```
|
||||
```shell
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "llama3.2",
|
||||
"messages": []
|
||||
@@ -878,6 +878,7 @@ curl http://localhost:11434/api/chat -d '{
|
||||
```
|
||||
|
||||
##### Response
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama3.2",
|
||||
@@ -897,7 +898,7 @@ If the messages array is empty and the `keep_alive` parameter is set to `0`, a m
|
||||
|
||||
##### Request
|
||||
|
||||
```
|
||||
```shell
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "llama3.2",
|
||||
"messages": [],
|
||||
@@ -924,7 +925,7 @@ A single JSON object is returned:
|
||||
|
||||
## Create a Model
|
||||
|
||||
```shell
|
||||
```
|
||||
POST /api/create
|
||||
```
|
||||
|
||||
@@ -1020,7 +1021,7 @@ curl http://localhost:11434/api/create -d '{
|
||||
|
||||
A stream of JSON objects is returned:
|
||||
|
||||
```
|
||||
```json
|
||||
{"status":"quantizing F16 model to Q4_K_M"}
|
||||
{"status":"creating new layer sha256:667b0c1932bc6ffc593ed1d03f895bf2dc8dc6df21db3042284a6f4416b06a29"}
|
||||
{"status":"using existing layer sha256:11ce4ee3e170f6adebac9a991c22e22ab3f8530e154ee669954c4bc73061c258"}
|
||||
@@ -1051,7 +1052,7 @@ curl http://localhost:11434/api/create -d '{
|
||||
|
||||
A stream of JSON objects is returned:
|
||||
|
||||
```
|
||||
```json
|
||||
{"status":"parsing GGUF"}
|
||||
{"status":"using existing layer sha256:432f310a77f4650a88d0fd59ecdd7cebed8d684bafea53cbff0473542964f0c3"}
|
||||
{"status":"writing manifest"}
|
||||
@@ -1118,7 +1119,7 @@ Return 200 OK if the blob exists, 404 Not Found if it does not.
|
||||
|
||||
## Push a Blob
|
||||
|
||||
```shell
|
||||
```
|
||||
POST /api/blobs/:digest
|
||||
```
|
||||
|
||||
@@ -1142,7 +1143,7 @@ Return 201 Created if the blob was successfully created, 400 Bad Request if the
|
||||
|
||||
## List Local Models
|
||||
|
||||
```shell
|
||||
```
|
||||
GET /api/tags
|
||||
```
|
||||
|
||||
@@ -1195,7 +1196,7 @@ A single JSON object will be returned.
|
||||
|
||||
## Show Model Information
|
||||
|
||||
```shell
|
||||
```
|
||||
POST /api/show
|
||||
```
|
||||
|
||||
@@ -1261,7 +1262,7 @@ curl http://localhost:11434/api/show -d '{
|
||||
|
||||
## Copy a Model
|
||||
|
||||
```shell
|
||||
```
|
||||
POST /api/copy
|
||||
```
|
||||
|
||||
@@ -1284,7 +1285,7 @@ Returns a 200 OK if successful, or a 404 Not Found if the source model doesn't e
|
||||
|
||||
## Delete a Model
|
||||
|
||||
```shell
|
||||
```
|
||||
DELETE /api/delete
|
||||
```
|
||||
|
||||
@@ -1310,7 +1311,7 @@ Returns a 200 OK if successful, 404 Not Found if the model to be deleted doesn't
|
||||
|
||||
## Pull a Model
|
||||
|
||||
```shell
|
||||
```
|
||||
POST /api/pull
|
||||
```
|
||||
|
||||
@@ -1382,7 +1383,7 @@ if `stream` is set to false, then the response is a single JSON object:
|
||||
|
||||
## Push a Model
|
||||
|
||||
```shell
|
||||
```
|
||||
POST /api/push
|
||||
```
|
||||
|
||||
@@ -1447,7 +1448,7 @@ If `stream` is set to `false`, then the response is a single JSON object:
|
||||
|
||||
## Generate Embeddings
|
||||
|
||||
```shell
|
||||
```
|
||||
POST /api/embed
|
||||
```
|
||||
|
||||
@@ -1515,7 +1516,7 @@ curl http://localhost:11434/api/embed -d '{
|
||||
```
|
||||
|
||||
## List Running Models
|
||||
```shell
|
||||
```
|
||||
GET /api/ps
|
||||
```
|
||||
|
||||
@@ -1562,7 +1563,7 @@ A single JSON object will be returned.
|
||||
|
||||
> Note: this endpoint has been superseded by `/api/embed`
|
||||
|
||||
```shell
|
||||
```
|
||||
POST /api/embeddings
|
||||
```
|
||||
|
||||
@@ -1602,7 +1603,7 @@ curl http://localhost:11434/api/embeddings -d '{
|
||||
|
||||
## Version
|
||||
|
||||
```shell
|
||||
```
|
||||
GET /api/version
|
||||
```
|
||||
|
||||
|
||||
@@ -3,11 +3,11 @@
|
||||
Install prerequisites:
|
||||
|
||||
- [Go](https://go.dev/doc/install)
|
||||
- C/C++ Compiler e.g. Clang on macOS, [TDM-GCC](https://jmeubank.github.io/tdm-gcc/download/) (Windows amd64) or [llvm-mingw](https://github.com/mstorsjo/llvm-mingw) (Windows arm64), GCC/Clang on Linux.
|
||||
- C/C++ Compiler e.g. Clang on macOS, [TDM-GCC](https://github.com/jmeubank/tdm-gcc/releases/latest) (Windows amd64) or [llvm-mingw](https://github.com/mstorsjo/llvm-mingw) (Windows arm64), GCC/Clang on Linux.
|
||||
|
||||
Then build and run Ollama from the root directory of the repository:
|
||||
|
||||
```
|
||||
```shell
|
||||
go run . serve
|
||||
```
|
||||
|
||||
@@ -23,14 +23,14 @@ Install prerequisites:
|
||||
|
||||
Then, configure and build the project:
|
||||
|
||||
```
|
||||
```shell
|
||||
cmake -B build
|
||||
cmake --build build
|
||||
```
|
||||
|
||||
Lastly, run Ollama:
|
||||
|
||||
```
|
||||
```shell
|
||||
go run . serve
|
||||
```
|
||||
|
||||
@@ -57,14 +57,14 @@ Install prerequisites:
|
||||
|
||||
Then, configure and build the project:
|
||||
|
||||
```
|
||||
```shell
|
||||
cmake -B build
|
||||
cmake --build build --config Release
|
||||
```
|
||||
|
||||
Lastly, run Ollama:
|
||||
|
||||
```
|
||||
```shell
|
||||
go run . serve
|
||||
```
|
||||
|
||||
@@ -88,26 +88,26 @@ Install prerequisites:
|
||||
|
||||
Then, configure and build the project:
|
||||
|
||||
```
|
||||
```shell
|
||||
cmake -B build
|
||||
cmake --build build
|
||||
```
|
||||
|
||||
Lastly, run Ollama:
|
||||
|
||||
```
|
||||
```shell
|
||||
go run . serve
|
||||
```
|
||||
|
||||
## Docker
|
||||
|
||||
```
|
||||
```shell
|
||||
docker build .
|
||||
```
|
||||
|
||||
### ROCm
|
||||
|
||||
```
|
||||
```shell
|
||||
docker build --build-arg FLAVOR=rocm .
|
||||
```
|
||||
|
||||
@@ -115,6 +115,17 @@ docker build --build-arg FLAVOR=rocm .
|
||||
|
||||
To run tests, use `go test`:
|
||||
|
||||
```
|
||||
```shell
|
||||
go test ./...
|
||||
```
|
||||
|
||||
## Library detection
|
||||
|
||||
Ollama looks for acceleration libraries in the following paths relative to the `ollama` executable:
|
||||
|
||||
* `./lib/ollama` (Windows)
|
||||
* `../lib/ollama` (Linux)
|
||||
* `.` (macOS)
|
||||
* `build/lib/ollama` (for development)
|
||||
|
||||
If the libraries are not found, Ollama will not run with any acceleration libraries.
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
### CPU only
|
||||
|
||||
```bash
|
||||
```shell
|
||||
docker run -d -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama
|
||||
```
|
||||
|
||||
@@ -11,42 +11,46 @@ Install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-
|
||||
|
||||
#### Install with Apt
|
||||
1. Configure the repository
|
||||
```bash
|
||||
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey \
|
||||
| sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg
|
||||
curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list \
|
||||
| sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' \
|
||||
| sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
|
||||
sudo apt-get update
|
||||
```
|
||||
|
||||
```shell
|
||||
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey \
|
||||
| sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg
|
||||
curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list \
|
||||
| sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' \
|
||||
| sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
|
||||
sudo apt-get update
|
||||
```
|
||||
|
||||
2. Install the NVIDIA Container Toolkit packages
|
||||
```bash
|
||||
sudo apt-get install -y nvidia-container-toolkit
|
||||
```
|
||||
|
||||
```shell
|
||||
sudo apt-get install -y nvidia-container-toolkit
|
||||
```
|
||||
|
||||
#### Install with Yum or Dnf
|
||||
1. Configure the repository
|
||||
|
||||
```bash
|
||||
curl -s -L https://nvidia.github.io/libnvidia-container/stable/rpm/nvidia-container-toolkit.repo \
|
||||
| sudo tee /etc/yum.repos.d/nvidia-container-toolkit.repo
|
||||
```
|
||||
```shell
|
||||
curl -s -L https://nvidia.github.io/libnvidia-container/stable/rpm/nvidia-container-toolkit.repo \
|
||||
| sudo tee /etc/yum.repos.d/nvidia-container-toolkit.repo
|
||||
```
|
||||
|
||||
2. Install the NVIDIA Container Toolkit packages
|
||||
|
||||
```bash
|
||||
sudo yum install -y nvidia-container-toolkit
|
||||
```
|
||||
```shell
|
||||
sudo yum install -y nvidia-container-toolkit
|
||||
```
|
||||
|
||||
#### Configure Docker to use Nvidia driver
|
||||
```
|
||||
|
||||
```shell
|
||||
sudo nvidia-ctk runtime configure --runtime=docker
|
||||
sudo systemctl restart docker
|
||||
```
|
||||
|
||||
#### Start the container
|
||||
|
||||
```bash
|
||||
```shell
|
||||
docker run -d --gpus=all -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama
|
||||
```
|
||||
|
||||
@@ -57,7 +61,7 @@ docker run -d --gpus=all -v ollama:/root/.ollama -p 11434:11434 --name ollama ol
|
||||
|
||||
To run Ollama using Docker with AMD GPUs, use the `rocm` tag and the following command:
|
||||
|
||||
```
|
||||
```shell
|
||||
docker run -d --device /dev/kfd --device /dev/dri -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama:rocm
|
||||
```
|
||||
|
||||
@@ -65,7 +69,7 @@ docker run -d --device /dev/kfd --device /dev/dri -v ollama:/root/.ollama -p 114
|
||||
|
||||
Now you can run a model:
|
||||
|
||||
```
|
||||
```shell
|
||||
docker exec -it ollama ollama run llama3.2
|
||||
```
|
||||
|
||||
|
||||
22
docs/faq.md
22
docs/faq.md
@@ -24,7 +24,7 @@ By default, Ollama uses a context window size of 2048 tokens.
|
||||
|
||||
To change this when using `ollama run`, use `/set parameter`:
|
||||
|
||||
```
|
||||
```shell
|
||||
/set parameter num_ctx 4096
|
||||
```
|
||||
|
||||
@@ -46,10 +46,15 @@ Use the `ollama ps` command to see what models are currently loaded into memory.
|
||||
|
||||
```shell
|
||||
ollama ps
|
||||
NAME ID SIZE PROCESSOR UNTIL
|
||||
llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
|
||||
```
|
||||
|
||||
> **Output**:
|
||||
>
|
||||
> ```
|
||||
> NAME ID SIZE PROCESSOR UNTIL
|
||||
> llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
|
||||
> ```
|
||||
|
||||
The `Processor` column will show which memory the model was loaded in to:
|
||||
* `100% GPU` means the model was loaded entirely into the GPU
|
||||
* `100% CPU` means the model was loaded entirely in system memory
|
||||
@@ -66,7 +71,7 @@ If Ollama is run as a macOS application, environment variables should be set usi
|
||||
1. For each environment variable, call `launchctl setenv`.
|
||||
|
||||
```bash
|
||||
launchctl setenv OLLAMA_HOST "0.0.0.0"
|
||||
launchctl setenv OLLAMA_HOST "0.0.0.0:11434"
|
||||
```
|
||||
|
||||
2. Restart Ollama application.
|
||||
@@ -81,14 +86,14 @@ If Ollama is run as a systemd service, environment variables should be set using
|
||||
|
||||
```ini
|
||||
[Service]
|
||||
Environment="OLLAMA_HOST=0.0.0.0"
|
||||
Environment="OLLAMA_HOST=0.0.0.0:11434"
|
||||
```
|
||||
|
||||
3. Save and exit.
|
||||
|
||||
4. Reload `systemd` and restart Ollama:
|
||||
|
||||
```bash
|
||||
```shell
|
||||
systemctl daemon-reload
|
||||
systemctl restart ollama
|
||||
```
|
||||
@@ -221,16 +226,19 @@ properties.
|
||||
If you are using the API you can preload a model by sending the Ollama server an empty request. This works with both the `/api/generate` and `/api/chat` API endpoints.
|
||||
|
||||
To preload the mistral model using the generate endpoint, use:
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/generate -d '{"model": "mistral"}'
|
||||
```
|
||||
|
||||
To use the chat completions endpoint, use:
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/chat -d '{"model": "mistral"}'
|
||||
```
|
||||
|
||||
To preload a model using the CLI, use the command:
|
||||
|
||||
```shell
|
||||
ollama run llama3.2 ""
|
||||
```
|
||||
@@ -250,11 +258,13 @@ If you're using the API, use the `keep_alive` parameter with the `/api/generate`
|
||||
* '0' which will unload the model immediately after generating a response
|
||||
|
||||
For example, to preload a model and leave it in memory use:
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/generate -d '{"model": "llama3.2", "keep_alive": -1}'
|
||||
```
|
||||
|
||||
To unload the model and free up memory use:
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/generate -d '{"model": "llama3.2", "keep_alive": 0}'
|
||||
```
|
||||
|
||||
@@ -7,7 +7,7 @@ Check your compute compatibility to see if your card is supported:
|
||||
|
||||
| Compute Capability | Family | Cards |
|
||||
| ------------------ | ------------------- | ----------------------------------------------------------------------------------------------------------- |
|
||||
| 9.0 | NVIDIA | `H100` |
|
||||
| 9.0 | NVIDIA | `H200` `H100` |
|
||||
| 8.9 | GeForce RTX 40xx | `RTX 4090` `RTX 4080 SUPER` `RTX 4080` `RTX 4070 Ti SUPER` `RTX 4070 Ti` `RTX 4070 SUPER` `RTX 4070` `RTX 4060 Ti` `RTX 4060` |
|
||||
| | NVIDIA Professional | `L4` `L40` `RTX 6000` |
|
||||
| 8.6 | GeForce RTX 30xx | `RTX 3090 Ti` `RTX 3090` `RTX 3080 Ti` `RTX 3080` `RTX 3070 Ti` `RTX 3070` `RTX 3060 Ti` `RTX 3060` `RTX 3050 Ti` `RTX 3050` |
|
||||
|
||||
@@ -20,13 +20,13 @@ Make sure that you use the same base model in the `FROM` command as you used to
|
||||
|
||||
Now run `ollama create` from the directory where the `Modelfile` was created:
|
||||
|
||||
```bash
|
||||
```shell
|
||||
ollama create my-model
|
||||
```
|
||||
|
||||
Lastly, test the model:
|
||||
|
||||
```bash
|
||||
```shell
|
||||
ollama run my-model
|
||||
```
|
||||
|
||||
|
||||
@@ -119,7 +119,7 @@ sudo systemctl status ollama
|
||||
|
||||
To customize the installation of Ollama, you can edit the systemd service file or the environment variables by running:
|
||||
|
||||
```
|
||||
```shell
|
||||
sudo systemctl edit ollama
|
||||
```
|
||||
|
||||
@@ -152,7 +152,7 @@ Use `OLLAMA_VERSION` environment variable with the install script to install a s
|
||||
For example:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.3.9 sh
|
||||
curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.5.7 sh
|
||||
```
|
||||
|
||||
## Viewing logs
|
||||
@@ -186,3 +186,9 @@ sudo rm -r /usr/share/ollama
|
||||
sudo userdel ollama
|
||||
sudo groupdel ollama
|
||||
```
|
||||
|
||||
Remove installed libraries:
|
||||
|
||||
```shell
|
||||
sudo rm -rf /usr/local/lib/ollama
|
||||
```
|
||||
|
||||
@@ -28,7 +28,7 @@ A model file is the blueprint to create and share models with Ollama.
|
||||
|
||||
The format of the `Modelfile`:
|
||||
|
||||
```modelfile
|
||||
```
|
||||
# comment
|
||||
INSTRUCTION arguments
|
||||
```
|
||||
@@ -49,7 +49,7 @@ INSTRUCTION arguments
|
||||
|
||||
An example of a `Modelfile` creating a mario blueprint:
|
||||
|
||||
```modelfile
|
||||
```
|
||||
FROM llama3.2
|
||||
# sets the temperature to 1 [higher is more creative, lower is more coherent]
|
||||
PARAMETER temperature 1
|
||||
@@ -69,24 +69,30 @@ To use this:
|
||||
|
||||
To view the Modelfile of a given model, use the `ollama show --modelfile` command.
|
||||
|
||||
```bash
|
||||
> ollama show --modelfile llama3.2
|
||||
# Modelfile generated by "ollama show"
|
||||
# To build a new Modelfile based on this one, replace the FROM line with:
|
||||
# FROM llama3.2:latest
|
||||
FROM /Users/pdevine/.ollama/models/blobs/sha256-00e1317cbf74d901080d7100f57580ba8dd8de57203072dc6f668324ba545f29
|
||||
TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
|
||||
```shell
|
||||
ollama show --modelfile llama3.2
|
||||
```
|
||||
|
||||
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
|
||||
> **Output**:
|
||||
>
|
||||
> ```
|
||||
> # Modelfile generated by "ollama show"
|
||||
> # To build a new Modelfile based on this one, replace the FROM line with:
|
||||
> # FROM llama3.2:latest
|
||||
> FROM /Users/pdevine/.ollama/models/blobs/sha256-00e1317cbf74d901080d7100f57580ba8dd8de57203072dc6f668324ba545f29
|
||||
> TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
|
||||
>
|
||||
> {{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
|
||||
>
|
||||
> {{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
|
||||
>
|
||||
> {{ .Response }}<|eot_id|>"""
|
||||
> PARAMETER stop "<|start_header_id|>"
|
||||
> PARAMETER stop "<|end_header_id|>"
|
||||
> PARAMETER stop "<|eot_id|>"
|
||||
> PARAMETER stop "<|reserved_special_token"
|
||||
> ```
|
||||
|
||||
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{{ .Response }}<|eot_id|>"""
|
||||
PARAMETER stop "<|start_header_id|>"
|
||||
PARAMETER stop "<|end_header_id|>"
|
||||
PARAMETER stop "<|eot_id|>"
|
||||
PARAMETER stop "<|reserved_special_token"
|
||||
```
|
||||
|
||||
## Instructions
|
||||
|
||||
@@ -94,13 +100,13 @@ To view the Modelfile of a given model, use the `ollama show --modelfile` comman
|
||||
|
||||
The `FROM` instruction defines the base model to use when creating a model.
|
||||
|
||||
```modelfile
|
||||
```
|
||||
FROM <model name>:<tag>
|
||||
```
|
||||
|
||||
#### Build from existing model
|
||||
|
||||
```modelfile
|
||||
```
|
||||
FROM llama3.2
|
||||
```
|
||||
|
||||
@@ -111,7 +117,7 @@ Additional models can be found at:
|
||||
|
||||
#### Build from a Safetensors model
|
||||
|
||||
```modelfile
|
||||
```
|
||||
FROM <model directory>
|
||||
```
|
||||
|
||||
@@ -125,7 +131,7 @@ Currently supported model architectures:
|
||||
|
||||
#### Build from a GGUF file
|
||||
|
||||
```modelfile
|
||||
```
|
||||
FROM ./ollama-model.gguf
|
||||
```
|
||||
|
||||
@@ -136,7 +142,7 @@ The GGUF file location should be specified as an absolute path or relative to th
|
||||
|
||||
The `PARAMETER` instruction defines a parameter that can be set when the model is run.
|
||||
|
||||
```modelfile
|
||||
```
|
||||
PARAMETER <parameter> <parametervalue>
|
||||
```
|
||||
|
||||
@@ -183,7 +189,7 @@ TEMPLATE """{{ if .System }}<|im_start|>system
|
||||
|
||||
The `SYSTEM` instruction specifies the system message to be used in the template, if applicable.
|
||||
|
||||
```modelfile
|
||||
```
|
||||
SYSTEM """<system message>"""
|
||||
```
|
||||
|
||||
@@ -193,7 +199,7 @@ The `ADAPTER` instruction specifies a fine tuned LoRA adapter that should apply
|
||||
|
||||
#### Safetensor adapter
|
||||
|
||||
```modelfile
|
||||
```
|
||||
ADAPTER <path to safetensor adapter>
|
||||
```
|
||||
|
||||
@@ -204,7 +210,7 @@ Currently supported Safetensor adapters:
|
||||
|
||||
#### GGUF adapter
|
||||
|
||||
```modelfile
|
||||
```
|
||||
ADAPTER ./ollama-lora.gguf
|
||||
```
|
||||
|
||||
@@ -212,7 +218,7 @@ ADAPTER ./ollama-lora.gguf
|
||||
|
||||
The `LICENSE` instruction allows you to specify the legal license under which the model used with this Modelfile is shared or distributed.
|
||||
|
||||
```modelfile
|
||||
```
|
||||
LICENSE """
|
||||
<license text>
|
||||
"""
|
||||
@@ -222,7 +228,7 @@ LICENSE """
|
||||
|
||||
The `MESSAGE` instruction allows you to specify a message history for the model to use when responding. Use multiple iterations of the MESSAGE command to build up a conversation which will guide the model to answer in a similar way.
|
||||
|
||||
```modelfile
|
||||
```
|
||||
MESSAGE <role> <message>
|
||||
```
|
||||
|
||||
@@ -237,7 +243,7 @@ MESSAGE <role> <message>
|
||||
|
||||
#### Example conversation
|
||||
|
||||
```modelfile
|
||||
```
|
||||
MESSAGE user Is Toronto in Canada?
|
||||
MESSAGE assistant yes
|
||||
MESSAGE user Is Sacramento in Canada?
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# OpenAI compatibility
|
||||
|
||||
> **Note:** OpenAI compatibility is experimental and is subject to major adjustments including breaking changes. For fully-featured access to the Ollama API, see the Ollama [Python library](https://github.com/ollama/ollama-python), [JavaScript library](https://github.com/ollama/ollama-js) and [REST API](https://github.com/ollama/ollama/blob/main/docs/api.md).
|
||||
> [!NOTE]
|
||||
> OpenAI compatibility is experimental and is subject to major adjustments including breaking changes. For fully-featured access to the Ollama API, see the Ollama [Python library](https://github.com/ollama/ollama-python), [JavaScript library](https://github.com/ollama/ollama-js) and [REST API](https://github.com/ollama/ollama/blob/main/docs/api.md).
|
||||
|
||||
Ollama provides experimental compatibility with parts of the [OpenAI API](https://platform.openai.com/docs/api-reference) to help connect existing applications to Ollama.
|
||||
|
||||
@@ -59,8 +60,10 @@ embeddings = client.embeddings.create(
|
||||
input=["why is the sky blue?", "why is the grass green?"],
|
||||
)
|
||||
```
|
||||
|
||||
#### Structured outputs
|
||||
```py
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel
|
||||
from openai import OpenAI
|
||||
|
||||
@@ -144,7 +147,7 @@ const embedding = await openai.embeddings.create({
|
||||
|
||||
### `curl`
|
||||
|
||||
``` shell
|
||||
```shell
|
||||
curl http://localhost:11434/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
@@ -319,7 +322,7 @@ ollama pull llama3.2
|
||||
|
||||
For tooling that relies on default OpenAI model names such as `gpt-3.5-turbo`, use `ollama cp` to copy an existing model name to a temporary name:
|
||||
|
||||
```
|
||||
```shell
|
||||
ollama cp llama3.2 gpt-3.5-turbo
|
||||
```
|
||||
|
||||
@@ -343,7 +346,7 @@ curl http://localhost:11434/v1/chat/completions \
|
||||
|
||||
The OpenAI API does not have a way of setting the context size for a model. If you need to change the context size, create a `Modelfile` which looks like:
|
||||
|
||||
```modelfile
|
||||
```
|
||||
FROM <some model>
|
||||
PARAMETER num_ctx <context size>
|
||||
```
|
||||
|
||||
@@ -17,6 +17,7 @@ When you run Ollama in a **container**, the logs go to stdout/stderr in the cont
|
||||
```shell
|
||||
docker logs <container-name>
|
||||
```
|
||||
|
||||
(Use `docker ps` to find the container name)
|
||||
|
||||
If manually running `ollama serve` in a terminal, the logs will be on that terminal.
|
||||
@@ -28,6 +29,7 @@ When you run Ollama on **Windows**, there are a few different locations. You can
|
||||
- `explorer %TEMP%` where temporary executable files are stored in one or more `ollama*` directories
|
||||
|
||||
To enable additional debug logging to help troubleshoot problems, first **Quit the running app from the tray menu** then in a powershell terminal
|
||||
|
||||
```powershell
|
||||
$env:OLLAMA_DEBUG="1"
|
||||
& "ollama app.exe"
|
||||
@@ -49,12 +51,13 @@ Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v11 rocm_v5]
|
||||
|
||||
You can set OLLAMA_LLM_LIBRARY to any of the available LLM libraries to bypass autodetection, so for example, if you have a CUDA card, but want to force the CPU LLM library with AVX2 vector support, use:
|
||||
|
||||
```
|
||||
```shell
|
||||
OLLAMA_LLM_LIBRARY="cpu_avx2" ollama serve
|
||||
```
|
||||
|
||||
You can see what features your CPU has with the following.
|
||||
```
|
||||
|
||||
```shell
|
||||
cat /proc/cpuinfo| grep flags | head -1
|
||||
```
|
||||
|
||||
@@ -62,8 +65,8 @@ cat /proc/cpuinfo| grep flags | head -1
|
||||
|
||||
If you run into problems on Linux and want to install an older version, or you'd like to try out a pre-release before it's officially released, you can tell the install script which version to install.
|
||||
|
||||
```sh
|
||||
curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION="0.1.29" sh
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.5.7 sh
|
||||
```
|
||||
|
||||
## Linux tmp noexec
|
||||
|
||||
@@ -47,6 +47,7 @@ If Ollama is already running, Quit the tray application and relaunch it from the
|
||||
## API Access
|
||||
|
||||
Here's a quick example showing API access from `powershell`
|
||||
|
||||
```powershell
|
||||
(Invoke-WebRequest -method POST -Body '{"model":"llama3.2", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json
|
||||
```
|
||||
@@ -54,7 +55,7 @@ Here's a quick example showing API access from `powershell`
|
||||
## Troubleshooting
|
||||
|
||||
Ollama on Windows stores files in a few different locations. You can view them in
|
||||
the explorer window by hitting `<cmd>+R` and type in:
|
||||
the explorer window by hitting `<Ctrl>+R` and type in:
|
||||
- `explorer %LOCALAPPDATA%\Ollama` contains logs, and downloaded updates
|
||||
- *app.log* contains most resent logs from the GUI application
|
||||
- *server.log* contains the most recent server logs
|
||||
|
||||
@@ -165,6 +165,8 @@ var (
|
||||
IntelGPU = Bool("OLLAMA_INTEL_GPU")
|
||||
// MultiUserCache optimizes prompt caching for multi-user scenarios
|
||||
MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
|
||||
// Enable the new Ollama engine
|
||||
NewEngine = Bool("OLLAMA_NEW_ENGINE")
|
||||
)
|
||||
|
||||
func String(s string) func() string {
|
||||
@@ -250,6 +252,7 @@ func AsMap() map[string]EnvVar {
|
||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
|
||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||
|
||||
// Informational
|
||||
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
||||
|
||||
@@ -40,8 +40,6 @@ func HumanBytes(b int64) string {
|
||||
}
|
||||
|
||||
switch {
|
||||
case value >= 100:
|
||||
return fmt.Sprintf("%d %s", int(value), unit)
|
||||
case value >= 10:
|
||||
return fmt.Sprintf("%d %s", int(value), unit)
|
||||
case value != math.Trunc(value):
|
||||
|
||||
91
format/bytes_test.go
Normal file
91
format/bytes_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package format
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHumanBytes(t *testing.T) {
|
||||
type testCase struct {
|
||||
input int64
|
||||
expected string
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
// Test bytes (B)
|
||||
{0, "0 B"},
|
||||
{1, "1 B"},
|
||||
{999, "999 B"},
|
||||
|
||||
// Test kilobytes (KB)
|
||||
{1000, "1 KB"},
|
||||
{1500, "1.5 KB"},
|
||||
{999999, "999 KB"},
|
||||
|
||||
// Test megabytes (MB)
|
||||
{1000000, "1 MB"},
|
||||
{1500000, "1.5 MB"},
|
||||
{999999999, "999 MB"},
|
||||
|
||||
// Test gigabytes (GB)
|
||||
{1000000000, "1 GB"},
|
||||
{1500000000, "1.5 GB"},
|
||||
{999999999999, "999 GB"},
|
||||
|
||||
// Test terabytes (TB)
|
||||
{1000000000000, "1 TB"},
|
||||
{1500000000000, "1.5 TB"},
|
||||
{1999999999999, "2.0 TB"},
|
||||
|
||||
// Test fractional values
|
||||
{1234, "1.2 KB"},
|
||||
{1234567, "1.2 MB"},
|
||||
{1234567890, "1.2 GB"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.expected, func(t *testing.T) {
|
||||
result := HumanBytes(tc.input)
|
||||
if result != tc.expected {
|
||||
t.Errorf("Expected %s, got %s", tc.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHumanBytes2(t *testing.T) {
|
||||
type testCase struct {
|
||||
input uint64
|
||||
expected string
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
// Test bytes (B)
|
||||
{0, "0 B"},
|
||||
{1, "1 B"},
|
||||
{1023, "1023 B"},
|
||||
|
||||
// Test kibibytes (KiB)
|
||||
{1024, "1.0 KiB"},
|
||||
{1536, "1.5 KiB"},
|
||||
{1048575, "1024.0 KiB"},
|
||||
|
||||
// Test mebibytes (MiB)
|
||||
{1048576, "1.0 MiB"},
|
||||
{1572864, "1.5 MiB"},
|
||||
{1073741823, "1024.0 MiB"},
|
||||
|
||||
// Test gibibytes (GiB)
|
||||
{1073741824, "1.0 GiB"},
|
||||
{1610612736, "1.5 GiB"},
|
||||
{2147483648, "2.0 GiB"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.expected, func(t *testing.T) {
|
||||
result := HumanBytes2(tc.input)
|
||||
if result != tc.expected {
|
||||
t.Errorf("Expected %s, got %s", tc.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,9 @@ func TestHumanNumber(t *testing.T) {
|
||||
|
||||
testCases := []testCase{
|
||||
{0, "0"},
|
||||
{999, "999"},
|
||||
{1000, "1K"},
|
||||
{1001, "1K"},
|
||||
{1000000, "1M"},
|
||||
{125000000, "125M"},
|
||||
{500500000, "500.50M"},
|
||||
|
||||
@@ -153,19 +153,17 @@ func (s Tensors) Items(prefix ...string) []*Tensor {
|
||||
return items
|
||||
}
|
||||
|
||||
func (ts Tensors) Layers() map[string]Layer {
|
||||
func (ts Tensors) GroupLayers() map[string]Layer {
|
||||
layers := make(map[string]Layer)
|
||||
for _, t := range ts.items {
|
||||
parts := strings.Split(t.Name, ".")
|
||||
if i := slices.Index(parts, "blk"); i > 0 {
|
||||
parts = append([]string{
|
||||
strings.Join(parts[:i], "."),
|
||||
strings.Join(parts[i:i+2], "."),
|
||||
}, parts[i+2:]...)
|
||||
} else if i == 0 {
|
||||
parts = append([]string{
|
||||
strings.Join(parts[i:i+2], "."),
|
||||
}, parts[i+2:]...)
|
||||
if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 {
|
||||
if len(parts) > index+2 {
|
||||
// blk and mm should have a number after them, join it
|
||||
parts = append(
|
||||
[]string{strings.Join(parts[:index+2], ".")},
|
||||
parts[index+2:]...)
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := layers[parts[0]]; !ok {
|
||||
@@ -377,22 +375,22 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
||||
}, offset, nil
|
||||
}
|
||||
|
||||
func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
|
||||
embedding := llm.KV().EmbeddingLength()
|
||||
heads := llm.KV().HeadCount()
|
||||
headsKV := llm.KV().HeadCountKV()
|
||||
vocab := uint64(llm.KV()["tokenizer.ggml.tokens"].(*array).size)
|
||||
func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
|
||||
embedding := f.KV().EmbeddingLength()
|
||||
heads := f.KV().HeadCount()
|
||||
headsKV := f.KV().HeadCountKV()
|
||||
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array).size)
|
||||
|
||||
embeddingHeads := llm.KV().EmbeddingHeadCount()
|
||||
embeddingHeadsK := llm.KV().EmbeddingHeadCountK()
|
||||
embeddingHeadsV := llm.KV().EmbeddingHeadCountV()
|
||||
embeddingHeads := f.KV().EmbeddingHeadCount()
|
||||
embeddingHeadsK := f.KV().EmbeddingHeadCountK()
|
||||
embeddingHeadsV := f.KV().EmbeddingHeadCountV()
|
||||
|
||||
layers := llm.Tensors().Layers()
|
||||
layers := f.Tensors().GroupLayers()
|
||||
|
||||
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
||||
kv = uint64(float64(context*llm.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
kv = uint64(float64(context*f.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
|
||||
switch llm.KV().Architecture() {
|
||||
switch f.KV().Architecture() {
|
||||
case "llama":
|
||||
fullOffload = max(
|
||||
4*batch*(1+4*embedding+context*(1+heads)),
|
||||
@@ -407,7 +405,7 @@ func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partia
|
||||
|
||||
if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
|
||||
// mixtral 8x22b
|
||||
ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
|
||||
ff := uint64(f.KV()["llama.feed_forward_length"].(uint32))
|
||||
partialOffload = max(
|
||||
3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
|
||||
4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
|
||||
@@ -424,11 +422,11 @@ func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partia
|
||||
case "mllama":
|
||||
var visionTokens, tiles uint64 = 1601, 4
|
||||
|
||||
if crossAttentionLayers, ok := llm.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
|
||||
if crossAttentionLayers, ok := f.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
|
||||
kv = headsKV *
|
||||
(embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
|
||||
(2* // sizeof(float16)
|
||||
(llm.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
|
||||
(f.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
|
||||
context +
|
||||
4* // sizeof(float32)
|
||||
uint64(crossAttentionLayers.size)* // num cross attention layers
|
||||
@@ -443,7 +441,7 @@ func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partia
|
||||
)
|
||||
|
||||
var ropeFreqsCount uint64
|
||||
if ropeFreqs, ok := llm.Tensors().Layers()["rope_freqs"]; ok {
|
||||
if ropeFreqs, ok := f.Tensors().GroupLayers()["rope_freqs"]; ok {
|
||||
if ropeFreqsWeights, ok := ropeFreqs["weights"]; ok {
|
||||
ropeFreqsCount = ropeFreqsWeights.parameters()
|
||||
}
|
||||
@@ -547,20 +545,20 @@ func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partia
|
||||
}
|
||||
|
||||
// SupportsKVCacheType checks if the requested cache type is supported
|
||||
func (llm GGML) SupportsKVCacheType(cacheType string) bool {
|
||||
func (f GGML) SupportsKVCacheType(cacheType string) bool {
|
||||
return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
|
||||
}
|
||||
|
||||
// SupportsFlashAttention checks if the model supports flash attention
|
||||
func (llm GGML) SupportsFlashAttention() bool {
|
||||
_, isEmbedding := llm.KV()[fmt.Sprintf("%s.pooling_type", llm.KV().Architecture())]
|
||||
func (f GGML) SupportsFlashAttention() bool {
|
||||
_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]
|
||||
if isEmbedding {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check head counts match and are non-zero
|
||||
headCountK := llm.KV().EmbeddingHeadCountK()
|
||||
headCountV := llm.KV().EmbeddingHeadCountV()
|
||||
headCountK := f.KV().EmbeddingHeadCountK()
|
||||
headCountV := f.KV().EmbeddingHeadCountV()
|
||||
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
|
||||
}
|
||||
|
||||
|
||||
159
fs/ggml/ggml_test.go
Normal file
159
fs/ggml/ggml_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package ggml
|
||||
|
||||
import (
|
||||
"maps"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestTensorLayers(t *testing.T) {
|
||||
tensors := make(map[string]*Tensor)
|
||||
for _, name := range []string{
|
||||
"token_embd.weight",
|
||||
"blk.0.attn_k.weight",
|
||||
"blk.0.attn_output.weight",
|
||||
"blk.0.attn_q.weight",
|
||||
"blk.0.attn_v.weight",
|
||||
"blk.0.attn_norm.weight",
|
||||
"blk.0.ffn_down.weight",
|
||||
"blk.0.ffn_gate.weight",
|
||||
"blk.0.ffn_up.weight",
|
||||
"blk.0.ffn_norm.weight",
|
||||
"output_norm.weight",
|
||||
"mm.0.bias",
|
||||
"mm.0.weight",
|
||||
"v.blk.0.attn_k.weight",
|
||||
"v.blk.0.attn_output.weight",
|
||||
"v.blk.0.attn_q.weight",
|
||||
"v.blk.0.attn_v.weight",
|
||||
"v.blk.0.attn_norm.weight",
|
||||
"v.blk.0.ffn_down.weight",
|
||||
"v.blk.0.ffn_gate.weight",
|
||||
"v.blk.0.ffn_up.weight",
|
||||
"v.blk.0.ffn_norm.weight",
|
||||
"v.patch_embd.weight",
|
||||
"v.position_embd.gate",
|
||||
"v.position_embd.weight",
|
||||
} {
|
||||
tensors[name] = &Tensor{Name: name}
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
items []*Tensor
|
||||
want map[string]Layer
|
||||
}{
|
||||
{
|
||||
name: "text",
|
||||
items: slices.Collect(func(yield func(*Tensor) bool) {
|
||||
for k, v := range tensors {
|
||||
if !strings.HasPrefix(k, "mm.") && !strings.HasPrefix(k, "v.") {
|
||||
if !yield(v) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
want: map[string]Layer{
|
||||
"blk.0": {
|
||||
"attn_k.weight": tensors["blk.0.attn_k.weight"],
|
||||
"attn_q.weight": tensors["blk.0.attn_q.weight"],
|
||||
"attn_v.weight": tensors["blk.0.attn_v.weight"],
|
||||
"attn_output.weight": tensors["blk.0.attn_output.weight"],
|
||||
"attn_norm.weight": tensors["blk.0.attn_norm.weight"],
|
||||
"ffn_down.weight": tensors["blk.0.ffn_down.weight"],
|
||||
"ffn_gate.weight": tensors["blk.0.ffn_gate.weight"],
|
||||
"ffn_up.weight": tensors["blk.0.ffn_up.weight"],
|
||||
"ffn_norm.weight": tensors["blk.0.ffn_norm.weight"],
|
||||
},
|
||||
"token_embd": {"weight": tensors["token_embd.weight"]},
|
||||
"output_norm": {"weight": tensors["output_norm.weight"]},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "vision",
|
||||
items: slices.Collect(func(yield func(*Tensor) bool) {
|
||||
for k, v := range tensors {
|
||||
if strings.HasPrefix(k, "mm.") || strings.HasPrefix(k, "v.") {
|
||||
if !yield(v) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
want: map[string]Layer{
|
||||
"mm.0": {
|
||||
"bias": tensors["mm.0.bias"],
|
||||
"weight": tensors["mm.0.weight"],
|
||||
},
|
||||
"v.blk.0": {
|
||||
"attn_k.weight": tensors["v.blk.0.attn_k.weight"],
|
||||
"attn_q.weight": tensors["v.blk.0.attn_q.weight"],
|
||||
"attn_v.weight": tensors["v.blk.0.attn_v.weight"],
|
||||
"attn_output.weight": tensors["v.blk.0.attn_output.weight"],
|
||||
"attn_norm.weight": tensors["v.blk.0.attn_norm.weight"],
|
||||
"ffn_down.weight": tensors["v.blk.0.ffn_down.weight"],
|
||||
"ffn_gate.weight": tensors["v.blk.0.ffn_gate.weight"],
|
||||
"ffn_up.weight": tensors["v.blk.0.ffn_up.weight"],
|
||||
"ffn_norm.weight": tensors["v.blk.0.ffn_norm.weight"],
|
||||
},
|
||||
"v": {
|
||||
"patch_embd.weight": tensors["v.patch_embd.weight"],
|
||||
"position_embd.gate": tensors["v.position_embd.gate"],
|
||||
"position_embd.weight": tensors["v.position_embd.weight"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "vision and text",
|
||||
items: slices.Collect(maps.Values(tensors)),
|
||||
want: map[string]Layer{
|
||||
"blk.0": {
|
||||
"attn_k.weight": tensors["blk.0.attn_k.weight"],
|
||||
"attn_q.weight": tensors["blk.0.attn_q.weight"],
|
||||
"attn_v.weight": tensors["blk.0.attn_v.weight"],
|
||||
"attn_output.weight": tensors["blk.0.attn_output.weight"],
|
||||
"attn_norm.weight": tensors["blk.0.attn_norm.weight"],
|
||||
"ffn_down.weight": tensors["blk.0.ffn_down.weight"],
|
||||
"ffn_gate.weight": tensors["blk.0.ffn_gate.weight"],
|
||||
"ffn_up.weight": tensors["blk.0.ffn_up.weight"],
|
||||
"ffn_norm.weight": tensors["blk.0.ffn_norm.weight"],
|
||||
},
|
||||
"token_embd": {"weight": tensors["token_embd.weight"]},
|
||||
"output_norm": {"weight": tensors["output_norm.weight"]},
|
||||
"mm.0": {
|
||||
"bias": tensors["mm.0.bias"],
|
||||
"weight": tensors["mm.0.weight"],
|
||||
},
|
||||
"v.blk.0": {
|
||||
"attn_k.weight": tensors["v.blk.0.attn_k.weight"],
|
||||
"attn_q.weight": tensors["v.blk.0.attn_q.weight"],
|
||||
"attn_v.weight": tensors["v.blk.0.attn_v.weight"],
|
||||
"attn_output.weight": tensors["v.blk.0.attn_output.weight"],
|
||||
"attn_norm.weight": tensors["v.blk.0.attn_norm.weight"],
|
||||
"ffn_down.weight": tensors["v.blk.0.ffn_down.weight"],
|
||||
"ffn_gate.weight": tensors["v.blk.0.ffn_gate.weight"],
|
||||
"ffn_up.weight": tensors["v.blk.0.ffn_up.weight"],
|
||||
"ffn_norm.weight": tensors["v.blk.0.ffn_norm.weight"],
|
||||
},
|
||||
"v": {
|
||||
"patch_embd.weight": tensors["v.patch_embd.weight"],
|
||||
"position_embd.gate": tensors["v.position_embd.gate"],
|
||||
"position_embd.weight": tensors["v.position_embd.weight"],
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := Tensors{items: tt.items}.GroupLayers()
|
||||
if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||
t.Errorf("unexpected layers (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -32,9 +32,10 @@ const (
|
||||
fileTypeIQ1_S
|
||||
fileTypeIQ4_NL
|
||||
fileTypeIQ3_S
|
||||
fileTypeIQ3_M
|
||||
fileTypeIQ2_S
|
||||
fileTypeIQ4_XS
|
||||
fileTypeIQ2_M
|
||||
fileTypeIQ4_XS
|
||||
fileTypeIQ1_M
|
||||
fileTypeBF16
|
||||
|
||||
@@ -93,12 +94,14 @@ func ParseFileType(s string) (fileType, error) {
|
||||
return fileTypeIQ4_NL, nil
|
||||
case "IQ3_S":
|
||||
return fileTypeIQ3_S, nil
|
||||
case "IQ3_M":
|
||||
return fileTypeIQ3_M, nil
|
||||
case "IQ2_S":
|
||||
return fileTypeIQ2_S, nil
|
||||
case "IQ4_XS":
|
||||
return fileTypeIQ4_XS, nil
|
||||
case "IQ2_M":
|
||||
return fileTypeIQ2_M, nil
|
||||
case "IQ4_XS":
|
||||
return fileTypeIQ4_XS, nil
|
||||
case "IQ1_M":
|
||||
return fileTypeIQ1_M, nil
|
||||
case "BF16":
|
||||
@@ -160,6 +163,8 @@ func (t fileType) String() string {
|
||||
return "IQ4_NL"
|
||||
case fileTypeIQ3_S:
|
||||
return "IQ3_S"
|
||||
case fileTypeIQ3_M:
|
||||
return "IQ3_M"
|
||||
case fileTypeIQ2_S:
|
||||
return "IQ2_S"
|
||||
case fileTypeIQ4_XS:
|
||||
|
||||
54
kvcache/cache.go
Normal file
54
kvcache/cache.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrKvCacheFull = errors.New("could not find a kv cache slot")
|
||||
ErrNotSupported = errors.New("model does not support operation")
|
||||
)
|
||||
|
||||
type Cache interface {
|
||||
// ** used by model implementations **
|
||||
|
||||
// SetLayer sets the active layer of the cache
|
||||
SetLayer(layer int)
|
||||
|
||||
// Get returns the history of key and value tensors plus a mask
|
||||
//
|
||||
// The shape of the tensors is documented in the specific
|
||||
// cache implementation used.
|
||||
Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
|
||||
|
||||
// Put stores a batch of key and value in the cache
|
||||
//
|
||||
// The shape of the tensors is documented in the specific
|
||||
// cache implementation used.
|
||||
Put(ctx ml.Context, key, value ml.Tensor)
|
||||
|
||||
// ** cache management **
|
||||
|
||||
// Init sets up runtime parameters
|
||||
Init(backend ml.Backend, dtype ml.DType, capacity int32)
|
||||
|
||||
// Close closes the cache and frees resources associated with it
|
||||
Close()
|
||||
|
||||
// StartForward is called before the start of the model's forward pass.
|
||||
// For each token in the coming batch, there must be a corresponding
|
||||
// entry in positions and seqs.
|
||||
StartForward(ctx ml.Context, positions []int32, seqs []int) error
|
||||
|
||||
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
|
||||
CopyPrefix(srcSeq, dstSeq int, len int32)
|
||||
|
||||
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
|
||||
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
|
||||
//
|
||||
// If an error occurs, the entire context for the sequence should be
|
||||
// removed by calling Remove(seq, 0, math.MaxInt32)
|
||||
Remove(seq int, beginIndex, endIndex int32) error
|
||||
}
|
||||
455
kvcache/causal.go
Normal file
455
kvcache/causal.go
Normal file
@@ -0,0 +1,455 @@
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
|
||||
|
||||
// Causal cache stores K and V tensors according to their position in the
|
||||
// sequence. Returns the history and a mask for attending to past tokens
|
||||
//
|
||||
// The tensors are of shape embed dim, kv heads, batch size
|
||||
// The mask is of shape history size, batch size
|
||||
type Causal struct {
|
||||
DType ml.DType
|
||||
Capacity int32
|
||||
windowSize int32
|
||||
|
||||
// ** current forward pass **
|
||||
|
||||
// the active layer for Get and Put
|
||||
curLayer int
|
||||
|
||||
// starting location for data storage for this batch
|
||||
curLoc int
|
||||
|
||||
// size of the current batch
|
||||
curBatchSize int
|
||||
|
||||
// mask of the cache as used by this batch
|
||||
curMask ml.Tensor
|
||||
|
||||
// locations in the cache that are needed for this batch
|
||||
curCellRange cellRange
|
||||
|
||||
// ** cache metadata **
|
||||
|
||||
// for each possible location in the cache, stores the position and set of sequences
|
||||
// that reference the data there
|
||||
cells []cacheCell
|
||||
|
||||
// maps from sequence to the range of locations where it is stored in the cache
|
||||
cellRanges map[int]cellRange
|
||||
|
||||
// ** cache data storage **
|
||||
|
||||
shiftFn shiftFn
|
||||
backend ml.Backend
|
||||
cacheCtx ml.Context
|
||||
keys, values []ml.Tensor
|
||||
}
|
||||
|
||||
type cacheCell struct {
|
||||
pos int32
|
||||
sequences []int
|
||||
}
|
||||
|
||||
type cellRange struct {
|
||||
min int
|
||||
max int
|
||||
}
|
||||
|
||||
func NewCausalCache(shift shiftFn) *Causal {
|
||||
return &Causal{windowSize: math.MaxInt32, shiftFn: shift}
|
||||
}
|
||||
|
||||
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
||||
return &Causal{windowSize: windowSize, shiftFn: shift}
|
||||
}
|
||||
|
||||
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
||||
c.DType = dtype
|
||||
c.Capacity = capacity
|
||||
c.cells = make([]cacheCell, capacity)
|
||||
c.cellRanges = make(map[int]cellRange)
|
||||
c.backend = backend
|
||||
c.cacheCtx = backend.NewContext()
|
||||
}
|
||||
|
||||
func (c *Causal) Close() {
|
||||
c.cacheCtx.Close()
|
||||
}
|
||||
|
||||
func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
|
||||
c.curBatchSize = len(positions)
|
||||
|
||||
var err error
|
||||
c.curLoc, err = c.findStartLoc()
|
||||
if errors.Is(err, ErrKvCacheFull) {
|
||||
c.defrag()
|
||||
c.curLoc, err = c.findStartLoc()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.curCellRange = newRange()
|
||||
for i, pos := range positions {
|
||||
seq := seqs[i]
|
||||
|
||||
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
|
||||
|
||||
seqRange, ok := c.cellRanges[seq]
|
||||
if !ok {
|
||||
seqRange = newRange()
|
||||
}
|
||||
|
||||
if c.curLoc+i > seqRange.max {
|
||||
seqRange.max = c.curLoc + i
|
||||
}
|
||||
if seqRange.max > c.curCellRange.max {
|
||||
c.curCellRange.max = seqRange.max
|
||||
}
|
||||
|
||||
if c.curLoc+i < seqRange.min {
|
||||
seqRange.min = c.curLoc + i
|
||||
}
|
||||
if seqRange.min < c.curCellRange.min {
|
||||
c.curCellRange.min = seqRange.min
|
||||
}
|
||||
c.cellRanges[seq] = seqRange
|
||||
}
|
||||
|
||||
c.curMask, err = c.buildMask(ctx, positions, seqs)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func newRange() cellRange {
|
||||
return cellRange{
|
||||
min: math.MaxInt,
|
||||
max: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Find the first contiguous block of at least curBatchSize
|
||||
func (c *Causal) findStartLoc() (int, error) {
|
||||
var start, count int
|
||||
for i := range c.cells {
|
||||
if len(c.cells[i].sequences) == 0 {
|
||||
count++
|
||||
if count >= c.curBatchSize {
|
||||
return start, nil
|
||||
}
|
||||
} else {
|
||||
start = i + 1
|
||||
count = 0
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
|
||||
}
|
||||
|
||||
// Builds a mask of history x batch indicating whether for each token in the batch the
|
||||
// token in the history should apply. This is based on both the sequence and causality (the
|
||||
// position of the history is not ahead of the token in the batch).
|
||||
func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
|
||||
// TODO(jessegross): This does not do padding, which is required for flash attention
|
||||
len := c.curCellRange.max - c.curCellRange.min + 1
|
||||
mask := make([]float32, c.curBatchSize*len)
|
||||
|
||||
for i := range c.curBatchSize {
|
||||
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
||||
if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
|
||||
c.cells[j].pos < positions[i]-c.windowSize {
|
||||
mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ctx.FromFloatSlice(mask, len, c.curBatchSize)
|
||||
}
|
||||
|
||||
func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
|
||||
for _, obj := range objs {
|
||||
if obj == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len)
|
||||
dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len)
|
||||
|
||||
ctx.Forward(srcView.Copy(ctx, dstView))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Causal) defrag() {
|
||||
slog.Debug("defragmenting kv cache")
|
||||
|
||||
// Defrag strategy:
|
||||
// - Search for empty holes at the beginning of the cache,
|
||||
// filling them with active data starting at the end
|
||||
// - If there are contiguous elements that need to be moved,
|
||||
// combine them into a single operation by holding new moves
|
||||
// until we see that the next one is non-contiguous
|
||||
// - Fill up the context with the maximum number of operations it
|
||||
// can hold then compute that and continue with a new context
|
||||
//
|
||||
// We could try to optimize placement by grouping blocks from
|
||||
// the same sequences together but most likely the next forward
|
||||
// pass will disrupt this anyways, so the real world benefit
|
||||
// seems limited as this time.
|
||||
|
||||
ctx := c.backend.NewContext()
|
||||
|
||||
// For every move, 6 tensors are required per layer (2 views and a
|
||||
// copy for each of k and v).
|
||||
layers := 0
|
||||
for _, key := range c.keys {
|
||||
if key == nil {
|
||||
continue
|
||||
}
|
||||
layers++
|
||||
}
|
||||
|
||||
maxMoves := ctx.MaxTensors() / (6 * layers)
|
||||
moves := 0
|
||||
|
||||
var pendingSrc, pendingDst, pendingLen int
|
||||
src := len(c.cells) - 1
|
||||
|
||||
for dst := 0; dst < src; dst++ {
|
||||
if len(c.cells[dst].sequences) == 0 {
|
||||
for ; src > dst; src-- {
|
||||
if len(c.cells[src].sequences) != 0 {
|
||||
c.cells[dst] = c.cells[src]
|
||||
c.cells[src] = cacheCell{}
|
||||
|
||||
if pendingLen > 0 {
|
||||
if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
|
||||
pendingSrc = src
|
||||
pendingLen++
|
||||
break
|
||||
} else {
|
||||
moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
|
||||
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
|
||||
moves++
|
||||
}
|
||||
}
|
||||
|
||||
pendingSrc = src
|
||||
pendingDst = dst
|
||||
pendingLen = 1
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if moves >= maxMoves {
|
||||
ctx.Compute()
|
||||
ctx.Close()
|
||||
ctx = c.backend.NewContext()
|
||||
|
||||
moves = 0
|
||||
}
|
||||
}
|
||||
|
||||
if pendingLen > 0 {
|
||||
moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
|
||||
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
|
||||
moves++
|
||||
}
|
||||
|
||||
if moves > 0 {
|
||||
ctx.Compute()
|
||||
}
|
||||
ctx.Close()
|
||||
|
||||
// Reset range metadata
|
||||
for seq := range c.cellRanges {
|
||||
seqRange := newRange()
|
||||
|
||||
for i, cell := range c.cells {
|
||||
if slices.Contains(cell.sequences, seq) {
|
||||
if i < seqRange.min {
|
||||
seqRange.min = i
|
||||
}
|
||||
if i > seqRange.max {
|
||||
seqRange.max = i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.cellRanges[seq] = seqRange
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Causal) SetLayer(layer int) {
|
||||
if layer >= len(c.keys) {
|
||||
c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
|
||||
c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
|
||||
}
|
||||
|
||||
c.curLayer = layer
|
||||
}
|
||||
|
||||
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
key := c.keys[c.curLayer]
|
||||
value := c.values[c.curLayer]
|
||||
|
||||
key = key.View(ctx, key.Stride(2)*c.curCellRange.min,
|
||||
key.Dim(0), key.Stride(1),
|
||||
key.Dim(1), key.Stride(2),
|
||||
c.curMask.Dim(0),
|
||||
)
|
||||
|
||||
value = value.View(ctx, key.Stride(2)*c.curCellRange.min,
|
||||
value.Dim(0), value.Stride(1),
|
||||
value.Dim(1), value.Stride(2),
|
||||
c.curMask.Dim(0),
|
||||
)
|
||||
|
||||
return key, value, c.curMask
|
||||
}
|
||||
|
||||
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
if c.curBatchSize != key.Dim(2) {
|
||||
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2)))
|
||||
}
|
||||
|
||||
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
|
||||
c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity))
|
||||
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
|
||||
}
|
||||
|
||||
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))))
|
||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))))
|
||||
}
|
||||
|
||||
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
seqRange := newRange()
|
||||
|
||||
for i := range c.cells {
|
||||
// Remove the contents of dstSeq so that we only have the copied prefix, metadata will be reset at the end
|
||||
if slices.Contains(c.cells[i].sequences, dstSeq) {
|
||||
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
|
||||
}
|
||||
|
||||
if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
|
||||
c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
|
||||
if i < seqRange.min {
|
||||
seqRange.min = i
|
||||
}
|
||||
if i > seqRange.max {
|
||||
seqRange.max = i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.cellRanges[dstSeq] = seqRange
|
||||
}
|
||||
|
||||
func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||
if c.shiftFn == nil {
|
||||
return ErrNotSupported
|
||||
}
|
||||
|
||||
ctx := c.backend.NewContext()
|
||||
defer ctx.Close()
|
||||
|
||||
seqRange := c.cellRanges[seq]
|
||||
size := seqRange.max - seqRange.min + 1
|
||||
|
||||
offsets := make([]int32, size)
|
||||
for i := range offsets {
|
||||
cell := c.cells[seqRange.min+i]
|
||||
|
||||
if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
|
||||
offsets[i] = offset
|
||||
}
|
||||
}
|
||||
|
||||
kShift, err := ctx.FromIntSlice(offsets, len(offsets))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, key := range c.keys {
|
||||
if key == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
key = key.View(ctx, key.Stride(2)*seqRange.min,
|
||||
key.Dim(0), key.Stride(1),
|
||||
key.Dim(1), key.Stride(2),
|
||||
size,
|
||||
)
|
||||
|
||||
roped, err := c.shiftFn(ctx, i, key, kShift)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Forward(roped.Copy(ctx, key))
|
||||
}
|
||||
|
||||
ctx.Compute()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
var offset int32
|
||||
if endIndex != math.MaxInt32 {
|
||||
offset = beginIndex - endIndex
|
||||
}
|
||||
|
||||
seqRange := newRange()
|
||||
|
||||
for i := range c.cells {
|
||||
if slices.Contains(c.cells[i].sequences, seq) {
|
||||
if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
|
||||
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||
} else {
|
||||
if c.cells[i].pos >= endIndex {
|
||||
if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
|
||||
// TODO(jessegross): Need to be careful about data shared between sequences
|
||||
return errors.New("shifting on cells shared by multiple sequences not yet implemented")
|
||||
}
|
||||
|
||||
c.cells[i].pos += offset
|
||||
}
|
||||
if i < seqRange.min {
|
||||
seqRange.min = i
|
||||
}
|
||||
if i > seqRange.max {
|
||||
seqRange.max = i
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if seqRange == newRange() {
|
||||
delete(c.cellRanges, seq)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.cellRanges[seq] = seqRange
|
||||
|
||||
if endIndex != math.MaxInt32 {
|
||||
err := c.shift(seq, endIndex+offset, offset)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
510
kvcache/causal_test.go
Normal file
510
kvcache/causal_test.go
Normal file
@@ -0,0 +1,510 @@
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"math"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
in []float32
|
||||
inShape []int
|
||||
seqs []int
|
||||
pos []int32
|
||||
expected []float32
|
||||
expectedShape []int
|
||||
expectedMask []float32
|
||||
}
|
||||
|
||||
func TestStore(t *testing.T) {
|
||||
backend := &testBackend{}
|
||||
cache := NewCausalCache(nil)
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 16)
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "FirstBatch",
|
||||
in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
|
||||
inShape: []int{2, 3, 4},
|
||||
seqs: []int{0, 0, 0, 0},
|
||||
pos: []int32{0, 1, 2, 3},
|
||||
expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
|
||||
expectedShape: []int{2, 3, 4},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
|
||||
},
|
||||
{
|
||||
name: "SecondBatch",
|
||||
in: []float32{115, 215, 125, 225, 135, 235},
|
||||
inShape: []int{2, 3, 1},
|
||||
seqs: []int{0},
|
||||
pos: []int32{4},
|
||||
expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
|
||||
expectedShape: []int{2, 3, 5},
|
||||
expectedMask: []float32{0, 0, 0, 0, 0},
|
||||
},
|
||||
}
|
||||
|
||||
testCache(t, backend, cache, tests)
|
||||
}
|
||||
|
||||
func TestSWA(t *testing.T) {
|
||||
backend := &testBackend{}
|
||||
cache := NewSWACache(1, nil)
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF32, 16)
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "SlidingWindow",
|
||||
in: []float32{1, 2, 3, 4},
|
||||
inShape: []int{1, 1, 4},
|
||||
seqs: []int{0, 0, 0, 0},
|
||||
pos: []int32{0, 1, 2, 3},
|
||||
expected: []float32{1, 2, 3, 4},
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||
},
|
||||
}
|
||||
|
||||
testCache(t, backend, cache, tests)
|
||||
}
|
||||
|
||||
func TestSequences(t *testing.T) {
|
||||
backend := &testBackend{}
|
||||
cache := NewCausalCache(nil)
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 16)
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "FirstBatch",
|
||||
in: []float32{1, 2, 3, 4},
|
||||
inShape: []int{1, 1, 4},
|
||||
seqs: []int{0, 0, 1, 1},
|
||||
pos: []int32{0, 1, 0, 1},
|
||||
expected: []float32{1, 2, 3, 4},
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||
},
|
||||
{
|
||||
name: "SecondBatch",
|
||||
in: []float32{5, 6},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{0, 1},
|
||||
pos: []int32{2, 2},
|
||||
expected: []float32{1, 2, 3, 4, 5, 6},
|
||||
expectedShape: []int{1, 1, 6},
|
||||
expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
|
||||
},
|
||||
}
|
||||
|
||||
testCache(t, backend, cache, tests)
|
||||
}
|
||||
|
||||
func TestRemove(t *testing.T) {
|
||||
backend := &testBackend{}
|
||||
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return key.Add(ctx, shift), nil
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 16)
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "FirstBatch",
|
||||
in: []float32{1, 2, 3, 4},
|
||||
inShape: []int{1, 1, 4},
|
||||
seqs: []int{0, 0, 1, 1},
|
||||
pos: []int32{0, 1, 0, 1},
|
||||
expected: []float32{1, 2, 3, 4},
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||
},
|
||||
}
|
||||
|
||||
testCache(t, backend, cache, tests)
|
||||
|
||||
err := cache.Remove(0, 1, math.MaxInt32)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
tests = []testCase{
|
||||
{
|
||||
name: "RemoveEnd",
|
||||
in: []float32{5, 6},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{0, 1},
|
||||
pos: []int32{1, 2},
|
||||
expected: []float32{1, 2, 3, 4, 5, 6},
|
||||
expectedShape: []int{1, 1, 6},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
|
||||
},
|
||||
}
|
||||
|
||||
testCache(t, backend, cache, tests)
|
||||
|
||||
err = cache.Remove(0, 0, 1)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
tests = []testCase{
|
||||
{
|
||||
name: "RemoveMiddle",
|
||||
in: []float32{7, 8},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{0, 0},
|
||||
pos: []int32{1, 2},
|
||||
expected: []float32{7, 8, 3, 4, 4},
|
||||
expectedShape: []int{1, 1, 5},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0},
|
||||
},
|
||||
}
|
||||
|
||||
testCache(t, backend, cache, tests)
|
||||
}
|
||||
|
||||
func TestDefrag(t *testing.T) {
|
||||
backend := &testBackend{}
|
||||
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return key.Add(ctx, shift), nil
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 16)
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "FirstBatch",
|
||||
in: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||
inShape: []int{1, 1, 16},
|
||||
seqs: []int{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
pos: []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
|
||||
expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||
expectedShape: []int{1, 1, 16},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
},
|
||||
}
|
||||
|
||||
testCache(t, backend, cache, tests)
|
||||
|
||||
err := cache.Remove(0, 2, 4)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = cache.Remove(0, 13, math.MaxInt32)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
tests = []testCase{
|
||||
{
|
||||
name: "Defrag",
|
||||
in: []float32{17, 18, 19},
|
||||
inShape: []int{1, 1, 3},
|
||||
seqs: []int{0, 0, 0},
|
||||
pos: []int32{16, 17, 18},
|
||||
expected: []float32{1, 2, 12, 13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 17, 18, 19},
|
||||
expectedShape: []int{1, 1, 16},
|
||||
expectedMask: []float32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
},
|
||||
}
|
||||
|
||||
testCache(t, backend, cache, tests)
|
||||
}
|
||||
|
||||
func TestCopy(t *testing.T) {
|
||||
backend := &testBackend{}
|
||||
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 16)
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "FirstBatch",
|
||||
in: []float32{1, 2, 3, 4},
|
||||
inShape: []int{1, 1, 4},
|
||||
seqs: []int{0, 0, 0, 0},
|
||||
pos: []int32{0, 1, 2, 3},
|
||||
expected: []float32{1, 2, 3, 4},
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
|
||||
},
|
||||
}
|
||||
|
||||
testCache(t, backend, cache, tests)
|
||||
|
||||
cache.CopyPrefix(0, 1, 2)
|
||||
|
||||
tests = []testCase{
|
||||
{
|
||||
name: "Copy",
|
||||
in: []float32{5, 6},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{1, 1},
|
||||
pos: []int32{3, 4},
|
||||
expected: []float32{1, 2, 3, 4, 5, 6},
|
||||
expectedShape: []int{1, 1, 6},
|
||||
expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||
},
|
||||
}
|
||||
|
||||
testCache(t, backend, cache, tests)
|
||||
}
|
||||
|
||||
func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
context := backend.NewContext()
|
||||
defer context.Close()
|
||||
|
||||
err := cache.StartForward(context, test.pos, test.seqs)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor, _ := context.FromFloatSlice(test.in, test.inShape...)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
out, _, mask := cache.Get(context)
|
||||
|
||||
context.Forward(out)
|
||||
context.Forward(mask)
|
||||
context.Compute(out, mask)
|
||||
|
||||
if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
|
||||
t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type testBackend struct{}
|
||||
|
||||
func (b *testBackend) Config() ml.Config {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (b *testBackend) Get(name string) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (b *testBackend) NewContext() ml.Context {
|
||||
return &testContext{}
|
||||
}
|
||||
|
||||
func (b *testBackend) SystemInfo() string {
|
||||
return "not implemented"
|
||||
}
|
||||
|
||||
type testContext struct{}
|
||||
|
||||
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
total := 0
|
||||
|
||||
if len(shape) > 0 {
|
||||
total = 1
|
||||
for _, s := range shape {
|
||||
total *= s
|
||||
}
|
||||
}
|
||||
|
||||
return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
|
||||
}
|
||||
|
||||
func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
|
||||
t := c.Zeros(ml.DTypeF32, shape...).(*testTensor)
|
||||
|
||||
copy(t.data, s)
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
|
||||
f := make([]float32, len(s))
|
||||
for i := range f {
|
||||
f[i] = float32(s[i])
|
||||
}
|
||||
|
||||
out, _ := c.FromFloatSlice(f, shape...)
|
||||
out.(*testTensor).dtype = ml.DTypeI32
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *testContext) Forward(ml.Tensor) {}
|
||||
|
||||
func (c *testContext) Compute(...ml.Tensor) {}
|
||||
|
||||
func (c *testContext) MaxTensors() int {
|
||||
return 10
|
||||
}
|
||||
|
||||
func (c *testContext) Close() {}
|
||||
|
||||
type testTensor struct {
|
||||
dtype ml.DType
|
||||
elementSize int
|
||||
data []float32
|
||||
shape []int
|
||||
}
|
||||
|
||||
func (t *testTensor) Dim(n int) int {
|
||||
return t.shape[n]
|
||||
}
|
||||
|
||||
func (t *testTensor) Stride(n int) int {
|
||||
stride := t.elementSize
|
||||
for i := range n {
|
||||
stride *= t.shape[i]
|
||||
}
|
||||
|
||||
return stride
|
||||
}
|
||||
|
||||
func (t *testTensor) Shape() []int {
|
||||
return t.shape
|
||||
}
|
||||
|
||||
func (t *testTensor) DType() ml.DType {
|
||||
return t.dtype
|
||||
}
|
||||
|
||||
func (t *testTensor) Bytes() []byte {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Floats() []float32 {
|
||||
out := make([]float32, len(t.data))
|
||||
copy(out, t.data)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
out := ctx.Zeros(t.DType(), t.Shape()...).(*testTensor)
|
||||
|
||||
for i := range out.data {
|
||||
out.data[i] = t.data[i] + t2.(*testTensor).data[i]
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Softmax(ctx ml.Context) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) LayerNorm(ctx ml.Context, weight, bias ml.Tensor, eps float32) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) RMSNorm(ctx ml.Context, weight ml.Tensor, eps float32) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim uint32, base, scale float32) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) GELU(ctx ml.Context) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) SILU(ctx ml.Context) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
offset /= t.elementSize
|
||||
|
||||
var s []int
|
||||
|
||||
switch len(shape) {
|
||||
case 1:
|
||||
s = []int{shape[0]}
|
||||
case 5:
|
||||
s = []int{shape[0], shape[2], shape[4]}
|
||||
default:
|
||||
panic("unsupported number of dimensions")
|
||||
}
|
||||
|
||||
context := &testContext{}
|
||||
|
||||
view := context.Zeros(t.dtype, s...).(*testTensor)
|
||||
view.data = t.data[offset : offset+len(view.data)]
|
||||
|
||||
return view
|
||||
}
|
||||
|
||||
func (t *testTensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
copy(t2.(*testTensor).data, t.data)
|
||||
return nil
|
||||
}
|
||||
97
kvcache/encoder.go
Normal file
97
kvcache/encoder.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
// Encoder cache stores K and V tensors that are position independent
|
||||
//
|
||||
// The tensors can be of any shape and will be returned as they were stored
|
||||
// The mask is currently always nil
|
||||
//
|
||||
// Not currently safe for multiple sequences
|
||||
type EncoderCache struct {
|
||||
// ** current forward pass **
|
||||
|
||||
// the active layer for Get and Put
|
||||
curLayer int
|
||||
|
||||
// if something is stored during this pass, this
|
||||
// will be the position (but there is no guarantee
|
||||
// anything will be stored)
|
||||
curPos int32
|
||||
|
||||
// ** cache metadata **
|
||||
|
||||
// was something stored in the cache?
|
||||
encoderCached bool
|
||||
|
||||
// position of the cached data
|
||||
encoderPos int32
|
||||
|
||||
// ** cache data storage **
|
||||
|
||||
cacheCtx ml.Context
|
||||
keys, values []ml.Tensor
|
||||
}
|
||||
|
||||
func NewEncoderCache() *EncoderCache {
|
||||
return &EncoderCache{}
|
||||
}
|
||||
|
||||
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
||||
c.cacheCtx = backend.NewContext()
|
||||
}
|
||||
|
||||
func (c *EncoderCache) Close() {
|
||||
c.cacheCtx.Close()
|
||||
}
|
||||
|
||||
func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
|
||||
// The image is always in the first position
|
||||
c.curPos = positions[0]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *EncoderCache) SetLayer(layer int) {
|
||||
if layer >= len(c.keys) {
|
||||
c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
|
||||
c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
|
||||
}
|
||||
|
||||
c.curLayer = layer
|
||||
}
|
||||
|
||||
func (c *EncoderCache) EncoderCached() bool {
|
||||
return c.encoderCached
|
||||
}
|
||||
|
||||
func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
return c.keys[c.curLayer], c.values[c.curLayer], nil
|
||||
}
|
||||
|
||||
func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
c.encoderPos = c.curPos
|
||||
c.encoderCached = true
|
||||
|
||||
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
|
||||
c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
|
||||
c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
|
||||
}
|
||||
|
||||
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer]))
|
||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer]))
|
||||
}
|
||||
|
||||
func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
panic("encoder cache does not support multiple sequences")
|
||||
}
|
||||
|
||||
func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
|
||||
c.encoderCached = false
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
93
kvcache/wrapper.go
Normal file
93
kvcache/wrapper.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
// Wrapper cache is a container for multiple types of caches,
|
||||
// such as for the encoding and decoding portions of a model.
|
||||
type WrapperCache struct {
|
||||
// caches we are wrapping
|
||||
caches []Cache
|
||||
|
||||
// cache to be used for this layer
|
||||
curType int
|
||||
}
|
||||
|
||||
func NewWrapperCache(caches ...Cache) *WrapperCache {
|
||||
return &WrapperCache{
|
||||
caches: caches,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
||||
for _, cache := range c.caches {
|
||||
cache.Init(backend, dtype, capacity)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WrapperCache) Close() {
|
||||
for _, cache := range c.caches {
|
||||
cache.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WrapperCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
|
||||
for i, cache := range c.caches {
|
||||
err := cache.StartForward(ctx, positions, seqs)
|
||||
if err != nil {
|
||||
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
|
||||
for j := i - 1; j >= 0; j-- {
|
||||
for k := range positions {
|
||||
_ = c.caches[j].Remove(seqs[k], positions[k], math.MaxInt32)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
c.curType = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WrapperCache) SetLayer(layer int) {
|
||||
for _, cache := range c.caches {
|
||||
cache.SetLayer(layer)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WrapperCache) SetLayerType(layerType int) {
|
||||
c.curType = layerType
|
||||
}
|
||||
|
||||
func (c *WrapperCache) UnderlyingCache() Cache {
|
||||
return c.caches[c.curType]
|
||||
}
|
||||
|
||||
func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
return c.caches[c.curType].Get(ctx)
|
||||
}
|
||||
|
||||
func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
c.caches[c.curType].Put(ctx, key, value)
|
||||
}
|
||||
|
||||
func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
for _, cache := range c.caches {
|
||||
cache.CopyPrefix(srcSeq, dstSeq, len)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
// If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
|
||||
for _, cache := range c.caches {
|
||||
err := cache.Remove(seq, beginIndex, endIndex)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -8,7 +8,7 @@ Ollama vendors [llama.cpp](https://github.com/ggerganov/llama.cpp/) and [ggml](h
|
||||
|
||||
If you update the vendoring code, start by running the following command to establish the tracking llama.cpp repo in the `./vendor/` directory.
|
||||
|
||||
```
|
||||
```shell
|
||||
make -f Makefile.sync apply-patches
|
||||
```
|
||||
|
||||
@@ -22,7 +22,7 @@ When updating to a newer base commit, the existing patches may not apply cleanly
|
||||
|
||||
Start by applying the patches. If any of the patches have conflicts, the `git am` will stop at the first failure.
|
||||
|
||||
```
|
||||
```shell
|
||||
make -f Makefile.sync apply-patches
|
||||
```
|
||||
|
||||
@@ -30,7 +30,7 @@ If there are conflicts, you will see an error message. Resolve the conflicts in
|
||||
|
||||
Once all patches are applied, commit the changes to the tracking repository.
|
||||
|
||||
```
|
||||
```shell
|
||||
make -f Makefile.sync format-patches sync
|
||||
```
|
||||
|
||||
@@ -38,13 +38,13 @@ make -f Makefile.sync format-patches sync
|
||||
|
||||
When working on new fixes or features that impact vendored code, use the following model. First get a clean tracking repo with all current patches applied:
|
||||
|
||||
```
|
||||
```shell
|
||||
make -f Makefile.sync clean apply-patches
|
||||
```
|
||||
|
||||
Iterate until you're ready to submit PRs. Once your code is ready, commit a change in the `./vendor/` directory, then generate the patches for ollama with
|
||||
|
||||
```
|
||||
```shell
|
||||
make -f Makefile.sync format-patches
|
||||
```
|
||||
|
||||
|
||||
2
llama/build-info.cpp
generated
vendored
2
llama/build-info.cpp
generated
vendored
@@ -1,4 +1,4 @@
|
||||
int LLAMA_BUILD_NUMBER = 0;
|
||||
char const *LLAMA_COMMIT = "ba1cb19cdd0d92e012e0f6e009e0620f854b6afd";
|
||||
char const *LLAMA_COMMIT = "46e3556e01b824e52395fb050b29804b6cff2a7c";
|
||||
char const *LLAMA_COMPILER = "";
|
||||
char const *LLAMA_BUILD_TARGET = "";
|
||||
|
||||
4
llama/build-info.cpp.in
Normal file
4
llama/build-info.cpp.in
Normal file
@@ -0,0 +1,4 @@
|
||||
int LLAMA_BUILD_NUMBER = 0;
|
||||
char const *LLAMA_COMMIT = "@FETCH_HEAD@";
|
||||
char const *LLAMA_COMPILER = "";
|
||||
char const *LLAMA_BUILD_TARGET = "";
|
||||
36
llama/llama.cpp/examples/llava/clip.cpp
vendored
36
llama/llama.cpp/examples/llava/clip.cpp
vendored
@@ -1235,35 +1235,15 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef GGML_USE_CUDA
|
||||
new_clip->backend = ggml_backend_cuda_init(0);
|
||||
LOG_INF("%s: CLIP using CUDA backend\n", __func__);
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_METAL
|
||||
new_clip->backend = ggml_backend_metal_init();
|
||||
LOG_INF("%s: CLIP using Metal backend\n", __func__);
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_CANN
|
||||
new_clip->backend = ggml_backend_cann_init(0);
|
||||
LOG_INF("%s: CLIP using CANN backend\n", __func__);
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_VULKAN
|
||||
new_clip->backend = ggml_backend_vk_init(0);
|
||||
LOG_INF("%s: CLIP using Vulkan backend\n", __func__);
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_SYCL
|
||||
new_clip->backend = ggml_backend_sycl_init(0);
|
||||
LOG_INF("%s: CLIP using SYCL backend\n", __func__);
|
||||
#endif
|
||||
|
||||
if (!new_clip->backend) {
|
||||
new_clip->backend = ggml_backend_cpu_init();
|
||||
LOG_INF("%s: CLIP using CPU backend\n", __func__);
|
||||
ggml_backend_t backend = ggml_backend_init_best();
|
||||
if (backend == nullptr) {
|
||||
LOG_ERR("%s: failed to initialize backend\n", __func__);
|
||||
clip_free(new_clip);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
LOG_INF("%s: using %s backend\n", __func__, ggml_backend_name(backend));
|
||||
new_clip->backend = backend;
|
||||
|
||||
// model size and capabilities
|
||||
{
|
||||
|
||||
@@ -199,21 +199,25 @@ func (c *Context) KvCacheDefrag() {
|
||||
|
||||
// Get the embeddings for a sequence id
|
||||
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
|
||||
embeddings := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
|
||||
if embeddings == nil {
|
||||
e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
|
||||
embeddings := make([]float32, c.Model().NEmbd())
|
||||
_ = copy(embeddings, unsafe.Slice((*float32)(e), c.Model().NEmbd()))
|
||||
return embeddings
|
||||
}
|
||||
|
||||
func (c *Context) GetEmbeddingsIth(i int) []float32 {
|
||||
embeddings := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))
|
||||
if embeddings == nil {
|
||||
e := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
|
||||
embeddings := make([]float32, c.Model().NEmbd())
|
||||
_ = copy(embeddings, unsafe.Slice((*float32)(e), c.Model().NEmbd()))
|
||||
return embeddings
|
||||
}
|
||||
|
||||
type ModelParams struct {
|
||||
|
||||
31
llama/mllama.cpp
vendored
31
llama/mllama.cpp
vendored
@@ -558,30 +558,15 @@ struct mllama_ctx *mllama_model_load(const char *fname, const int verbosity = 1)
|
||||
|
||||
mllama_ctx *new_mllama = new mllama_ctx{};
|
||||
|
||||
#ifdef GGML_USE_CUDA
|
||||
new_mllama->backend = ggml_backend_cuda_init(0);
|
||||
LOG("vision using CUDA backend");
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_METAL
|
||||
new_mllama->backend = ggml_backend_metal_init();
|
||||
LOG("vision using Metal backend");
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_CANN
|
||||
new_mllama->backend = ggml_backend_cann_init(0);
|
||||
LOG("vision using CANN backend");
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_VULKAN
|
||||
new_mllama->backend = ggml_backend_vk_init(0);
|
||||
LOG("vision using Vulkan backend");
|
||||
#endif
|
||||
|
||||
if (!new_mllama->backend) {
|
||||
new_mllama->backend = ggml_backend_cpu_init();
|
||||
LOG("vision using CPU backend");
|
||||
ggml_backend_t backend = ggml_backend_init_best();
|
||||
if (backend == nullptr) {
|
||||
LOG("%s: failed to initialize backend\n", __func__);
|
||||
mllama_free(new_mllama);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
LOG("%s: using %s backend\n", __func__, ggml_backend_name(backend));
|
||||
new_mllama->backend = backend;
|
||||
|
||||
// load tensors
|
||||
{
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: jmorganca <jmorganca@gmail.com>
|
||||
Date: Sat, 4 Jan 2025 22:52:48 -0800
|
||||
Subject: [PATCH] re-enable gpu for clip
|
||||
Subject: [PATCH] use dynamic backend loading for clip
|
||||
|
||||
---
|
||||
examples/llava/clip.cpp | 86 ++++++++++++++++++++---------------------
|
||||
1 file changed, 43 insertions(+), 43 deletions(-)
|
||||
examples/llava/clip.cpp | 74 +++++++++++++++--------------------------
|
||||
1 file changed, 27 insertions(+), 47 deletions(-)
|
||||
|
||||
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
|
||||
index b3c1829f..718052e1 100644
|
||||
index b3c1829f..86b91d5c 100644
|
||||
--- a/examples/llava/clip.cpp
|
||||
+++ b/examples/llava/clip.cpp
|
||||
@@ -8,25 +8,25 @@
|
||||
@@ -56,7 +56,7 @@ index b3c1829f..718052e1 100644
|
||||
|
||||
#define STB_IMAGE_IMPLEMENTATION
|
||||
#include "stb_image.h"
|
||||
@@ -1235,30 +1235,30 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||
@@ -1235,35 +1235,15 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,30 +84,19 @@ index b3c1829f..718052e1 100644
|
||||
-// new_clip->backend = ggml_backend_sycl_init(0);
|
||||
-// LOG_INF("%s: CLIP using SYCL backend\n", __func__);
|
||||
-//#endif
|
||||
+#ifdef GGML_USE_CUDA
|
||||
+ new_clip->backend = ggml_backend_cuda_init(0);
|
||||
+ LOG_INF("%s: CLIP using CUDA backend\n", __func__);
|
||||
+#endif
|
||||
+
|
||||
+#ifdef GGML_USE_METAL
|
||||
+ new_clip->backend = ggml_backend_metal_init();
|
||||
+ LOG_INF("%s: CLIP using Metal backend\n", __func__);
|
||||
+#endif
|
||||
+
|
||||
+#ifdef GGML_USE_CANN
|
||||
+ new_clip->backend = ggml_backend_cann_init(0);
|
||||
+ LOG_INF("%s: CLIP using CANN backend\n", __func__);
|
||||
+#endif
|
||||
+
|
||||
+#ifdef GGML_USE_VULKAN
|
||||
+ new_clip->backend = ggml_backend_vk_init(0);
|
||||
+ LOG_INF("%s: CLIP using Vulkan backend\n", __func__);
|
||||
+#endif
|
||||
+
|
||||
+#ifdef GGML_USE_SYCL
|
||||
+ new_clip->backend = ggml_backend_sycl_init(0);
|
||||
+ LOG_INF("%s: CLIP using SYCL backend\n", __func__);
|
||||
+#endif
|
||||
-
|
||||
- if (!new_clip->backend) {
|
||||
- new_clip->backend = ggml_backend_cpu_init();
|
||||
- LOG_INF("%s: CLIP using CPU backend\n", __func__);
|
||||
+ ggml_backend_t backend = ggml_backend_init_best();
|
||||
+ if (backend == nullptr) {
|
||||
+ LOG_ERR("%s: failed to initialize backend\n", __func__);
|
||||
+ clip_free(new_clip);
|
||||
+ gguf_free(ctx);
|
||||
+ return nullptr;
|
||||
}
|
||||
+ LOG_INF("%s: using %s backend\n", __func__, ggml_backend_name(backend));
|
||||
+ new_clip->backend = backend;
|
||||
|
||||
if (!new_clip->backend) {
|
||||
new_clip->backend = ggml_backend_cpu_init();
|
||||
// model size and capabilities
|
||||
{
|
||||
@@ -8,7 +8,7 @@ Subject: [PATCH] sort devices by score
|
||||
1 file changed, 13 insertions(+), 8 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
|
||||
index 899d16f2..ac5cda07 100644
|
||||
index 899d16f2..135f7df0 100644
|
||||
--- a/ggml/src/ggml-backend-reg.cpp
|
||||
+++ b/ggml/src/ggml-backend-reg.cpp
|
||||
@@ -150,7 +150,7 @@ struct ggml_backend_reg_entry {
|
||||
@@ -29,7 +29,7 @@ index 899d16f2..ac5cda07 100644
|
||||
if (!reg) {
|
||||
return;
|
||||
}
|
||||
@@ -206,15 +206,15 @@ struct ggml_backend_registry {
|
||||
@@ -206,15 +206,20 @@ struct ggml_backend_registry {
|
||||
#endif
|
||||
backends.push_back({ reg, std::move(handle) });
|
||||
for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
|
||||
@@ -45,10 +45,15 @@ index 899d16f2..ac5cda07 100644
|
||||
#endif
|
||||
- devices.push_back(device);
|
||||
+ devices.push_back({device, score});
|
||||
+ std::stable_sort(devices.begin(), devices.end(),
|
||||
+ [](const auto & a, const auto & b) {
|
||||
+ return a.second > b.second;
|
||||
+ }
|
||||
+ );
|
||||
}
|
||||
|
||||
ggml_backend_reg_t load_backend(const std::wstring & path, bool silent) {
|
||||
@@ -257,7 +257,7 @@ struct ggml_backend_registry {
|
||||
@@ -257,7 +262,7 @@ struct ggml_backend_registry {
|
||||
|
||||
GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), utf16_to_utf8(path).c_str());
|
||||
|
||||
@@ -57,7 +62,7 @@ index 899d16f2..ac5cda07 100644
|
||||
|
||||
return reg;
|
||||
}
|
||||
@@ -280,7 +280,7 @@ struct ggml_backend_registry {
|
||||
@@ -280,7 +285,7 @@ struct ggml_backend_registry {
|
||||
// remove devices
|
||||
devices.erase(
|
||||
std::remove_if(devices.begin(), devices.end(),
|
||||
@@ -66,17 +71,12 @@ index 899d16f2..ac5cda07 100644
|
||||
devices.end());
|
||||
|
||||
// remove backend
|
||||
@@ -338,7 +338,12 @@ size_t ggml_backend_dev_count() {
|
||||
@@ -338,7 +343,7 @@ size_t ggml_backend_dev_count() {
|
||||
|
||||
ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
|
||||
GGML_ASSERT(index < ggml_backend_dev_count());
|
||||
- return get_reg().devices[index];
|
||||
+ auto devices = get_reg().devices;
|
||||
+ if (!std::is_heap(devices.begin(), devices.end())) {
|
||||
+ std::make_heap(devices.begin(), devices.end(), [](const auto & a, const auto & b) { return a.second < b.second; });
|
||||
+ }
|
||||
+
|
||||
+ return devices[index].first;
|
||||
+ return get_reg().devices[index].first;
|
||||
}
|
||||
|
||||
ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) {
|
||||
|
||||
55
llama/patches/0016-remove-sgemm-global-variables.patch
Normal file
55
llama/patches/0016-remove-sgemm-global-variables.patch
Normal file
@@ -0,0 +1,55 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: jmorganca <jmorganca@gmail.com>
|
||||
Date: Sun, 9 Feb 2025 17:22:15 -0800
|
||||
Subject: [PATCH] remove sgemm global variables
|
||||
|
||||
removes the 'iq4nlt' global variable in sgemm.cpp that causes
|
||||
a runtime crash when calling dlopen on ggml-cpu libraries as
|
||||
its initialization depends on AVX instructions the host machine
|
||||
may not have
|
||||
---
|
||||
ggml/src/ggml-cpu/llamafile/sgemm.cpp | 17 +++++++++--------
|
||||
1 file changed, 9 insertions(+), 8 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
|
||||
index 8fce576c..3f260ce5 100644
|
||||
--- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp
|
||||
+++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
|
||||
@@ -279,14 +279,6 @@ template <> inline __m256bh load(const float *p) {
|
||||
}
|
||||
#endif
|
||||
|
||||
-////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
-// CONSTANTS
|
||||
-
|
||||
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
||||
-static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
||||
-static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
|
||||
-#endif
|
||||
-
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// FLOATING POINT MATRIX MULTIPLICATION
|
||||
|
||||
@@ -613,6 +605,14 @@ class tinyBLAS_Q0_AVX {
|
||||
TC *C, int64_t ldc,
|
||||
int ith, int nth)
|
||||
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
||||
+ const int8_t kvalues_iq4nl[16] = {
|
||||
+ -127, -104, -83, -65,
|
||||
+ -49, -35, -22, -10,
|
||||
+ 1, 13, 25, 38,
|
||||
+ 53, 69, 89, 113
|
||||
+ };
|
||||
+
|
||||
+ iq4nlt = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
|
||||
}
|
||||
|
||||
void matmul(int64_t m, int64_t n) {
|
||||
@@ -1037,6 +1037,7 @@ class tinyBLAS_Q0_AVX {
|
||||
const int64_t ldc;
|
||||
const int ith;
|
||||
const int nth;
|
||||
+ __m128i iq4nlt;
|
||||
};
|
||||
#endif // __AVX__
|
||||
|
||||
69
llama/patches/0017-try-catch-backend-load.patch
Normal file
69
llama/patches/0017-try-catch-backend-load.patch
Normal file
@@ -0,0 +1,69 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Michael Yang <mxyng@pm.me>
|
||||
Date: Tue, 11 Feb 2025 14:06:36 -0800
|
||||
Subject: [PATCH] try/catch backend load
|
||||
|
||||
---
|
||||
ggml/src/ggml-backend-reg.cpp | 45 ++++++++++++++++++-----------------
|
||||
1 file changed, 23 insertions(+), 22 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
|
||||
index 135f7df0..84b21dd8 100644
|
||||
--- a/ggml/src/ggml-backend-reg.cpp
|
||||
+++ b/ggml/src/ggml-backend-reg.cpp
|
||||
@@ -512,32 +512,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||
}
|
||||
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
|
||||
for (const auto & entry : dir_it) {
|
||||
- if (entry.is_regular_file()) {
|
||||
- std::wstring filename = entry.path().filename().wstring();
|
||||
- std::wstring ext = entry.path().extension().wstring();
|
||||
- if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
|
||||
- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
|
||||
- if (!handle && !silent) {
|
||||
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||
- }
|
||||
- if (handle) {
|
||||
+ try {
|
||||
+ if (entry.is_regular_file()) {
|
||||
+ std::wstring filename = entry.path().filename().wstring();
|
||||
+ std::wstring ext = entry.path().extension().wstring();
|
||||
+ if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
|
||||
+ dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
|
||||
+ if (!handle) {
|
||||
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||
+ continue;
|
||||
+ }
|
||||
+
|
||||
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
||||
- if (score_fn) {
|
||||
- int s = score_fn();
|
||||
-#ifndef NDEBUG
|
||||
- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
|
||||
-#endif
|
||||
- if (s > best_score) {
|
||||
- best_score = s;
|
||||
- best_path = entry.path().wstring();
|
||||
- }
|
||||
- } else {
|
||||
- if (!silent) {
|
||||
- GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||
- }
|
||||
+ if (!score_fn) {
|
||||
+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||
+ continue;
|
||||
+ }
|
||||
+
|
||||
+ int s = score_fn();
|
||||
+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
|
||||
+ if (s > best_score) {
|
||||
+ best_score = s;
|
||||
+ best_path = entry.path().wstring();
|
||||
}
|
||||
}
|
||||
}
|
||||
+ } catch (const std::exception & e) {
|
||||
+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -116,7 +116,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
||||
opts.NumCtx = max(opts.NumCtx, 2048)
|
||||
}
|
||||
|
||||
layers := f.Tensors().Layers()
|
||||
layers := f.Tensors().GroupLayers()
|
||||
// add one layer worth of memory as a buffer
|
||||
if blk0, ok := layers["blk.0"]; ok {
|
||||
layerSize = blk0.Size()
|
||||
@@ -410,7 +410,7 @@ func projectorMemoryRequirements(filename string) (weights, graphSize uint64) {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
for _, layer := range ggml.Tensors().Layers() {
|
||||
for _, layer := range ggml.Tensors().GroupLayers() {
|
||||
weights += layer.Size()
|
||||
}
|
||||
|
||||
@@ -431,7 +431,7 @@ func projectorMemoryRequirements(filename string) (weights, graphSize uint64) {
|
||||
headCount := kv("attention.head_count")
|
||||
|
||||
numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
|
||||
if _, ok := ggml.Tensors().Layers()["v"]["class_embd"]; ok {
|
||||
if _, ok := ggml.Tensors().GroupLayers()["v"]["class_embd"]; ok {
|
||||
numPatches++
|
||||
}
|
||||
|
||||
|
||||
@@ -90,8 +90,6 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
|
||||
// NewLlamaServer will run a server for the given GPUs
|
||||
// The gpu list must be a single family.
|
||||
func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
|
||||
var err error
|
||||
|
||||
systemInfo := discover.GetSystemInfo()
|
||||
systemTotalMemory := systemInfo.System.TotalMemory
|
||||
systemFreeMemory := systemInfo.System.FreeMemory
|
||||
@@ -231,19 +229,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
|
||||
params = append(params, "--multiuser-cache")
|
||||
}
|
||||
|
||||
// get available libraries
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not get libollama dir: %w", err)
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(discover.LibOllamaPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read libollama dir: %w", err)
|
||||
}
|
||||
|
||||
libs := make(map[string]string)
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
if entries, err := os.ReadDir(discover.LibOllamaPath); err == nil {
|
||||
for _, entry := range entries {
|
||||
libs[entry.Name()] = filepath.Join(discover.LibOllamaPath, entry.Name())
|
||||
}
|
||||
}
|
||||
@@ -283,16 +271,24 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
|
||||
}
|
||||
}
|
||||
if port == 0 {
|
||||
slog.Debug("ResolveTCPAddr failed ", "error", err)
|
||||
slog.Debug("ResolveTCPAddr failed, using random port")
|
||||
port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
|
||||
}
|
||||
finalParams := []string{"runner"}
|
||||
if envconfig.NewEngine() {
|
||||
finalParams = append(finalParams, "--ollama-engine")
|
||||
}
|
||||
finalParams = append(finalParams, params...)
|
||||
finalParams = append(finalParams, "--port", strconv.Itoa(port))
|
||||
|
||||
pathEnv := "LD_LIBRARY_PATH"
|
||||
if runtime.GOOS == "windows" {
|
||||
var pathEnv string
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
pathEnv = "PATH"
|
||||
case "darwin":
|
||||
pathEnv = "DYLD_LIBRARY_PATH"
|
||||
default:
|
||||
pathEnv = "LD_LIBRARY_PATH"
|
||||
}
|
||||
|
||||
var libraryPaths []string
|
||||
@@ -324,9 +320,8 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
|
||||
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
|
||||
}
|
||||
|
||||
exe, err = filepath.EvalSymlinks(exe)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to evaluate symlinks for executable path: %w", err)
|
||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||
exe = eval
|
||||
}
|
||||
|
||||
// TODO - once fully switched to the Go runner, load the model here for tokenize/detokenize cgo access
|
||||
@@ -394,7 +389,8 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
|
||||
strings.HasPrefix(ev, "HSA_") ||
|
||||
strings.HasPrefix(ev, "GGML_") ||
|
||||
strings.HasPrefix(ev, "PATH=") ||
|
||||
strings.HasPrefix(ev, "LD_LIBRARY_PATH=") {
|
||||
strings.HasPrefix(ev, "LD_LIBRARY_PATH=") ||
|
||||
strings.HasPrefix(ev, "DYLD_LIBRARY_PATH=") {
|
||||
filteredEnv = append(filteredEnv, ev)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,14 +6,14 @@ This app builds upon Ollama to provide a desktop experience for running models.
|
||||
|
||||
First, build the `ollama` binary:
|
||||
|
||||
```
|
||||
```shell
|
||||
cd ..
|
||||
go build .
|
||||
```
|
||||
|
||||
Then run the desktop app with `npm start`:
|
||||
|
||||
```
|
||||
```shell
|
||||
cd macapp
|
||||
npm install
|
||||
npm start
|
||||
|
||||
@@ -19,7 +19,7 @@ const config: ForgeConfig = {
|
||||
icon: './assets/icon.icns',
|
||||
extraResource: [
|
||||
path.join(__dirname, '../dist/darwin/ollama'),
|
||||
...fs.readdirSync(path.join(__dirname, '../dist/darwin/amd64')).map(f => path.join(__dirname, '../dist/darwin/amd64', f)),
|
||||
...fs.readdirSync(path.join(__dirname, '../dist/darwin-amd64/lib/ollama')).map(f => path.join(__dirname, '../dist/darwin-amd64/lib/ollama', f)),
|
||||
path.join(__dirname, './assets/iconTemplate.png'),
|
||||
path.join(__dirname, './assets/iconTemplate@2x.png'),
|
||||
path.join(__dirname, './assets/iconUpdateTemplate.png'),
|
||||
|
||||
14
main.go
14
main.go
@@ -2,6 +2,8 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
@@ -9,5 +11,15 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
cobra.CheckErr(cmd.NewCLI().ExecuteContext(context.Background()))
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, os.Interrupt)
|
||||
go func() {
|
||||
<-sigChan
|
||||
cancel()
|
||||
}()
|
||||
|
||||
cobra.CheckErr(cmd.NewCLI().ExecuteContext(ctx))
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -22,6 +23,7 @@ type Backend interface {
|
||||
Config() Config
|
||||
Get(name string) Tensor
|
||||
NewContext() Context
|
||||
SystemInfo() string
|
||||
}
|
||||
|
||||
var backends = make(map[string]func(*os.File) (Backend, error))
|
||||
@@ -48,15 +50,16 @@ type Context interface {
|
||||
FromIntSlice(s []int32, shape ...int) (Tensor, error)
|
||||
|
||||
Forward(Tensor)
|
||||
Compute(Tensor) Tensor
|
||||
Close() error
|
||||
Compute(...Tensor)
|
||||
MaxTensors() int
|
||||
Close()
|
||||
}
|
||||
|
||||
type Tensor interface {
|
||||
Dim(n int) int64
|
||||
Stride(n int) int64
|
||||
Dim(n int) int
|
||||
Stride(n int) int
|
||||
|
||||
Shape() []int64
|
||||
Shape() []int
|
||||
DType() DType
|
||||
|
||||
Bytes() []byte
|
||||
@@ -65,6 +68,7 @@ type Tensor interface {
|
||||
Add(ctx Context, t2 Tensor) Tensor
|
||||
Mul(ctx Context, t2 Tensor) Tensor
|
||||
Mulmat(ctx Context, t2 Tensor) Tensor
|
||||
MulmatFullPrec(ctx Context, t2 Tensor) Tensor
|
||||
|
||||
Softmax(ctx Context) Tensor
|
||||
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
|
||||
@@ -78,13 +82,13 @@ type Tensor interface {
|
||||
GELU(ctx Context) Tensor
|
||||
SILU(ctx Context) Tensor
|
||||
|
||||
Reshape(ctx Context, shape ...int64) Tensor
|
||||
Reshape(ctx Context, shape ...int) Tensor
|
||||
View(ctx Context, offset int, shape ...int) Tensor
|
||||
Permute(ctx Context, shape ...int) Tensor
|
||||
Contiguous(ctx Context) Tensor
|
||||
|
||||
Pad(ctx Context, shape ...int64) Tensor
|
||||
Unpad(ctx Context, shape ...int64) Tensor
|
||||
Pad(ctx Context, shape ...int) Tensor
|
||||
Unpad(ctx Context, shape ...int) Tensor
|
||||
|
||||
Stack(ctx Context, dim int, s ...Tensor) Tensor
|
||||
Concat(ctx Context, t2 Tensor, dim int) Tensor
|
||||
@@ -110,13 +114,13 @@ func mul[T number](s ...T) T {
|
||||
|
||||
type DumpOptions struct {
|
||||
// Items is the number of elements to print at the beginning and end of each dimension.
|
||||
Items int64
|
||||
Items int
|
||||
|
||||
// Precision is the number of decimal places to print. Applies to float32 and float64.
|
||||
Precision int
|
||||
}
|
||||
|
||||
func Dump(t Tensor, opts ...DumpOptions) string {
|
||||
func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
|
||||
if len(opts) < 1 {
|
||||
opts = append(opts, DumpOptions{
|
||||
Items: 3,
|
||||
@@ -126,18 +130,28 @@ func Dump(t Tensor, opts ...DumpOptions) string {
|
||||
|
||||
switch t.DType() {
|
||||
case DTypeF32:
|
||||
return dump[[]float32](t, opts[0])
|
||||
return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
|
||||
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
|
||||
})
|
||||
case DTypeF16:
|
||||
f32 := ctx.Zeros(DTypeF32, t.Shape()...)
|
||||
f32 = t.Copy(ctx, f32)
|
||||
return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
|
||||
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
|
||||
})
|
||||
case DTypeI32:
|
||||
return dump[[]int32](t, opts[0])
|
||||
return dump[[]int32](ctx, t, opts[0].Items, func(i int32) string {
|
||||
return strconv.FormatInt(int64(i), 10)
|
||||
})
|
||||
default:
|
||||
return "<unsupported>"
|
||||
}
|
||||
}
|
||||
|
||||
func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string {
|
||||
bts := t.Bytes()
|
||||
if bts == nil {
|
||||
return "<nil>"
|
||||
func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
|
||||
if t.Bytes() == nil {
|
||||
ctx.Forward(t)
|
||||
ctx.Compute(t)
|
||||
}
|
||||
|
||||
s := make(S, mul(t.Shape()...))
|
||||
@@ -148,16 +162,16 @@ func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string {
|
||||
shape := t.Shape()
|
||||
|
||||
var sb strings.Builder
|
||||
var f func([]int64, int64)
|
||||
f = func(dims []int64, stride int64) {
|
||||
var f func([]int, int)
|
||||
f = func(dims []int, stride int) {
|
||||
prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
|
||||
fmt.Fprint(&sb, "[")
|
||||
defer func() { fmt.Fprint(&sb, "]") }()
|
||||
for i := int64(0); i < dims[0]; i++ {
|
||||
if i >= opts.Items && i < dims[0]-opts.Items {
|
||||
for i := 0; i < dims[0]; i++ {
|
||||
if i >= items && i < dims[0]-items {
|
||||
fmt.Fprint(&sb, "..., ")
|
||||
// skip to next printable element
|
||||
skip := dims[0] - 2*opts.Items
|
||||
skip := dims[0] - 2*items
|
||||
if len(dims) > 1 {
|
||||
stride += mul(append(dims[1:], skip)...)
|
||||
fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
|
||||
@@ -170,7 +184,7 @@ func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string {
|
||||
fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
|
||||
}
|
||||
} else {
|
||||
fmt.Fprint(&sb, s[stride+i])
|
||||
fmt.Fprint(&sb, fn(s[stride+i]))
|
||||
if i < dims[0]-1 {
|
||||
fmt.Fprint(&sb, ", ")
|
||||
}
|
||||
@@ -185,7 +199,8 @@ func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string {
|
||||
type DType int
|
||||
|
||||
const (
|
||||
DTypeF32 DType = iota
|
||||
DTypeOther DType = iota
|
||||
DTypeF32
|
||||
DTypeF16
|
||||
DTypeI32
|
||||
DTypeOther
|
||||
)
|
||||
|
||||
@@ -1,16 +1,30 @@
|
||||
package ggml
|
||||
|
||||
// #cgo CPPFLAGS: -I${SRCDIR}/ggml/include
|
||||
// #include <stdlib.h>
|
||||
// #include <stdint.h>
|
||||
// #include "ggml.h"
|
||||
// #include "ggml-cpu.h"
|
||||
// #include "ggml-backend.h"
|
||||
/*
|
||||
#cgo CPPFLAGS: -I${SRCDIR}/ggml/include
|
||||
#include <stdlib.h>
|
||||
#include <stdint.h>
|
||||
#include "ggml.h"
|
||||
#include "ggml-cpu.h"
|
||||
#include "ggml-backend.h"
|
||||
static struct ggml_backend_feature * getBackendFeatures(void *fp, ggml_backend_reg_t reg) {return ((ggml_backend_get_features_t)(fp))(reg);}
|
||||
static struct ggml_backend_feature * getNextBackendFeatures(struct ggml_backend_feature * feature) { return &feature[1];}
|
||||
|
||||
typedef enum {COMP_UNKNOWN,COMP_GCC,COMP_CLANG} COMPILER;
|
||||
COMPILER inline get_compiler() {
|
||||
#if defined(__clang__)
|
||||
return COMP_CLANG;
|
||||
#elif defined(__GNUC__)
|
||||
return COMP_GCC;
|
||||
#else
|
||||
return UNKNOWN_COMPILER;
|
||||
#endif
|
||||
}
|
||||
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
@@ -23,7 +37,7 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/ml/backend/ggml/ggml/src"
|
||||
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
|
||||
)
|
||||
|
||||
type device struct {
|
||||
@@ -198,10 +212,9 @@ func (b *Backend) Get(name string) ml.Tensor {
|
||||
|
||||
func (b *Backend) NewContext() ml.Context {
|
||||
nodes := max(8192, len(b.meta.Tensors().Items())*5)
|
||||
bts := make([]byte, C.size_t(nodes)*C.ggml_tensor_overhead()+C.ggml_graph_overhead_custom(C.size_t(nodes), false))
|
||||
c := C.ggml_init(C.struct_ggml_init_params{
|
||||
mem_buffer: unsafe.Pointer(&bts[0]),
|
||||
mem_size: C.size_t(len(bts)),
|
||||
mem_buffer: nil,
|
||||
mem_size: C.size_t(nodes)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(nodes), false),
|
||||
no_alloc: true,
|
||||
})
|
||||
|
||||
@@ -243,15 +256,35 @@ func (c *Context) Forward(t ml.Tensor) {
|
||||
C.ggml_build_forward_expand(c.graph, t.(*Tensor).t)
|
||||
}
|
||||
|
||||
func (c *Context) Compute(t ml.Tensor) ml.Tensor {
|
||||
c.Forward(t)
|
||||
func (c *Context) Compute(tensors ...ml.Tensor) {
|
||||
C.ggml_backend_sched_graph_compute_async(c.sched, c.graph)
|
||||
|
||||
backend := C.ggml_backend_sched_get_tensor_backend(c.sched, t.(*Tensor).t)
|
||||
needSync := true
|
||||
sync := func() {
|
||||
if needSync {
|
||||
C.ggml_backend_sched_synchronize(c.sched)
|
||||
needSync = false
|
||||
}
|
||||
}
|
||||
|
||||
t.(*Tensor).data = make([]byte, C.ggml_nbytes(t.(*Tensor).t))
|
||||
C.ggml_backend_tensor_get_async(backend, t.(*Tensor).t, unsafe.Pointer(&t.(*Tensor).data[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
|
||||
return t
|
||||
for _, t := range tensors {
|
||||
if C.ggml_nbytes(t.(*Tensor).t) > 0 {
|
||||
t.(*Tensor).sync = sync
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Context) MaxTensors() int {
|
||||
return c.nodes
|
||||
}
|
||||
|
||||
func shapeToGGML(shape []int) *C.int64_t {
|
||||
sh := make([]C.int64_t, len(shape))
|
||||
for i, s := range shape {
|
||||
sh[i] = (C.int64_t)(s)
|
||||
}
|
||||
|
||||
return &sh[0]
|
||||
}
|
||||
|
||||
func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
@@ -268,9 +301,11 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
var t *C.struct_ggml_tensor
|
||||
switch dtype {
|
||||
case ml.DTypeF32:
|
||||
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), (*C.int64_t)(unsafe.Pointer(&shape[0])))
|
||||
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
|
||||
case ml.DTypeF16:
|
||||
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
|
||||
case ml.DTypeI32:
|
||||
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), (*C.int64_t)(unsafe.Pointer(&shape[0])))
|
||||
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
|
||||
default:
|
||||
panic("unsupported dtype")
|
||||
}
|
||||
@@ -283,6 +318,13 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
|
||||
func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
|
||||
n := len(s)
|
||||
|
||||
if n == 0 {
|
||||
var shape C.int64_t = 0
|
||||
t := C.ggml_new_tensor(ctx.ctx, dtype, 1, &shape)
|
||||
return &Tensor{t: t}, nil
|
||||
}
|
||||
|
||||
for _, v := range shape {
|
||||
n /= v
|
||||
}
|
||||
@@ -291,7 +333,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u
|
||||
return nil, fmt.Errorf("invalid shape %v for %d elements", shape, len(s))
|
||||
}
|
||||
|
||||
t := C.ggml_new_tensor(ctx.ctx, dtype, C.int(len(shape)), (*C.int64_t)(unsafe.Pointer(&shape[0])))
|
||||
t := C.ggml_new_tensor(ctx.ctx, dtype, C.int(len(shape)), shapeToGGML(shape))
|
||||
b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
|
||||
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
|
||||
C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t))
|
||||
@@ -306,15 +348,16 @@ func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
|
||||
return fromSlice(c, s, shape, C.GGML_TYPE_I32)
|
||||
}
|
||||
|
||||
func (c *Context) Close() error {
|
||||
C.ggml_backend_sched_free(c.sched)
|
||||
C.ggml_free(c.ctx)
|
||||
return nil
|
||||
func (c *Context) Close() {
|
||||
if c != nil {
|
||||
C.ggml_backend_sched_free(c.sched)
|
||||
C.ggml_free(c.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
type Tensor struct {
|
||||
t *C.struct_ggml_tensor
|
||||
data []byte
|
||||
sync func()
|
||||
}
|
||||
|
||||
func (t *Tensor) LogValue() slog.Value {
|
||||
@@ -325,16 +368,16 @@ func (t *Tensor) LogValue() slog.Value {
|
||||
)
|
||||
}
|
||||
|
||||
func (t *Tensor) Dim(n int) int64 {
|
||||
return int64(t.t.ne[n])
|
||||
func (t *Tensor) Dim(n int) int {
|
||||
return int(t.t.ne[n])
|
||||
}
|
||||
|
||||
func (t *Tensor) Stride(n int) int64 {
|
||||
return int64(t.t.nb[n])
|
||||
func (t *Tensor) Stride(n int) int {
|
||||
return int(t.t.nb[n])
|
||||
}
|
||||
|
||||
func (t *Tensor) Shape() []int64 {
|
||||
shape := make([]int64, C.ggml_n_dims(t.t))
|
||||
func (t *Tensor) Shape() []int {
|
||||
shape := make([]int, C.ggml_n_dims(t.t))
|
||||
for i := range shape {
|
||||
shape[i] = t.Dim(i)
|
||||
}
|
||||
@@ -342,18 +385,23 @@ func (t *Tensor) Shape() []int64 {
|
||||
return shape
|
||||
}
|
||||
|
||||
func (t *Tensor) Bytes() []byte {
|
||||
if bts := C.ggml_get_data(t.t); bts != nil {
|
||||
return C.GoBytes(bts, C.int(C.ggml_nbytes(t.t)))
|
||||
func (t *Tensor) Bytes() (data []byte) {
|
||||
if t.sync != nil {
|
||||
data = make([]byte, C.ggml_nbytes(t.t))
|
||||
|
||||
t.sync()
|
||||
C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
|
||||
}
|
||||
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
func (t *Tensor) Floats() (f32s []float32) {
|
||||
if t.data != nil {
|
||||
f32s = make([]float32, C.ggml_nelements(t.t))
|
||||
_ = binary.Read(bytes.NewReader(t.data), binary.LittleEndian, f32s)
|
||||
func (t *Tensor) Floats() (data []float32) {
|
||||
if t.sync != nil {
|
||||
data = make([]float32, C.ggml_nelements(t.t))
|
||||
|
||||
t.sync()
|
||||
C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
|
||||
}
|
||||
|
||||
return
|
||||
@@ -363,6 +411,8 @@ func (t *Tensor) DType() ml.DType {
|
||||
switch t.t._type {
|
||||
case C.GGML_TYPE_F32:
|
||||
return ml.DTypeF32
|
||||
case C.GGML_TYPE_F16:
|
||||
return ml.DTypeF16
|
||||
case C.GGML_TYPE_I32:
|
||||
return ml.DTypeI32
|
||||
default:
|
||||
@@ -408,6 +458,15 @@ func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
|
||||
C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
|
||||
|
||||
return &Tensor{
|
||||
t: mul,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
|
||||
tt := (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
|
||||
if b != nil {
|
||||
@@ -421,7 +480,7 @@ func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
|
||||
return (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
|
||||
}
|
||||
|
||||
func (t *Tensor) Pad(ctx ml.Context, shape ...int64) ml.Tensor {
|
||||
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
if len(shape) != 4 {
|
||||
panic("expected 4 dimensions")
|
||||
}
|
||||
@@ -453,7 +512,7 @@ func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Reshape(ctx ml.Context, shape ...int64) ml.Tensor {
|
||||
func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
switch len(shape) {
|
||||
case 1:
|
||||
return &Tensor{
|
||||
@@ -494,7 +553,7 @@ func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Unpad(ctx ml.Context, shape ...int64) ml.Tensor {
|
||||
func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
if len(shape) != 4 {
|
||||
panic("expected 4 dimensions")
|
||||
}
|
||||
@@ -545,9 +604,14 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
|
||||
ropeFactors = &Tensor{}
|
||||
}
|
||||
|
||||
dequant := t.t
|
||||
if C.ggml_is_quantized(t.t._type) {
|
||||
dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
|
||||
}
|
||||
|
||||
return &Tensor{
|
||||
t: C.ggml_rope_ext(
|
||||
ctx.(*Context).ctx, t.t, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
|
||||
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
|
||||
C.int(ropeDim),
|
||||
131072, // YaRN n_ctx_train
|
||||
ropeTypeNorm, // ROPE_TYPE_NORM
|
||||
@@ -578,3 +642,34 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
|
||||
t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Backend) SystemInfo() string {
|
||||
var compiler string
|
||||
switch C.get_compiler() {
|
||||
case C.COMP_UNKNOWN:
|
||||
compiler = "cgo(unknown_compiler)"
|
||||
case C.COMP_GCC:
|
||||
compiler = "cgo(gcc)"
|
||||
case C.COMP_CLANG:
|
||||
compiler = "cgo(clang)"
|
||||
}
|
||||
|
||||
var s string
|
||||
for i := range C.ggml_backend_reg_count() {
|
||||
reg := C.ggml_backend_reg_get(i)
|
||||
fName := C.CString("ggml_backend_get_features")
|
||||
defer C.free(unsafe.Pointer(fName))
|
||||
get_features_fn := C.ggml_backend_reg_get_proc_address(reg, fName)
|
||||
if get_features_fn != nil {
|
||||
s += C.GoString(C.ggml_backend_reg_name(reg))
|
||||
s += " : "
|
||||
for features := C.getBackendFeatures(get_features_fn, reg); features.name != nil; features = C.getNextBackendFeatures(features) {
|
||||
s += C.GoString(features.name)
|
||||
s += " = "
|
||||
s += C.GoString(features.value)
|
||||
s += " | "
|
||||
}
|
||||
}
|
||||
}
|
||||
return s + compiler
|
||||
}
|
||||
|
||||
57
ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
vendored
57
ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
vendored
@@ -215,6 +215,11 @@ struct ggml_backend_registry {
|
||||
GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device));
|
||||
#endif
|
||||
devices.push_back({device, score});
|
||||
std::stable_sort(devices.begin(), devices.end(),
|
||||
[](const auto & a, const auto & b) {
|
||||
return a.second > b.second;
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
ggml_backend_reg_t load_backend(const std::wstring & path, bool silent) {
|
||||
@@ -338,12 +343,7 @@ size_t ggml_backend_dev_count() {
|
||||
|
||||
ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
|
||||
GGML_ASSERT(index < ggml_backend_dev_count());
|
||||
auto devices = get_reg().devices;
|
||||
if (!std::is_heap(devices.begin(), devices.end())) {
|
||||
std::make_heap(devices.begin(), devices.end(), [](const auto & a, const auto & b) { return a.second < b.second; });
|
||||
}
|
||||
|
||||
return devices[index].first;
|
||||
return get_reg().devices[index].first;
|
||||
}
|
||||
|
||||
ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) {
|
||||
@@ -512,32 +512,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||
}
|
||||
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
|
||||
for (const auto & entry : dir_it) {
|
||||
if (entry.is_regular_file()) {
|
||||
std::wstring filename = entry.path().filename().wstring();
|
||||
std::wstring ext = entry.path().extension().wstring();
|
||||
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
|
||||
dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
|
||||
if (!handle && !silent) {
|
||||
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||
}
|
||||
if (handle) {
|
||||
try {
|
||||
if (entry.is_regular_file()) {
|
||||
std::wstring filename = entry.path().filename().wstring();
|
||||
std::wstring ext = entry.path().extension().wstring();
|
||||
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
|
||||
dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
|
||||
if (!handle) {
|
||||
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
||||
if (score_fn) {
|
||||
int s = score_fn();
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
|
||||
#endif
|
||||
if (s > best_score) {
|
||||
best_score = s;
|
||||
best_path = entry.path().wstring();
|
||||
}
|
||||
} else {
|
||||
if (!silent) {
|
||||
GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||
}
|
||||
if (!score_fn) {
|
||||
GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
int s = score_fn();
|
||||
GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
|
||||
if (s > best_score) {
|
||||
best_score = s;
|
||||
best_path = entry.path().wstring();
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package cpu
|
||||
|
||||
// #cgo CFLAGS: -Wno-implicit-function-declaration
|
||||
// #cgo CFLAGS: -O3 -Wno-implicit-function-declaration
|
||||
// #cgo CXXFLAGS: -std=c++17
|
||||
// #cgo CPPFLAGS: -I${SRCDIR}/amx -I${SRCDIR}/llamafile -I${SRCDIR}/.. -I${SRCDIR}/../../include
|
||||
// #cgo CPPFLAGS: -DGGML_USE_LLAMAFILE
|
||||
|
||||
@@ -279,14 +279,6 @@ template <> inline __m256bh load(const float *p) {
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// CONSTANTS
|
||||
|
||||
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
||||
static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
||||
static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// FLOATING POINT MATRIX MULTIPLICATION
|
||||
|
||||
@@ -613,6 +605,14 @@ class tinyBLAS_Q0_AVX {
|
||||
TC *C, int64_t ldc,
|
||||
int ith, int nth)
|
||||
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
||||
const int8_t kvalues_iq4nl[16] = {
|
||||
-127, -104, -83, -65,
|
||||
-49, -35, -22, -10,
|
||||
1, 13, 25, 38,
|
||||
53, 69, 89, 113
|
||||
};
|
||||
|
||||
iq4nlt = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
|
||||
}
|
||||
|
||||
void matmul(int64_t m, int64_t n) {
|
||||
@@ -1037,6 +1037,7 @@ class tinyBLAS_Q0_AVX {
|
||||
const int64_t ldc;
|
||||
const int ith;
|
||||
const int nth;
|
||||
__m128i iq4nlt;
|
||||
};
|
||||
#endif // __AVX__
|
||||
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from glob import glob
|
||||
import os
|
||||
|
||||
TYPES_KV = ["GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_F16"]
|
||||
|
||||
SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f{vkq_size}.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F{vkq_size}_CASE({head_size}, {type_k}, {type_v});
|
||||
"""
|
||||
|
||||
SOURCE_FATTN_WMMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-wmma-f16.cuh"
|
||||
|
||||
"""
|
||||
|
||||
SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}, {kq_acc_t});\n"
|
||||
|
||||
TYPES_MMQ = [
|
||||
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
|
||||
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
|
||||
"GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
|
||||
"GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS"
|
||||
]
|
||||
|
||||
SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmq.cuh"
|
||||
|
||||
DECL_MMQ_CASE({type});
|
||||
"""
|
||||
|
||||
|
||||
def get_short_name(long_quant_name):
|
||||
return long_quant_name.replace("GGML_TYPE_", "").lower()
|
||||
|
||||
|
||||
def get_head_sizes(type_k, type_v):
|
||||
if type_k == "GGML_TYPE_F16" and type_v == "GGML_TYPE_F16":
|
||||
return [64, 128, 256]
|
||||
if type_k == "GGML_TYPE_F16":
|
||||
return [64, 128]
|
||||
return [128]
|
||||
|
||||
|
||||
for filename in glob("*.cu"):
|
||||
os.remove(filename)
|
||||
|
||||
for vkq_size in [16, 32]:
|
||||
for type_k in TYPES_KV:
|
||||
for type_v in TYPES_KV:
|
||||
for head_size in get_head_sizes(type_k, type_v):
|
||||
with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
|
||||
f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
|
||||
|
||||
for kq_acc_t in ["half", "float"]:
|
||||
for cols_per_block in [8, 16, 32]:
|
||||
if kq_acc_t == "float" and cols_per_block == 8:
|
||||
continue
|
||||
|
||||
with open(f"fattn-wmma-f16-instance-kq{kq_acc_t}-cpb{cols_per_block}.cu", "w") as f:
|
||||
f.write(SOURCE_FATTN_WMMA_START)
|
||||
|
||||
for head_size in [64, 80, 96, 112, 128, 256]:
|
||||
if cols_per_block == 8 and head_size % 32 != 0: # wmma fragment is 8x32
|
||||
continue
|
||||
if kq_acc_t == "float" and cols_per_block == 32 and head_size == 256: # register spilling, bad performance
|
||||
continue
|
||||
f.write(SOURCE_FATTN_WMMA_CASE.format(kq_acc_t=kq_acc_t, cols_per_block=cols_per_block, head_size=head_size))
|
||||
|
||||
for type in TYPES_MMQ:
|
||||
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
|
||||
f.write(SOURCE_MMQ.format(type=type))
|
||||
@@ -41,36 +41,53 @@ func sink(level C.int, text *C.char, _ unsafe.Pointer) {
|
||||
}
|
||||
|
||||
var OnceLoad = sync.OnceFunc(func() {
|
||||
var lib struct{ name, defaultValue string }
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
slog.Warn("failed to get executable path", "error", err)
|
||||
exe = "."
|
||||
}
|
||||
|
||||
// PATH, LD_LIBRARY_PATH, and DYLD_LIBRARY_PATH are often
|
||||
// set by the parent process, however, use a default value
|
||||
// if the environment variable is not set.
|
||||
var name, value string
|
||||
switch runtime.GOOS {
|
||||
case "darwin", "linux":
|
||||
lib.name = "LD_LIBRARY_PATH"
|
||||
lib.defaultValue = "/usr/local/lib:/usr/lib"
|
||||
case "darwin":
|
||||
// On macOS, DYLD_LIBRARY_PATH is often not set, so
|
||||
// we use the directory of the executable as the default.
|
||||
name = "DYLD_LIBRARY_PATH"
|
||||
value = filepath.Dir(exe)
|
||||
case "windows":
|
||||
lib.name = "PATH"
|
||||
lib.defaultValue = "."
|
||||
name = "PATH"
|
||||
value = filepath.Join(filepath.Dir(exe), "lib", "ollama")
|
||||
default:
|
||||
return
|
||||
name = "LD_LIBRARY_PATH"
|
||||
value = filepath.Join(filepath.Dir(exe), "..", "lib", "ollama")
|
||||
}
|
||||
|
||||
paths, ok := os.LookupEnv(lib.name)
|
||||
paths, ok := os.LookupEnv(name)
|
||||
if !ok {
|
||||
paths = lib.defaultValue
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
if _, ok := os.LookupEnv("DYLD_LIBRARY_PATH"); !ok {
|
||||
os.Setenv("DYLD_LIBRARY_PATH", paths)
|
||||
}
|
||||
paths = value
|
||||
}
|
||||
|
||||
split := filepath.SplitList(paths)
|
||||
visited := make(map[string]struct{}, len(split))
|
||||
for _, path := range split {
|
||||
abspath, _ := filepath.Abs(path)
|
||||
abspath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
slog.Error("failed to get absolute path", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if abspath != filepath.Dir(exe) && !strings.Contains(abspath, filepath.FromSlash("lib/ollama")) {
|
||||
slog.Debug("skipping path which is not part of ollama", "path", abspath)
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := visited[abspath]; !ok {
|
||||
func() {
|
||||
cpath := C.CString(path)
|
||||
slog.Debug("ggml backend load all from path", "path", abspath)
|
||||
cpath := C.CString(abspath)
|
||||
defer C.free(unsafe.Pointer(cpath))
|
||||
C.ggml_backend_load_all_from_path(cpath)
|
||||
}()
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
//go:build debug
|
||||
|
||||
package ggml
|
||||
|
||||
// #cgo CPPFLAGS: -DOLLAMA_DEBUG
|
||||
import "C"
|
||||
@@ -1,160 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"image"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/cache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
_ "github.com/ollama/ollama/model/llama"
|
||||
_ "github.com/ollama/ollama/model/mllama"
|
||||
"github.com/ollama/ollama/sample"
|
||||
)
|
||||
|
||||
var args struct {
|
||||
n int
|
||||
debug bool
|
||||
image string
|
||||
cache bool
|
||||
}
|
||||
|
||||
func temp() error {
|
||||
flag.IntVar(&args.n, "n", 10, "number of samples")
|
||||
flag.BoolVar(&args.debug, "debug", false, "enable debug logging")
|
||||
flag.StringVar(&args.image, "image", "", "path to image file")
|
||||
flag.BoolVar(&args.cache, "cache", false, "enable KV cache")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
var prompt string
|
||||
if n := len(flag.Args()); n == 1 {
|
||||
bts, err := io.ReadAll(os.Stdin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prompt = string(bts)
|
||||
} else if n > 1 {
|
||||
prompt = strings.Join(flag.Args()[1:], " ")
|
||||
} else {
|
||||
return fmt.Errorf("usage: %s path/to/file <prompt\n", filepath.Base(os.Args[0]))
|
||||
}
|
||||
|
||||
level := slog.LevelInfo
|
||||
if args.debug {
|
||||
level = slog.LevelDebug
|
||||
}
|
||||
|
||||
slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
Level: level,
|
||||
AddSource: true,
|
||||
ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
|
||||
if attr.Key == slog.SourceKey {
|
||||
source := attr.Value.Any().(*slog.Source)
|
||||
source.File = filepath.Base(source.File)
|
||||
}
|
||||
|
||||
return attr
|
||||
},
|
||||
})))
|
||||
|
||||
m, err := model.New(flag.Arg(0))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
inputIDs, err := m.(model.TextProcessor).Encode(prompt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var opts []model.OptionsFunc
|
||||
if args.cache {
|
||||
opts = append(opts, model.WithCache(&cache.Simple{
|
||||
Capacity: 2048,
|
||||
DType: ml.DTypeF32,
|
||||
}))
|
||||
}
|
||||
|
||||
if args.image != "" {
|
||||
if err := func() error {
|
||||
f, err := os.Open(args.image)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
img, _, err := image.Decode(f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
opts = append(opts, model.WithImage(img))
|
||||
return nil
|
||||
}(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var offset int
|
||||
for range args.n {
|
||||
logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f32s := logit.Floats()
|
||||
f64s := make([]float64, len(f32s))
|
||||
for i, f32 := range f32s {
|
||||
f64s[i] = float64(f32)
|
||||
}
|
||||
|
||||
// do sampling
|
||||
f64s, err = sample.Sample(f64s, sample.Greedy())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var outputIDs []int32
|
||||
for _, f64 := range f64s {
|
||||
if !m.(model.TextProcessor).Is(uint32(f64), model.SpecialEOS) {
|
||||
outputIDs = append(outputIDs, int32(f64))
|
||||
}
|
||||
}
|
||||
|
||||
if len(outputIDs) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
s, err := m.(model.TextProcessor).Decode(outputIDs)
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Print(s)
|
||||
|
||||
inputIDs = append(inputIDs, outputIDs...)
|
||||
if args.cache {
|
||||
offset = len(inputIDs) - 1
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
if err := temp(); err != nil {
|
||||
fmt.Println("err", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package mllama
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
type TextProcessor struct {
|
||||
model.BytePairEncoding
|
||||
}
|
||||
|
||||
func newTextProcessor(c ml.Config) TextProcessor {
|
||||
return TextProcessor{
|
||||
BytePairEncoding: model.BytePairEncoding{
|
||||
Pretokenizer: c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
Vocabulary: &model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Uints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
BOS: c.Uint("tokenizer.ggml.bos_token_id"),
|
||||
EOS: c.Uint("tokenizer.ggml.eos_token_id"),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,87 +0,0 @@
|
||||
package mllama
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
func TestProcessText(t *testing.T) {
|
||||
ours, err := model.New(filepath.Join("testdata", "model.bin"))
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
t.Skip("no model.bin")
|
||||
} else if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("decode", func(t *testing.T) {
|
||||
f, err := os.Open(filepath.Join("testdata", "theirs.json"))
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
t.Skip("no theirs.json")
|
||||
} else if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var theirs [][]byte
|
||||
if err := json.NewDecoder(f).Decode(&theirs); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for id := range theirs {
|
||||
ids := []int32{int32(id)}
|
||||
s, err := ours.(model.TextProcessor).Decode(ids)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(string(theirs[id]), s); diff != "" {
|
||||
t.Errorf("%d no match (-theirs +ours):\n%s", id, diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("encode", func(t *testing.T) {
|
||||
f, err := os.Open(filepath.Join("..", "testdata", "inputs.json"))
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
t.Skip("no inputs.json")
|
||||
} else if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var inputs []struct {
|
||||
Values []byte `json:"base64"`
|
||||
IDs []int32 `json:"ids"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(f).Decode(&inputs); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for i, input := range inputs {
|
||||
if i == 45 {
|
||||
t.Skip("skip 45")
|
||||
}
|
||||
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
ids, err := ours.(model.TextProcessor).Encode(string(input.Values))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(input.IDs, ids, cmpopts.EquateEmpty()); diff != "" {
|
||||
t.Errorf("%s: no match (-theirs +ours):\n%s", input.Values, diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
1
model/mllama/testdata/model.bin
vendored
1
model/mllama/testdata/model.bin
vendored
@@ -1 +0,0 @@
|
||||
/Users/michaelyang/git/ollama/library/nltpt/Llama-3.2-11B-Vision-Instruct/merged.gguf
|
||||
1
model/mllama/testdata/theirs.json
vendored
1
model/mllama/testdata/theirs.json
vendored
File diff suppressed because one or more lines are too long
174
model/model.go
174
model/model.go
@@ -1,6 +1,7 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/jpeg"
|
||||
@@ -15,106 +16,51 @@ import (
|
||||
_ "golang.org/x/image/tiff"
|
||||
_ "golang.org/x/image/webp"
|
||||
|
||||
"github.com/ollama/ollama/cache"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
_ "github.com/ollama/ollama/ml/backend"
|
||||
)
|
||||
|
||||
type Cache struct {
|
||||
cache.Cache
|
||||
cache.Options
|
||||
}
|
||||
|
||||
func (c Cache) Sub(i int) Cache {
|
||||
if c.Cache != nil {
|
||||
return Cache{
|
||||
Cache: c.Cache.Sub(i),
|
||||
Options: c.Options,
|
||||
}
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c Cache) Put(ctx ml.Context, key, value ml.Tensor, opts cache.Options) (ml.Tensor, ml.Tensor) {
|
||||
if c.Cache != nil {
|
||||
return c.Cache.Put(ctx, key, value, opts)
|
||||
}
|
||||
|
||||
return key, value
|
||||
}
|
||||
|
||||
// Options contains the inputs for a model forward pass
|
||||
type Options struct {
|
||||
inputs []int32
|
||||
|
||||
Offset int
|
||||
Inputs []int32
|
||||
Positions []int32
|
||||
Sequences []int
|
||||
Outputs []int32
|
||||
|
||||
Images []image.Image
|
||||
|
||||
Cache
|
||||
}
|
||||
|
||||
func (opts Options) Inputs() []int32 {
|
||||
return opts.inputs[opts.Offset:]
|
||||
}
|
||||
|
||||
func (opts Options) Positions() []int32 {
|
||||
positions := make([]int32, len(opts.inputs)-opts.Offset)
|
||||
for i := range positions {
|
||||
positions[i] = int32(opts.Offset + i)
|
||||
}
|
||||
|
||||
return positions
|
||||
}
|
||||
|
||||
type OptionsFunc func(Model, *Options)
|
||||
|
||||
func WithInputIDs(ids []int32) OptionsFunc {
|
||||
return func(m Model, opts *Options) {
|
||||
opts.inputs = ids
|
||||
}
|
||||
}
|
||||
|
||||
func WithOffset(offset int) OptionsFunc {
|
||||
return func(m Model, opts *Options) {
|
||||
opts.Offset = offset
|
||||
opts.Cache.Position = offset
|
||||
}
|
||||
}
|
||||
|
||||
func WithImage(img image.Image) OptionsFunc {
|
||||
return func(m Model, opts *Options) {
|
||||
opts.Images = append(opts.Images, img)
|
||||
}
|
||||
}
|
||||
|
||||
func WithCache(c cache.Cache) OptionsFunc {
|
||||
return func(m Model, opts *Options) {
|
||||
opts.Cache = Cache{
|
||||
Cache: c,
|
||||
Options: cache.Options{
|
||||
Position: opts.Offset,
|
||||
},
|
||||
}
|
||||
}
|
||||
type config struct {
|
||||
Cache kvcache.Cache
|
||||
}
|
||||
|
||||
// Base implements the common fields and methods for all models
|
||||
type Base struct {
|
||||
b ml.Backend
|
||||
config
|
||||
}
|
||||
|
||||
// Backend returns the underlying backend that will run the model
|
||||
func (m *Base) Backend() ml.Backend {
|
||||
return m.b
|
||||
}
|
||||
|
||||
func (m *Base) Config() config {
|
||||
return m.config
|
||||
}
|
||||
|
||||
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
|
||||
type Model interface {
|
||||
Forward(ml.Context, Options) (ml.Tensor, error)
|
||||
|
||||
Backend() ml.Backend
|
||||
Config() config
|
||||
}
|
||||
|
||||
var models = make(map[string]func(ml.Config) (Model, error))
|
||||
|
||||
// Register registers a model constructor for the given architecture
|
||||
func Register(name string, f func(ml.Config) (Model, error)) {
|
||||
if _, ok := models[name]; ok {
|
||||
panic("model: model already registered")
|
||||
@@ -123,8 +69,9 @@ func Register(name string, f func(ml.Config) (Model, error)) {
|
||||
models[name] = f
|
||||
}
|
||||
|
||||
func New(s string) (Model, error) {
|
||||
r, err := os.Open(s)
|
||||
// New initializes a new model instance with the provided configuration based on the metadata in the model file
|
||||
func New(modelPath string) (Model, error) {
|
||||
r, err := os.Open(modelPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -146,16 +93,15 @@ func New(s string) (Model, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
base := Base{b: b, config: m.Config()}
|
||||
|
||||
v := reflect.ValueOf(m)
|
||||
v.Elem().Set(populateFields(b, v))
|
||||
v.Elem().Set(populateFields(base, v.Elem()))
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
|
||||
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
||||
t := v.Type()
|
||||
if t.Kind() == reflect.Pointer {
|
||||
t, v = t.Elem(), v.Elem()
|
||||
}
|
||||
|
||||
if t.Kind() == reflect.Struct {
|
||||
allNil := true
|
||||
@@ -173,7 +119,7 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
|
||||
}
|
||||
|
||||
if tt == reflect.TypeOf((*Base)(nil)).Elem() {
|
||||
vv.Set(reflect.ValueOf(Base{b: b}))
|
||||
vv.Set(reflect.ValueOf(base))
|
||||
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
|
||||
var fn func([]Tag) [][]string
|
||||
fn = func(tags []Tag) (values [][]string) {
|
||||
@@ -199,24 +145,22 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
|
||||
|
||||
names := fn(tagsCopy)
|
||||
for _, name := range names {
|
||||
if tensor := b.Get(strings.Join(name, ".")); tensor != nil {
|
||||
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
|
||||
slog.Debug("found tensor", "", tensor)
|
||||
vv.Set(reflect.ValueOf(tensor))
|
||||
break
|
||||
}
|
||||
}
|
||||
} else if tt.Kind() == reflect.Pointer {
|
||||
vvv := vv.Elem()
|
||||
if vv.IsNil() {
|
||||
vvv = reflect.New(tt.Elem())
|
||||
}
|
||||
|
||||
if f := populateFields(b, vvv, tagsCopy...); f.CanAddr() {
|
||||
vv.Set(f.Addr())
|
||||
}
|
||||
} else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
|
||||
setPointer(base, vv, tagsCopy)
|
||||
} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
|
||||
for i := range vv.Len() {
|
||||
vv.Index(i).Set(populateFields(b, vv.Index(i), append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
|
||||
vvv := vv.Index(i)
|
||||
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
|
||||
setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)}))
|
||||
} else {
|
||||
vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -233,6 +177,26 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
|
||||
return v
|
||||
}
|
||||
|
||||
func setPointer(base Base, v reflect.Value, tags []Tag) {
|
||||
vv := v
|
||||
if v.Kind() == reflect.Interface {
|
||||
if v.IsNil() {
|
||||
return
|
||||
}
|
||||
|
||||
vv = vv.Elem()
|
||||
}
|
||||
|
||||
vv = vv.Elem()
|
||||
if v.IsNil() {
|
||||
vv = reflect.New(v.Type().Elem()).Elem()
|
||||
}
|
||||
|
||||
if f := populateFields(base, vv, tags...); f.CanAddr() {
|
||||
v.Set(f.Addr())
|
||||
}
|
||||
}
|
||||
|
||||
type Tag struct {
|
||||
Name string
|
||||
Alternate []string
|
||||
@@ -262,18 +226,30 @@ func canNil(t reflect.Type) bool {
|
||||
t.Kind() == reflect.Slice
|
||||
}
|
||||
|
||||
func Forward(m Model, optsFuncs ...OptionsFunc) (ml.Tensor, error) {
|
||||
var opts Options
|
||||
for _, optsFunc := range optsFuncs {
|
||||
optsFunc(m, &opts)
|
||||
func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
|
||||
if len(opts.Positions) != len(opts.Sequences) {
|
||||
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
|
||||
}
|
||||
|
||||
if len(opts.Positions) < 1 {
|
||||
return nil, errors.New("batch size cannot be less than 1")
|
||||
}
|
||||
|
||||
cache := m.Config().Cache
|
||||
if cache != nil {
|
||||
err := cache.StartForward(ctx, opts.Positions, opts.Sequences)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
ctx := m.Backend().NewContext()
|
||||
t, err := m.Forward(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer ctx.Close()
|
||||
|
||||
return ctx.Compute(t), nil
|
||||
ctx.Forward(t)
|
||||
ctx.Compute(t)
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
@@ -78,7 +78,7 @@ func TestPopulateFields(t *testing.T) {
|
||||
|
||||
var m fakeModel
|
||||
v := reflect.ValueOf(&m)
|
||||
v.Elem().Set(populateFields(&fakeBackend{
|
||||
v.Elem().Set(populateFields(Base{b: &fakeBackend{
|
||||
names: []string{
|
||||
"input.weight",
|
||||
"blk.0.attn_q.weight",
|
||||
@@ -90,7 +90,7 @@ func TestPopulateFields(t *testing.T) {
|
||||
"output_norm.weight",
|
||||
"output.weight",
|
||||
},
|
||||
}, v))
|
||||
}}, v.Elem()))
|
||||
|
||||
if diff := cmp.Diff(fakeModel{
|
||||
Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
|
||||
@@ -121,11 +121,11 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
|
||||
|
||||
m := fakeModel{}
|
||||
v := reflect.ValueOf(&m)
|
||||
v.Elem().Set(populateFields(&fakeBackend{
|
||||
v.Elem().Set(populateFields(Base{b: &fakeBackend{
|
||||
names: []string{
|
||||
"input.weight",
|
||||
},
|
||||
}, v))
|
||||
}}, v.Elem()))
|
||||
|
||||
if diff := cmp.Diff(fakeModel{
|
||||
Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
|
||||
|
||||
@@ -3,6 +3,7 @@ package llama
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
@@ -10,7 +11,7 @@ import (
|
||||
|
||||
type Options struct {
|
||||
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
|
||||
hiddenSize, numHeads, numKVHeads int64
|
||||
hiddenSize, numHeads, numKVHeads int
|
||||
eps, ropeBase, ropeScale float32
|
||||
ropeDim uint32
|
||||
}
|
||||
@@ -28,28 +29,32 @@ type Model struct {
|
||||
}
|
||||
|
||||
func New(c ml.Config) (model.Model, error) {
|
||||
return &Model{
|
||||
BytePairEncoding: model.BytePairEncoding{
|
||||
Pretokenizer: c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
Vocabulary: &model.Vocabulary{
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Uints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
BOS: c.Uint("tokenizer.ggml.bos_token_id"),
|
||||
EOS: c.Uint("tokenizer.ggml.eos_token_id"),
|
||||
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
|
||||
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
||||
},
|
||||
},
|
||||
),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: &Options{
|
||||
hiddenSize: int64(c.Uint("embedding_length")),
|
||||
numHeads: int64(c.Uint("attention.head_count")),
|
||||
numKVHeads: int64(c.Uint("attention.head_count_kv")),
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
ropeDim: c.Uint("rope.dimension_count"),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewCausalCache(m.Shift)
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
type SelfAttention struct {
|
||||
@@ -59,7 +64,7 @@ type SelfAttention struct {
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
|
||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor {
|
||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
|
||||
@@ -74,14 +79,16 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
k, v = cache.Put(ctx, k, v, cache.Options)
|
||||
cache.Put(ctx, k, v)
|
||||
k, v, mask := cache.Get(ctx)
|
||||
|
||||
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
kq := k.Mulmat(ctx, q)
|
||||
kq := k.MulmatFullPrec(ctx, q)
|
||||
kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
|
||||
kq = kq.Add(ctx, mask)
|
||||
kq = kq.Softmax(ctx)
|
||||
|
||||
kqv := v.Mulmat(ctx, kq)
|
||||
@@ -91,6 +98,10 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
@@ -109,7 +120,7 @@ type Layer struct {
|
||||
MLP *MLP
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor {
|
||||
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
@@ -123,12 +134,12 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cach
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
|
||||
inputs, err := ctx.FromIntSlice(opts.Inputs(), len(opts.Inputs()))
|
||||
inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions()))
|
||||
positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -136,13 +147,14 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positions, opts.Cache.Sub(i), m.Options)
|
||||
m.Cache.SetLayer(i)
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
|
||||
}
|
||||
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
hiddenState = m.Output.Forward(ctx, hiddenState)
|
||||
|
||||
outputs, err := ctx.FromIntSlice([]int32{int32(len(opts.Positions())) - 1}, 1)
|
||||
outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package mllama
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
@@ -8,6 +9,7 @@ import (
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
|
||||
*VisionModel `gguf:"v,vision"`
|
||||
*TextModel
|
||||
@@ -15,16 +17,33 @@ type Model struct {
|
||||
Projector *nn.Linear `gguf:"mm.0"`
|
||||
|
||||
ImageProcessor
|
||||
TextProcessor
|
||||
}
|
||||
|
||||
const (
|
||||
crossAttentionLayer = iota
|
||||
selfAttentionLayer
|
||||
)
|
||||
|
||||
func New(c ml.Config) (model.Model, error) {
|
||||
return &Model{
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Uints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
|
||||
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
||||
},
|
||||
),
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
TextProcessor: newTextProcessor(c),
|
||||
TextModel: newTextModel(c),
|
||||
}, nil
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift))
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
|
||||
@@ -64,20 +83,20 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
|
||||
crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates)
|
||||
}
|
||||
|
||||
inputs, err := ctx.FromIntSlice(opts.Inputs(), len(opts.Inputs()))
|
||||
inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions()))
|
||||
positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: attention mask, cross attention mask
|
||||
hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, opts.Cache)
|
||||
hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache))
|
||||
|
||||
outputs, err := ctx.FromIntSlice([]int32{int32(len(opts.Positions())) - 1}, 1)
|
||||
outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -4,9 +4,9 @@ import (
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
type TextSelfAttention struct {
|
||||
@@ -16,7 +16,7 @@ type TextSelfAttention struct {
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
|
||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
|
||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
|
||||
@@ -31,19 +31,16 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
key, value = cache.Put(ctx, key, value, cache.Options)
|
||||
cache.Put(ctx, key, value)
|
||||
key, value, mask := cache.Get(ctx)
|
||||
|
||||
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
scores := key.Mulmat(ctx, query)
|
||||
scores := key.MulmatFullPrec(ctx, query)
|
||||
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
|
||||
|
||||
if mask != nil {
|
||||
scores = scores.Add(ctx, mask)
|
||||
}
|
||||
|
||||
scores = scores.Add(ctx, mask)
|
||||
scores = scores.Softmax(ctx)
|
||||
|
||||
attention := value.Mulmat(ctx, scores)
|
||||
@@ -53,6 +50,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
// This will only get called for layers in the cache, which are just the self attention layers
|
||||
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
|
||||
}
|
||||
|
||||
type TextMLP struct {
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
@@ -72,7 +74,7 @@ type TextSelfAttentionDecoderLayer struct {
|
||||
MLP *TextMLP
|
||||
}
|
||||
|
||||
func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
|
||||
func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
@@ -94,23 +96,29 @@ type TextCrossAttention struct {
|
||||
Output *nn.Linear `gguf:"cross_attn_o_proj"`
|
||||
}
|
||||
|
||||
func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
|
||||
func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
|
||||
|
||||
query := ca.Query.Forward(ctx, hiddenState)
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
query = ca.QueryNorm.Forward(ctx, query, opts.eps)
|
||||
|
||||
key := ca.Key.Forward(ctx, crossAttentionStates)
|
||||
key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
|
||||
key = ca.KeyNorm.Forward(ctx, key, opts.eps)
|
||||
var key, value ml.Tensor
|
||||
if crossAttentionStates != nil {
|
||||
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
|
||||
|
||||
value := ca.Value.Forward(ctx, crossAttentionStates)
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
|
||||
key = ca.Key.Forward(ctx, crossAttentionStates)
|
||||
key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
|
||||
key = ca.KeyNorm.Forward(ctx, key, opts.eps)
|
||||
|
||||
// TODO cache key, value
|
||||
value = ca.Value.Forward(ctx, crossAttentionStates)
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
|
||||
|
||||
cache.Put(ctx, key, value)
|
||||
} else {
|
||||
key, value, _ = cache.Get(ctx)
|
||||
}
|
||||
|
||||
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
@@ -137,7 +145,7 @@ type TextCrossAttentionDecoderLayer struct {
|
||||
MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
|
||||
}
|
||||
|
||||
func (d TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
|
||||
func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
@@ -153,17 +161,25 @@ func (d TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _,
|
||||
}
|
||||
|
||||
type TextDecoderLayer interface {
|
||||
Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor
|
||||
Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
|
||||
}
|
||||
|
||||
type TextDecoder struct {
|
||||
Layers []TextDecoderLayer
|
||||
}
|
||||
|
||||
func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
|
||||
func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
||||
for i, layer := range d.Layers {
|
||||
if !slices.Contains(opts.crossAttentionLayers, uint32(i)) || crossAttentionStates != nil {
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache.Sub(i), opts)
|
||||
layerType := selfAttentionLayer
|
||||
if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
|
||||
layerType = crossAttentionLayer
|
||||
}
|
||||
|
||||
cache.SetLayer(i)
|
||||
cache.SetLayerType(layerType)
|
||||
|
||||
if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() {
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,7 +189,7 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, cr
|
||||
type TextModelOptions struct {
|
||||
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
|
||||
|
||||
hiddenSize, numHeads, numKVHeads int64
|
||||
hiddenSize, numHeads, numKVHeads int
|
||||
eps, ropeBase, ropeScale float32
|
||||
ropeDim uint32
|
||||
|
||||
@@ -189,7 +205,7 @@ type TextModel struct {
|
||||
*TextModelOptions
|
||||
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache) ml.Tensor {
|
||||
func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
|
||||
hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
@@ -212,9 +228,9 @@ func newTextModel(c ml.Config) *TextModel {
|
||||
return &TextModel{
|
||||
Transformer: &TextDecoder{Layers: decoderLayers},
|
||||
TextModelOptions: &TextModelOptions{
|
||||
hiddenSize: int64(c.Uint("embedding_length")),
|
||||
numHeads: int64(c.Uint("attention.head_count")),
|
||||
numKVHeads: int64(c.Uint("attention.head_count_kv")),
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
var batchSize int64 = 1
|
||||
var batchSize int = 1
|
||||
|
||||
type VisionSelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
@@ -99,7 +99,7 @@ func (e *VisionEncoder) Forward(ctx ml.Context, hiddenState ml.Tensor, intermedi
|
||||
var intermediateHiddenStates []ml.Tensor
|
||||
for i, layer := range e.Layers {
|
||||
if slices.Contains(intermediateLayersIndices, uint32(i)) {
|
||||
intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append([]int64{1}, hiddenState.Shape()...)...))
|
||||
intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append([]int{1}, hiddenState.Shape()...)...))
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, hiddenState, opts)
|
||||
@@ -131,7 +131,7 @@ type PrecomputedPositionEmbedding struct {
|
||||
TilePositionEmbeddingGate ml.Tensor `gguf:"tile_position_embd.gate"`
|
||||
}
|
||||
|
||||
func (e *PrecomputedPositionEmbedding) Forward(ctx ml.Context, hiddenState, positionIDs, aspectRatioIDs ml.Tensor, numPositions int64, opts *VisionModelOptions) ml.Tensor {
|
||||
func (e *PrecomputedPositionEmbedding) Forward(ctx ml.Context, hiddenState, positionIDs, aspectRatioIDs ml.Tensor, numPositions int, opts *VisionModelOptions) ml.Tensor {
|
||||
positionEmbedding := e.PositionEmbedding.Forward(ctx, positionIDs)
|
||||
if e.PositionEmbeddingGate != nil {
|
||||
positionEmbedding = positionEmbedding.Mul(ctx, e.PositionEmbeddingGate)
|
||||
@@ -149,7 +149,7 @@ func (e *PrecomputedPositionEmbedding) Forward(ctx ml.Context, hiddenState, posi
|
||||
}
|
||||
|
||||
type VisionModelOptions struct {
|
||||
hiddenSize, numHeads, numTiles int64
|
||||
hiddenSize, numHeads, numTiles int
|
||||
imageSize, patchSize int
|
||||
eps float32
|
||||
|
||||
@@ -174,7 +174,7 @@ type VisionModel struct {
|
||||
}
|
||||
|
||||
func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRatioIDs ml.Tensor) ml.Tensor {
|
||||
numPatches := int64((m.imageSize / m.patchSize) * (m.imageSize / m.patchSize))
|
||||
numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)
|
||||
numPositions := numPatches
|
||||
if m.ClassEmbedding != nil {
|
||||
numPositions++
|
||||
@@ -185,7 +185,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa
|
||||
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
hiddenState = m.PreTilePositionEmbedding.Forward(ctx, hiddenState, aspectRatioIDs, m.VisionModelOptions)
|
||||
hiddenState = m.ClassEmbedding.Stack(ctx, 2, slices.Repeat([]ml.Tensor{m.ClassEmbedding}, int(m.numTiles)-1)...).Concat(ctx, hiddenState, 1)
|
||||
hiddenState = m.ClassEmbedding.Stack(ctx, 2, slices.Repeat([]ml.Tensor{m.ClassEmbedding}, m.numTiles-1)...).Concat(ctx, hiddenState, 1)
|
||||
|
||||
hiddenState = m.PositionEmbedding.Forward(ctx, hiddenState, positionIDs, aspectRatioIDs, numPositions, m.VisionModelOptions)
|
||||
hiddenState = m.PreLayerNorm.Forward(ctx, hiddenState, m.eps)
|
||||
@@ -205,7 +205,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa
|
||||
hiddenState, _ = m.GlobalTransformer.Forward(ctx, hiddenState, nil, m.VisionModelOptions)
|
||||
|
||||
hiddenStates := intermediateHiddenStates[0].Stack(ctx, 0, intermediateHiddenStates[1:]...)
|
||||
hiddenStates = hiddenStates.Reshape(ctx, int64(len(intermediateHiddenStates))*m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize)
|
||||
hiddenStates = hiddenStates.Reshape(ctx, len(intermediateHiddenStates)*m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize)
|
||||
hiddenStates = hiddenStates.Unpad(ctx, 0, numPaddingPatches, 0, 0)
|
||||
|
||||
hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize)
|
||||
@@ -219,9 +219,9 @@ func newVisionModel(c ml.Config) *VisionModel {
|
||||
GlobalTransformer: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.global.block_count"))},
|
||||
|
||||
VisionModelOptions: &VisionModelOptions{
|
||||
hiddenSize: int64(c.Uint("vision.embedding_length")),
|
||||
numHeads: int64(c.Uint("vision.attention.head_count")),
|
||||
numTiles: int64(c.Uint("vision.max_num_tiles")),
|
||||
hiddenSize: int(c.Uint("vision.embedding_length")),
|
||||
numHeads: int(c.Uint("vision.attention.head_count")),
|
||||
numTiles: int(c.Uint("vision.max_num_tiles")),
|
||||
|
||||
imageSize: int(c.Uint("vision.image_size")),
|
||||
patchSize: int(c.Uint("vision.patch_size")),
|
||||
6
model/models/models.go
Normal file
6
model/models/models.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
_ "github.com/ollama/ollama/model/models/llama"
|
||||
_ "github.com/ollama/ollama/model/models/mllama"
|
||||
)
|
||||
@@ -2,6 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -20,7 +21,7 @@ const (
|
||||
type TextProcessor interface {
|
||||
Encode(string) ([]int32, error)
|
||||
Decode([]int32) (string, error)
|
||||
Is(uint32, Special) bool
|
||||
Is(int32, Special) bool
|
||||
}
|
||||
|
||||
type Vocabulary struct {
|
||||
@@ -29,7 +30,7 @@ type Vocabulary struct {
|
||||
Scores []uint32
|
||||
Merges []string
|
||||
|
||||
BOS, EOS uint32
|
||||
BOS, EOS int32
|
||||
|
||||
specialOnce sync.Once
|
||||
special []string
|
||||
@@ -41,7 +42,7 @@ type Vocabulary struct {
|
||||
merge map[string]int32
|
||||
}
|
||||
|
||||
func (v *Vocabulary) Is(id uint32, special Special) bool {
|
||||
func (v *Vocabulary) Is(id int32, special Special) bool {
|
||||
switch special {
|
||||
case SpecialBOS:
|
||||
return id == v.BOS
|
||||
@@ -99,23 +100,29 @@ func (v *Vocabulary) Merge(left, right string) int {
|
||||
}
|
||||
|
||||
type BytePairEncoding struct {
|
||||
Pretokenizer string
|
||||
|
||||
*Vocabulary
|
||||
pre *regexp2.Regexp
|
||||
vocab *Vocabulary
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) split(s string) ([]string, error) {
|
||||
re, err := regexp2.Compile(bpe.Pretokenizer, regexp2.Unicode|regexp2.RE2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding {
|
||||
return BytePairEncoding{
|
||||
pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
|
||||
vocab: vocab,
|
||||
}
|
||||
}
|
||||
|
||||
var matches []string
|
||||
for m, _ := re.FindStringMatch(s); m != nil; m, _ = re.FindNextMatch(m) {
|
||||
matches = append(matches, m.String())
|
||||
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
|
||||
return bpe.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
|
||||
return func(yield func(string) bool) {
|
||||
for m, _ := bpe.pre.FindStringMatch(s); m != nil; m, _ = bpe.pre.FindNextMatch(m) {
|
||||
if !yield(m.String()) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return matches, nil
|
||||
}
|
||||
|
||||
// fragment is a string fragment and their corresponding token IDs
|
||||
@@ -138,9 +145,9 @@ type merge struct {
|
||||
|
||||
func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
|
||||
fragments := []fragment{{value: s}}
|
||||
for _, special := range bpe.Vocabulary.SpecialVocabulary() {
|
||||
for _, special := range bpe.vocab.SpecialVocabulary() {
|
||||
// TODO: process special tokens concurrently
|
||||
id := bpe.Vocabulary.Encode(special)
|
||||
id := bpe.vocab.Encode(special)
|
||||
for i := 0; i < len(fragments); i++ {
|
||||
frag := fragments[i]
|
||||
if len(frag.ids) > 0 {
|
||||
@@ -173,13 +180,7 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
// split fragment using pretokenizer
|
||||
splits, err := bpe.split(frag.value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, split := range splits {
|
||||
for split := range bpe.split(frag.value) {
|
||||
// TODO: process splits concurrently
|
||||
var sb strings.Builder
|
||||
for _, b := range []byte(split) {
|
||||
@@ -197,7 +198,7 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
|
||||
}
|
||||
|
||||
// short circuit if the fragment is in the vocabulary
|
||||
if id := bpe.Vocabulary.Encode(sb.String()); id >= 0 {
|
||||
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
slog.Debug("encoded", "text", sb.String(), "ids", []int32{id})
|
||||
continue
|
||||
@@ -219,7 +220,7 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
|
||||
}
|
||||
|
||||
left, right := string(merges[a].runes), string(merges[b].runes)
|
||||
rank := bpe.Vocabulary.Merge(left, right)
|
||||
rank := bpe.vocab.Merge(left, right)
|
||||
if rank < 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -271,7 +272,7 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
|
||||
for _, merge := range merges {
|
||||
if len(merge.runes) > 0 {
|
||||
// TODO: handle the edge case where the rune isn't in the vocabulary
|
||||
if id := bpe.Vocabulary.Encode(string(merge.runes)); id >= 0 {
|
||||
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
slog.Debug("encoded", "text", string(merge.runes), "ids", []int32{id})
|
||||
}
|
||||
@@ -286,7 +287,7 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
|
||||
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for _, id := range ids {
|
||||
for _, r := range bpe.Vocabulary.Decode(id) {
|
||||
for _, r := range bpe.vocab.Decode(id) {
|
||||
switch {
|
||||
case r == 0x0100:
|
||||
// this produces 0x00 aka NULL
|
||||
|
||||
@@ -1,227 +1,253 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestBytePairEncoding(t *testing.T) {
|
||||
// Create a simple test vocabulary
|
||||
vocab := &Vocabulary{
|
||||
Values: []string{
|
||||
"Hello",
|
||||
"World",
|
||||
"!",
|
||||
"How",
|
||||
"are",
|
||||
"you",
|
||||
"t",
|
||||
"o",
|
||||
"d",
|
||||
"a",
|
||||
"y",
|
||||
"to",
|
||||
"tod",
|
||||
"toda",
|
||||
"today",
|
||||
" ",
|
||||
},
|
||||
Types: []uint32{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3}, // 3 for special token (space)
|
||||
Merges: []string{
|
||||
"to",
|
||||
"tod",
|
||||
"toda",
|
||||
"today",
|
||||
},
|
||||
BOS: 0,
|
||||
EOS: 1,
|
||||
func llama(t testing.TB) BytePairEncoding {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.Open(filepath.Join("testdata", "llama3.2", "encoder.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
vocab := make(map[string]int32)
|
||||
if err := json.NewDecoder(f).Decode(&vocab); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
bpe := BytePairEncoding{
|
||||
Pretokenizer: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
Vocabulary: vocab,
|
||||
types := make([]uint32, len(vocab))
|
||||
tokens := make([]string, len(vocab))
|
||||
for token, id := range vocab {
|
||||
tokens[id] = token
|
||||
types[id] = 1
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want []int32
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple hello world",
|
||||
input: "Hello World!",
|
||||
want: []int32{0, 15, 1, 2}, // indexes in the vocabulary
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "just spaces",
|
||||
input: " ",
|
||||
want: []int32{15, 15, 15}, // space token repeated
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "today with merges",
|
||||
input: "today",
|
||||
want: []int32{14}, // should merge
|
||||
wantErr: false,
|
||||
},
|
||||
for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
|
||||
if _, ok := vocab[token]; !ok {
|
||||
tokens = append(tokens, token) //nolint:makezero
|
||||
types = append(types, 3) //nolint:makezero
|
||||
vocab[token] = int32(len(vocab))
|
||||
}
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := bpe.Encode(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("BytePairEncoding.Encode() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("BytePairEncoding.Encode() = %v, want %v", got, tt.want)
|
||||
f, err = os.Open(filepath.Join("testdata", "llama3.2", "vocab.bpe"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
merges := make([]string, 0, 50000)
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
if !strings.HasPrefix(scanner.Text(), "#") {
|
||||
merges = append(merges, scanner.Text())
|
||||
}
|
||||
}
|
||||
|
||||
return NewBytePairEncoding(
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
&Vocabulary{
|
||||
Values: tokens,
|
||||
Types: types,
|
||||
Merges: merges,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func TestLlama(t *testing.T) {
|
||||
tokenizer := llama(t)
|
||||
|
||||
t.Run("simple", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ids, err := tokenizer.Encode("hello world")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff([]int32{15339, 1917}, ids); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
|
||||
s, err := tokenizer.Decode([]int32{15339, 1917})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if s != "hello world" {
|
||||
t.Errorf("got %q, want hello world", s)
|
||||
}
|
||||
|
||||
ids, err = tokenizer.Encode("hello <|end_of_text|>")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff([]int32{15339, 220, 128001}, ids); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("simple repeated", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string][]int32{
|
||||
strings.Repeat("0", 1): {15},
|
||||
strings.Repeat("0", 2): {410},
|
||||
strings.Repeat("0", 3): {931},
|
||||
strings.Repeat("0", 4): {931, 15},
|
||||
strings.Repeat("0", 5): {931, 410},
|
||||
strings.Repeat("0", 6): {931, 931},
|
||||
strings.Repeat("0", 7): {931, 931, 15},
|
||||
strings.Repeat("0", 8): {931, 931, 410},
|
||||
strings.Repeat("0", 9): {931, 931, 931},
|
||||
strings.Repeat("0", 10): {931, 931, 931, 15},
|
||||
strings.Repeat("0", 11): {931, 931, 931, 410},
|
||||
strings.Repeat("0", 12): {931, 931, 931, 931},
|
||||
strings.Repeat("0", 13): {931, 931, 931, 931, 15},
|
||||
strings.Repeat("0", 14): {931, 931, 931, 931, 410},
|
||||
strings.Repeat("0", 15): {931, 931, 931, 931, 931},
|
||||
strings.Repeat("0", 16): {931, 931, 931, 931, 931, 15},
|
||||
strings.Repeat("0", 17): {931, 931, 931, 931, 931, 410},
|
||||
}
|
||||
|
||||
for s, want := range cases {
|
||||
ids, err := tokenizer.Encode(s)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// Test round trip if encoding succeeded
|
||||
if err == nil {
|
||||
decoded, err := bpe.Decode(got)
|
||||
if diff := cmp.Diff(want, ids); diff != "" {
|
||||
t.Errorf("%q no match (-theirs +ours):\n%s", s, diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("basic roundtrip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []string{
|
||||
"hello",
|
||||
"hello ",
|
||||
"hello ",
|
||||
" hello",
|
||||
" hello ",
|
||||
" hello ",
|
||||
"hello world",
|
||||
"请考试我的软件!12345",
|
||||
}
|
||||
|
||||
for _, want := range cases {
|
||||
ids, err := tokenizer.Encode(want)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if got, err := tokenizer.Decode(ids); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("special", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string][]int32{
|
||||
"<|begin_of_text|>A B!": {128000, 32, 426, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
|
||||
}
|
||||
|
||||
for s, want := range cases {
|
||||
ids, err := tokenizer.Encode(s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, ids); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("split", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string][]string{
|
||||
"Hello World!": {"Hello", " World", "!"},
|
||||
"I'm don't won't": {"I", "'m", " don", "'t", " won", "'t"},
|
||||
"In 2024 there are 366 days": {"In", " ", "202", "4", " there", " are", " ", "366", " days"},
|
||||
"Hello!! ...world": {"Hello", "!!", " ...", "world"},
|
||||
"Hello World": {"Hello", " ", " World"},
|
||||
"Hello\nWorld": {"Hello", "\n", "World"},
|
||||
"Hello, WORLD!! How's it going?": {"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"},
|
||||
}
|
||||
|
||||
for s, want := range cases {
|
||||
got := slices.Collect(tokenizer.split(s))
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkBytePairEncoding(b *testing.B) {
|
||||
tokenizer := llama(b)
|
||||
bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for i := range 8 {
|
||||
n := min(int(math.Pow10(i)), len(bts))
|
||||
bts := bts[:n]
|
||||
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
_, err := tokenizer.Encode(string(bts))
|
||||
if err != nil {
|
||||
t.Errorf("BytePairEncoding.Decode() error = %v", err)
|
||||
return
|
||||
b.Fatal(err)
|
||||
}
|
||||
// Note: The decoded string might not exactly match the input due to
|
||||
// tokenization/normalization, so we re-encode it to compare
|
||||
reEncoded, err := bpe.Encode(decoded)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
|
||||
ids, err := tokenizer.Encode(string(bts))
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
_, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
t.Errorf("BytePairEncoding.Encode() error on round trip = %v", err)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(reEncoded, got) {
|
||||
t.Errorf("Round trip failed: original tokens = %v, after round trip = %v", got, reEncoded)
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBytePairEncodingSpecialTokens(t *testing.T) {
|
||||
vocab := &Vocabulary{
|
||||
Values: []string{
|
||||
"<s>",
|
||||
"</s>",
|
||||
"<pad>",
|
||||
"Hello",
|
||||
"World",
|
||||
},
|
||||
Types: []uint32{3, 3, 3, 1, 1}, // 3 for special tokens
|
||||
BOS: 0,
|
||||
EOS: 1,
|
||||
}
|
||||
|
||||
bpe := BytePairEncoding{
|
||||
Pretokenizer: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
Vocabulary: vocab,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want []int32
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "text with special token at start",
|
||||
input: "<s>Hello",
|
||||
want: []int32{0, 3},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "text with special token at end",
|
||||
input: "World</s>",
|
||||
want: []int32{4, 1},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "special token in middle",
|
||||
input: "Hello<pad>World",
|
||||
want: []int32{3, 2, 4},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := bpe.Encode(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("BytePairEncoding.Encode() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("BytePairEncoding.Encode() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBytePairEncodingSplit(t *testing.T) {
|
||||
bpe := BytePairEncoding{
|
||||
Pretokenizer: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "basic splitting",
|
||||
input: "Hello World!",
|
||||
want: []string{"Hello", " World", "!"},
|
||||
},
|
||||
{
|
||||
name: "contractions",
|
||||
input: "I'm don't won't",
|
||||
want: []string{"I", "'m", " don", "'t", " won", "'t"},
|
||||
},
|
||||
{
|
||||
name: "numbers",
|
||||
input: "In 2024 there are 365 days",
|
||||
want: []string{"In", " ", "202", "4", " there", " are", " ", "365", " days"},
|
||||
},
|
||||
{
|
||||
name: "special characters",
|
||||
input: "Hello!! ...world",
|
||||
want: []string{"Hello", "!!", " ...", "world"},
|
||||
},
|
||||
{
|
||||
name: "multiple spaces",
|
||||
input: "Hello World",
|
||||
want: []string{"Hello", " ", " World"},
|
||||
},
|
||||
{
|
||||
name: "newlines",
|
||||
input: "Hello\nWorld",
|
||||
want: []string{"Hello", "\n", "World"},
|
||||
},
|
||||
{
|
||||
name: "mixed case and punctuation",
|
||||
input: "Hello, WORLD!! How's it going?",
|
||||
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := bpe.split(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("BytePairEncoding.split() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("BytePairEncoding.split() = %v, want %v", got, tt.want)
|
||||
b.Run("split"+strconv.Itoa(n), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
slices.Collect(tokenizer.split(string(bts)))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
586
model/testdata/inputs.json
vendored
586
model/testdata/inputs.json
vendored
@@ -1,586 +0,0 @@
|
||||
[
|
||||
{
|
||||
"base64": "aWVkIDQgwr0gbW9udGhz",
|
||||
"ids": [
|
||||
1142,
|
||||
220,
|
||||
19,
|
||||
220,
|
||||
27154,
|
||||
4038
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "RsO8aHJlcg==",
|
||||
"ids": [
|
||||
37,
|
||||
51853,
|
||||
261
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "",
|
||||
"ids": []
|
||||
},
|
||||
{
|
||||
"base64": "IA==",
|
||||
"ids": [
|
||||
220
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "ICA=",
|
||||
"ids": [
|
||||
256
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "ICAg",
|
||||
"ids": [
|
||||
262
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "CQ==",
|
||||
"ids": [
|
||||
197
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "Cg==",
|
||||
"ids": [
|
||||
198
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "Cgo=",
|
||||
"ids": [
|
||||
271
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "CgoK",
|
||||
"ids": [
|
||||
1432
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "CQo=",
|
||||
"ids": [
|
||||
1602
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "SGVsbG8gd29ybGQ=",
|
||||
"ids": [
|
||||
9906,
|
||||
1917
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "IEhlbGxvIHdvcmxk",
|
||||
"ids": [
|
||||
22691,
|
||||
1917
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "SGVsbG8gV29ybGQ=",
|
||||
"ids": [
|
||||
9906,
|
||||
4435
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "IEhlbGxvIFdvcmxk",
|
||||
"ids": [
|
||||
22691,
|
||||
4435
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "IEhlbGxvIFdvcmxkIQ==",
|
||||
"ids": [
|
||||
22691,
|
||||
4435,
|
||||
0
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "SGVsbG8sIHdvcmxkIQ==",
|
||||
"ids": [
|
||||
9906,
|
||||
11,
|
||||
1917,
|
||||
0
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "IEhlbGxvLCB3b3JsZCE=",
|
||||
"ids": [
|
||||
22691,
|
||||
11,
|
||||
1917,
|
||||
0
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "IHRoaXMgaXMg8J+mmS5jcHA=",
|
||||
"ids": [
|
||||
420,
|
||||
374,
|
||||
11410,
|
||||
99,
|
||||
247,
|
||||
13,
|
||||
11055
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "dzA0OCA3dHVpamsgZHNkZmh1",
|
||||
"ids": [
|
||||
86,
|
||||
23904,
|
||||
220,
|
||||
22,
|
||||
83,
|
||||
2005,
|
||||
42908,
|
||||
11729,
|
||||
3013,
|
||||
17156
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "0L3QtdGJ0L4g0L3QsCDQkdGK0LvQs9Cw0YDRgdC60Lg=",
|
||||
"ids": [
|
||||
79862,
|
||||
102118,
|
||||
13373,
|
||||
64571,
|
||||
34694,
|
||||
3114,
|
||||
112203,
|
||||
80112
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "4Z6A4Z624Z6T4Z+L4Z6P4Z+C4Z6W4Z634Z6f4Z+B4Z6f4Z6i4Z624Z6F4Z6B4Z6b4Z6F4Z+B4Z6J",
|
||||
"ids": [
|
||||
21549,
|
||||
222,
|
||||
98629,
|
||||
241,
|
||||
45358,
|
||||
233,
|
||||
21549,
|
||||
237,
|
||||
45358,
|
||||
224,
|
||||
21549,
|
||||
244,
|
||||
21549,
|
||||
115,
|
||||
21549,
|
||||
253,
|
||||
45358,
|
||||
223,
|
||||
21549,
|
||||
253,
|
||||
21549,
|
||||
95,
|
||||
98629,
|
||||
227,
|
||||
21549,
|
||||
223,
|
||||
21549,
|
||||
249,
|
||||
21549,
|
||||
227,
|
||||
45358,
|
||||
223,
|
||||
21549,
|
||||
231
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "8J+agCAobm9ybWFsKSDwn5i24oCN8J+Mq++4jyAobXVsdGlwbGUgZW1vamlzIGNvbmNhdGVuYXRlZCkg4pyFIChvbmx5IGVtb2ppIHRoYXQgaGFzIGl0cyBvd24gdG9rZW4p",
|
||||
"ids": [
|
||||
9468,
|
||||
248,
|
||||
222,
|
||||
320,
|
||||
8416,
|
||||
8,
|
||||
27623,
|
||||
114,
|
||||
102470,
|
||||
9468,
|
||||
234,
|
||||
104,
|
||||
31643,
|
||||
320,
|
||||
36773,
|
||||
100166,
|
||||
98634,
|
||||
8,
|
||||
26602,
|
||||
227,
|
||||
320,
|
||||
3323,
|
||||
43465,
|
||||
430,
|
||||
706,
|
||||
1202,
|
||||
1866,
|
||||
4037,
|
||||
8
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "SGVsbG8=",
|
||||
"ids": [
|
||||
9906
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "IEhlbGxv",
|
||||
"ids": [
|
||||
22691
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "ICBIZWxsbw==",
|
||||
"ids": [
|
||||
220,
|
||||
22691
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "ICAgSGVsbG8=",
|
||||
"ids": [
|
||||
256,
|
||||
22691
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "ICAgIEhlbGxv",
|
||||
"ids": [
|
||||
262,
|
||||
22691
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "ICAgIEhlbGxvCiAgICBIZWxsbw==",
|
||||
"ids": [
|
||||
262,
|
||||
22691,
|
||||
198,
|
||||
262,
|
||||
22691
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "ICg=",
|
||||
"ids": [
|
||||
320
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "CiA9",
|
||||
"ids": [
|
||||
198,
|
||||
284
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "JyBlcmE=",
|
||||
"ids": [
|
||||
6,
|
||||
11639
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "SGVsbG8sIHknYWxsISBIb3cgYXJlIHlvdSDwn5iBID/miJHmg7PlnKhhcHBsZeW3peS9nDEzMTQxNTHlpKnvvZ4=",
|
||||
"ids": [
|
||||
9906,
|
||||
11,
|
||||
379,
|
||||
65948,
|
||||
0,
|
||||
2650,
|
||||
527,
|
||||
499,
|
||||
27623,
|
||||
223,
|
||||
949,
|
||||
37046,
|
||||
101067,
|
||||
19000,
|
||||
23182,
|
||||
102301,
|
||||
9263,
|
||||
18136,
|
||||
16,
|
||||
36827,
|
||||
21909
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "ISEhISEh",
|
||||
"ids": [
|
||||
17523,
|
||||
3001
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "Mw==",
|
||||
"ids": [
|
||||
18
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "MzM=",
|
||||
"ids": [
|
||||
1644
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "MzMz",
|
||||
"ids": [
|
||||
8765
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "MzMzMw==",
|
||||
"ids": [
|
||||
8765,
|
||||
18
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "MzMzMzM=",
|
||||
"ids": [
|
||||
8765,
|
||||
1644
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "MzMzMzMz",
|
||||
"ids": [
|
||||
8765,
|
||||
8765
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "MzMzMzMzMw==",
|
||||
"ids": [
|
||||
8765,
|
||||
8765,
|
||||
18
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "MzMzMzMzMzM=",
|
||||
"ids": [
|
||||
8765,
|
||||
8765,
|
||||
1644
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "MzMzMzMzMzMz",
|
||||
"ids": [
|
||||
8765,
|
||||
8765,
|
||||
8765
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "Q+G7rWEgVmnhu4d0",
|
||||
"ids": [
|
||||
34,
|
||||
91163,
|
||||
101798
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "IGRpc2NhcmRz",
|
||||
"ids": [
|
||||
2624,
|
||||
2402
|
||||
]
|
||||
},
|
||||
{
|
||||
"base64": "CiAKCiAKCgogCSAJCSAJCiAgCiAgIAogICAgCiAgICAgCvCfmoAgKG5vcm1hbCkg8J+YtuKAjfCfjKvvuI8gKG11bHRpcGxlIGVtb2ppcyBjb25jYXRlbmF0ZWQpIOKchSDwn6aZ8J+mmSAzIDMzIDMzMyAzMzMzIDMzMzMzIDMzMzMzMyAzMzMzMzMzIDMzMzMzMzMzIDMuMyAzLi4zIDMuLi4zIOGegOGetuGek+Gfi+Gej+GfguGeluGet+Gen+GfgeGen+GeouGetuGehfCfmIEgP+aIkeaDs+WcqGFwcGxl5bel5L2cMTMxNDE1MeWkqe+9niAtLS0tLS09PT09PT09INC90LXRidC+INC90LAg0JHRitC70LPQsNGA0YHQutC4ICcnJycnJ2BgYGBgYGAiIiIiLi4uLi4uISEhISEhPz8/Pz8/IEkndmUgYmVlbiAndG9sZCBoZSdzIHRoZXJlLCAnUkUgeW91IHN1cmU/ICdNIG5vdCBzdXJlIEknbGwgbWFrZSBpdCwgJ0QgeW91IGxpa2Ugc29tZSB0ZWE/IFdlJ1ZlIGEnbEw=",
|
||||
"ids": [
|
||||
198,
|
||||
4815,
|
||||
15073,
|
||||
66597,
|
||||
8004,
|
||||
1602,
|
||||
2355,
|
||||
79772,
|
||||
11187,
|
||||
9468,
|
||||
248,
|
||||
222,
|
||||
320,
|
||||
8416,
|
||||
8,
|
||||
27623,
|
||||
114,
|
||||
102470,
|
||||
9468,
|
||||
234,
|
||||
104,
|
||||
31643,
|
||||
320,
|
||||
36773,
|
||||
100166,
|
||||
98634,
|
||||
8,
|
||||
26602,
|
||||
227,
|
||||
11410,
|
||||
99,
|
||||
247,
|
||||
9468,
|
||||
99,
|
||||
247,
|
||||
220,
|
||||
18,
|
||||
220,
|
||||
1644,
|
||||
220,
|
||||
8765,
|
||||
220,
|
||||
8765,
|
||||
18,
|
||||
220,
|
||||
8765,
|
||||
1644,
|
||||
220,
|
||||
8765,
|
||||
8765,
|
||||
220,
|
||||
8765,
|
||||
8765,
|
||||
18,
|
||||
220,
|
||||
8765,
|
||||
8765,
|
||||
1644,
|
||||
220,
|
||||
18,
|
||||
13,
|
||||
18,
|
||||
220,
|
||||
18,
|
||||
497,
|
||||
18,
|
||||
220,
|
||||
18,
|
||||
1131,
|
||||
18,
|
||||
220,
|
||||
21549,
|
||||
222,
|
||||
98629,
|
||||
241,
|
||||
45358,
|
||||
233,
|
||||
21549,
|
||||
237,
|
||||
45358,
|
||||
224,
|
||||
21549,
|
||||
244,
|
||||
21549,
|
||||
115,
|
||||
21549,
|
||||
253,
|
||||
45358,
|
||||
223,
|
||||
21549,
|
||||
253,
|
||||
21549,
|
||||
95,
|
||||
98629,
|
||||
227,
|
||||
76460,
|
||||
223,
|
||||
949,
|
||||
37046,
|
||||
101067,
|
||||
19000,
|
||||
23182,
|
||||
102301,
|
||||
9263,
|
||||
18136,
|
||||
16,
|
||||
36827,
|
||||
21909,
|
||||
56560,
|
||||
54337,
|
||||
19175,
|
||||
102118,
|
||||
13373,
|
||||
64571,
|
||||
34694,
|
||||
3114,
|
||||
112203,
|
||||
80112,
|
||||
3436,
|
||||
106451,
|
||||
14196,
|
||||
14196,
|
||||
74694,
|
||||
3089,
|
||||
3089,
|
||||
29249,
|
||||
17523,
|
||||
3001,
|
||||
27708,
|
||||
7801,
|
||||
358,
|
||||
3077,
|
||||
1027,
|
||||
364,
|
||||
83,
|
||||
820,
|
||||
568,
|
||||
596,
|
||||
1070,
|
||||
11,
|
||||
364,
|
||||
793,
|
||||
499,
|
||||
2771,
|
||||
30,
|
||||
364,
|
||||
44,
|
||||
539,
|
||||
2771,
|
||||
358,
|
||||
3358,
|
||||
1304,
|
||||
433,
|
||||
11,
|
||||
364,
|
||||
35,
|
||||
499,
|
||||
1093,
|
||||
1063,
|
||||
15600,
|
||||
30,
|
||||
1226,
|
||||
6,
|
||||
43712,
|
||||
264,
|
||||
64966,
|
||||
43
|
||||
]
|
||||
}
|
||||
]
|
||||
128002
model/testdata/llama3.2/encoder.json
vendored
Normal file
128002
model/testdata/llama3.2/encoder.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
280147
model/testdata/llama3.2/vocab.bpe
vendored
Normal file
280147
model/testdata/llama3.2/vocab.bpe
vendored
Normal file
File diff suppressed because it is too large
Load Diff
63845
model/testdata/war-and-peace.txt
vendored
Normal file
63845
model/testdata/war-and-peace.txt
vendored
Normal file
File diff suppressed because it is too large
Load Diff
@@ -20,6 +20,8 @@ import (
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
var finishReasonToolCalls = "tool_calls"
|
||||
|
||||
type Error struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
@@ -266,7 +268,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||
}
|
||||
}
|
||||
|
||||
func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
|
||||
func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
|
||||
toolCalls := toToolCalls(r.Message.ToolCalls)
|
||||
return ChatCompletionChunk{
|
||||
Id: id,
|
||||
@@ -279,6 +281,9 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
|
||||
Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
|
||||
FinishReason: func(reason string) *string {
|
||||
if len(reason) > 0 {
|
||||
if toolCallSent {
|
||||
return &finishReasonToolCalls
|
||||
}
|
||||
return &reason
|
||||
}
|
||||
return nil
|
||||
@@ -585,6 +590,7 @@ type ChatWriter struct {
|
||||
stream bool
|
||||
streamOptions *StreamOptions
|
||||
id string
|
||||
toolCallSent bool
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
@@ -634,11 +640,14 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
||||
|
||||
// chat chunk
|
||||
if w.stream {
|
||||
c := toChunk(w.id, chatResponse)
|
||||
c := toChunk(w.id, chatResponse, w.toolCallSent)
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
|
||||
w.toolCallSent = true
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package progress
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
@@ -13,7 +14,8 @@ type State interface {
|
||||
|
||||
type Progress struct {
|
||||
mu sync.Mutex
|
||||
w io.Writer
|
||||
// buffer output to minimize flickering on all terminals
|
||||
w *bufio.Writer
|
||||
|
||||
pos int
|
||||
|
||||
@@ -22,7 +24,7 @@ type Progress struct {
|
||||
}
|
||||
|
||||
func NewProgress(w io.Writer) *Progress {
|
||||
p := &Progress{w: w}
|
||||
p := &Progress{w: bufio.NewWriter(w)}
|
||||
go p.start()
|
||||
return p
|
||||
}
|
||||
@@ -47,26 +49,29 @@ func (p *Progress) stop() bool {
|
||||
func (p *Progress) Stop() bool {
|
||||
stopped := p.stop()
|
||||
if stopped {
|
||||
fmt.Fprint(p.w, "\n")
|
||||
fmt.Fprintln(p.w)
|
||||
}
|
||||
|
||||
// show cursor
|
||||
fmt.Fprint(p.w, "\033[?25h")
|
||||
p.w.Flush()
|
||||
return stopped
|
||||
}
|
||||
|
||||
func (p *Progress) StopAndClear() bool {
|
||||
fmt.Fprint(p.w, "\033[?25l")
|
||||
defer fmt.Fprint(p.w, "\033[?25h")
|
||||
|
||||
stopped := p.stop()
|
||||
if stopped {
|
||||
// clear all progress lines
|
||||
for i := range p.pos {
|
||||
if i > 0 {
|
||||
fmt.Fprint(p.w, "\033[A")
|
||||
}
|
||||
fmt.Fprint(p.w, "\033[2K\033[1G")
|
||||
for range p.pos - 1 {
|
||||
fmt.Fprint(p.w, "\033[A")
|
||||
}
|
||||
|
||||
fmt.Fprint(p.w, "\033[2K", "\033[1G")
|
||||
}
|
||||
|
||||
// show cursor
|
||||
fmt.Fprint(p.w, "\033[?25h")
|
||||
p.w.Flush()
|
||||
return stopped
|
||||
}
|
||||
|
||||
@@ -81,30 +86,31 @@ func (p *Progress) render() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
fmt.Fprint(p.w, "\033[?25l")
|
||||
defer fmt.Fprint(p.w, "\033[?25h")
|
||||
fmt.Fprint(p.w, "\033[?2026h")
|
||||
defer fmt.Fprint(p.w, "\033[?2026l")
|
||||
|
||||
// clear already rendered progress lines
|
||||
for i := range p.pos {
|
||||
if i > 0 {
|
||||
fmt.Fprint(p.w, "\033[A")
|
||||
}
|
||||
fmt.Fprint(p.w, "\033[2K\033[1G")
|
||||
for range p.pos - 1 {
|
||||
fmt.Fprint(p.w, "\033[A")
|
||||
}
|
||||
|
||||
fmt.Fprint(p.w, "\033[1G")
|
||||
|
||||
// render progress lines
|
||||
for i, state := range p.states {
|
||||
fmt.Fprint(p.w, state.String())
|
||||
fmt.Fprint(p.w, state.String(), "\033[K")
|
||||
if i < len(p.states)-1 {
|
||||
fmt.Fprint(p.w, "\n")
|
||||
}
|
||||
}
|
||||
|
||||
p.pos = len(p.states)
|
||||
p.w.Flush()
|
||||
}
|
||||
|
||||
func (p *Progress) start() {
|
||||
p.ticker = time.NewTicker(100 * time.Millisecond)
|
||||
// hide cursor
|
||||
fmt.Fprint(p.w, "\033[?25l")
|
||||
for range p.ticker.C {
|
||||
p.render()
|
||||
}
|
||||
|
||||
@@ -4,18 +4,18 @@
|
||||
|
||||
A minimial runner for loading a model and running inference via a http web server.
|
||||
|
||||
```
|
||||
```shell
|
||||
./runner -model <model binary>
|
||||
```
|
||||
|
||||
### Completion
|
||||
|
||||
```
|
||||
```shell
|
||||
curl -X POST -H "Content-Type: application/json" -d '{"prompt": "hi"}' http://localhost:8080/completion
|
||||
```
|
||||
|
||||
### Embeddings
|
||||
|
||||
```
|
||||
```shell
|
||||
curl -X POST -H "Content-Type: application/json" -d '{"prompt": "turn me into an embedding"}' http://localhost:8080/embedding
|
||||
```
|
||||
@@ -1,10 +1,10 @@
|
||||
package runner
|
||||
package common
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func findStop(sequence string, stops []string) (bool, string) {
|
||||
func FindStop(sequence string, stops []string) (bool, string) {
|
||||
for _, stop := range stops {
|
||||
if strings.Contains(sequence, stop) {
|
||||
return true, stop
|
||||
@@ -14,7 +14,7 @@ func findStop(sequence string, stops []string) (bool, string) {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
func containsStopSuffix(sequence string, stops []string) bool {
|
||||
func ContainsStopSuffix(sequence string, stops []string) bool {
|
||||
for _, stop := range stops {
|
||||
for i := 1; i <= len(stop); i++ {
|
||||
if strings.HasSuffix(sequence, stop[:i]) {
|
||||
@@ -29,7 +29,7 @@ func containsStopSuffix(sequence string, stops []string) bool {
|
||||
// truncateStop removes the provided stop string from pieces,
|
||||
// returning the partial pieces with stop removed, including truncating
|
||||
// the last piece if required (and signalling if this was the case)
|
||||
func truncateStop(pieces []string, stop string) ([]string, bool) {
|
||||
func TruncateStop(pieces []string, stop string) ([]string, bool) {
|
||||
joined := strings.Join(pieces, "")
|
||||
|
||||
index := strings.Index(joined, stop)
|
||||
@@ -65,7 +65,7 @@ func truncateStop(pieces []string, stop string) ([]string, bool) {
|
||||
return result, tokenTruncated
|
||||
}
|
||||
|
||||
func incompleteUnicode(token string) bool {
|
||||
func IncompleteUnicode(token string) bool {
|
||||
incomplete := false
|
||||
|
||||
// check if there is incomplete UTF-8 character at the end
|
||||
@@ -1,4 +1,4 @@
|
||||
package runner
|
||||
package common
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
@@ -52,7 +52,7 @@ func TestTruncateStop(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, resultTrunc := truncateStop(tt.pieces, tt.stop)
|
||||
result, resultTrunc := TruncateStop(tt.pieces, tt.stop)
|
||||
if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
|
||||
t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
|
||||
}
|
||||
@@ -120,7 +120,7 @@ func TestIncompleteUnicode(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := incompleteUnicode(tt.input)
|
||||
result := IncompleteUnicode(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("incompleteUnicode(%s): have %v; want %v", tt.input, result, tt.expected)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package runner
|
||||
package llamarunner
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -1,4 +1,4 @@
|
||||
package runner
|
||||
package llamarunner
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -1,4 +1,4 @@
|
||||
package runner
|
||||
package llamarunner
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -1,4 +1,4 @@
|
||||
package runner
|
||||
package llamarunner
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
@@ -1,4 +1,4 @@
|
||||
package runner
|
||||
package llamarunner
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/runner/common"
|
||||
)
|
||||
|
||||
// input is an element of the prompt to process, either
|
||||
@@ -498,12 +499,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
||||
sequence := strings.Join(seq.pendingResponses, "")
|
||||
|
||||
if ok, stop := findStop(sequence, seq.stop); ok {
|
||||
if ok, stop := common.FindStop(sequence, seq.stop); ok {
|
||||
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
||||
|
||||
var tokenTruncated bool
|
||||
origLen := len(seq.pendingResponses)
|
||||
seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop)
|
||||
seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop)
|
||||
newLen := len(seq.pendingResponses)
|
||||
|
||||
// Update the cache based on the tokens that will be returned:
|
||||
@@ -524,11 +525,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
continue
|
||||
}
|
||||
|
||||
if containsStopSuffix(sequence, seq.stop) {
|
||||
if common.ContainsStopSuffix(sequence, seq.stop) {
|
||||
continue
|
||||
}
|
||||
|
||||
if incompleteUnicode(sequence) {
|
||||
if common.IncompleteUnicode(sequence) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -844,8 +845,6 @@ func (s *Server) loadModel(
|
||||
threads int,
|
||||
multiUserCache bool,
|
||||
) {
|
||||
llama.BackendInit()
|
||||
|
||||
var err error
|
||||
s.model, err = llama.LoadModelFromFile(mpath, params)
|
||||
if err != nil {
|
||||
@@ -885,9 +884,6 @@ func (s *Server) loadModel(
|
||||
}
|
||||
|
||||
func Execute(args []string) error {
|
||||
if args[0] == "runner" {
|
||||
args = args[1:]
|
||||
}
|
||||
fs := flag.NewFlagSet("runner", flag.ExitOnError)
|
||||
mpath := fs.String("model", "", "Path to model binary file")
|
||||
ppath := fs.String("mmproj", "", "Path to projector binary file")
|
||||
@@ -934,6 +930,8 @@ func Execute(args []string) error {
|
||||
})
|
||||
slog.SetDefault(slog.New(handler))
|
||||
slog.Info("starting go runner")
|
||||
|
||||
llama.BackendInit()
|
||||
slog.Info("system", "info", llama.PrintSystemInfo(), "threads", *threads)
|
||||
|
||||
server := &Server{
|
||||
280
runner/ollamarunner/cache.go
Normal file
280
runner/ollamarunner/cache.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package ollamarunner
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
type InputCache struct {
|
||||
// context window size (per slot)
|
||||
numCtx int32
|
||||
|
||||
// does the cache store data or do we need to always send the full input?
|
||||
// note that when enabled is false the underlying cache may either be nil
|
||||
// or a non-nil dummy that doesn't actually store anything
|
||||
enabled bool
|
||||
|
||||
// individual KV caches
|
||||
slots []InputCacheSlot
|
||||
|
||||
// optimize cache eviction for multiple users
|
||||
multiUserCache bool
|
||||
|
||||
cache kvcache.Cache
|
||||
}
|
||||
|
||||
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, multiUserCache bool) (*InputCache, error) {
|
||||
if kvSize/int32(numSlots) < 1 {
|
||||
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
|
||||
}
|
||||
|
||||
slots := make([]InputCacheSlot, numSlots)
|
||||
|
||||
for i := range slots {
|
||||
slots[i] = InputCacheSlot{
|
||||
Id: i,
|
||||
Inputs: make([]input, 0),
|
||||
}
|
||||
}
|
||||
|
||||
cache := model.Config().Cache
|
||||
if cache != nil {
|
||||
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), kvSize)
|
||||
}
|
||||
|
||||
return &InputCache{
|
||||
numCtx: kvSize / int32(numSlots),
|
||||
enabled: cache != nil,
|
||||
slots: slots,
|
||||
multiUserCache: multiUserCache,
|
||||
cache: cache,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func kvCacheTypeFromStr(s string) ml.DType {
|
||||
switch s {
|
||||
case "q8_0":
|
||||
panic("kv cache quantization not yet implemented")
|
||||
case "q4_0":
|
||||
panic("kv cache quantization not yet implemented")
|
||||
default:
|
||||
return ml.DTypeF16
|
||||
}
|
||||
}
|
||||
|
||||
func (c *InputCache) Close() {
|
||||
c.cache.Close()
|
||||
}
|
||||
|
||||
// Locking: Operations on InputCacheSlot (including finding one
|
||||
// through LoadCacheSlot) require a lock to be be held that serializes
|
||||
// these operations with each other and processBatch
|
||||
|
||||
type InputCacheSlot struct {
|
||||
// Index in the KV cache
|
||||
Id int
|
||||
|
||||
// Inputs that are stored in the KV cache
|
||||
Inputs []input
|
||||
|
||||
// is this cache actively being processed as part of a sequence?
|
||||
InUse bool
|
||||
|
||||
// last time this cache was used (as of start of processing)
|
||||
lastUsed time.Time
|
||||
}
|
||||
|
||||
func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) {
|
||||
var slot *InputCacheSlot
|
||||
var numPast int32
|
||||
var err error
|
||||
|
||||
// In single-user scenarios, the longest cache slot works fine for getting good input
|
||||
// cache hit rates and it keeps the footprint of the cache small, which improves throughput.
|
||||
// For multiple users, the "best" cache slot produces better input cache hit rates
|
||||
// at the cost of worse performance when we miss the input cache.
|
||||
if !c.multiUserCache {
|
||||
slot, numPast, err = c.findLongestCacheSlot(prompt)
|
||||
} else {
|
||||
slot, numPast, err = c.findBestCacheSlot(prompt)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if !cachePrompt {
|
||||
numPast = 0
|
||||
}
|
||||
|
||||
slot.InUse = true
|
||||
slot.lastUsed = time.Now()
|
||||
|
||||
if numPast == int32(len(prompt)) {
|
||||
// Leave one input to sample so we can get a response
|
||||
numPast--
|
||||
}
|
||||
|
||||
if c.cache != nil {
|
||||
err = c.cache.Remove(slot.Id, numPast, math.MaxInt32)
|
||||
if err != nil {
|
||||
// Some models don't support partial erasure
|
||||
err = c.cache.Remove(slot.Id, 0, math.MaxInt32)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
numPast = 0
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt),
|
||||
"used", numPast, "remaining", int32(len(prompt))-numPast)
|
||||
|
||||
prompt = prompt[numPast:]
|
||||
slot.Inputs = slot.Inputs[:numPast]
|
||||
|
||||
return slot, prompt, nil
|
||||
}
|
||||
|
||||
func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) {
|
||||
longest := int32(-1)
|
||||
var longestSlot *InputCacheSlot
|
||||
|
||||
for i, s := range c.slots {
|
||||
if s.InUse {
|
||||
continue
|
||||
}
|
||||
|
||||
count := countCommonPrefix(s.Inputs, prompt)
|
||||
if count > longest {
|
||||
longest = count
|
||||
longestSlot = &c.slots[i]
|
||||
}
|
||||
}
|
||||
|
||||
if longestSlot == nil {
|
||||
return nil, 0, errors.New("no available cache slots")
|
||||
}
|
||||
|
||||
return longestSlot, longest, nil
|
||||
}
|
||||
|
||||
func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) {
|
||||
oldest := time.Now()
|
||||
var oldestSlot *InputCacheSlot
|
||||
|
||||
longest := int32(-1)
|
||||
var longestSlot *InputCacheSlot
|
||||
|
||||
for i, s := range c.slots {
|
||||
count := countCommonPrefix(s.Inputs, prompt)
|
||||
if count > longest {
|
||||
longest = count
|
||||
longestSlot = &c.slots[i]
|
||||
}
|
||||
|
||||
if s.lastUsed.Compare(oldest) < 0 && !s.InUse {
|
||||
oldest = s.lastUsed
|
||||
oldestSlot = &c.slots[i]
|
||||
}
|
||||
}
|
||||
|
||||
if longest == int32(len(longestSlot.Inputs)) && !longestSlot.InUse {
|
||||
return longestSlot, longest, nil
|
||||
}
|
||||
|
||||
if oldestSlot.InUse {
|
||||
return nil, 0, errors.New("no available cache slots")
|
||||
}
|
||||
|
||||
if len(oldestSlot.Inputs) != 0 {
|
||||
slog.Debug("evicting cache slot", "id", oldestSlot.Id, "inputs", len(oldestSlot.Inputs),
|
||||
"used", oldestSlot.lastUsed)
|
||||
}
|
||||
|
||||
if longest > 0 && longestSlot != oldestSlot {
|
||||
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
|
||||
len(longestSlot.Inputs))
|
||||
oldestSlot.Inputs = make([]input, longest)
|
||||
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
|
||||
if c.cache != nil {
|
||||
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
|
||||
}
|
||||
}
|
||||
|
||||
return oldestSlot, longest, nil
|
||||
}
|
||||
|
||||
func countCommonPrefix(a []input, b []input) int32 {
|
||||
var count int32
|
||||
|
||||
for i := range a {
|
||||
if i >= len(b) {
|
||||
break
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(a[i], b[i]) {
|
||||
break
|
||||
}
|
||||
|
||||
count++
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
|
||||
targetFree := (c.numCtx - numKeep) / 2
|
||||
targetFree = max(targetFree, 1)
|
||||
|
||||
currentFree := c.numCtx - inputLen
|
||||
discard := targetFree - currentFree
|
||||
|
||||
if discard < 0 {
|
||||
discard = 0
|
||||
}
|
||||
|
||||
return discard
|
||||
}
|
||||
|
||||
// Frees up space in the KV cache by deleting the oldest half of history and shifting
|
||||
// the newest half into that space (saving numKeep inputs at the beginning).
|
||||
//
|
||||
// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
|
||||
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
|
||||
if numKeep >= c.numCtx {
|
||||
return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
|
||||
}
|
||||
|
||||
inputLen := int32(len(slot.Inputs))
|
||||
discard := c.ShiftDiscard(inputLen, numKeep)
|
||||
|
||||
if discard <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
|
||||
"keep", numKeep, "discard", discard)
|
||||
|
||||
// TODO (jessegross): KV cache removal can fail for certain types of models
|
||||
if c.cache != nil {
|
||||
err := c.cache.Remove(slot.Id, numKeep, numKeep+discard)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err)
|
||||
}
|
||||
}
|
||||
|
||||
for i := numKeep + discard; i < inputLen; i++ {
|
||||
slot.Inputs[i-discard] = slot.Inputs[i]
|
||||
}
|
||||
slot.Inputs = slot.Inputs[:inputLen-discard]
|
||||
|
||||
return nil
|
||||
}
|
||||
291
runner/ollamarunner/cache_test.go
Normal file
291
runner/ollamarunner/cache_test.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package ollamarunner
|
||||
|
||||
import (
|
||||
"image"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCountCommon(t *testing.T) {
|
||||
imgA := image.NewRGBA(image.Rect(0, 0, 100, 100))
|
||||
imgB := image.NewRGBA(image.Rect(0, 0, 50, 50))
|
||||
imgC := image.NewRGBA(image.Rect(50, 50, 100, 100))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
t1 []input
|
||||
t2 []input
|
||||
expected int32
|
||||
}{
|
||||
{
|
||||
name: "Equal",
|
||||
t1: []input{{token: 1}, {token: 2}, {token: 3}},
|
||||
t2: []input{{token: 1}, {token: 2}, {token: 3}},
|
||||
expected: 3,
|
||||
},
|
||||
{
|
||||
name: "Prefix",
|
||||
t1: []input{{token: 1}},
|
||||
t2: []input{{token: 1}, {token: 2}, {token: 3}},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "Image Prefix",
|
||||
t1: []input{{image: imgA}},
|
||||
t2: []input{{image: imgA}, {image: imgB}, {image: imgC}},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "Mixed",
|
||||
t1: []input{{token: 1}, {image: imgA}},
|
||||
t2: []input{{token: 1}, {image: imgA}, {token: 5}},
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
name: "Empty",
|
||||
t1: []input{},
|
||||
t2: []input{{token: 1}, {token: 2}, {token: 3}},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "Both Empty",
|
||||
t1: []input{},
|
||||
t2: []input{},
|
||||
expected: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := countCommonPrefix(tt.t1, tt.t2)
|
||||
if result != tt.expected {
|
||||
t.Errorf("countCommonPrefix(%v, %v): have %v; want %v", tt.t1, tt.t2, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindCacheSlot(t *testing.T) {
|
||||
type expected struct {
|
||||
result int
|
||||
len int32
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cache InputCache
|
||||
prompt []input
|
||||
longest expected
|
||||
best expected
|
||||
}{
|
||||
{
|
||||
name: "Empty",
|
||||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Time{},
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Time{},
|
||||
},
|
||||
}},
|
||||
prompt: []input{{token: 1}},
|
||||
longest: expected{result: 0, len: 0},
|
||||
best: expected{result: 0, len: 0},
|
||||
},
|
||||
{
|
||||
name: "Extend",
|
||||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input{{token: 1}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input{{token: 1}, {token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
}},
|
||||
prompt: []input{{token: 1}, {token: 2}},
|
||||
longest: expected{result: 1, len: 2},
|
||||
best: expected{result: 1, len: 2},
|
||||
},
|
||||
{
|
||||
name: "New",
|
||||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input{{token: 1}, {token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Time{},
|
||||
},
|
||||
}},
|
||||
prompt: []input{{token: 2}},
|
||||
longest: expected{result: 0, len: 0},
|
||||
best: expected{result: 1, len: 0},
|
||||
},
|
||||
{
|
||||
name: "Fork",
|
||||
cache: InputCache{
|
||||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input{{token: 1}, {token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Time{},
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input{{token: 1}},
|
||||
longest: expected{result: 0, len: 1},
|
||||
best: expected{result: 1, len: 1},
|
||||
},
|
||||
{
|
||||
name: "Evict",
|
||||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input{{token: 1}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input{{token: 1}, {token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
}},
|
||||
prompt: []input{{token: 2}, {token: 3}},
|
||||
longest: expected{result: 0, len: 0},
|
||||
best: expected{result: 1, len: 0},
|
||||
},
|
||||
{
|
||||
name: "In use",
|
||||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input{{token: 1}, {token: 2}},
|
||||
InUse: true,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input{{token: 1}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
}},
|
||||
prompt: []input{{token: 1}, {token: 2}},
|
||||
longest: expected{result: 1, len: 1},
|
||||
best: expected{result: 1, len: 2},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run("Longest-"+tt.name, func(t *testing.T) {
|
||||
result, resultLen, err := tt.cache.findLongestCacheSlot(tt.prompt)
|
||||
if err != nil {
|
||||
t.Errorf("findLongestCacheSlot: err %v", err)
|
||||
} else if result.Id != tt.longest.result || resultLen != tt.longest.len {
|
||||
t.Errorf("findLongestCacheSlot: slot have %v, want %v len have %v, want %v",
|
||||
result.Id, tt.longest.result, resultLen, tt.longest.len)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run("Best-"+tt.name, func(t *testing.T) {
|
||||
result, resultLen, err := tt.cache.findBestCacheSlot(tt.prompt)
|
||||
if err != nil {
|
||||
t.Errorf("findBestCacheSlot: err %v", err)
|
||||
} else if result.Id != tt.best.result || resultLen != tt.best.len {
|
||||
t.Errorf("findBestCacheSlot: slot have %v, want %v len have %v, want %v",
|
||||
result.Id, tt.best.result, resultLen, tt.best.len)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShiftDiscard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
numCtx int32
|
||||
numKeep int32
|
||||
inputLen int32
|
||||
expected int32
|
||||
}{
|
||||
{
|
||||
name: "Shift",
|
||||
numCtx: 2048,
|
||||
numKeep: 5,
|
||||
inputLen: 2048,
|
||||
expected: 1021,
|
||||
},
|
||||
{
|
||||
name: "Max Keep",
|
||||
numCtx: 2048,
|
||||
numKeep: 2047,
|
||||
inputLen: 2048,
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "No Keep",
|
||||
numCtx: 2048,
|
||||
numKeep: 0,
|
||||
inputLen: 2048,
|
||||
expected: 1024,
|
||||
},
|
||||
{
|
||||
name: "Truncate",
|
||||
numCtx: 2048,
|
||||
numKeep: 5,
|
||||
inputLen: 5000,
|
||||
expected: 3973,
|
||||
},
|
||||
{
|
||||
name: "Truncate Keep",
|
||||
numCtx: 2048,
|
||||
numKeep: 2047,
|
||||
inputLen: 5000,
|
||||
expected: 2953,
|
||||
},
|
||||
{
|
||||
name: "No Op",
|
||||
numCtx: 2048,
|
||||
numKeep: 5,
|
||||
inputLen: 512,
|
||||
expected: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := InputCache{numCtx: tt.numCtx}
|
||||
result := c.ShiftDiscard(tt.inputLen, tt.numKeep)
|
||||
if result != tt.expected {
|
||||
t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
946
runner/ollamarunner/runner.go
Normal file
946
runner/ollamarunner/runner.go
Normal file
@@ -0,0 +1,946 @@
|
||||
package ollamarunner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"image"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"golang.org/x/sync/semaphore"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/runner/common"
|
||||
"github.com/ollama/ollama/sample"
|
||||
|
||||
_ "github.com/ollama/ollama/model/models"
|
||||
)
|
||||
|
||||
// input is an element of the prompt to process, either a token or an image
|
||||
type input struct {
|
||||
token int32
|
||||
|
||||
image image.Image
|
||||
}
|
||||
|
||||
type Sequence struct {
|
||||
// batch index
|
||||
iBatch int
|
||||
|
||||
// prompt inputs left to evaluate
|
||||
inputs []input
|
||||
|
||||
// inputs that have been added to a batch but not yet submitted to Forward
|
||||
pendingInputs []input
|
||||
|
||||
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
||||
pendingResponses []string
|
||||
|
||||
// input cache being used by this sequence
|
||||
cache *InputCacheSlot
|
||||
|
||||
// channel to send responses over
|
||||
responses chan string
|
||||
|
||||
// channel to stop decoding (such as if the remote connection is closed)
|
||||
quit chan bool
|
||||
|
||||
// number of tokens to predict
|
||||
numPredict int
|
||||
|
||||
// set of samplers to run on generated logits
|
||||
samplers []sample.Sampler
|
||||
|
||||
// channel to send back the embedding if embedding only
|
||||
embedding chan []float32
|
||||
|
||||
// stop sequences
|
||||
stop []string
|
||||
|
||||
// number of inputs to keep at the beginning when shifting context window
|
||||
numKeep int32
|
||||
|
||||
// true if an embedding are to be returned instead of text generation
|
||||
embeddingOnly bool
|
||||
|
||||
doneReason string
|
||||
|
||||
// Metrics
|
||||
startProcessingTime time.Time
|
||||
startGenerationTime time.Time
|
||||
numPredicted int
|
||||
numPromptInputs int
|
||||
}
|
||||
|
||||
type NewSequenceParams struct {
|
||||
numPredict int
|
||||
stop []string
|
||||
numKeep int32
|
||||
samplers []sample.Sampler
|
||||
embedding bool
|
||||
}
|
||||
|
||||
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||
s.ready.Wait()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
inputs, err := s.inputs(prompt, images)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process inputs: %w", err)
|
||||
} else if len(inputs) == 0 {
|
||||
return nil, errors.New("no input provided")
|
||||
}
|
||||
|
||||
if params.numKeep < 0 {
|
||||
params.numKeep = int32(len(inputs))
|
||||
}
|
||||
|
||||
// Ensure that at least 1 input can be discarded during shift
|
||||
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
|
||||
|
||||
if int32(len(inputs)) > s.cache.numCtx {
|
||||
discard := int32(len(inputs)) - s.cache.numCtx
|
||||
newInputs := inputs[:params.numKeep]
|
||||
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
|
||||
|
||||
slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
|
||||
inputs = newInputs
|
||||
}
|
||||
|
||||
// TODO(jessegross): Ingest cached history for grammar
|
||||
|
||||
return &Sequence{
|
||||
inputs: inputs,
|
||||
numPromptInputs: len(inputs),
|
||||
startProcessingTime: startTime,
|
||||
numPredict: params.numPredict,
|
||||
pendingResponses: make([]string, 0),
|
||||
responses: make(chan string, 100),
|
||||
quit: make(chan bool, 1),
|
||||
embedding: make(chan []float32, 1),
|
||||
samplers: params.samplers,
|
||||
embeddingOnly: params.embedding,
|
||||
stop: params.stop,
|
||||
numKeep: params.numKeep,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// inputs processes the prompt and images into a list of inputs
|
||||
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
||||
// decoding images
|
||||
func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
|
||||
var inputs []input
|
||||
var parts []string
|
||||
var matches [][]string
|
||||
|
||||
// TODO(jessegross): This can sometimes trigger for matching text in the
|
||||
// user's prompt. We previously tried to avoid it by only looking for images
|
||||
// on image models. We don't have a clear indication now but it would be better
|
||||
// to properly escape it in any case.
|
||||
re := regexp.MustCompile(`\[img-(\d+)\]`)
|
||||
parts = re.Split(prompt, -1)
|
||||
matches = re.FindAllStringSubmatch(prompt, -1)
|
||||
|
||||
for i, part := range parts {
|
||||
// text - tokenize
|
||||
tokens, err := s.model.(model.TextProcessor).Encode(part)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, t := range tokens {
|
||||
inputs = append(inputs, input{token: t})
|
||||
}
|
||||
|
||||
// image - decode and store
|
||||
if i < len(matches) {
|
||||
n, _ := strconv.Atoi(matches[i][1])
|
||||
|
||||
imageIndex := -1
|
||||
for j := range images {
|
||||
if images[j].ID == n {
|
||||
imageIndex = j
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if imageIndex < 0 {
|
||||
return nil, fmt.Errorf("invalid image index: %d", n)
|
||||
}
|
||||
|
||||
image, _, err := image.Decode(bytes.NewReader(images[imageIndex].Data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inputs = append(inputs, input{image: image})
|
||||
}
|
||||
}
|
||||
|
||||
return inputs, nil
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
// is the server ready to process requests?
|
||||
// protects access to model and image
|
||||
ready sync.WaitGroup
|
||||
|
||||
// loaded model
|
||||
model model.Model
|
||||
|
||||
// status for external health reporting - loading, ready to serve, etc.
|
||||
status ServerStatus
|
||||
|
||||
// current progress on loading the model
|
||||
progress float32
|
||||
|
||||
// number of simultaneous requests to handle
|
||||
parallel int
|
||||
|
||||
// maximum number of elements in a batch (per sequence)
|
||||
// TODO (jmorganca): make this n_batch
|
||||
batchSize int
|
||||
|
||||
// protects access to everything below this line
|
||||
// this is context state needed for decoding
|
||||
mu sync.Mutex
|
||||
|
||||
// indicates that data is ready for processing
|
||||
cond *sync.Cond
|
||||
|
||||
// the list of simultaneous sequences being evaluated
|
||||
seqs []*Sequence
|
||||
|
||||
// seqs can have a maximum of parallel entries, which
|
||||
// is enfoced by seqSem
|
||||
seqsSem *semaphore.Weighted
|
||||
|
||||
// KV cache
|
||||
cache *InputCache
|
||||
|
||||
// next sequence for prompt processing to avoid starvation
|
||||
nextSeq int
|
||||
}
|
||||
|
||||
func (s *Server) allNil() bool {
|
||||
for _, item := range s.seqs {
|
||||
if item != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func flushPending(seq *Sequence) bool {
|
||||
joined := strings.Join(seq.pendingResponses, "")
|
||||
seq.pendingResponses = []string{}
|
||||
|
||||
// Check if there are any partial UTF-8 characters remaining.
|
||||
// We already check and queue as we are generating but some may
|
||||
// still make it here:
|
||||
// - Sequence is ending, e.g. generation limit has been hit
|
||||
// - Invalid characters in the middle of a string
|
||||
// This is a stricter check to ensure we never output invalid Unicode.
|
||||
for !utf8.ValidString(joined) {
|
||||
joined = joined[:len(joined)-1]
|
||||
}
|
||||
|
||||
if len(joined) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
select {
|
||||
case seq.responses <- joined:
|
||||
return true
|
||||
case <-seq.quit:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) removeSequence(seqIndex int, reason string) {
|
||||
seq := s.seqs[seqIndex]
|
||||
|
||||
flushPending(seq)
|
||||
seq.doneReason = reason
|
||||
close(seq.responses)
|
||||
close(seq.embedding)
|
||||
seq.cache.InUse = false
|
||||
s.seqs[seqIndex] = nil
|
||||
s.seqsSem.Release(1)
|
||||
}
|
||||
|
||||
func (s *Server) run(ctx context.Context) {
|
||||
s.ready.Wait()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
err := s.processBatch()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) processBatch() error {
|
||||
s.mu.Lock()
|
||||
for s.allNil() {
|
||||
s.cond.Wait() // Wait until an item is added
|
||||
}
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var options model.Options
|
||||
imgSeq := -1
|
||||
|
||||
seqIdx := s.nextSeq - 1
|
||||
for range s.seqs {
|
||||
seqIdx = (seqIdx + 1) % len(s.seqs)
|
||||
seq := s.seqs[seqIdx]
|
||||
|
||||
if seq == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// if past the num predict limit
|
||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||
s.removeSequence(seqIdx, "limit")
|
||||
continue
|
||||
}
|
||||
|
||||
if !s.cache.enabled {
|
||||
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
|
||||
seq.cache.Inputs = []input{}
|
||||
}
|
||||
|
||||
for i, input := range seq.inputs {
|
||||
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
|
||||
if len(seq.pendingInputs) == 0 {
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if i >= s.batchSize {
|
||||
break
|
||||
}
|
||||
|
||||
// TODO(jessegross): Image inputs need to be rethought - it's
|
||||
// it doesn't work well for different types of models or multiple sequences
|
||||
if input.image != nil {
|
||||
if len(seq.pendingInputs) != len(options.Images) {
|
||||
break
|
||||
}
|
||||
|
||||
if imgSeq != seqIdx && imgSeq != -1 {
|
||||
s.nextSeq = seqIdx
|
||||
break
|
||||
}
|
||||
|
||||
imgSeq = seqIdx
|
||||
options.Images = append(options.Images, input.image)
|
||||
seq.pendingInputs = append(seq.pendingInputs, input)
|
||||
continue
|
||||
}
|
||||
|
||||
options.Inputs = append(options.Inputs, input.token)
|
||||
options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
||||
options.Sequences = append(options.Sequences, seq.cache.Id)
|
||||
|
||||
seq.iBatch = len(options.Outputs)
|
||||
if i+1 == len(seq.inputs) {
|
||||
options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
|
||||
}
|
||||
seq.pendingInputs = append(seq.pendingInputs, input)
|
||||
}
|
||||
|
||||
seq.inputs = seq.inputs[len(seq.pendingInputs):]
|
||||
}
|
||||
|
||||
if len(options.Inputs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx := s.model.Backend().NewContext()
|
||||
defer ctx.Close()
|
||||
|
||||
modelOutput, err := model.Forward(ctx, s.model, options)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode batch: %w", err)
|
||||
}
|
||||
|
||||
f32s := modelOutput.Floats()
|
||||
|
||||
// TODO(jessegross): This will no longer be necessary once the sampling interface takes f32s
|
||||
logits := make([]float64, len(f32s))
|
||||
for i, f32 := range f32s {
|
||||
logits[i] = float64(f32)
|
||||
}
|
||||
|
||||
for i, seq := range s.seqs {
|
||||
if seq == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// After calling Forward, pending inputs are now in the cache
|
||||
if len(seq.pendingInputs) > 0 {
|
||||
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
|
||||
seq.pendingInputs = []input{}
|
||||
}
|
||||
|
||||
// don't sample prompt processing
|
||||
if len(seq.inputs) != 0 {
|
||||
if !s.cache.enabled {
|
||||
return errors.New("caching disabled but unable to fit entire input in a batch")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
seq.numPredicted++
|
||||
if seq.numPredicted == 1 {
|
||||
seq.startGenerationTime = time.Now()
|
||||
}
|
||||
|
||||
// if done processing the prompt, generate an embedding and return
|
||||
if seq.embeddingOnly {
|
||||
// TODO(jessegross): Embedding support
|
||||
s.removeSequence(i, "")
|
||||
continue
|
||||
}
|
||||
|
||||
// sample a token
|
||||
vocabSize := len(f32s) / len(options.Outputs)
|
||||
tokens, err := sample.Sample(logits[seq.iBatch*vocabSize:(seq.iBatch+1)*vocabSize], seq.samplers...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(jessegross): Sampler will output a single int32 in the future
|
||||
token := int32(tokens[0])
|
||||
|
||||
// if it's an end of sequence token, break
|
||||
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
||||
// TODO (jmorganca): we should send this back
|
||||
// as it's important for the /api/generate context
|
||||
// seq.responses <- piece
|
||||
|
||||
s.removeSequence(i, "stop")
|
||||
continue
|
||||
}
|
||||
|
||||
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
seq.inputs = []input{{token: token}}
|
||||
|
||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
||||
sequence := strings.Join(seq.pendingResponses, "")
|
||||
|
||||
if ok, stop := common.FindStop(sequence, seq.stop); ok {
|
||||
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
||||
|
||||
var tokenTruncated bool
|
||||
origLen := len(seq.pendingResponses)
|
||||
seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop)
|
||||
newLen := len(seq.pendingResponses)
|
||||
|
||||
// Update the cache based on the tokens that will be returned:
|
||||
// - We have 1 token more than is currently in the cache because
|
||||
// the last one generated wasn't submitted to Decode
|
||||
// - Remove any stop sequences that we stripped out
|
||||
// - If truncateStop removed a portion of a token, drop that
|
||||
// - As defense-in-depth, if truncatedToken didn't find a stop token
|
||||
// remove the extra one that we added to the cache len
|
||||
tokenLen := len(seq.cache.Inputs) + 1
|
||||
tokenLen -= origLen - newLen
|
||||
if tokenTruncated || origLen == newLen {
|
||||
tokenLen--
|
||||
}
|
||||
seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
|
||||
|
||||
s.removeSequence(i, "stop")
|
||||
continue
|
||||
}
|
||||
|
||||
if common.ContainsStopSuffix(sequence, seq.stop) {
|
||||
continue
|
||||
}
|
||||
|
||||
if common.IncompleteUnicode(sequence) {
|
||||
continue
|
||||
}
|
||||
|
||||
if !flushPending(seq) {
|
||||
s.removeSequence(i, "connection")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO (jmorganca): use structs from the api package to avoid duplication
|
||||
// this way the api acts as a proxy instead of using a different api for the
|
||||
// runner
|
||||
type Options struct {
|
||||
api.Runner
|
||||
|
||||
NumKeep int `json:"n_keep"`
|
||||
Seed int `json:"seed"`
|
||||
NumPredict int `json:"n_predict"`
|
||||
TopK int `json:"top_k"`
|
||||
TopP float32 `json:"top_p"`
|
||||
MinP float32 `json:"min_p"`
|
||||
TypicalP float32 `json:"typical_p"`
|
||||
RepeatLastN int `json:"repeat_last_n"`
|
||||
Temperature float32 `json:"temperature"`
|
||||
RepeatPenalty float32 `json:"repeat_penalty"`
|
||||
PresencePenalty float32 `json:"presence_penalty"`
|
||||
FrequencyPenalty float32 `json:"frequency_penalty"`
|
||||
Mirostat int `json:"mirostat"`
|
||||
MirostatTau float32 `json:"mirostat_tau"`
|
||||
MirostatEta float32 `json:"mirostat_eta"`
|
||||
Stop []string `json:"stop"`
|
||||
}
|
||||
|
||||
type ImageData struct {
|
||||
Data []byte `json:"data"`
|
||||
ID int `json:"id"`
|
||||
AspectRatioID int `json:"aspect_ratio_id"`
|
||||
}
|
||||
|
||||
type CompletionRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Images []ImageData `json:"image_data"`
|
||||
Grammar string `json:"grammar"`
|
||||
CachePrompt bool `json:"cache_prompt"`
|
||||
|
||||
Options
|
||||
}
|
||||
|
||||
type Timings struct {
|
||||
PredictedN int `json:"predicted_n"`
|
||||
PredictedMS float64 `json:"predicted_ms"`
|
||||
PromptN int `json:"prompt_n"`
|
||||
PromptMS float64 `json:"prompt_ms"`
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
Content string `json:"content"`
|
||||
Stop bool `json:"stop"`
|
||||
|
||||
Model string `json:"model,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
StoppedLimit bool `json:"stopped_limit,omitempty"`
|
||||
PredictedN int `json:"predicted_n,omitempty"`
|
||||
PredictedMS float64 `json:"predicted_ms,omitempty"`
|
||||
PromptN int `json:"prompt_n,omitempty"`
|
||||
PromptMS float64 `json:"prompt_ms,omitempty"`
|
||||
|
||||
Timings Timings `json:"timings"`
|
||||
}
|
||||
|
||||
func getSamplers(_ CompletionRequest) []sample.Sampler {
|
||||
// TODO(jessegross): Waiting for sampling code
|
||||
|
||||
/*samplingParams.TopK = req.TopK
|
||||
samplingParams.TopP = req.TopP
|
||||
samplingParams.MinP = req.MinP
|
||||
samplingParams.TypicalP = req.TypicalP
|
||||
samplingParams.Temp = req.Temperature
|
||||
samplingParams.RepeatLastN = req.RepeatLastN
|
||||
samplingParams.PenaltyRepeat = req.RepeatPenalty
|
||||
samplingParams.PenaltyFreq = req.FrequencyPenalty
|
||||
samplingParams.PenaltyPresent = req.PresencePenalty
|
||||
samplingParams.Mirostat = req.Mirostat
|
||||
samplingParams.MirostatTau = req.MirostatTau
|
||||
samplingParams.MirostatEta = req.MirostatEta
|
||||
samplingParams.Seed = uint32(req.Seed)
|
||||
samplingParams.Grammar = req.Grammar*/
|
||||
|
||||
return []sample.Sampler{sample.Greedy()}
|
||||
}
|
||||
|
||||
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
var req CompletionRequest
|
||||
req.Options = Options(api.DefaultOptions())
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Set the headers to indicate streaming
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||
numPredict: req.NumPredict,
|
||||
stop: req.Stop,
|
||||
numKeep: int32(req.NumKeep),
|
||||
samplers: getSamplers(req),
|
||||
embedding: false,
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure there is a place to put the sequence, released when removed from s.seqs
|
||||
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
slog.Info("aborting completion request due to client closing the connection")
|
||||
} else {
|
||||
slog.Error("Failed to acquire semaphore", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
found := false
|
||||
for i, sq := range s.seqs {
|
||||
if sq == nil {
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
s.seqs[i] = seq
|
||||
s.cond.Signal()
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !found {
|
||||
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
close(seq.quit)
|
||||
return
|
||||
case content, ok := <-seq.responses:
|
||||
if ok {
|
||||
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
||||
Content: content,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
close(seq.quit)
|
||||
return
|
||||
}
|
||||
|
||||
flusher.Flush()
|
||||
} else {
|
||||
// Send the final response
|
||||
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
||||
Stop: true,
|
||||
StoppedLimit: seq.doneReason == "limit",
|
||||
Timings: Timings{
|
||||
PromptN: seq.numPromptInputs,
|
||||
PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
|
||||
PredictedN: seq.numPredicted,
|
||||
PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
|
||||
},
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Content string `json:"content"`
|
||||
CachePrompt bool `json:"cache_prompt"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
}
|
||||
|
||||
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||
var req EmbeddingRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
slog.Debug("embedding request", "content", req.Content)
|
||||
|
||||
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure there is a place to put the sequence, released when removed from s.seqs
|
||||
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
slog.Info("aborting embeddings request due to client closing the connection")
|
||||
} else {
|
||||
slog.Error("Failed to acquire semaphore", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
found := false
|
||||
for i, sq := range s.seqs {
|
||||
if sq == nil {
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
s.seqs[i] = seq
|
||||
s.cond.Signal()
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !found {
|
||||
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
embedding := <-seq.embedding
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
|
||||
Embedding: embedding,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
type HealthResponse struct {
|
||||
Status string `json:"status"`
|
||||
Progress float32 `json:"progress"`
|
||||
}
|
||||
|
||||
type ServerStatus int
|
||||
|
||||
const (
|
||||
ServerStatusReady ServerStatus = iota
|
||||
ServerStatusLoadingModel
|
||||
ServerStatusError
|
||||
)
|
||||
|
||||
func (s ServerStatus) ToString() string {
|
||||
switch s {
|
||||
case ServerStatusReady:
|
||||
return "ok"
|
||||
case ServerStatusLoadingModel:
|
||||
return "loading model"
|
||||
default:
|
||||
return "server error"
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(&HealthResponse{
|
||||
Status: s.status.ToString(),
|
||||
Progress: s.progress,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
type multiLPath []string
|
||||
|
||||
func (m *multiLPath) Set(value string) error {
|
||||
*m = append(*m, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *multiLPath) String() string {
|
||||
return strings.Join(*m, ", ")
|
||||
}
|
||||
|
||||
func (s *Server) loadModel(
|
||||
mpath string,
|
||||
lpath multiLPath,
|
||||
parallel int,
|
||||
kvCacheType string,
|
||||
kvSize int,
|
||||
multiUserCache bool,
|
||||
) {
|
||||
var err error
|
||||
s.model, err = model.New(mpath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
slog.Info("system", "info", s.model.Backend().SystemInfo() /* "threads", *threads */)
|
||||
|
||||
// TODO(jessegross): LoRA loading
|
||||
if lpath.String() != "" {
|
||||
panic("loras are not yet implemented")
|
||||
}
|
||||
|
||||
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, multiUserCache)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if !s.cache.enabled && parallel > 1 {
|
||||
parallel = 1
|
||||
slog.Warn("model does not support caching, disabling parallel processing")
|
||||
}
|
||||
|
||||
s.parallel = parallel
|
||||
s.seqs = make([]*Sequence, s.parallel)
|
||||
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
|
||||
|
||||
s.status = ServerStatusReady
|
||||
s.ready.Done()
|
||||
}
|
||||
|
||||
func Execute(args []string) error {
|
||||
fs := flag.NewFlagSet("runner", flag.ExitOnError)
|
||||
mpath := fs.String("model", "", "Path to model binary file")
|
||||
parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously")
|
||||
batchSize := fs.Int("batch-size", 512, "Batch size")
|
||||
_ = fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
|
||||
_ = fs.Int("main-gpu", 0, "Main GPU")
|
||||
_ = fs.Bool("flash-attn", false, "Enable flash attention")
|
||||
kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
|
||||
kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
|
||||
port := fs.Int("port", 8080, "Port to expose the server on")
|
||||
_ = fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
|
||||
verbose := fs.Bool("verbose", false, "verbose output (default: disabled)")
|
||||
_ = fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)")
|
||||
_ = fs.Bool("mlock", false, "force system to keep model in RAM rather than swapping or compressing")
|
||||
_ = fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
|
||||
multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
|
||||
|
||||
var lpaths multiLPath
|
||||
fs.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)")
|
||||
|
||||
fs.Usage = func() {
|
||||
fmt.Fprintf(fs.Output(), "Runner usage\n")
|
||||
fs.PrintDefaults()
|
||||
}
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return err
|
||||
}
|
||||
level := slog.LevelInfo
|
||||
if *verbose {
|
||||
level = slog.LevelDebug
|
||||
}
|
||||
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
Level: level,
|
||||
AddSource: true,
|
||||
ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
|
||||
if attr.Key == slog.SourceKey {
|
||||
source := attr.Value.Any().(*slog.Source)
|
||||
source.File = filepath.Base(source.File)
|
||||
}
|
||||
return attr
|
||||
},
|
||||
})
|
||||
slog.SetDefault(slog.New(handler))
|
||||
slog.Info("starting ollama engine")
|
||||
|
||||
server := &Server{
|
||||
batchSize: *batchSize,
|
||||
status: ServerStatusLoadingModel,
|
||||
}
|
||||
|
||||
// TODO(jessegross): Parameters that need to be implemented:
|
||||
// n-gpu-layers
|
||||
// main-gpu
|
||||
// flash-attn
|
||||
// threads
|
||||
// no-mmap
|
||||
// mlock
|
||||
// tensor-split
|
||||
|
||||
/*var tensorSplitFloats []float32
|
||||
if *tensorSplit != "" {
|
||||
stringFloats := regexp.MustCompile(",").Split(*tensorSplit, -1)
|
||||
|
||||
tensorSplitFloats = make([]float32, 0, len(stringFloats))
|
||||
for _, s := range stringFloats {
|
||||
f, _ := strconv.ParseFloat(s, 32)
|
||||
tensorSplitFloats = append(tensorSplitFloats, float32(f))
|
||||
}
|
||||
}*/
|
||||
|
||||
server.ready.Add(1)
|
||||
go server.loadModel(*mpath, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
|
||||
|
||||
server.cond = sync.NewCond(&server.mu)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go server.run(ctx)
|
||||
|
||||
addr := "127.0.0.1:" + strconv.Itoa(*port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
fmt.Println("Listen error:", err)
|
||||
cancel()
|
||||
return err
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/embedding", server.embeddings)
|
||||
mux.HandleFunc("/completion", server.completion)
|
||||
mux.HandleFunc("/health", server.health)
|
||||
|
||||
httpServer := http.Server{
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
log.Println("Server listening on", addr)
|
||||
if err := httpServer.Serve(listener); err != nil {
|
||||
log.Fatal("server error:", err)
|
||||
return err
|
||||
}
|
||||
|
||||
cancel()
|
||||
return nil
|
||||
}
|
||||
24
runner/runner.go
Normal file
24
runner/runner.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/runner/llamarunner"
|
||||
"github.com/ollama/ollama/runner/ollamarunner"
|
||||
)
|
||||
|
||||
func Execute(args []string) error {
|
||||
if args[0] == "runner" {
|
||||
args = args[1:]
|
||||
}
|
||||
|
||||
var newRunner bool
|
||||
if args[0] == "--ollama-engine" {
|
||||
args = args[1:]
|
||||
newRunner = true
|
||||
}
|
||||
|
||||
if newRunner {
|
||||
return ollamarunner.Execute(args)
|
||||
} else {
|
||||
return llamarunner.Execute(args)
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user