Mirror of https://github.com/ollama/ollama.git (synced 2026-01-19 04:51:17 -05:00)

Compare commits: v0.13.1 ... parth/decr (111 commits)
Commit SHA1s, newest first:

6b2abfb433, 805ed4644c, e4b488a7b5, 98079ddd79, d70942f47b, 58e4701557, dbf47ee55a, af7ea6e96e, 8f1e0140e7, 35c3c9e3c2,
d06acbcb19, 9667c2282f, a937a68317, 2185112d84, 91926601dc, 361d6c16c2, 7e2496e88e, 5b84e29882, 7cc2a653f2, 2584940016,
c6d4c0c7f2, 1ef4241727, 68fafd3002, 2b2cda7a2b, 3cfe9fe146, a23b559b4c, 33ee7168ba, 34d0c55ea5, 53a5a9e9ae, e30e08a7d6,
12e2b3514a, 626af2d809, 76912c062a, 6c3faafed2, e51dead636, d087e46bd1, 37f6f3af24, e1bdc23dd2, 2e78653ff9, f5f74e12c1,
18fdcc94e5, 7ad036992f, 172b5924af, 8852220f59, 7325791599, 522c11a763, 0fadeffaee, 49a9c9ba6a, 1c094038bc, a013693f80,
f6a016f49d, 45c4739374, 2dd029de12, 903b1fc97f, 89eb795293, 7e3ea813c1, 7b95087b9d, 971d62595a, ffbe8e076d, 2c639431b1,
aacd1cb394, e3731fb160, 8dbc9e7b68, abe67acf8a, 4ff8a691bc, 1b308e1d2a, bd6c1d6b49, 3af5d3b738, 7730895158, de9ecfd01c,
95fdd8d619, 9f7822851c, 9b2035d194, 93d45d7a04, 709f842457, 2dfb74410d, 1eb5e75972, 3475d915cb, 48e78e9be1, a838421ea3,
1c4e85b4df, dac4f17fea, 56b8fb024c, b95693056c, c34fc64688, 7cf6f18c1f, bbbb6b2a01, 76f88caf43, 2bccf8c624, 0c5e5f6630,
d475d1f081, d2f334c1f7, 603ceefaa6, e082d60a24, 5dae738067, 0c78723174, 5a41d69b2a, c146a138e3, 31b8c6a214, 9191dfaf05,
1108d8b34e, 7837a5bc7e, 0a844f8e96, a03223b86f, 0cf7794b16, 854d40edc5, 84a2cedf18, 3f30836734, cc9555aff0, 20aee96706,
18b5958d46
.github/ISSUE_TEMPLATE/10_bug_report.yml (vendored, 2 lines changed)

@@ -13,7 +13,7 @@ body:
     id: logs
     attributes:
       label: Relevant log output
-      description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
+      description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.mdx#how-to-troubleshoot-issues) for details.
       render: shell
   validations:
     required: false
.github/workflows/release.yaml (vendored, 30 lines changed)

@@ -16,13 +16,15 @@ jobs:
     outputs:
       GOFLAGS: ${{ steps.goflags.outputs.GOFLAGS }}
       VERSION: ${{ steps.goflags.outputs.VERSION }}
+      vendorsha: ${{ steps.changes.outputs.vendorsha }}
     steps:
       - uses: actions/checkout@v4
       - name: Set environment
         id: goflags
         run: |
-          echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" >>$GITHUB_OUTPUT
-          echo VERSION="${GITHUB_REF_NAME#v}" >>$GITHUB_OUTPUT
+          echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" | tee -a $GITHUB_OUTPUT
+          echo VERSION="${GITHUB_REF_NAME#v}" | tee -a $GITHUB_OUTPUT
+          echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT

   darwin-build:
     runs-on: macos-14-xlarge
@@ -53,6 +55,9 @@ jobs:
       - uses: actions/setup-go@v5
         with:
           go-version-file: go.mod
+          cache-dependency-path: |
+            go.sum
+            Makefile.sync
       - run: |
           ./scripts/build_darwin.sh
       - name: Log build results
@@ -63,6 +68,7 @@ jobs:
           name: bundles-darwin
           path: |
             dist/*.tgz
+            dist/*.tar.zst
             dist/*.zip
             dist/*.dmg

@@ -185,7 +191,7 @@ jobs:
       - uses: actions/cache@v4
         with:
           path: ${{ github.workspace }}\.ccache
-          key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}
+          key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}-${{ needs.setup-environment.outputs.vendorsha }}
       - name: Build target "${{ matrix.preset }}"
         run: |
           Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
@@ -249,6 +255,9 @@ jobs:
       - uses: actions/setup-go@v5
         with:
           go-version-file: go.mod
+          cache-dependency-path: |
+            go.sum
+            Makefile.sync
       - name: Verify gcc is actually clang
        run: |
          $ErrorActionPreference='Continue'
@@ -302,6 +311,9 @@ jobs:
       - uses: actions/setup-go@v5
         with:
           go-version-file: go.mod
+          cache-dependency-path: |
+            go.sum
+            Makefile.sync
       - uses: actions/download-artifact@v4
         with:
           pattern: depends-windows*
@@ -360,13 +372,17 @@ jobs:
           outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
           cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
           cache-to: type=inline
+      - name: Deduplicate CUDA libraries
+        run: |
+          ./scripts/deduplicate_cuda_libs.sh dist/${{ matrix.os }}-${{ matrix.arch }}
       - run: |
           for COMPONENT in bin/* lib/ollama/*; do
             case "$COMPONENT" in
-              bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
+              bin/ollama*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
               lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
               lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
               lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
+              lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
               lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
               lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
               lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
@@ -381,13 +397,13 @@ jobs:
           done
       - run: |
           for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
-            tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
+            tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | zstd --ultra -22 -T0 >$(basename ${ARCHIVE//.*/}.tar.zst);
           done
       - uses: actions/upload-artifact@v4
         with:
           name: bundles-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.target }}
           path: |
-            *.tgz
+            *.tar.zst

   # Build each Docker variant (OS, arch, and flavor) separately. Using QEMU is unreliable and slower.
   docker-build-push:
@@ -520,7 +536,7 @@ jobs:
       - name: Upload release artifacts
         run: |
           pids=()
-          for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.exe dist/*.dmg ; do
+          for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg ; do
            echo "Uploading $payload"
            gh release upload ${GITHUB_REF_NAME} $payload --clobber &
            pids[$!]=$!
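The packaging step above switches the release bundles from pigz-compressed `.tgz` to zstd-compressed `.tar.zst`. As a hedged illustration of what a consumer of these artifacts needs, here is a minimal Go sketch that lists the entries of such a bundle; it assumes the third-party `github.com/klauspost/compress/zstd` module, which is not part of this diff:

```go
package main

import (
	"archive/tar"
	"fmt"
	"io"
	"log"
	"os"

	"github.com/klauspost/compress/zstd" // assumed dependency, not part of this diff
)

func main() {
	// Open a bundle produced by the release workflow, e.g. ollama-linux-amd64.tar.zst.
	f, err := os.Open("ollama-linux-amd64.tar.zst")
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// zstd replaces the gzip layer that pigz used to provide.
	dec, err := zstd.NewReader(f)
	if err != nil {
		log.Fatal(err)
	}
	defer dec.Close()

	// The inner stream is still a plain tar archive.
	tr := tar.NewReader(dec)
	for {
		hdr, err := tr.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			log.Fatal(err)
		}
		fmt.Println(hdr.Name) // entries such as bin/ollama and lib/ollama/*.so
	}
}
```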
.github/workflows/test.yaml (vendored, 9 lines changed)

@@ -22,6 +22,7 @@ jobs:
     runs-on: ubuntu-latest
     outputs:
       changed: ${{ steps.changes.outputs.changed }}
+      vendorsha: ${{ steps.changes.outputs.vendorsha }}
     steps:
       - uses: actions/checkout@v4
         with:
@@ -37,6 +38,7 @@ jobs:
           }

           echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
+          echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT

   linux:
     needs: [changes]
@@ -83,7 +85,7 @@ jobs:
       - uses: actions/cache@v4
         with:
           path: /github/home/.cache/ccache
-          key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
+          key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
       - run: |
           cmake --preset ${{ matrix.preset }} ${{ matrix.flags }}
           cmake --build --preset ${{ matrix.preset }} --parallel
@@ -178,7 +180,7 @@ jobs:
       - uses: actions/cache@v4
         with:
           path: ${{ github.workspace }}\.ccache
-          key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
+          key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
       - run: |
           Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
           Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
@@ -206,6 +208,9 @@ jobs:
       - uses: actions/setup-go@v5
         with:
           go-version-file: 'go.mod'
+          cache-dependency-path: |
+            go.sum
+            Makefile.sync
       - uses: actions/setup-node@v4
         with:
           node-version: '20'
.golangci.yml

@@ -1,76 +1,51 @@
 version: "2"
 linters:
-  default: none
   enable:
     - asasalint
     - bidichk
     - bodyclose
     - containedctx
-    - copyloopvar
-    - errcheck
-    - errorlint
-    - exptostd
     - gocheckcompilerdirectives
-    - govet
-    - ineffassign
     - intrange
     - makezero
     - misspell
-    - modernize
     - nilerr
-    - nilnil
     - nolintlint
     - nosprintfhostport
-    - perfsprint
-    - prealloc
-    - sloglint
-    - staticcheck
     - unconvert
-    - unused
-    - usestdlibvars
     - usetesting
     - wastedassign
     - whitespace
+  disable:
+    - errcheck
+    - usestdlibvars
   settings:
-    errcheck:
-      exclude-functions:
-        - fmt.Fprintf
-    perfsprint:
-      strconcat: false
-      concat-loop: false
+    govet:
+      disable:
+        - unusedresult
     staticcheck:
       checks:
         - all
-        # Using a deprecated function, variable, constant or field.
-        # https://staticcheck.dev/docs/checks/#SA1019
+        - -QF* # disable quick fix suggestions
         - -SA1019
-        # Incorrect or missing package comment.
-        # https://staticcheck.dev/docs/checks/#ST1000
-        - -ST1000
-        # Poorly chosen identifier.
-        # https://staticcheck.dev/docs/checks/#ST1003
-        - -ST1003
-        # The documentation of an exported function should start with the function's name.
-        # https://staticcheck.dev/docs/checks/#ST1020
-        - -ST1020
-        # The documentation of an exported type should start with type's name.
-        # https://staticcheck.dev/docs/checks/#ST1021
-        - -ST1021
-        # The documentation of an exported variable or constant should start with variable's name.
-        # https://staticcheck.dev/docs/checks/#ST1022
-        - -ST1022
-    usestdlibvars:
-      http-method: false
-      http-status-code: false
+        - -ST1000 # package comment format
+        - -ST1003 # underscores in package names
+        - -ST1005 # error strings should not be capitalized
+        - -ST1012 # error var naming (ErrFoo)
+        - -ST1016 # receiver name consistency
+        - -ST1020 # comment on exported function format
+        - -ST1021 # comment on exported type format
+        - -ST1022 # comment on exported var format
+        - -ST1023 # omit type from declaration
+severity:
+  default: error
+  rules:
+    - linters:
+        - gofmt
+        - goimports
+        - intrange
+      severity: info

 formatters:
   enable:
-    - gci
     - gofmt
     - gofumpt
-  settings:
-    gci:
-      sections:
-        - standard
-        - default
-        - localmodule
|||||||
@@ -2,6 +2,22 @@ cmake_minimum_required(VERSION 3.21)
|
|||||||
|
|
||||||
project(Ollama C CXX)
|
project(Ollama C CXX)
|
||||||
|
|
||||||
|
# Handle cross-compilation on macOS: when CMAKE_OSX_ARCHITECTURES is set to a
|
||||||
|
# single architecture different from the host, override CMAKE_SYSTEM_PROCESSOR
|
||||||
|
# to match. This is necessary because CMAKE_SYSTEM_PROCESSOR defaults to the
|
||||||
|
# host architecture, but downstream projects (like MLX) use it to detect the
|
||||||
|
# target architecture.
|
||||||
|
if(CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES ";")
|
||||||
|
# Single architecture specified
|
||||||
|
if(CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
|
||||||
|
message(STATUS "Cross-compiling for x86_64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to x86_64")
|
||||||
|
set(CMAKE_SYSTEM_PROCESSOR "x86_64")
|
||||||
|
elseif(CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
|
||||||
|
message(STATUS "Cross-compiling for arm64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to arm64")
|
||||||
|
set(CMAKE_SYSTEM_PROCESSOR "arm64")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
include(CheckLanguage)
|
include(CheckLanguage)
|
||||||
include(GNUInstallDirs)
|
include(GNUInstallDirs)
|
||||||
|
|
||||||
@@ -12,7 +28,7 @@ set(BUILD_SHARED_LIBS ON)
|
|||||||
|
|
||||||
set(CMAKE_CXX_STANDARD 17)
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
set(CMAKE_CXX_EXTENSIONS ON) # Recent versions of MLX Requires gnu++17 extensions to compile properly
|
||||||
|
|
||||||
set(GGML_BUILD ON)
|
set(GGML_BUILD ON)
|
||||||
set(GGML_SHARED ON)
|
set(GGML_SHARED ON)
|
||||||
@@ -32,9 +48,10 @@ if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
|
|||||||
set(GGML_CPU_ALL_VARIANTS ON)
|
set(GGML_CPU_ALL_VARIANTS ON)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
|
if(APPLE)
|
||||||
set(CMAKE_BUILD_RPATH "@loader_path")
|
set(CMAKE_BUILD_RPATH "@loader_path")
|
||||||
set(CMAKE_INSTALL_RPATH "@loader_path")
|
set(CMAKE_INSTALL_RPATH "@loader_path")
|
||||||
|
set(CMAKE_BUILD_WITH_INSTALL_RPATH ON)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
|
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
|
||||||
@@ -54,6 +71,13 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cp
|
|||||||
|
|
||||||
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
|
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
|
||||||
|
|
||||||
|
# Define GGML version variables for shared library SOVERSION
|
||||||
|
# These are required by ggml/src/CMakeLists.txt for proper library versioning
|
||||||
|
set(GGML_VERSION_MAJOR 0)
|
||||||
|
set(GGML_VERSION_MINOR 0)
|
||||||
|
set(GGML_VERSION_PATCH 0)
|
||||||
|
set(GGML_VERSION "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||||
|
|
||||||
set(GGML_CPU ON)
|
set(GGML_CPU ON)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
|
||||||
set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE)
|
set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE)
|
||||||
@@ -140,14 +164,56 @@ if(CMAKE_HIP_COMPILER)
|
|||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
find_package(Vulkan)
|
if(NOT APPLE)
|
||||||
if(Vulkan_FOUND)
|
find_package(Vulkan)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
if(Vulkan_FOUND)
|
||||||
install(TARGETS ggml-vulkan
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
||||||
RUNTIME_DEPENDENCIES
|
install(TARGETS ggml-vulkan
|
||||||
PRE_INCLUDE_REGEXES vulkan
|
RUNTIME_DEPENDENCIES
|
||||||
PRE_EXCLUDE_REGEXES ".*"
|
PRE_INCLUDE_REGEXES vulkan
|
||||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
PRE_EXCLUDE_REGEXES ".*"
|
||||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||||
)
|
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||||
|
)
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
option(MLX_ENGINE "Enable MLX backend" OFF)
|
||||||
|
|
||||||
|
if(MLX_ENGINE)
|
||||||
|
message(STATUS "Setting up MLX (this takes a while...)")
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
|
||||||
|
|
||||||
|
# Find CUDA toolkit if MLX is built with CUDA support
|
||||||
|
find_package(CUDAToolkit)
|
||||||
|
|
||||||
|
install(TARGETS mlx mlxc
|
||||||
|
RUNTIME_DEPENDENCIES
|
||||||
|
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||||
|
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
|
||||||
|
PRE_EXCLUDE_REGEXES ".*"
|
||||||
|
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||||
|
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||||
|
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||||
|
)
|
||||||
|
|
||||||
|
# Install the Metal library for macOS arm64 (must be colocated with the binary)
|
||||||
|
# Metal backend is only built for arm64, not x86_64
|
||||||
|
if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
|
||||||
|
install(FILES ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/backend/metal/kernels/mlx.metallib
|
||||||
|
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||||
|
COMPONENT MLX)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Manually install cudart and cublas since they might not be picked up as direct dependencies
|
||||||
|
if(CUDAToolkit_FOUND)
|
||||||
|
file(GLOB CUDART_LIBS
|
||||||
|
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
|
||||||
|
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
|
||||||
|
if(CUDART_LIBS)
|
||||||
|
install(FILES ${CUDART_LIBS}
|
||||||
|
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||||
|
COMPONENT MLX)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
CMakePresets.json

@@ -41,7 +41,7 @@
       "inherits": [ "CUDA" ],
       "cacheVariables": {
         "CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
-        "CMAKE_CUDA_FLAGS": "-t 2",
+        "CMAKE_CUDA_FLAGS": "-t 4",
         "OLLAMA_RUNNER_DIR": "cuda_v13"
       }
     },
@@ -83,6 +83,28 @@
       "cacheVariables": {
         "OLLAMA_RUNNER_DIR": "vulkan"
       }
+    },
+    {
+      "name": "MLX",
+      "inherits": [ "Default" ],
+      "cacheVariables": {
+        "MLX_ENGINE": "ON",
+        "OLLAMA_RUNNER_DIR": "mlx"
+      }
+    },
+    {
+      "name": "MLX CUDA 12",
+      "inherits": [ "MLX", "CUDA 12" ],
+      "cacheVariables": {
+        "OLLAMA_RUNNER_DIR": "mlx_cuda_v12"
+      }
+    },
+    {
+      "name": "MLX CUDA 13",
+      "inherits": [ "MLX", "CUDA 13" ],
+      "cacheVariables": {
+        "OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
+      }
     }
   ],
   "buildPresets": [
@@ -140,6 +162,21 @@
       "name": "Vulkan",
       "targets": [ "ggml-vulkan" ],
       "configurePreset": "Vulkan"
+    },
+    {
+      "name": "MLX",
+      "targets": [ "mlx", "mlxc" ],
+      "configurePreset": "MLX"
+    },
+    {
+      "name": "MLX CUDA 12",
+      "targets": [ "mlx", "mlxc" ],
+      "configurePreset": "MLX CUDA 12"
+    },
+    {
+      "name": "MLX CUDA 13",
+      "targets": [ "mlx", "mlxc" ],
+      "configurePreset": "MLX CUDA 13"
     }
   ]
 }
Dockerfile (37 lines changed)

@@ -131,8 +131,39 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
 RUN --mount=type=cache,target=/root/.ccache \
     cmake --preset 'Vulkan' \
     && cmake --build --parallel --preset 'Vulkan' \
     && cmake --install build --component Vulkan --strip --parallel 8

+FROM base AS mlx
+ARG CUDA13VERSION=13.0
+RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} \
+    && dnf install -y openblas-devel lapack-devel \
+    && dnf install -y libcudnn9-cuda-13 libcudnn9-devel-cuda-13 \
+    && dnf install -y libnccl libnccl-devel
+ENV PATH=/usr/local/cuda-13/bin:$PATH
+ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
+ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
+ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
+ARG PARALLEL
+WORKDIR /go/src/github.com/ollama/ollama
+COPY CMakeLists.txt CMakePresets.json .
+COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
+COPY x/ml/backend/mlx x/ml/backend/mlx
+COPY go.mod go.sum .
+RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
+ENV PATH=/usr/local/go/bin:$PATH
+RUN go mod download
+RUN --mount=type=cache,target=/root/.ccache \
+    cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
+    && cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
+    && cmake --install build --component MLX --strip --parallel ${PARALLEL}
+COPY . .
+ARG GOFLAGS="'-ldflags=-w -s'"
+ENV CGO_ENABLED=1
+ARG CGO_CFLAGS
+ARG CGO_CXXFLAGS
+RUN mkdir -p dist/bin
+RUN --mount=type=cache,target=/root/.cache/go-build \
+    go build -tags mlx -trimpath -buildmode=pie -o dist/bin/ollama-mlx .
+
 FROM base AS build
 WORKDIR /go/src/github.com/ollama/ollama
@@ -153,6 +184,8 @@ FROM --platform=linux/amd64 scratch AS amd64
 COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
 COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
 COPY --from=vulkan dist/lib/ollama /lib/ollama/
+COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
+COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/

 FROM --platform=linux/arm64 scratch AS arm64
 # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
@@ -171,7 +204,7 @@ COPY --from=build /bin/ollama /bin/ollama

 FROM ubuntu:24.04
 RUN apt-get update \
-    && apt-get install -y ca-certificates libvulkan1 \
+    && apt-get install -y ca-certificates libvulkan1 libopenblas0 \
     && apt-get clean \
     && rm -rf /var/lib/apt/lists/*
 COPY --from=archive /bin /usr/bin
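The new `mlx` stage compiles a separate `ollama-mlx` binary with `go build -tags mlx`, leaving the default build untouched. For readers unfamiliar with tag-gated compilation, a schematic of how such a constraint works (hypothetical file and symbols, not from this diff):

```go
//go:build mlx

// This file is only part of the build when the mlx tag is set, e.g.
// `go build -tags mlx`, which is how the Dockerfile's mlx stage selects
// the MLX backend while the default build leaves it out.
package backend

// MLXEnabled reports that MLX support was compiled in. (Hypothetical
// package and constant, shown only to illustrate the mechanism.)
const MLXEnabled = true
```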
Makefile.sync

@@ -1,6 +1,6 @@
 UPSTREAM=https://github.com/ggml-org/llama.cpp.git
 WORKDIR=llama/vendor
-FETCH_HEAD=3cfa9c3f125763305b4226bc032f1954f08990dc
+FETCH_HEAD=ec98e2002

 .PHONY: help
 help:
@@ -57,7 +57,7 @@ checkout: $(WORKDIR)
 $(WORKDIR):
 	git clone $(UPSTREAM) $(WORKDIR)

-.PHONE: format-patches
+.PHONY: format-patches
 format-patches: llama/patches
 	git -C $(WORKDIR) format-patch \
 		--no-signature \
@@ -66,7 +66,11 @@ format-patches: llama/patches
 		-o $(realpath $<) \
 		$(FETCH_HEAD)

-.PHONE: clean
+.PHONY: clean
 clean: checkout
 	@git -C $(WORKDIR) am --abort || true
 	$(RM) llama/patches/.*.patched
+
+.PHONY: print-base
+print-base:
+	@echo $(FETCH_HEAD)
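The new `print-base` target prints the vendored llama.cpp commit, and both workflows above shell out to it (`make -f Makefile.sync print-base`) to salt their ccache keys so caches are invalidated whenever the vendored tree moves. A hedged Go sketch of a build helper doing the same thing; this is a hypothetical utility, not code from this diff:

```go
package main

import (
	"fmt"
	"log"
	"os/exec"
	"strings"
)

// vendorSHA shells out to the print-base target, exactly as the release
// and test workflows do, and returns the vendored llama.cpp commit.
func vendorSHA() (string, error) {
	out, err := exec.Command("make", "-f", "Makefile.sync", "print-base").Output()
	if err != nil {
		return "", err
	}
	return strings.TrimSpace(string(out)), nil
}

func main() {
	sha, err := vendorSHA()
	if err != nil {
		log.Fatal(err)
	}
	// The workflows append this value to their ccache keys.
	fmt.Printf("ccache-linux-x64-preset-%s\n", sha)
}
```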
README.md

@@ -555,7 +555,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Parakeet](https://github.com/parakeet-nest/parakeet) is a GoLang library, made to simplify the development of small generative AI applications with Ollama.
 - [Haverscript](https://github.com/andygill/haverscript) with [examples](https://github.com/andygill/haverscript/tree/main/examples)
 - [Ollama for Swift](https://github.com/mattt/ollama-swift)
-- [Swollama for Swift](https://github.com/marcusziade/Swollama) with [DocC](https://marcusziade.github.io/Swollama/documentation/swollama/)
+- [Swollama for Swift](https://github.com/guitaripod/Swollama) with [DocC](https://guitaripod.github.io/Swollama/documentation/swollama)
 - [GoLamify](https://github.com/prasad89/golamify)
 - [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
 - [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API)
778
anthropic/anthropic.go
Normal file
778
anthropic/anthropic.go
Normal file
@@ -0,0 +1,778 @@
|
|||||||
|
package anthropic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Error types matching Anthropic API
|
||||||
|
type Error struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ErrorResponse struct {
|
||||||
|
Type string `json:"type"` // always "error"
|
||||||
|
Error Error `json:"error"`
|
||||||
|
RequestID string `json:"request_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewError creates a new ErrorResponse with the appropriate error type based on HTTP status code
|
||||||
|
func NewError(code int, message string) ErrorResponse {
|
||||||
|
var etype string
|
||||||
|
switch code {
|
||||||
|
case http.StatusBadRequest:
|
||||||
|
etype = "invalid_request_error"
|
||||||
|
case http.StatusUnauthorized:
|
||||||
|
etype = "authentication_error"
|
||||||
|
case http.StatusForbidden:
|
||||||
|
etype = "permission_error"
|
||||||
|
case http.StatusNotFound:
|
||||||
|
etype = "not_found_error"
|
||||||
|
case http.StatusTooManyRequests:
|
||||||
|
etype = "rate_limit_error"
|
||||||
|
case http.StatusServiceUnavailable, 529:
|
||||||
|
etype = "overloaded_error"
|
||||||
|
default:
|
||||||
|
etype = "api_error"
|
||||||
|
}
|
||||||
|
|
||||||
|
return ErrorResponse{
|
||||||
|
Type: "error",
|
||||||
|
Error: Error{Type: etype, Message: message},
|
||||||
|
RequestID: generateID("req"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request types
|
||||||
|
|
||||||
|
// MessagesRequest represents an Anthropic Messages API request
|
||||||
|
type MessagesRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
MaxTokens int `json:"max_tokens"`
|
||||||
|
Messages []MessageParam `json:"messages"`
|
||||||
|
System any `json:"system,omitempty"` // string or []ContentBlock
|
||||||
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
Temperature *float64 `json:"temperature,omitempty"`
|
||||||
|
TopP *float64 `json:"top_p,omitempty"`
|
||||||
|
TopK *int `json:"top_k,omitempty"`
|
||||||
|
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||||
|
Tools []Tool `json:"tools,omitempty"`
|
||||||
|
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
|
||||||
|
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||||
|
Metadata *Metadata `json:"metadata,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageParam represents a message in the request
|
||||||
|
type MessageParam struct {
|
||||||
|
Role string `json:"role"` // "user" or "assistant"
|
||||||
|
Content any `json:"content"` // string or []ContentBlock
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContentBlock represents a content block in a message.
|
||||||
|
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
|
||||||
|
// only when set, which is required for SDK streaming accumulation.
|
||||||
|
type ContentBlock struct {
|
||||||
|
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
|
||||||
|
|
||||||
|
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
|
||||||
|
Text *string `json:"text,omitempty"`
|
||||||
|
|
||||||
|
// For image blocks
|
||||||
|
Source *ImageSource `json:"source,omitempty"`
|
||||||
|
|
||||||
|
// For tool_use blocks
|
||||||
|
ID string `json:"id,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Input any `json:"input,omitempty"`
|
||||||
|
|
||||||
|
// For tool_result blocks
|
||||||
|
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||||
|
Content any `json:"content,omitempty"` // string or []ContentBlock
|
||||||
|
IsError bool `json:"is_error,omitempty"`
|
||||||
|
|
||||||
|
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
|
||||||
|
Thinking *string `json:"thinking,omitempty"`
|
||||||
|
Signature string `json:"signature,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImageSource represents the source of an image
|
||||||
|
type ImageSource struct {
|
||||||
|
Type string `json:"type"` // "base64" or "url"
|
||||||
|
MediaType string `json:"media_type,omitempty"`
|
||||||
|
Data string `json:"data,omitempty"`
|
||||||
|
URL string `json:"url,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tool represents a tool definition
|
||||||
|
type Tool struct {
|
||||||
|
Type string `json:"type,omitempty"` // "custom" for user-defined tools
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
InputSchema json.RawMessage `json:"input_schema,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToolChoice controls how the model uses tools
|
||||||
|
type ToolChoice struct {
|
||||||
|
Type string `json:"type"` // "auto", "any", "tool", "none"
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ThinkingConfig controls extended thinking
|
||||||
|
type ThinkingConfig struct {
|
||||||
|
Type string `json:"type"` // "enabled" or "disabled"
|
||||||
|
BudgetTokens int `json:"budget_tokens,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Metadata for the request
|
||||||
|
type Metadata struct {
|
||||||
|
UserID string `json:"user_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Response types
|
||||||
|
|
||||||
|
// MessagesResponse represents an Anthropic Messages API response
|
||||||
|
type MessagesResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Type string `json:"type"` // "message"
|
||||||
|
Role string `json:"role"` // "assistant"
|
||||||
|
Model string `json:"model"`
|
||||||
|
Content []ContentBlock `json:"content"`
|
||||||
|
StopReason string `json:"stop_reason,omitempty"`
|
||||||
|
StopSequence string `json:"stop_sequence,omitempty"`
|
||||||
|
Usage Usage `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Usage contains token usage information
|
||||||
|
type Usage struct {
|
||||||
|
InputTokens int `json:"input_tokens"`
|
||||||
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Streaming event types
|
||||||
|
|
||||||
|
// MessageStartEvent is sent at the start of streaming
|
||||||
|
type MessageStartEvent struct {
|
||||||
|
Type string `json:"type"` // "message_start"
|
||||||
|
Message MessagesResponse `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContentBlockStartEvent signals the start of a content block
|
||||||
|
type ContentBlockStartEvent struct {
|
||||||
|
Type string `json:"type"` // "content_block_start"
|
||||||
|
Index int `json:"index"`
|
||||||
|
ContentBlock ContentBlock `json:"content_block"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContentBlockDeltaEvent contains incremental content updates
|
||||||
|
type ContentBlockDeltaEvent struct {
|
||||||
|
Type string `json:"type"` // "content_block_delta"
|
||||||
|
Index int `json:"index"`
|
||||||
|
Delta Delta `json:"delta"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delta represents an incremental update
|
||||||
|
type Delta struct {
|
||||||
|
Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta"
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
PartialJSON string `json:"partial_json,omitempty"`
|
||||||
|
Thinking string `json:"thinking,omitempty"`
|
||||||
|
Signature string `json:"signature,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContentBlockStopEvent signals the end of a content block
|
||||||
|
type ContentBlockStopEvent struct {
|
||||||
|
Type string `json:"type"` // "content_block_stop"
|
||||||
|
Index int `json:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageDeltaEvent contains updates to the message
|
||||||
|
type MessageDeltaEvent struct {
|
||||||
|
Type string `json:"type"` // "message_delta"
|
||||||
|
Delta MessageDelta `json:"delta"`
|
||||||
|
Usage DeltaUsage `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageDelta contains stop information
|
||||||
|
type MessageDelta struct {
|
||||||
|
StopReason string `json:"stop_reason,omitempty"`
|
||||||
|
StopSequence string `json:"stop_sequence,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeltaUsage contains cumulative token usage
|
||||||
|
type DeltaUsage struct {
|
||||||
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageStopEvent signals the end of the message
|
||||||
|
type MessageStopEvent struct {
|
||||||
|
Type string `json:"type"` // "message_stop"
|
||||||
|
}
|
||||||
|
|
||||||
|
// PingEvent is a keepalive event
|
||||||
|
type PingEvent struct {
|
||||||
|
Type string `json:"type"` // "ping"
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamErrorEvent is an error during streaming
|
||||||
|
type StreamErrorEvent struct {
|
||||||
|
Type string `json:"type"` // "error"
|
||||||
|
Error Error `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
|
||||||
|
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||||
|
var messages []api.Message
|
||||||
|
|
||||||
|
if r.System != nil {
|
||||||
|
switch sys := r.System.(type) {
|
||||||
|
case string:
|
||||||
|
if sys != "" {
|
||||||
|
messages = append(messages, api.Message{Role: "system", Content: sys})
|
||||||
|
}
|
||||||
|
case []any:
|
||||||
|
// System can be an array of content blocks
|
||||||
|
var content strings.Builder
|
||||||
|
for _, block := range sys {
|
||||||
|
if blockMap, ok := block.(map[string]any); ok {
|
||||||
|
if blockMap["type"] == "text" {
|
||||||
|
if text, ok := blockMap["text"].(string); ok {
|
||||||
|
content.WriteString(text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if content.Len() > 0 {
|
||||||
|
messages = append(messages, api.Message{Role: "system", Content: content.String()})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, msg := range r.Messages {
|
||||||
|
converted, err := convertMessage(msg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
messages = append(messages, converted...)
|
||||||
|
}
|
||||||
|
|
||||||
|
options := make(map[string]any)
|
||||||
|
|
||||||
|
options["num_predict"] = r.MaxTokens
|
||||||
|
|
||||||
|
if r.Temperature != nil {
|
||||||
|
options["temperature"] = *r.Temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.TopP != nil {
|
||||||
|
options["top_p"] = *r.TopP
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.TopK != nil {
|
||||||
|
options["top_k"] = *r.TopK
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.StopSequences) > 0 {
|
||||||
|
options["stop"] = r.StopSequences
|
||||||
|
}
|
||||||
|
|
||||||
|
var tools api.Tools
|
||||||
|
for _, t := range r.Tools {
|
||||||
|
tool, err := convertTool(t)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tools = append(tools, tool)
|
||||||
|
}
|
||||||
|
|
||||||
|
var think *api.ThinkValue
|
||||||
|
if r.Thinking != nil && r.Thinking.Type == "enabled" {
|
||||||
|
think = &api.ThinkValue{Value: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
stream := r.Stream
|
||||||
|
|
||||||
|
return &api.ChatRequest{
|
||||||
|
Model: r.Model,
|
||||||
|
Messages: messages,
|
||||||
|
Options: options,
|
||||||
|
Stream: &stream,
|
||||||
|
Tools: tools,
|
||||||
|
Think: think,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
|
||||||
|
func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||||
|
var messages []api.Message
|
||||||
|
role := strings.ToLower(msg.Role)
|
||||||
|
|
||||||
|
switch content := msg.Content.(type) {
|
||||||
|
case string:
|
||||||
|
messages = append(messages, api.Message{Role: role, Content: content})
|
||||||
|
|
||||||
|
case []any:
|
||||||
|
var textContent strings.Builder
|
||||||
|
var images []api.ImageData
|
||||||
|
var toolCalls []api.ToolCall
|
||||||
|
var thinking string
|
||||||
|
var toolResults []api.Message
|
||||||
|
|
||||||
|
for _, block := range content {
|
||||||
|
blockMap, ok := block.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("invalid content block format")
|
||||||
|
}
|
||||||
|
|
||||||
|
blockType, _ := blockMap["type"].(string)
|
||||||
|
|
||||||
|
switch blockType {
|
||||||
|
case "text":
|
||||||
|
if text, ok := blockMap["text"].(string); ok {
|
||||||
|
textContent.WriteString(text)
|
||||||
|
}
|
||||||
|
|
||||||
|
case "image":
|
||||||
|
source, ok := blockMap["source"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("invalid image source")
|
||||||
|
}
|
||||||
|
|
||||||
|
sourceType, _ := source["type"].(string)
|
||||||
|
if sourceType == "base64" {
|
||||||
|
data, _ := source["data"].(string)
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||||
|
}
|
||||||
|
images = append(images, decoded)
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
|
||||||
|
}
|
||||||
|
// URL images would need to be fetched - skip for now
|
||||||
|
|
||||||
|
case "tool_use":
|
||||||
|
id, ok := blockMap["id"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("tool_use block missing required 'id' field")
|
||||||
|
}
|
||||||
|
name, ok := blockMap["name"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("tool_use block missing required 'name' field")
|
||||||
|
}
|
||||||
|
tc := api.ToolCall{
|
||||||
|
ID: id,
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: name,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if input, ok := blockMap["input"].(map[string]any); ok {
|
||||||
|
tc.Function.Arguments = mapToArgs(input)
|
||||||
|
}
|
||||||
|
toolCalls = append(toolCalls, tc)
|
||||||
|
|
||||||
|
case "tool_result":
|
||||||
|
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||||
|
var resultContent string
|
||||||
|
|
||||||
|
switch c := blockMap["content"].(type) {
|
||||||
|
case string:
|
||||||
|
resultContent = c
|
||||||
|
case []any:
|
||||||
|
for _, cb := range c {
|
||||||
|
if cbMap, ok := cb.(map[string]any); ok {
|
||||||
|
if cbMap["type"] == "text" {
|
||||||
|
if text, ok := cbMap["text"].(string); ok {
|
||||||
|
resultContent += text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
toolResults = append(toolResults, api.Message{
|
||||||
|
Role: "tool",
|
||||||
|
Content: resultContent,
|
||||||
|
ToolCallID: toolUseID,
|
||||||
|
})
|
||||||
|
|
||||||
|
case "thinking":
|
||||||
|
if t, ok := blockMap["thinking"].(string); ok {
|
||||||
|
thinking = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
|
||||||
|
m := api.Message{
|
||||||
|
Role: role,
|
||||||
|
Content: textContent.String(),
|
||||||
|
Images: images,
|
||||||
|
ToolCalls: toolCalls,
|
||||||
|
Thinking: thinking,
|
||||||
|
}
|
||||||
|
messages = append(messages, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add tool results as separate messages
|
||||||
|
messages = append(messages, toolResults...)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid message content type: %T", content)
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertTool converts an Anthropic Tool to an Ollama api.Tool
|
||||||
|
func convertTool(t Tool) (api.Tool, error) {
|
||||||
|
var params api.ToolFunctionParameters
|
||||||
|
if len(t.InputSchema) > 0 {
|
||||||
|
if err := json.Unmarshal(t.InputSchema, ¶ms); err != nil {
|
||||||
|
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return api.Tool{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: t.Name,
|
||||||
|
Description: t.Description,
|
||||||
|
Parameters: params,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
|
||||||
|
func ToMessagesResponse(id string, r api.ChatResponse) MessagesResponse {
|
||||||
|
var content []ContentBlock
|
||||||
|
|
||||||
|
if r.Message.Thinking != "" {
|
||||||
|
content = append(content, ContentBlock{
|
||||||
|
Type: "thinking",
|
||||||
|
Thinking: ptr(r.Message.Thinking),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Message.Content != "" {
|
||||||
|
content = append(content, ContentBlock{
|
||||||
|
Type: "text",
|
||||||
|
Text: ptr(r.Message.Content),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range r.Message.ToolCalls {
|
||||||
|
content = append(content, ContentBlock{
|
||||||
|
Type: "tool_use",
|
||||||
|
ID: tc.ID,
|
||||||
|
Name: tc.Function.Name,
|
||||||
|
Input: tc.Function.Arguments,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
stopReason := mapStopReason(r.DoneReason, len(r.Message.ToolCalls) > 0)
|
||||||
|
|
||||||
|
return MessagesResponse{
|
||||||
|
ID: id,
|
||||||
|
Type: "message",
|
||||||
|
Role: "assistant",
|
||||||
|
Model: r.Model,
|
||||||
|
Content: content,
|
||||||
|
StopReason: stopReason,
|
||||||
|
Usage: Usage{
|
||||||
|
InputTokens: r.Metrics.PromptEvalCount,
|
||||||
|
OutputTokens: r.Metrics.EvalCount,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapStopReason converts Ollama done_reason to Anthropic stop_reason
|
||||||
|
func mapStopReason(reason string, hasToolCalls bool) string {
|
||||||
|
if hasToolCalls {
|
||||||
|
return "tool_use"
|
||||||
|
}
|
||||||
|
|
||||||
|
switch reason {
|
||||||
|
case "stop":
|
||||||
|
return "end_turn"
|
||||||
|
case "length":
|
||||||
|
return "max_tokens"
|
||||||
|
default:
|
||||||
|
if reason != "" {
|
||||||
|
return "stop_sequence"
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
|
||||||
|
type StreamConverter struct {
|
||||||
|
ID string
|
||||||
|
Model string
|
||||||
|
firstWrite bool
|
||||||
|
contentIndex int
|
||||||
|
inputTokens int
|
||||||
|
outputTokens int
|
||||||
|
thinkingStarted bool
|
||||||
|
thinkingDone bool
|
||||||
|
textStarted bool
|
||||||
|
toolCallsSent map[string]bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStreamConverter(id, model string) *StreamConverter {
|
||||||
|
return &StreamConverter{
|
||||||
|
ID: id,
|
||||||
|
Model: model,
|
||||||
|
firstWrite: true,
|
||||||
|
toolCallsSent: make(map[string]bool),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamEvent represents a streaming event to be sent to the client
|
||||||
|
type StreamEvent struct {
|
||||||
|
Event string
|
||||||
|
Data any
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process converts an Ollama ChatResponse to Anthropic streaming events
|
||||||
|
func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||||
|
var events []StreamEvent
|
||||||
|
|
||||||
|
if c.firstWrite {
|
||||||
|
c.firstWrite = false
|
||||||
|
c.inputTokens = r.Metrics.PromptEvalCount
|
||||||
|
|
||||||
|
events = append(events, StreamEvent{
|
||||||
|
Event: "message_start",
|
||||||
|
Data: MessageStartEvent{
|
||||||
|
Type: "message_start",
|
||||||
|
Message: MessagesResponse{
|
||||||
|
ID: c.ID,
|
||||||
|
Type: "message",
|
||||||
|
Role: "assistant",
|
||||||
|
Model: c.Model,
|
||||||
|
Content: []ContentBlock{},
|
||||||
|
Usage: Usage{
|
||||||
|
InputTokens: c.inputTokens,
|
||||||
|
OutputTokens: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Message.Thinking != "" && !c.thinkingDone {
|
||||||
|
if !c.thinkingStarted {
|
||||||
|
c.thinkingStarted = true
|
||||||
|
events = append(events, StreamEvent{
|
||||||
|
Event: "content_block_start",
|
||||||
|
Data: ContentBlockStartEvent{
|
||||||
|
Type: "content_block_start",
|
||||||
|
Index: c.contentIndex,
|
||||||
|
ContentBlock: ContentBlock{
|
||||||
|
Type: "thinking",
|
||||||
|
Thinking: ptr(""),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
events = append(events, StreamEvent{
|
||||||
|
Event: "content_block_delta",
|
||||||
|
Data: ContentBlockDeltaEvent{
|
||||||
|
Type: "content_block_delta",
|
||||||
|
Index: c.contentIndex,
|
||||||
|
Delta: Delta{
|
||||||
|
Type: "thinking_delta",
|
||||||
|
Thinking: r.Message.Thinking,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Message.Content != "" {
|
||||||
|
if c.thinkingStarted && !c.thinkingDone {
|
||||||
|
c.thinkingDone = true
|
||||||
|
events = append(events, StreamEvent{
|
||||||
|
				Event: "content_block_stop",
				Data: ContentBlockStopEvent{
					Type:  "content_block_stop",
					Index: c.contentIndex,
				},
			})
			c.contentIndex++
		}

		if !c.textStarted {
			c.textStarted = true
			events = append(events, StreamEvent{
				Event: "content_block_start",
				Data: ContentBlockStartEvent{
					Type:  "content_block_start",
					Index: c.contentIndex,
					ContentBlock: ContentBlock{
						Type: "text",
						Text: ptr(""),
					},
				},
			})
		}

		events = append(events, StreamEvent{
			Event: "content_block_delta",
			Data: ContentBlockDeltaEvent{
				Type:  "content_block_delta",
				Index: c.contentIndex,
				Delta: Delta{
					Type: "text_delta",
					Text: r.Message.Content,
				},
			},
		})
	}

	for _, tc := range r.Message.ToolCalls {
		if c.toolCallsSent[tc.ID] {
			continue
		}

		if c.textStarted {
			events = append(events, StreamEvent{
				Event: "content_block_stop",
				Data: ContentBlockStopEvent{
					Type:  "content_block_stop",
					Index: c.contentIndex,
				},
			})
			c.contentIndex++
			c.textStarted = false
		}

		argsJSON, err := json.Marshal(tc.Function.Arguments)
		if err != nil {
			slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
			continue
		}

		events = append(events, StreamEvent{
			Event: "content_block_start",
			Data: ContentBlockStartEvent{
				Type:  "content_block_start",
				Index: c.contentIndex,
				ContentBlock: ContentBlock{
					Type:  "tool_use",
					ID:    tc.ID,
					Name:  tc.Function.Name,
					Input: map[string]any{},
				},
			},
		})

		events = append(events, StreamEvent{
			Event: "content_block_delta",
			Data: ContentBlockDeltaEvent{
				Type:  "content_block_delta",
				Index: c.contentIndex,
				Delta: Delta{
					Type:        "input_json_delta",
					PartialJSON: string(argsJSON),
				},
			},
		})

		events = append(events, StreamEvent{
			Event: "content_block_stop",
			Data: ContentBlockStopEvent{
				Type:  "content_block_stop",
				Index: c.contentIndex,
			},
		})

		c.toolCallsSent[tc.ID] = true
		c.contentIndex++
	}

	if r.Done {
		if c.textStarted {
			events = append(events, StreamEvent{
				Event: "content_block_stop",
				Data: ContentBlockStopEvent{
					Type:  "content_block_stop",
					Index: c.contentIndex,
				},
			})
		} else if c.thinkingStarted && !c.thinkingDone {
			events = append(events, StreamEvent{
				Event: "content_block_stop",
				Data: ContentBlockStopEvent{
					Type:  "content_block_stop",
					Index: c.contentIndex,
				},
			})
		}

		c.outputTokens = r.Metrics.EvalCount
		stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)

		events = append(events, StreamEvent{
			Event: "message_delta",
			Data: MessageDeltaEvent{
				Type: "message_delta",
				Delta: MessageDelta{
					StopReason: stopReason,
				},
				Usage: DeltaUsage{
					OutputTokens: c.outputTokens,
				},
			},
		})

		events = append(events, StreamEvent{
			Event: "message_stop",
			Data: MessageStopEvent{
				Type: "message_stop",
			},
		})
	}

	return events
}

// generateID generates a unique ID with the given prefix using crypto/rand
func generateID(prefix string) string {
	b := make([]byte, 12)
	if _, err := rand.Read(b); err != nil {
		// Fallback to time-based ID if crypto/rand fails
		return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
	}
	return fmt.Sprintf("%s_%x", prefix, b)
}

// GenerateMessageID generates a unique message ID
func GenerateMessageID() string {
	return generateID("msg")
}

// ptr returns a pointer to the given string value
func ptr(s string) *string {
	return &s
}

// mapToArgs converts a map to ToolCallFunctionArguments
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
	args := api.NewToolCallFunctionArguments()
	for k, v := range m {
		args.Set(k, v)
	}
	return args
}
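For orientation, the converter above is driven once per streamed chat chunk, and each Process call returns zero or more Anthropic-style SSE events. Below is a minimal consumption sketch, not part of the diff: the import path and the two-chunk input are assumptions for illustration, while NewStreamConverter, GenerateMessageID, StreamEvent, and Process are the names defined above.

package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/anthropic" // assumed import path for the package above
	"github.com/ollama/ollama/api"
)

func main() {
	conv := anthropic.NewStreamConverter(anthropic.GenerateMessageID(), "test-model")
	chunks := []api.ChatResponse{
		{Message: api.Message{Role: "assistant", Content: "Hello"}},
		{Message: api.Message{Role: "assistant", Content: " world"}, Done: true, DoneReason: "stop"},
	}
	for _, r := range chunks {
		// Each StreamEvent maps onto one SSE "event:"/"data:" frame.
		for _, ev := range conv.Process(r) {
			data, _ := json.Marshal(ev.Data)
			fmt.Printf("event: %s\ndata: %s\n\n", ev.Event, data)
		}
	}
}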
953	anthropic/anthropic_test.go	Normal file
@@ -0,0 +1,953 @@
package anthropic

import (
	"encoding/base64"
	"encoding/json"
	"testing"

	"github.com/google/go-cmp/cmp"

	"github.com/ollama/ollama/api"
)

const (
	testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)

// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
	args := api.NewToolCallFunctionArguments()
	for k, v := range m {
		args.Set(k, v)
	}
	return args
}

func TestFromMessagesRequest_Basic(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages: []MessageParam{
			{Role: "user", Content: "Hello"},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if result.Model != "test-model" {
		t.Errorf("expected model 'test-model', got %q", result.Model)
	}

	if len(result.Messages) != 1 {
		t.Fatalf("expected 1 message, got %d", len(result.Messages))
	}

	if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
		t.Errorf("unexpected message: %+v", result.Messages[0])
	}

	if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 1024 {
		t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"])
	}
}

func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		System:    "You are a helpful assistant.",
		Messages: []MessageParam{
			{Role: "user", Content: "Hello"},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Messages) != 2 {
		t.Fatalf("expected 2 messages, got %d", len(result.Messages))
	}

	if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." {
		t.Errorf("unexpected system message: %+v", result.Messages[0])
	}
}

func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		System: []any{
			map[string]any{"type": "text", "text": "You are helpful."},
			map[string]any{"type": "text", "text": " Be concise."},
		},
		Messages: []MessageParam{
			{Role: "user", Content: "Hello"},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Messages) != 2 {
		t.Fatalf("expected 2 messages, got %d", len(result.Messages))
	}

	if result.Messages[0].Content != "You are helpful. Be concise." {
		t.Errorf("unexpected system message content: %q", result.Messages[0].Content)
	}
}

func TestFromMessagesRequest_WithOptions(t *testing.T) {
	temp := 0.7
	topP := 0.9
	topK := 40
	req := MessagesRequest{
		Model:         "test-model",
		MaxTokens:     2048,
		Messages:      []MessageParam{{Role: "user", Content: "Hello"}},
		Temperature:   &temp,
		TopP:          &topP,
		TopK:          &topK,
		StopSequences: []string{"\n", "END"},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if result.Options["temperature"] != 0.7 {
		t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"])
	}
	if result.Options["top_p"] != 0.9 {
		t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"])
	}
	if result.Options["top_k"] != 40 {
		t.Errorf("expected top_k 40, got %v", result.Options["top_k"])
	}
	if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" {
		t.Errorf("stop sequences mismatch: %s", diff)
	}
}

func TestFromMessagesRequest_WithImage(t *testing.T) {
	imgData, _ := base64.StdEncoding.DecodeString(testImage)

	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages: []MessageParam{
			{
				Role: "user",
				Content: []any{
					map[string]any{"type": "text", "text": "What's in this image?"},
					map[string]any{
						"type": "image",
						"source": map[string]any{
							"type":       "base64",
							"media_type": "image/png",
							"data":       testImage,
						},
					},
				},
			},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Messages) != 1 {
		t.Fatalf("expected 1 message, got %d", len(result.Messages))
	}

	if result.Messages[0].Content != "What's in this image?" {
		t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content)
	}

	if len(result.Messages[0].Images) != 1 {
		t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images))
	}

	if string(result.Messages[0].Images[0]) != string(imgData) {
		t.Error("image data mismatch")
	}
}

func TestFromMessagesRequest_WithToolUse(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages: []MessageParam{
			{Role: "user", Content: "What's the weather in Paris?"},
			{
				Role: "assistant",
				Content: []any{
					map[string]any{
						"type":  "tool_use",
						"id":    "call_123",
						"name":  "get_weather",
						"input": map[string]any{"location": "Paris"},
					},
				},
			},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Messages) != 2 {
		t.Fatalf("expected 2 messages, got %d", len(result.Messages))
	}

	if len(result.Messages[1].ToolCalls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls))
	}

	tc := result.Messages[1].ToolCalls[0]
	if tc.ID != "call_123" {
		t.Errorf("expected tool call ID 'call_123', got %q", tc.ID)
	}
	if tc.Function.Name != "get_weather" {
		t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name)
	}
}

func TestFromMessagesRequest_WithToolResult(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages: []MessageParam{
			{
				Role: "user",
				Content: []any{
					map[string]any{
						"type":        "tool_result",
						"tool_use_id": "call_123",
						"content":     "The weather in Paris is sunny, 22°C",
					},
				},
			},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Messages) != 1 {
		t.Fatalf("expected 1 message, got %d", len(result.Messages))
	}

	msg := result.Messages[0]
	if msg.Role != "tool" {
		t.Errorf("expected role 'tool', got %q", msg.Role)
	}
	if msg.ToolCallID != "call_123" {
		t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID)
	}
	if msg.Content != "The weather in Paris is sunny, 22°C" {
		t.Errorf("unexpected content: %q", msg.Content)
	}
}

func TestFromMessagesRequest_WithTools(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages:  []MessageParam{{Role: "user", Content: "Hello"}},
		Tools: []Tool{
			{
				Name:        "get_weather",
				Description: "Get current weather",
				InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
			},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Tools) != 1 {
		t.Fatalf("expected 1 tool, got %d", len(result.Tools))
	}

	tool := result.Tools[0]
	if tool.Type != "function" {
		t.Errorf("expected type 'function', got %q", tool.Type)
	}
	if tool.Function.Name != "get_weather" {
		t.Errorf("expected name 'get_weather', got %q", tool.Function.Name)
	}
	if tool.Function.Description != "Get current weather" {
		t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description)
	}
}

func TestFromMessagesRequest_WithThinking(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages:  []MessageParam{{Role: "user", Content: "Hello"}},
		Thinking:  &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if result.Think == nil {
		t.Fatal("expected Think to be set")
	}
	if v, ok := result.Think.Value.(bool); !ok || !v {
		t.Errorf("expected Think.Value to be true, got %v", result.Think.Value)
	}
}

// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages: []MessageParam{
			{Role: "user", Content: "Hello"},
			{
				Role: "assistant",
				Content: []any{
					map[string]any{
						"type":     "thinking",
						"thinking": "Let me think about this...",
					},
				},
			},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Messages) != 2 {
		t.Fatalf("expected 2 messages, got %d", len(result.Messages))
	}

	assistantMsg := result.Messages[1]
	if assistantMsg.Thinking != "Let me think about this..." {
		t.Errorf("expected thinking content, got %q", assistantMsg.Thinking)
	}
}

func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages: []MessageParam{
			{
				Role: "assistant",
				Content: []any{
					map[string]any{
						"type": "tool_use",
						"name": "get_weather",
					},
				},
			},
		},
	}

	_, err := FromMessagesRequest(req)
	if err == nil {
		t.Fatal("expected error for missing tool_use id")
	}
	if err.Error() != "tool_use block missing required 'id' field" {
		t.Errorf("unexpected error message: %v", err)
	}
}

func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages: []MessageParam{
			{
				Role: "assistant",
				Content: []any{
					map[string]any{
						"type": "tool_use",
						"id":   "call_123",
					},
				},
			},
		},
	}

	_, err := FromMessagesRequest(req)
	if err == nil {
		t.Fatal("expected error for missing tool_use name")
	}
	if err.Error() != "tool_use block missing required 'name' field" {
		t.Errorf("unexpected error message: %v", err)
	}
}

func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages:  []MessageParam{{Role: "user", Content: "Hello"}},
		Tools: []Tool{
			{
				Name:        "bad_tool",
				InputSchema: json.RawMessage(`{invalid json`),
			},
		},
	}

	_, err := FromMessagesRequest(req)
	if err == nil {
		t.Fatal("expected error for invalid tool schema")
	}
}

func TestToMessagesResponse_Basic(t *testing.T) {
	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role:    "assistant",
			Content: "Hello there!",
		},
		Done:       true,
		DoneReason: "stop",
		Metrics: api.Metrics{
			PromptEvalCount: 10,
			EvalCount:       5,
		},
	}

	result := ToMessagesResponse("msg_123", resp)

	if result.ID != "msg_123" {
		t.Errorf("expected ID 'msg_123', got %q", result.ID)
	}
	if result.Type != "message" {
		t.Errorf("expected type 'message', got %q", result.Type)
	}
	if result.Role != "assistant" {
		t.Errorf("expected role 'assistant', got %q", result.Role)
	}
	if len(result.Content) != 1 {
		t.Fatalf("expected 1 content block, got %d", len(result.Content))
	}
	if result.Content[0].Type != "text" || result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
		t.Errorf("unexpected content: %+v", result.Content[0])
	}
	if result.StopReason != "end_turn" {
		t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
	}
	if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 {
		t.Errorf("unexpected usage: %+v", result.Usage)
	}
}

func TestToMessagesResponse_WithToolCalls(t *testing.T) {
	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role: "assistant",
			ToolCalls: []api.ToolCall{
				{
					ID: "call_123",
					Function: api.ToolCallFunction{
						Name:      "get_weather",
						Arguments: testArgs(map[string]any{"location": "Paris"}),
					},
				},
			},
		},
		Done:       true,
		DoneReason: "stop",
	}

	result := ToMessagesResponse("msg_123", resp)

	if len(result.Content) != 1 {
		t.Fatalf("expected 1 content block, got %d", len(result.Content))
	}
	if result.Content[0].Type != "tool_use" {
		t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type)
	}
	if result.Content[0].ID != "call_123" {
		t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID)
	}
	if result.Content[0].Name != "get_weather" {
		t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name)
	}
	if result.StopReason != "tool_use" {
		t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason)
	}
}

func TestToMessagesResponse_WithThinking(t *testing.T) {
	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role:     "assistant",
			Content:  "The answer is 42.",
			Thinking: "Let me think about this...",
		},
		Done:       true,
		DoneReason: "stop",
	}

	result := ToMessagesResponse("msg_123", resp)

	if len(result.Content) != 2 {
		t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
	}
	if result.Content[0].Type != "thinking" {
		t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type)
	}
	if result.Content[0].Thinking == nil || *result.Content[0].Thinking != "Let me think about this..." {
		t.Errorf("unexpected thinking content: %v", result.Content[0].Thinking)
	}
	if result.Content[1].Type != "text" {
		t.Errorf("expected second block type 'text', got %q", result.Content[1].Type)
	}
}

func TestMapStopReason(t *testing.T) {
	tests := []struct {
		reason       string
		hasToolCalls bool
		want         string
	}{
		{"stop", false, "end_turn"},
		{"length", false, "max_tokens"},
		{"stop", true, "tool_use"},
		{"other", false, "stop_sequence"},
		{"", false, ""},
	}

	for _, tt := range tests {
		got := mapStopReason(tt.reason, tt.hasToolCalls)
		if got != tt.want {
			t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want)
		}
	}
}

func TestNewError(t *testing.T) {
	tests := []struct {
		code int
		want string
	}{
		{400, "invalid_request_error"},
		{401, "authentication_error"},
		{403, "permission_error"},
		{404, "not_found_error"},
		{429, "rate_limit_error"},
		{500, "api_error"},
		{503, "overloaded_error"},
		{529, "overloaded_error"},
	}

	for _, tt := range tests {
		result := NewError(tt.code, "test message")
		if result.Type != "error" {
			t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type)
		}
		if result.Error.Type != tt.want {
			t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want)
		}
		if result.Error.Message != "test message" {
			t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message)
		}
		if result.RequestID == "" {
			t.Errorf("NewError(%d) request_id should not be empty", tt.code)
		}
	}
}

func TestGenerateMessageID(t *testing.T) {
	id1 := GenerateMessageID()
	id2 := GenerateMessageID()

	if id1 == "" {
		t.Error("GenerateMessageID returned empty string")
	}
	if id1 == id2 {
		t.Error("GenerateMessageID returned duplicate IDs")
	}
	if len(id1) < 10 {
		t.Errorf("GenerateMessageID returned short ID: %q", id1)
	}
	if id1[:4] != "msg_" {
		t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4])
	}
}

func TestStreamConverter_Basic(t *testing.T) {
	conv := NewStreamConverter("msg_123", "test-model")

	// First chunk
	resp1 := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role:    "assistant",
			Content: "Hello",
		},
		Metrics: api.Metrics{PromptEvalCount: 10},
	}

	events1 := conv.Process(resp1)
	if len(events1) < 3 {
		t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1))
	}

	// Should have message_start, content_block_start, content_block_delta
	if events1[0].Event != "message_start" {
		t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
	}
	if events1[1].Event != "content_block_start" {
		t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event)
	}
	if events1[2].Event != "content_block_delta" {
		t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event)
	}

	// Final chunk
	resp2 := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role:    "assistant",
			Content: " world!",
		},
		Done:       true,
		DoneReason: "stop",
		Metrics:    api.Metrics{EvalCount: 5},
	}

	events2 := conv.Process(resp2)

	// Should have content_block_delta, content_block_stop, message_delta, message_stop
	hasStop := false
	for _, e := range events2 {
		if e.Event == "message_stop" {
			hasStop = true
		}
	}
	if !hasStop {
		t.Error("expected message_stop event in final chunk")
	}
}

func TestStreamConverter_WithToolCalls(t *testing.T) {
	conv := NewStreamConverter("msg_123", "test-model")

	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role: "assistant",
			ToolCalls: []api.ToolCall{
				{
					ID: "call_123",
					Function: api.ToolCallFunction{
						Name:      "get_weather",
						Arguments: testArgs(map[string]any{"location": "Paris"}),
					},
				},
			},
		},
		Done:       true,
		DoneReason: "stop",
		Metrics:    api.Metrics{PromptEvalCount: 10, EvalCount: 5},
	}

	events := conv.Process(resp)

	hasToolStart := false
	hasToolDelta := false
	for _, e := range events {
		if e.Event == "content_block_start" {
			if start, ok := e.Data.(ContentBlockStartEvent); ok {
				if start.ContentBlock.Type == "tool_use" {
					hasToolStart = true
				}
			}
		}
		if e.Event == "content_block_delta" {
			if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
				if delta.Delta.Type == "input_json_delta" {
					hasToolDelta = true
				}
			}
		}
	}

	if !hasToolStart {
		t.Error("expected tool_use content_block_start event")
	}
	if !hasToolDelta {
		t.Error("expected input_json_delta event")
	}
}

func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
	// Test that unmarshalable arguments (like channels) are handled gracefully
	// and don't cause a panic or corrupt stream
	conv := NewStreamConverter("msg_123", "test-model")

	// Create a channel which cannot be JSON marshaled
	unmarshalable := make(chan int)
	badArgs := api.NewToolCallFunctionArguments()
	badArgs.Set("channel", unmarshalable)

	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role: "assistant",
			ToolCalls: []api.ToolCall{
				{
					ID: "call_bad",
					Function: api.ToolCallFunction{
						Name:      "bad_function",
						Arguments: badArgs,
					},
				},
			},
		},
		Done:       true,
		DoneReason: "stop",
	}

	// Should not panic and should skip the unmarshalable tool call
	events := conv.Process(resp)

	// Verify no tool_use block was started (since marshal failed before block start)
	hasToolStart := false
	for _, e := range events {
		if e.Event == "content_block_start" {
			if start, ok := e.Data.(ContentBlockStartEvent); ok {
				if start.ContentBlock.Type == "tool_use" {
					hasToolStart = true
				}
			}
		}
	}

	if hasToolStart {
		t.Error("expected no tool_use block when arguments cannot be marshaled")
	}
}

func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
	// Test that valid tool calls still work when mixed with invalid ones
	conv := NewStreamConverter("msg_123", "test-model")

	unmarshalable := make(chan int)
	badArgs := api.NewToolCallFunctionArguments()
	badArgs.Set("channel", unmarshalable)

	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role: "assistant",
			ToolCalls: []api.ToolCall{
				{
					ID: "call_good",
					Function: api.ToolCallFunction{
						Name:      "good_function",
						Arguments: testArgs(map[string]any{"location": "Paris"}),
					},
				},
				{
					ID: "call_bad",
					Function: api.ToolCallFunction{
						Name:      "bad_function",
						Arguments: badArgs,
					},
				},
			},
		},
		Done:       true,
		DoneReason: "stop",
	}

	events := conv.Process(resp)

	// Count tool_use blocks - should only have 1 (the valid one)
	toolStartCount := 0
	toolDeltaCount := 0
	for _, e := range events {
		if e.Event == "content_block_start" {
			if start, ok := e.Data.(ContentBlockStartEvent); ok {
				if start.ContentBlock.Type == "tool_use" {
					toolStartCount++
					if start.ContentBlock.Name != "good_function" {
						t.Errorf("expected tool name 'good_function', got %q", start.ContentBlock.Name)
					}
				}
			}
		}
		if e.Event == "content_block_delta" {
			if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
				if delta.Delta.Type == "input_json_delta" {
					toolDeltaCount++
				}
			}
		}
	}

	if toolStartCount != 1 {
		t.Errorf("expected 1 tool_use block, got %d", toolStartCount)
	}
	if toolDeltaCount != 1 {
		t.Errorf("expected 1 input_json_delta, got %d", toolDeltaCount)
	}
}

// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
	tests := []struct {
		name     string
		block    ContentBlock
		wantKeys []string
	}{
		{
			name: "text block includes empty text field",
			block: ContentBlock{
				Type: "text",
				Text: ptr(""),
			},
			wantKeys: []string{"type", "text"},
		},
		{
			name: "thinking block includes empty thinking field",
			block: ContentBlock{
				Type:     "thinking",
				Thinking: ptr(""),
			},
			wantKeys: []string{"type", "thinking"},
		},
		{
			name: "text block with content",
			block: ContentBlock{
				Type: "text",
				Text: ptr("hello"),
			},
			wantKeys: []string{"type", "text"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			data, err := json.Marshal(tt.block)
			if err != nil {
				t.Fatalf("failed to marshal: %v", err)
			}

			var result map[string]any
			if err := json.Unmarshal(data, &result); err != nil {
				t.Fatalf("failed to unmarshal: %v", err)
			}

			for _, key := range tt.wantKeys {
				if _, ok := result[key]; !ok {
					t.Errorf("expected key %q to be present in JSON output, got: %s", key, string(data))
				}
			}
		})
	}
}

// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
// events include the required empty fields for SDK compatibility.
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
	t.Run("text block start includes empty text", func(t *testing.T) {
		conv := NewStreamConverter("msg_123", "test-model")

		resp := api.ChatResponse{
			Model:   "test-model",
			Message: api.Message{Role: "assistant", Content: "hello"},
		}

		events := conv.Process(resp)

		var foundTextStart bool
		for _, e := range events {
			if e.Event == "content_block_start" {
				if start, ok := e.Data.(ContentBlockStartEvent); ok {
					if start.ContentBlock.Type == "text" {
						foundTextStart = true
						// Marshal and verify the text field is present
						data, _ := json.Marshal(start)
						var result map[string]any
						json.Unmarshal(data, &result)
						cb := result["content_block"].(map[string]any)
						if _, ok := cb["text"]; !ok {
							t.Error("content_block_start for text should include 'text' field")
						}
					}
				}
			}
		}

		if !foundTextStart {
			t.Error("expected text content_block_start event")
		}
	})

	t.Run("thinking block start includes empty thinking", func(t *testing.T) {
		conv := NewStreamConverter("msg_123", "test-model")

		resp := api.ChatResponse{
			Model:   "test-model",
			Message: api.Message{Role: "assistant", Thinking: "let me think..."},
		}

		events := conv.Process(resp)

		var foundThinkingStart bool
		for _, e := range events {
			if e.Event == "content_block_start" {
				if start, ok := e.Data.(ContentBlockStartEvent); ok {
					if start.ContentBlock.Type == "thinking" {
						foundThinkingStart = true
						data, _ := json.Marshal(start)
						var result map[string]any
						json.Unmarshal(data, &result)
						cb := result["content_block"].(map[string]any)
						if _, ok := cb["thinking"]; !ok {
							t.Error("content_block_start for thinking should include 'thinking' field")
						}
					}
				}
			}
		}

		if !foundThinkingStart {
			t.Error("expected thinking content_block_start event")
		}
	})
}
@@ -165,7 +165,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 	return nil
 }
 
-const maxBufferSize = 512 * format.KiloByte
+const maxBufferSize = 8 * format.MegaByte
 
 func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
 	var buf io.Reader
@@ -347,7 +347,7 @@ type CreateProgressFunc func(ProgressResponse) error
 // Create creates a model from a [Modelfile]. fn is a progress function that
 // behaves similarly to other methods (see [Client.Pull]).
 //
-// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md
+// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.mdx
 func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
 	return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
 		var resp ProgressResponse
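The first hunk above raises the cap on a single streamed line from 512 KiB to 8 MiB. As a rough sketch of why that constant matters (this assumes the stream method reads the response body line by line with a bufio.Scanner, which the surrounding context suggests but this excerpt does not show):

// Sketch only: how a line-oriented stream reader applies maxBufferSize.
func scanStream(r io.Reader, fn func([]byte) error) error {
	scanner := bufio.NewScanner(r)
	// Lines longer than the cap fail with bufio.ErrTooLong, so raising it
	// from 512 KiB to 8 MiB lets much larger JSON chunks through.
	scanner.Buffer(make([]byte, 0, maxBufferSize), maxBufferSize)
	for scanner.Scan() {
		if err := fn(scanner.Bytes()); err != nil {
			return err
		}
	}
	return scanner.Err()
}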
@@ -15,19 +15,19 @@ func main() {
 	}
 
 	messages := []api.Message{
-		api.Message{
+		{
 			Role:    "system",
 			Content: "Provide very brief, concise responses",
 		},
-		api.Message{
+		{
 			Role:    "user",
 			Content: "Name some unusual animals",
 		},
-		api.Message{
+		{
 			Role:    "assistant",
 			Content: "Monotreme, platypus, echidna",
 		},
-		api.Message{
+		{
 			Role:    "user",
 			Content: "which of these is the most dangerous?",
 		},
162	api/types.go
@@ -3,6 +3,7 @@ package api
 import (
 	"encoding/json"
 	"fmt"
+	"iter"
 	"log/slog"
 	"math"
 	"os"
@@ -14,6 +15,7 @@ import (
 	"github.com/google/uuid"
 
 	"github.com/ollama/ollama/envconfig"
+	"github.com/ollama/ollama/internal/orderedmap"
 	"github.com/ollama/ollama/types/model"
 )
 
@@ -227,13 +229,79 @@ type ToolCallFunction struct {
 	Arguments ToolCallFunctionArguments `json:"arguments"`
 }
 
-type ToolCallFunctionArguments map[string]any
+// ToolCallFunctionArguments holds tool call arguments in insertion order.
+type ToolCallFunctionArguments struct {
+	om *orderedmap.Map[string, any]
+}
+
+// NewToolCallFunctionArguments creates a new empty ToolCallFunctionArguments.
+func NewToolCallFunctionArguments() ToolCallFunctionArguments {
+	return ToolCallFunctionArguments{om: orderedmap.New[string, any]()}
+}
+
+// Get retrieves a value by key.
+func (t *ToolCallFunctionArguments) Get(key string) (any, bool) {
+	if t == nil || t.om == nil {
+		return nil, false
+	}
+	return t.om.Get(key)
+}
+
+// Set sets a key-value pair, preserving insertion order.
+func (t *ToolCallFunctionArguments) Set(key string, value any) {
+	if t == nil {
+		return
+	}
+	if t.om == nil {
+		t.om = orderedmap.New[string, any]()
+	}
+	t.om.Set(key, value)
+}
+
+// Len returns the number of arguments.
+func (t *ToolCallFunctionArguments) Len() int {
+	if t == nil || t.om == nil {
+		return 0
+	}
+	return t.om.Len()
+}
+
+// All returns an iterator over all key-value pairs in insertion order.
+func (t *ToolCallFunctionArguments) All() iter.Seq2[string, any] {
+	if t == nil || t.om == nil {
+		return func(yield func(string, any) bool) {}
+	}
+	return t.om.All()
+}
+
+// ToMap returns a regular map (order not preserved).
+func (t *ToolCallFunctionArguments) ToMap() map[string]any {
+	if t == nil || t.om == nil {
+		return nil
+	}
+	return t.om.ToMap()
+}
+
 func (t *ToolCallFunctionArguments) String() string {
-	bts, _ := json.Marshal(t)
+	if t == nil || t.om == nil {
+		return "{}"
+	}
+	bts, _ := json.Marshal(t.om)
 	return string(bts)
 }
 
+func (t *ToolCallFunctionArguments) UnmarshalJSON(data []byte) error {
+	t.om = orderedmap.New[string, any]()
+	return json.Unmarshal(data, t.om)
+}
+
+func (t ToolCallFunctionArguments) MarshalJSON() ([]byte, error) {
+	if t.om == nil {
+		return []byte("{}"), nil
+	}
+	return json.Marshal(t.om)
+}
+
 type Tool struct {
 	Type  string `json:"type"`
 	Items any    `json:"items,omitempty"`
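A small usage sketch of the new ordered-arguments type (only the accessors introduced in the hunk above are used; the program wrapper is illustrative):

package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	args := api.NewToolCallFunctionArguments()
	args.Set("location", "Paris")
	args.Set("unit", "celsius")

	// All iterates in insertion order, unlike a plain map[string]any.
	for k, v := range args.All() {
		fmt.Println(k, "=", v)
	}

	// Marshaling keeps the same key order: {"location":"Paris","unit":"celsius"}
	out, _ := json.Marshal(args)
	fmt.Println(string(out))
}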
@@ -282,12 +350,78 @@ func (pt PropertyType) String() string {
 	return fmt.Sprintf("%v", []string(pt))
 }
 
+// ToolPropertiesMap holds tool properties in insertion order.
+type ToolPropertiesMap struct {
+	om *orderedmap.Map[string, ToolProperty]
+}
+
+// NewToolPropertiesMap creates a new empty ToolPropertiesMap.
+func NewToolPropertiesMap() *ToolPropertiesMap {
+	return &ToolPropertiesMap{om: orderedmap.New[string, ToolProperty]()}
+}
+
+// Get retrieves a property by name.
+func (t *ToolPropertiesMap) Get(key string) (ToolProperty, bool) {
+	if t == nil || t.om == nil {
+		return ToolProperty{}, false
+	}
+	return t.om.Get(key)
+}
+
+// Set sets a property, preserving insertion order.
+func (t *ToolPropertiesMap) Set(key string, value ToolProperty) {
+	if t == nil {
+		return
+	}
+	if t.om == nil {
+		t.om = orderedmap.New[string, ToolProperty]()
+	}
+	t.om.Set(key, value)
+}
+
+// Len returns the number of properties.
+func (t *ToolPropertiesMap) Len() int {
+	if t == nil || t.om == nil {
+		return 0
+	}
+	return t.om.Len()
+}
+
+// All returns an iterator over all properties in insertion order.
+func (t *ToolPropertiesMap) All() iter.Seq2[string, ToolProperty] {
+	if t == nil || t.om == nil {
+		return func(yield func(string, ToolProperty) bool) {}
+	}
+	return t.om.All()
+}
+
+// ToMap returns a regular map (order not preserved).
+func (t *ToolPropertiesMap) ToMap() map[string]ToolProperty {
+	if t == nil || t.om == nil {
+		return nil
+	}
+	return t.om.ToMap()
+}
+
+func (t ToolPropertiesMap) MarshalJSON() ([]byte, error) {
+	if t.om == nil {
+		return []byte("null"), nil
+	}
+	return json.Marshal(t.om)
+}
+
+func (t *ToolPropertiesMap) UnmarshalJSON(data []byte) error {
+	t.om = orderedmap.New[string, ToolProperty]()
+	return json.Unmarshal(data, t.om)
+}
+
 type ToolProperty struct {
 	AnyOf       []ToolProperty `json:"anyOf,omitempty"`
 	Type        PropertyType   `json:"type,omitempty"`
 	Items       any            `json:"items,omitempty"`
 	Description string         `json:"description,omitempty"`
 	Enum        []any          `json:"enum,omitempty"`
+	Properties  *ToolPropertiesMap `json:"properties,omitempty"`
 }
 
 // ToTypeScriptType converts a ToolProperty to a TypeScript type string
@@ -336,11 +470,11 @@ func mapToTypeScriptType(jsonType string) string {
 }
 
 type ToolFunctionParameters struct {
 	Type       string   `json:"type"`
 	Defs       any      `json:"$defs,omitempty"`
 	Items      any      `json:"items,omitempty"`
 	Required   []string `json:"required,omitempty"`
-	Properties map[string]ToolProperty `json:"properties"`
+	Properties *ToolPropertiesMap      `json:"properties"`
 }
 
 func (t *ToolFunctionParameters) String() string {
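Putting the two ordered types together, a tool schema can now be declared so that its properties serialize in registration order. A sketch using only the types changed above (the program wrapper is illustrative):

package main

import (
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	props := api.NewToolPropertiesMap()
	props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "City name"})
	props.Set("unit", api.ToolProperty{Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}})

	params := api.ToolFunctionParameters{
		Type:       "object",
		Required:   []string{"location"},
		Properties: props, // now *ToolPropertiesMap, per the hunk above
	}

	// "location" serializes before "unit", matching registration order.
	fmt.Println(params.String())
}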
@@ -553,6 +687,9 @@ type CreateRequest struct {
 	Renderer string `json:"renderer,omitempty"`
 	Parser   string `json:"parser,omitempty"`
 
+	// Requires is the minimum version of Ollama required by the model.
+	Requires string `json:"requires,omitempty"`
+
 	// Info is a map of additional information for the model
 	Info map[string]any `json:"info,omitempty"`
 
@@ -603,6 +740,7 @@ type ShowResponse struct {
 	Tensors      []Tensor           `json:"tensors,omitempty"`
 	Capabilities []model.Capability `json:"capabilities,omitempty"`
 	ModifiedAt   time.Time          `json:"modified_at,omitempty"`
+	Requires     string             `json:"requires,omitempty"`
 }
 
 // CopyRequest is the request passed to [Client.Copy].
@@ -11,6 +11,24 @@ import (
 	"github.com/stretchr/testify/require"
 )
 
+// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
+func testPropsMap(m map[string]ToolProperty) *ToolPropertiesMap {
+	props := NewToolPropertiesMap()
+	for k, v := range m {
+		props.Set(k, v)
+	}
+	return props
+}
+
+// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
+func testArgs(m map[string]any) ToolCallFunctionArguments {
+	args := NewToolCallFunctionArguments()
+	for k, v := range m {
+		args.Set(k, v)
+	}
+	return args
+}
+
 func TestKeepAliveParsingFromJSON(t *testing.T) {
 	tests := []struct {
 		name string
@@ -309,9 +327,9 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
 			input: ToolFunctionParameters{
 				Type:     "object",
 				Required: []string{"name"},
-				Properties: map[string]ToolProperty{
+				Properties: testPropsMap(map[string]ToolProperty{
 					"name": {Type: PropertyType{"string"}},
-				},
+				}),
 			},
 			expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string"}}}`,
 		},
@@ -319,9 +337,9 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
 			name: "no required",
 			input: ToolFunctionParameters{
 				Type: "object",
-				Properties: map[string]ToolProperty{
+				Properties: testPropsMap(map[string]ToolProperty{
 					"name": {Type: PropertyType{"string"}},
-				},
+				}),
 			},
 			expected: `{"type":"object","properties":{"name":{"type":"string"}}}`,
 		},
@@ -339,7 +357,7 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
 func TestToolCallFunction_IndexAlwaysMarshals(t *testing.T) {
 	fn := ToolCallFunction{
 		Name:      "echo",
-		Arguments: ToolCallFunctionArguments{"message": "hi"},
+		Arguments: testArgs(map[string]any{"message": "hi"}),
 	}
 
 	data, err := json.Marshal(fn)
@@ -504,6 +522,116 @@ func TestThinking_UnmarshalJSON(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestToolPropertyNestedProperties(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected ToolProperty
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nested object properties",
|
||||||
|
input: `{
|
||||||
|
"type": "object",
|
||||||
|
"description": "Location details",
|
||||||
|
"properties": {
|
||||||
|
"address": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Street address"
|
||||||
|
},
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "City name"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`,
|
||||||
|
expected: ToolProperty{
|
||||||
|
Type: PropertyType{"object"},
|
||||||
|
Description: "Location details",
|
||||||
|
Properties: testPropsMap(map[string]ToolProperty{
|
||||||
|
"address": {
|
||||||
|
Type: PropertyType{"string"},
|
||||||
|
Description: "Street address",
|
||||||
|
},
|
||||||
|
"city": {
|
||||||
|
Type: PropertyType{"string"},
|
||||||
|
Description: "City name",
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deeply nested properties",
|
||||||
|
input: `{
|
||||||
|
"type": "object",
|
||||||
|
"description": "Event",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Location",
|
||||||
|
"properties": {
|
||||||
|
"coordinates": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "GPS coordinates",
|
||||||
|
"properties": {
|
||||||
|
"lat": {"type": "number", "description": "Latitude"},
|
||||||
|
"lng": {"type": "number", "description": "Longitude"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`,
|
||||||
|
expected: ToolProperty{
|
||||||
|
Type: PropertyType{"object"},
|
||||||
|
Description: "Event",
|
||||||
|
Properties: testPropsMap(map[string]ToolProperty{
|
||||||
|
"location": {
|
||||||
|
Type: PropertyType{"object"},
|
||||||
|
Description: "Location",
|
||||||
|
Properties: testPropsMap(map[string]ToolProperty{
|
||||||
|
"coordinates": {
|
||||||
|
Type: PropertyType{"object"},
|
||||||
|
Description: "GPS coordinates",
|
||||||
|
Properties: testPropsMap(map[string]ToolProperty{
|
||||||
|
"lat": {Type: PropertyType{"number"}, Description: "Latitude"},
|
||||||
|
"lng": {Type: PropertyType{"number"}, Description: "Longitude"},
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var prop ToolProperty
|
||||||
|
err := json.Unmarshal([]byte(tt.input), &prop)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Compare JSON representations since pointer comparison doesn't work
|
||||||
|
expectedJSON, err := json.Marshal(tt.expected)
|
||||||
|
require.NoError(t, err)
|
||||||
|
actualJSON, err := json.Marshal(prop)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.JSONEq(t, string(expectedJSON), string(actualJSON))
|
||||||
|
|
||||||
|
// Round-trip test: marshal and unmarshal again
|
||||||
|
data, err := json.Marshal(prop)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var prop2 ToolProperty
|
||||||
|
err = json.Unmarshal(data, &prop2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
prop2JSON, err := json.Marshal(prop2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.JSONEq(t, string(expectedJSON), string(prop2JSON))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestToolFunctionParameters_String(t *testing.T) {
|
func TestToolFunctionParameters_String(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -515,12 +643,12 @@ func TestToolFunctionParameters_String(t *testing.T) {
|
|||||||
params: ToolFunctionParameters{
|
params: ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Required: []string{"name"},
|
Required: []string{"name"},
|
||||||
Properties: map[string]ToolProperty{
|
Properties: testPropsMap(map[string]ToolProperty{
|
||||||
"name": {
|
"name": {
|
||||||
Type: PropertyType{"string"},
|
Type: PropertyType{"string"},
|
||||||
Description: "The name of the person",
|
Description: "The name of the person",
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
|
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
|
||||||
},
|
},
|
||||||
@@ -537,7 +665,7 @@ func TestToolFunctionParameters_String(t *testing.T) {
|
|||||||
s.Self = s
|
s.Self = s
|
||||||
return s
|
return s
|
||||||
}(),
|
}(),
|
||||||
Properties: map[string]ToolProperty{},
|
Properties: testPropsMap(map[string]ToolProperty{}),
|
||||||
},
|
},
|
||||||
expected: "",
|
expected: "",
|
||||||
},
|
},
|
||||||
@@ -550,3 +678,235 @@ func TestToolFunctionParameters_String(t *testing.T) {
 		})
 	}
 }
+
+func TestToolCallFunctionArguments_OrderPreservation(t *testing.T) {
+	t.Run("marshal preserves insertion order", func(t *testing.T) {
+		args := NewToolCallFunctionArguments()
+		args.Set("zebra", "z")
+		args.Set("apple", "a")
+		args.Set("mango", "m")
+
+		data, err := json.Marshal(args)
+		require.NoError(t, err)
+
+		// Should preserve insertion order, not alphabetical
+		assert.Equal(t, `{"zebra":"z","apple":"a","mango":"m"}`, string(data))
+	})
+
+	t.Run("unmarshal preserves JSON order", func(t *testing.T) {
+		jsonData := `{"zebra":"z","apple":"a","mango":"m"}`
+
+		var args ToolCallFunctionArguments
+		err := json.Unmarshal([]byte(jsonData), &args)
+		require.NoError(t, err)
+
+		// Verify iteration order matches JSON order
+		var keys []string
+		for k := range args.All() {
+			keys = append(keys, k)
+		}
+		assert.Equal(t, []string{"zebra", "apple", "mango"}, keys)
+	})
+
+	t.Run("round trip preserves order", func(t *testing.T) {
+		original := `{"z":1,"a":2,"m":3,"b":4}`
+
+		var args ToolCallFunctionArguments
+		err := json.Unmarshal([]byte(original), &args)
+		require.NoError(t, err)
+
+		data, err := json.Marshal(args)
+		require.NoError(t, err)
+
+		assert.Equal(t, original, string(data))
+	})
+
+	t.Run("String method returns ordered JSON", func(t *testing.T) {
+		args := NewToolCallFunctionArguments()
+		args.Set("c", 3)
+		args.Set("a", 1)
+		args.Set("b", 2)
+
+		assert.Equal(t, `{"c":3,"a":1,"b":2}`, args.String())
+	})
+
+	t.Run("Get retrieves correct values", func(t *testing.T) {
+		args := NewToolCallFunctionArguments()
+		args.Set("key1", "value1")
+		args.Set("key2", 42)
+
+		v, ok := args.Get("key1")
+		assert.True(t, ok)
+		assert.Equal(t, "value1", v)
+
+		v, ok = args.Get("key2")
+		assert.True(t, ok)
+		assert.Equal(t, 42, v)
+
+		_, ok = args.Get("nonexistent")
+		assert.False(t, ok)
+	})
+
+	t.Run("Len returns correct count", func(t *testing.T) {
+		args := NewToolCallFunctionArguments()
+		assert.Equal(t, 0, args.Len())
+
+		args.Set("a", 1)
+		assert.Equal(t, 1, args.Len())
+
+		args.Set("b", 2)
+		assert.Equal(t, 2, args.Len())
+	})
+
+	t.Run("empty args marshal to empty object", func(t *testing.T) {
+		args := NewToolCallFunctionArguments()
+		data, err := json.Marshal(args)
+		require.NoError(t, err)
+		assert.Equal(t, `{}`, string(data))
+	})
+
+	t.Run("zero value args marshal to empty object", func(t *testing.T) {
+		var args ToolCallFunctionArguments
+		assert.Equal(t, "{}", args.String())
+	})
+}
+
+func TestToolPropertiesMap_OrderPreservation(t *testing.T) {
+	t.Run("marshal preserves insertion order", func(t *testing.T) {
+		props := NewToolPropertiesMap()
+		props.Set("zebra", ToolProperty{Type: PropertyType{"string"}})
+		props.Set("apple", ToolProperty{Type: PropertyType{"number"}})
+		props.Set("mango", ToolProperty{Type: PropertyType{"boolean"}})
+
+		data, err := json.Marshal(props)
+		require.NoError(t, err)
+
+		// Should preserve insertion order, not alphabetical
+		expected := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}`
+		assert.Equal(t, expected, string(data))
+	})
+
+	t.Run("unmarshal preserves JSON order", func(t *testing.T) {
+		jsonData := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}`
+
+		var props ToolPropertiesMap
+		err := json.Unmarshal([]byte(jsonData), &props)
+		require.NoError(t, err)
+
+		// Verify iteration order matches JSON order
+		var keys []string
+		for k := range props.All() {
+			keys = append(keys, k)
+		}
+		assert.Equal(t, []string{"zebra", "apple", "mango"}, keys)
+	})
+
+	t.Run("round trip preserves order", func(t *testing.T) {
+		original := `{"z":{"type":"string"},"a":{"type":"number"},"m":{"type":"boolean"}}`
+
+		var props ToolPropertiesMap
+		err := json.Unmarshal([]byte(original), &props)
+		require.NoError(t, err)
+
+		data, err := json.Marshal(props)
+		require.NoError(t, err)
+
+		assert.Equal(t, original, string(data))
+	})
+
+	t.Run("Get retrieves correct values", func(t *testing.T) {
+		props := NewToolPropertiesMap()
+		props.Set("name", ToolProperty{Type: PropertyType{"string"}, Description: "The name"})
+		props.Set("age", ToolProperty{Type: PropertyType{"integer"}, Description: "The age"})
+
+		v, ok := props.Get("name")
+		assert.True(t, ok)
+		assert.Equal(t, "The name", v.Description)
+
+		v, ok = props.Get("age")
+		assert.True(t, ok)
+		assert.Equal(t, "The age", v.Description)
+
+		_, ok = props.Get("nonexistent")
+		assert.False(t, ok)
+	})
+
+	t.Run("Len returns correct count", func(t *testing.T) {
+		props := NewToolPropertiesMap()
+		assert.Equal(t, 0, props.Len())
+
+		props.Set("a", ToolProperty{})
+		assert.Equal(t, 1, props.Len())
+
+		props.Set("b", ToolProperty{})
+		assert.Equal(t, 2, props.Len())
+	})
+
+	t.Run("nil props marshal to null", func(t *testing.T) {
+		var props *ToolPropertiesMap
+		data, err := json.Marshal(props)
+		require.NoError(t, err)
+		assert.Equal(t, `null`, string(data))
+	})
+
+	t.Run("ToMap returns regular map", func(t *testing.T) {
+		props := NewToolPropertiesMap()
+		props.Set("a", ToolProperty{Type: PropertyType{"string"}})
+		props.Set("b", ToolProperty{Type: PropertyType{"number"}})
+
+		m := props.ToMap()
+		assert.Equal(t, 2, len(m))
+		assert.Equal(t, PropertyType{"string"}, m["a"].Type)
+		assert.Equal(t, PropertyType{"number"}, m["b"].Type)
+	})
+}
+
+func TestToolCallFunctionArguments_ComplexValues(t *testing.T) {
+	t.Run("nested objects preserve order", func(t *testing.T) {
+		jsonData := `{"outer":{"z":1,"a":2},"simple":"value"}`
+
+		var args ToolCallFunctionArguments
+		err := json.Unmarshal([]byte(jsonData), &args)
+		require.NoError(t, err)
+
+		// Outer keys should be in order
+		var keys []string
+		for k := range args.All() {
+			keys = append(keys, k)
+		}
+		assert.Equal(t, []string{"outer", "simple"}, keys)
+	})
+
+	t.Run("arrays as values", func(t *testing.T) {
+		args := NewToolCallFunctionArguments()
+		args.Set("items", []string{"a", "b", "c"})
+		args.Set("numbers", []int{1, 2, 3})
+
+		data, err := json.Marshal(args)
+		require.NoError(t, err)
+
+		assert.Equal(t, `{"items":["a","b","c"],"numbers":[1,2,3]}`, string(data))
+	})
+}
+
+func TestToolPropertiesMap_NestedProperties(t *testing.T) {
+	t.Run("nested properties preserve order", func(t *testing.T) {
+		props := NewToolPropertiesMap()
+
+		nestedProps := NewToolPropertiesMap()
+		nestedProps.Set("z_field", ToolProperty{Type: PropertyType{"string"}})
+		nestedProps.Set("a_field", ToolProperty{Type: PropertyType{"number"}})
+
+		props.Set("outer", ToolProperty{
+			Type:       PropertyType{"object"},
+			Properties: nestedProps,
+		})
+
+		data, err := json.Marshal(props)
+		require.NoError(t, err)
+
+		// Both outer and inner should preserve order
+		expected := `{"outer":{"type":"object","properties":{"z_field":{"type":"string"},"a_field":{"type":"number"}}}}`
+		assert.Equal(t, expected, string(data))
+	})
+}
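The new tests above double as a specification for the insertion-ordered maps. For orientation, here is a minimal sketch (not part of this diff) of how external code would use that API, assuming the constructors and accessors exercised in the tests — `NewToolCallFunctionArguments`, `Set`, `All` — are exported from the `api` package, as the `app/ui` hunks later in this compare suggest:

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	// Keys marshal in insertion order, not alphabetically.
	args := api.NewToolCallFunctionArguments()
	args.Set("location", "Paris")
	args.Set("unit", "celsius")

	data, err := json.Marshal(args)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(data)) // {"location":"Paris","unit":"celsius"}

	// Iteration follows the same insertion order.
	for k := range args.All() {
		fmt.Println(k)
	}
}
```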
@@ -273,10 +273,6 @@ func main() {
 		Handler: uiServer.Handler(),
 	}
 
-	if _, err := uiServer.UserData(ctx); err != nil {
-		slog.Warn("failed to load user data", "error", err)
-	}
-
 	// Start the UI server
 	slog.Info("starting ui server", "port", port)
 	go func() {
@@ -320,6 +316,17 @@ func main() {
 		slog.Debug("no URL scheme request to handle")
 	}
 
+	go func() {
+		slog.Debug("waiting for ollama server to be ready")
+		if err := ui.WaitForServer(ctx, 10*time.Second); err != nil {
+			slog.Warn("ollama server not ready, continuing anyway", "error", err)
+		}
+
+		if _, err := uiServer.UserData(ctx); err != nil {
+			slog.Warn("failed to load user data", "error", err)
+		}
+	}()
+
 	osRun(cancel, hasCompletedFirstRun, startHidden)
 
 	slog.Info("shutting down desktop server")
@@ -361,7 +368,7 @@ func checkUserLoggedIn(uiServerPort int) bool {
 		return false
 	}
 
-	resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/api/v1/me", uiServerPort))
+	resp, err := http.Post(fmt.Sprintf("http://127.0.0.1:%d/api/me", uiServerPort), "application/json", nil)
 	if err != nil {
 		slog.Debug("failed to call local auth endpoint", "error", err)
 		return false
@@ -191,13 +191,6 @@ func LaunchNewApp() {
 	C.launchApp(appName)
 }
 
-// Send a request to the main app thread to load a UI page
-func sendUIRequestMessage(path string) {
-	p := C.CString(path)
-	defer C.free(unsafe.Pointer(p))
-	C.uiRequest(p)
-}
-
 func registerLaunchAgent(hasCompletedFirstRun bool) {
 	// Remove any stale Login Item registrations
 	C.unregisterSelfFromLoginItem()
@@ -263,11 +263,6 @@ func createLoginShortcut() error {
 	return nil
 }
 
-// Send a request to the main app thread to load a UI page
-func sendUIRequestMessage(path string) {
-	wintray.SendUIRequestMessage(path)
-}
-
 func LaunchNewApp() {
 }
 
@@ -169,37 +169,47 @@ DlgResult fileDlg(FileDlgParams* params) {
     }
 
     NSArray* urls = [panel URLs];
-    if(self->params->allowMultiple && [urls count] >= 1) {
+    if([urls count] == 0) {
+        return DLG_CANCEL;
+    }
+
+    if(self->params->allowMultiple) {
         // For multiple files, we need to return all paths separated by null bytes
         char* bufPtr = self->params->buf;
         int remainingBuf = self->params->nbuf;
 
         // Calculate total required buffer size first
         int totalSize = 0;
         for(NSURL* url in urls) {
            char tempBuf[PATH_MAX];
            if(![url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]) {
                return DLG_URLFAIL;
            }
            totalSize += strlen(tempBuf) + 1; // +1 for null terminator
         }
         totalSize += 1; // Final null terminator
 
         if(totalSize > self->params->nbuf) {
            // Not enough buffer space
            return DLG_URLFAIL;
         }
 
         // Now actually copy the paths (we know we have space)
         bufPtr = self->params->buf;
         for(NSURL* url in urls) {
            char tempBuf[PATH_MAX];
            [url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX];
            int pathLen = strlen(tempBuf);
            strcpy(bufPtr, tempBuf);
            bufPtr += pathLen + 1;
         }
         *bufPtr = '\0'; // Final null terminator
+    } else {
+        // Single file/directory selection - write path to buffer
+        NSURL* url = [urls firstObject];
+        if(![url getFileSystemRepresentation:self->params->buf maxLength:self->params->nbuf]) {
+            return DLG_URLFAIL;
+        }
     }
 
     return DLG_OK;
@@ -15,7 +15,7 @@ const multiFileBufferSize = w32.MAX_PATH * 10
 type WinDlgError int
 
 func (e WinDlgError) Error() string {
-	return fmt.Sprintf("CommDlgExtendedError: %#x", e)
+	return fmt.Sprintf("CommDlgExtendedError: %#x", int(e))
 }
 
 func err() error {
@@ -224,9 +224,7 @@ func (s *Server) cmd(ctx context.Context) (*exec.Cmd, error) {
 		if _, err := os.Stat(settings.Models); err == nil {
 			env["OLLAMA_MODELS"] = settings.Models
 		} else {
-			slog.Warn("models path not accessible, clearing models setting", "path", settings.Models, "err", err)
-			settings.Models = ""
-			s.store.SetSettings(settings)
+			slog.Warn("models path not accessible, using default", "path", settings.Models, "err", err)
 		}
 	}
 	if settings.ContextLength > 0 {
@@ -469,26 +469,24 @@ export class HealthResponse {
 }
 export class User {
   id: string;
-  name: string;
   email: string;
-  avatarURL: string;
-  plan: string;
-  bio: string;
-  firstName: string;
-  lastName: string;
-  overThreshold: boolean;
+  name: string;
+  bio?: string;
+  avatarurl?: string;
+  firstname?: string;
+  lastname?: string;
+  plan?: string;
 
   constructor(source: any = {}) {
     if ('string' === typeof source) source = JSON.parse(source);
     this.id = source["id"];
-    this.name = source["name"];
     this.email = source["email"];
-    this.avatarURL = source["avatarURL"];
-    this.plan = source["plan"];
+    this.name = source["name"];
     this.bio = source["bio"];
-    this.firstName = source["firstName"];
-    this.lastName = source["lastName"];
-    this.overThreshold = source["overThreshold"];
+    this.avatarurl = source["avatarurl"];
+    this.firstname = source["firstname"];
+    this.lastname = source["lastname"];
+    this.plan = source["plan"];
   }
 }
 export class Attachment {
@@ -15,7 +15,7 @@ import {
 import { parseJsonlFromResponse } from "./util/jsonl-parsing";
 import { ollamaClient as ollama } from "./lib/ollama-client";
 import type { ModelResponse } from "ollama/browser";
-import { API_BASE } from "./lib/config";
+import { API_BASE, OLLAMA_DOT_COM } from "./lib/config";
 
 // Extend Model class with utility methods
 declare module "@/gotypes" {
@@ -27,7 +27,6 @@ declare module "@/gotypes" {
 Model.prototype.isCloud = function (): boolean {
   return this.model.endsWith("cloud");
 };
-
 // Helper function to convert Uint8Array to base64
 function uint8ArrayToBase64(uint8Array: Uint8Array): string {
   const chunkSize = 0x8000; // 32KB chunks to avoid stack overflow
@@ -42,44 +41,50 @@
 }
 
 export async function fetchUser(): Promise<User | null> {
-  try {
-    const response = await fetch(`${API_BASE}/api/v1/me`, {
-      method: "GET",
-      headers: {
-        "Content-Type": "application/json",
-      },
-    });
-
-    if (response.ok) {
-      const userData: User = await response.json();
-      return userData;
-    }
-
-    return null;
-  } catch (error) {
-    console.error("Error fetching user:", error);
-    return null;
-  }
-}
-
-export async function fetchConnectUrl(): Promise<string> {
-  const response = await fetch(`${API_BASE}/api/v1/connect`, {
-    method: "GET",
+  const response = await fetch(`${API_BASE}/api/me`, {
+    method: "POST",
     headers: {
       "Content-Type": "application/json",
     },
   });
 
-  if (!response.ok) {
-    throw new Error("Failed to fetch connect URL");
+  if (response.ok) {
+    const userData: User = await response.json();
+
+    if (userData.avatarurl && !userData.avatarurl.startsWith("http")) {
+      userData.avatarurl = `${OLLAMA_DOT_COM}${userData.avatarurl}`;
+    }
+
+    return userData;
   }
 
-  const data = await response.json();
-  return data.connect_url;
+  if (response.status === 401 || response.status === 403) {
+    return null;
+  }
+
+  throw new Error(`Failed to fetch user: ${response.status}`);
+}
+
+export async function fetchConnectUrl(): Promise<string> {
+  const response = await fetch(`${API_BASE}/api/me`, {
+    method: "POST",
+    headers: {
+      "Content-Type": "application/json",
+    },
+  });
+
+  if (response.status === 401) {
+    const data = await response.json();
+    if (data.signin_url) {
+      return data.signin_url;
+    }
+  }
+
+  throw new Error("Failed to fetch connect URL");
 }
 
 export async function disconnectUser(): Promise<void> {
-  const response = await fetch(`${API_BASE}/api/v1/disconnect`, {
+  const response = await fetch(`${API_BASE}/api/signout`, {
     method: "POST",
     headers: {
       "Content-Type": "application/json",
@@ -204,12 +209,10 @@ export async function* sendMessage(
     data: uint8ArrayToBase64(att.data),
   }));
 
-  // Only send think parameter when actually requesting thinking
-  // Don't send false as it causes issues with some providers
+  // Send think parameter when it's explicitly set (true, false, or a non-empty string).
   const shouldSendThink =
     think !== undefined &&
-    ((typeof think === "boolean" && think) ||
-      (typeof think === "string" && think !== ""));
+    (typeof think === "boolean" || (typeof think === "string" && think !== ""));
 
   const response = await fetch(`${API_BASE}/api/v1/chat/${chatId}`, {
     method: "POST",
@@ -391,7 +394,8 @@ export async function getInferenceCompute(): Promise<InferenceCompute[]> {
 
 export async function fetchHealth(): Promise<boolean> {
   try {
-    const response = await fetch(`${API_BASE}/api/v1/health`, {
+    // Use the /api/version endpoint as a health check
+    const response = await fetch(`${API_BASE}/api/version`, {
       method: "GET",
       headers: {
         "Content-Type": "application/json",
@@ -400,7 +404,8 @@ export async function fetchHealth(): Promise<boolean> {
 
     if (response.ok) {
       const data = await response.json();
-      return data.healthy || false;
+      // If we get a version back, the server is healthy
+      return !!data.version;
     }
 
     return false;
@@ -299,9 +299,9 @@ export default function Settings() {
               </Button>
             </div>
           </div>
-          {user?.avatarURL && (
+          {user?.avatarurl && (
             <img
-              src={user.avatarURL}
+              src={user.avatarurl}
               alt={user?.name}
              className="h-10 w-10 rounded-full bg-neutral-200 dark:bg-neutral-700 flex-shrink-0"
              onError={(e) => {
@@ -50,21 +50,33 @@ export default function Thinking({
   // Position content to show bottom when collapsed
   useEffect(() => {
     if (isCollapsed && contentRef.current && wrapperRef.current) {
-      const contentHeight = contentRef.current.scrollHeight;
-      const wrapperHeight = wrapperRef.current.clientHeight;
-      if (contentHeight > wrapperHeight) {
-        const translateY = -(contentHeight - wrapperHeight);
-        contentRef.current.style.transform = `translateY(${translateY}px)`;
-        setHasOverflow(true);
-      } else {
-        setHasOverflow(false);
-      }
+      requestAnimationFrame(() => {
+        if (!contentRef.current || !wrapperRef.current) return;
+
+        const contentHeight = contentRef.current.scrollHeight;
+        const wrapperHeight = wrapperRef.current.clientHeight;
+        if (contentHeight > wrapperHeight) {
+          const translateY = -(contentHeight - wrapperHeight);
+          contentRef.current.style.transform = `translateY(${translateY}px)`;
+          setHasOverflow(true);
+        } else {
+          contentRef.current.style.transform = "translateY(0)";
+          setHasOverflow(false);
+        }
+      });
     } else if (contentRef.current) {
       contentRef.current.style.transform = "translateY(0)";
       setHasOverflow(false);
     }
   }, [thinking, isCollapsed]);
 
+  useEffect(() => {
+    if (activelyThinking && wrapperRef.current && !isCollapsed) {
+      // When expanded and actively thinking, scroll to bottom
+      wrapperRef.current.scrollTop = wrapperRef.current.scrollHeight;
+    }
+  }, [thinking, activelyThinking, isCollapsed]);
+
   const handleToggle = () => {
     setIsCollapsed(!isCollapsed);
     setHasUserInteracted(true);
@@ -7,6 +7,7 @@ import { createQueryBatcher } from "./useQueryBatcher";
 import { useRefetchModels } from "./useModels";
 import { useStreamingContext } from "@/contexts/StreamingContext";
 import { useSettings } from "./useSettings";
+import { getModelCapabilities } from "@/api";
 
 export const useChats = () => {
   return useQuery({
@@ -606,6 +607,24 @@ export const useSendMessage = (chatId: string) => {
           queryClient.setQueryData(["staleModels"], newStaleMap);
 
           queryClient.invalidateQueries({ queryKey: ["models"] });
+
+          // Fetch fresh capabilities for the downloaded model
+          getModelCapabilities(selectedModel.model)
+            .then((capabilities) => {
+              queryClient.setQueryData(
+                ["modelCapabilities", selectedModel.model],
+                capabilities,
+              );
+            })
+            .catch((error) => {
+              console.error(
+                "Failed to fetch capabilities after download:",
+                error,
+              );
+              queryClient.invalidateQueries({
+                queryKey: ["modelCapabilities", selectedModel.model],
+              });
+            });
         }
         break;
       }
@@ -1,114 +0,0 @@
-import { useMutation, useQueryClient } from "@tanstack/react-query";
-import { useState } from "react";
-import { pullModel } from "@/api";
-import { useSelectedModel } from "./useSelectedModel";
-import { useSettings } from "./useSettings";
-
-interface DownloadProgress {
-  status: string;
-  digest?: string;
-  total?: number;
-  completed?: number;
-  done?: boolean;
-}
-
-export function useDownloadModel(chatId?: string) {
-  const queryClient = useQueryClient();
-  const { selectedModel } = useSelectedModel(chatId);
-  const { setSettings } = useSettings();
-  const [downloadProgress, setDownloadProgress] =
-    useState<DownloadProgress | null>(null);
-  const [abortController, setAbortController] =
-    useState<AbortController | null>(null);
-  const [downloadingChatIds, setDownloadingChatIds] = useState<Set<string>>(
-    new Set(),
-  );
-
-  const mutation = useMutation({
-    mutationFn: async (modelName: string) => {
-      const controller = new AbortController();
-      setAbortController(controller);
-      setDownloadProgress({ status: "Starting download..." });
-      if (chatId) {
-        setDownloadingChatIds((prev) => new Set(prev).add(chatId));
-      }
-
-      try {
-        for await (const progress of pullModel(modelName, controller.signal)) {
-          setDownloadProgress(progress);
-
-          if (progress.status === "success") {
-            // Update selected model to indicate it's now available locally
-            if (selectedModel && selectedModel.model === modelName) {
-              setSettings({ SelectedModel: modelName });
-            }
-            // Invalidate models query to refresh the list
-            await queryClient.invalidateQueries({ queryKey: ["models"] });
-            break;
-          }
-        }
-      } finally {
-        setAbortController(null);
-        if (chatId) {
-          setDownloadingChatIds((prev) => {
-            const newSet = new Set(prev);
-            newSet.delete(chatId);
-            return newSet;
-          });
-        }
-      }
-    },
-    onSuccess: () => {
-      setDownloadProgress(null);
-      if (chatId) {
-        setDownloadingChatIds((prev) => {
-          const newSet = new Set(prev);
-          newSet.delete(chatId);
-          return newSet;
-        });
-      }
-    },
-    onError: (error: Error) => {
-      const status =
-        error.name === "AbortError" ? "Download cancelled" : "Download failed";
-      setDownloadProgress({ status, done: true });
-
-      // Clear error message after delay
-      const delay = error.name === "AbortError" ? 1500 : 3000;
-      setTimeout(() => {
-        setDownloadProgress(null);
-        if (chatId) {
-          setDownloadingChatIds((prev) => {
-            const newSet = new Set(prev);
-            newSet.delete(chatId);
-            return newSet;
-          });
-        }
-      }, delay);
-    },
-  });
-
-  const cancelDownload = () => {
-    if (abortController) {
-      abortController.abort();
-      setAbortController(null);
-      if (chatId) {
-        setDownloadingChatIds((prev) => {
-          const newSet = new Set(prev);
-          newSet.delete(chatId);
-          return newSet;
-        });
-      }
-    }
-  };
-
-  return {
-    downloadModel: mutation.mutate,
-    isDownloading:
-      mutation.isPending && chatId ? downloadingChatIds.has(chatId) : false,
-    downloadProgress:
-      chatId && downloadingChatIds.has(chatId) ? downloadProgress : null,
-    error: mutation.error,
-    cancelDownload,
-  };
-}
@@ -1,29 +1,20 @@
 import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
-import { useEffect, useState } from "react";
 import { fetchUser, fetchConnectUrl, disconnectUser } from "@/api";
 
 export function useUser() {
   const queryClient = useQueryClient();
-  const [initialDataLoaded, setInitialDataLoaded] = useState(false);
-
-  // Wait for initial data to be loaded
-  useEffect(() => {
-    const initialPromise = window.__initialUserDataPromise;
-    if (initialPromise) {
-      initialPromise.finally(() => {
-        setInitialDataLoaded(true);
-      });
-    } else {
-      setInitialDataLoaded(true);
-    }
-  }, []);
 
   const userQuery = useQuery({
     queryKey: ["user"],
-    queryFn: () => fetchUser(),
+    queryFn: async () => {
+      const result = await fetchUser();
+      return result;
+    },
     staleTime: 5 * 60 * 1000, // Consider data stale after 5 minutes
     gcTime: 10 * 60 * 1000, // Keep in cache for 10 minutes
-    initialData: null, // Start with null to prevent flashing
+    retry: 10,
+    retryDelay: (attemptIndex) => Math.min(500 * attemptIndex, 2000),
+    refetchOnMount: true, // Always fetch when component mounts
   });
 
   // Mutation to refresh user data
@@ -49,14 +40,15 @@ export function useUser() {
     },
   });
 
+  const isLoading = userQuery.isLoading || userQuery.isFetching;
+  const isAuthenticated = Boolean(userQuery.data?.name);
+
   return {
     user: userQuery.data,
-    isLoading:
-      !initialDataLoaded ||
-      (userQuery.isLoading && userQuery.data === undefined), // Show loading until initial data is loaded
+    isLoading,
     isError: userQuery.isError,
     error: userQuery.error,
-    isAuthenticated: Boolean(userQuery.data?.name),
+    isAuthenticated,
     refreshUser: refreshUser.mutate,
     isRefreshing: refreshUser.isPending,
     refetchUser: userQuery.refetch,
@@ -8,3 +8,6 @@ export const API_BASE = import.meta.env.DEV ? DEV_API_URL : "";
 export const OLLAMA_HOST = import.meta.env.DEV
   ? DEV_API_URL
   : window.location.origin;
+
+export const OLLAMA_DOT_COM =
+  import.meta.env.VITE_OLLAMA_DOT_COM_URL || "https://ollama.com";
@@ -147,6 +147,7 @@ export const highlighterPromise = createHighlighter({
     "c",
     "cpp",
     "sql",
+    "swift",
     "yaml",
     "markdown",
   ],
@@ -5,13 +5,6 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
 import { routeTree } from "./routeTree.gen";
 import { fetchUser } from "./api";
 import { StreamingProvider } from "./contexts/StreamingContext";
-import { User } from "@/gotypes";
-
-declare global {
-  interface Window {
-    __initialUserDataPromise?: Promise<User | null>;
-  }
-}
 
 const queryClient = new QueryClient({
   defaultOptions: {
@@ -24,27 +17,11 @@ const queryClient = new QueryClient({
   },
 });
 
-// Track initial user data fetch
-let initialUserDataPromise: Promise<User | null> | null = null;
-
-// Initialize user data on app startup
-const initializeUserData = async () => {
-  try {
-    const userData = await fetchUser();
+fetchUser().then((userData) => {
+  if (userData) {
     queryClient.setQueryData(["user"], userData);
-    return userData;
-  } catch (error) {
-    console.error("Error initializing user data:", error);
-    queryClient.setQueryData(["user"], null);
-    return null;
   }
-};
-
-// Start initialization immediately and track the promise
-initialUserDataPromise = initializeUserData();
-
-// Export the promise so hooks can await it
-window.__initialUserDataPromise = initialUserDataPromise;
+});
 
 const router = createRouter({
   routeTree,
@@ -101,15 +101,14 @@ type HealthResponse struct {
 }
 
 type User struct {
 	ID        string `json:"id"`
-	Name      string `json:"name"`
 	Email     string `json:"email"`
-	AvatarURL string `json:"avatarURL"`
-	Plan      string `json:"plan"`
-	Bio       string `json:"bio"`
-	FirstName string `json:"firstName"`
-	LastName  string `json:"lastName"`
-	OverThreshold bool `json:"overThreshold"`
+	Name      string `json:"name"`
+	Bio       string `json:"bio,omitempty"`
+	AvatarURL string `json:"avatarurl,omitempty"`
+	FirstName string `json:"firstname,omitempty"`
+	LastName  string `json:"lastname,omitempty"`
+	Plan      string `json:"plan,omitempty"`
 }
 
 type Attachment struct {
app/ui/ui.go
@@ -12,18 +12,17 @@ import (
 	"log/slog"
 	"net/http"
 	"net/http/httputil"
-	"net/url"
 	"os"
 	"runtime"
 	"runtime/debug"
 	"slices"
 	"strconv"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/google/uuid"
 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/app/auth"
 	"github.com/ollama/ollama/app/server"
 	"github.com/ollama/ollama/app/store"
 	"github.com/ollama/ollama/app/tools"
@@ -118,40 +117,66 @@ func (s *Server) log() *slog.Logger {
 
 // ollamaProxy creates a reverse proxy handler to the Ollama server
 func (s *Server) ollamaProxy() http.Handler {
-	ollamaHost := os.Getenv("OLLAMA_HOST")
-	if ollamaHost == "" {
-		ollamaHost = "http://127.0.0.1:11434"
-	}
+	var (
+		proxy   http.Handler
+		proxyMu sync.Mutex
+	)
 
-	if !strings.HasPrefix(ollamaHost, "http://") && !strings.HasPrefix(ollamaHost, "https://") {
-		ollamaHost = "http://" + ollamaHost
-	}
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		proxyMu.Lock()
+		p := proxy
+		proxyMu.Unlock()
 
-	target, err := url.Parse(ollamaHost)
-	if err != nil {
-		s.log().Error("failed to parse OLLAMA_HOST", "error", err, "host", ollamaHost)
-		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-			http.Error(w, "failed to configure proxy", http.StatusInternalServerError)
-		})
-	}
+		if p == nil {
+			proxyMu.Lock()
+			if proxy == nil {
+				var err error
+				for i := range 2 {
+					if i > 0 {
+						s.log().Warn("ollama server not ready, retrying", "attempt", i+1)
+						time.Sleep(1 * time.Second)
+					}
 
-	s.log().Info("configuring ollama proxy", "target", target.String())
+					err = WaitForServer(context.Background(), 10*time.Second)
+					if err == nil {
+						break
+					}
+				}
 
-	proxy := httputil.NewSingleHostReverseProxy(target)
+				if err != nil {
+					proxyMu.Unlock()
+					s.log().Error("ollama server not ready after retries", "error", err)
+					http.Error(w, "Ollama server is not ready", http.StatusServiceUnavailable)
+					return
+				}
 
-	originalDirector := proxy.Director
-	proxy.Director = func(req *http.Request) {
-		originalDirector(req)
-		req.Host = target.Host
-		s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host)
-	}
+				target := envconfig.Host()
+				s.log().Info("configuring ollama proxy", "target", target.String())
 
-	proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
-		s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String())
-		http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
-	}
+				newProxy := httputil.NewSingleHostReverseProxy(target)
 
-	return proxy
+				originalDirector := newProxy.Director
+				newProxy.Director = func(req *http.Request) {
+					originalDirector(req)
+					req.Host = target.Host
+					s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host)
+				}
+
+				newProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
+					s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String())
+					http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
+				}
+
+				proxy = newProxy
+				p = newProxy
+			} else {
+				p = proxy
+			}
+			proxyMu.Unlock()
+		}
+
+		p.ServeHTTP(w, r)
+	})
 }
 
 type errHandlerFunc func(http.ResponseWriter, *http.Request) error
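The rewrite above lazily builds the reverse proxy on first request, using double-checked locking so concurrent requests don't race the initialization, and it deliberately leaves `proxy` unset on failure so a later request can retry. For contrast, a hedged sketch of the same lazy setup using `sync.OnceValues` — the names here are illustrative, not part of this change, and note the trade-off called out in the comments:

```go
package ui

import (
	"context"
	"net/http"
	"net/http/httputil"
	"sync"
	"time"

	"github.com/ollama/ollama/envconfig"
)

// buildProxy runs at most once; later calls return the cached result.
// Trade-off vs. the handler above: if WaitForServer fails here, the error
// is cached forever and the proxy can never recover without a restart,
// which is exactly what the mutex-based retry loop avoids.
var buildProxy = sync.OnceValues(func() (http.Handler, error) {
	if err := WaitForServer(context.Background(), 10*time.Second); err != nil {
		return nil, err
	}
	return httputil.NewSingleHostReverseProxy(envconfig.Host()), nil
})

func proxyHandler() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		p, err := buildProxy()
		if err != nil {
			http.Error(w, "Ollama server is not ready", http.StatusServiceUnavailable)
			return
		}
		p.ServeHTTP(w, r)
	})
}
```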
@@ -264,11 +289,10 @@ func (s *Server) Handler() http.Handler {
 	ollamaProxy := s.ollamaProxy()
 	mux.Handle("GET /api/tags", ollamaProxy)
 	mux.Handle("POST /api/show", ollamaProxy)
-	mux.Handle("GET /api/v1/me", handle(s.me))
-	mux.Handle("POST /api/v1/disconnect", handle(s.disconnect))
-	mux.Handle("GET /api/v1/connect", handle(s.connectURL))
-	mux.Handle("GET /api/v1/health", handle(s.health))
+	mux.Handle("GET /api/version", ollamaProxy)
+	mux.Handle("HEAD /api/version", ollamaProxy)
+	mux.Handle("POST /api/me", ollamaProxy)
+	mux.Handle("POST /api/signout", ollamaProxy)
 
 	// React app - catch all non-API routes and serve the React app
 	mux.Handle("GET /", s.appHandler())
@@ -338,7 +362,7 @@ func (s *Server) doSelfSigned(ctx context.Context, method, path string) (*http.R
 }
 
 // UserData fetches user data from ollama.com API for the current ollama key
-func (s *Server) UserData(ctx context.Context) (*responses.User, error) {
+func (s *Server) UserData(ctx context.Context) (*api.UserResponse, error) {
 	resp, err := s.doSelfSigned(ctx, http.MethodPost, "/api/me")
 	if err != nil {
 		return nil, fmt.Errorf("failed to call ollama.com/api/me: %w", err)
@@ -349,7 +373,7 @@
 		return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
 	}
 
-	var user responses.User
+	var user api.UserResponse
 	if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
 		return nil, fmt.Errorf("failed to parse user response: %w", err)
 	}
@@ -368,29 +392,27 @@
 	return &user, nil
 }
 
-func waitForServer(ctx context.Context) error {
-	timeout := time.Now().Add(10 * time.Second)
-	// TODO: this avoids an error on first load of the app
-	// however we should either show a loading state or
-	// wait for the Ollama server to be ready before redirecting
-	for {
+// WaitForServer waits for the Ollama server to be ready
+func WaitForServer(ctx context.Context, timeout time.Duration) error {
+	deadline := time.Now().Add(timeout)
+	for time.Now().Before(deadline) {
 		c, err := api.ClientFromEnvironment()
 		if err != nil {
 			return err
 		}
 		if _, err := c.Version(ctx); err == nil {
-			break
-		}
-		if time.Now().After(timeout) {
-			return fmt.Errorf("timeout waiting for Ollama server to be ready")
+			slog.Debug("ollama server is ready")
+			return nil
 		}
 		time.Sleep(10 * time.Millisecond)
 	}
-	return nil
+	return errors.New("timeout waiting for Ollama server to be ready")
}
 
 func (s *Server) createChat(w http.ResponseWriter, r *http.Request) error {
-	waitForServer(r.Context())
+	if err := WaitForServer(r.Context(), 10*time.Second); err != nil {
+		return err
+	}
+
 	id, err := uuid.NewV7()
 	if err != nil {
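The exported `WaitForServer` polls `/api/version` in a tight loop until a wall-clock deadline passes. One subtlety: the loop sleeps through context cancellation, so a canceled request keeps polling until the deadline. A hedged sketch of a context-aware variant — illustrative only, not part of this change:

```go
package ui

import (
	"context"
	"errors"
	"time"

	"github.com/ollama/ollama/api"
)

// waitForServerCtx is a sketch of the same poll-until-ready loop, but it
// returns as soon as ctx is canceled instead of sleeping through it.
func waitForServerCtx(ctx context.Context, timeout time.Duration) error {
	ctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	c, err := api.ClientFromEnvironment()
	if err != nil {
		return err
	}

	ticker := time.NewTicker(10 * time.Millisecond)
	defer ticker.Stop()
	for {
		if _, err := c.Version(ctx); err == nil {
			return nil
		}
		select {
		case <-ctx.Done():
			return errors.New("timeout waiting for Ollama server to be ready")
		case <-ticker.C:
		}
	}
}
```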
@@ -975,7 +997,7 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
 			for _, toolCall := range res.Message.ToolCalls {
 				// continues loop as tools were executed
 				toolsExecuted = true
-				result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
+				result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments.ToMap())
 				if err != nil {
 					errContent := fmt.Sprintf("Error: %v", err)
 					toolErrMsg := store.NewMessage("tool", errContent, nil)
@@ -1438,129 +1460,6 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
 	})
 }
 
-func (s *Server) me(w http.ResponseWriter, r *http.Request) error {
-	if r.Method != http.MethodGet {
-		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
-		return nil
-	}
-
-	user, err := s.UserData(r.Context())
-	if err != nil {
-		// If fetching from API fails, try to return cached user data if available
-		if cachedUser, cacheErr := s.Store.User(); cacheErr == nil && cachedUser != nil {
-			s.log().Info("API request failed, returning cached user data", "error", err)
-			responseUser := &responses.User{
-				Name:  cachedUser.Name,
-				Email: cachedUser.Email,
-				Plan:  cachedUser.Plan,
-			}
-			w.Header().Set("Content-Type", "application/json")
-			w.WriteHeader(http.StatusOK)
-			return json.NewEncoder(w).Encode(responseUser)
-		}
-
-		s.log().Error("failed to get user data", "error", err)
-		w.WriteHeader(http.StatusInternalServerError)
-		return json.NewEncoder(w).Encode(responses.Error{
-			Error: "failed to get user data",
-		})
-	}
-
-	w.Header().Set("Content-Type", "application/json")
-	w.WriteHeader(http.StatusOK)
-	return json.NewEncoder(w).Encode(user)
-}
-
-func (s *Server) disconnect(w http.ResponseWriter, r *http.Request) error {
-	if r.Method != http.MethodPost {
-		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
-		return nil
-	}
-
-	if err := s.Store.ClearUser(); err != nil {
-		s.log().Warn("failed to clear cached user data", "error", err)
-	}
-
-	// Get the SSH public key to encode for the delete request
-	pubKey, err := ollamaAuth.GetPublicKey()
-	if err != nil {
-		s.log().Error("failed to get public key", "error", err)
-		w.WriteHeader(http.StatusInternalServerError)
-		return json.NewEncoder(w).Encode(responses.Error{
-			Error: "failed to get public key",
-		})
-	}
-
-	// Encode the key using base64 URL encoding
-	encodedKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
-
-	// Call the /api/user/keys/{encodedKey} endpoint with DELETE
-	resp, err := s.doSelfSigned(r.Context(), http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey))
-	if err != nil {
-		s.log().Error("failed to call ollama.com/api/user/keys", "error", err)
-		w.WriteHeader(http.StatusInternalServerError)
-		return json.NewEncoder(w).Encode(responses.Error{
-			Error: "failed to disconnect from ollama.com",
-		})
-	}
-	defer resp.Body.Close()
-
-	if resp.StatusCode != http.StatusOK {
-		s.log().Error("disconnect request failed", "status", resp.StatusCode)
-		w.WriteHeader(http.StatusInternalServerError)
-		return json.NewEncoder(w).Encode(responses.Error{
-			Error: "failed to disconnect from ollama.com",
-		})
-	}
-
-	w.Header().Set("Content-Type", "application/json")
-	w.WriteHeader(http.StatusOK)
-	return json.NewEncoder(w).Encode(map[string]string{"status": "disconnected"})
-}
-
-func (s *Server) connectURL(w http.ResponseWriter, r *http.Request) error {
-	if r.Method != http.MethodGet {
-		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
-		return nil
-	}
-
-	connectURL, err := auth.BuildConnectURL(OllamaDotCom)
-	if err != nil {
-		s.log().Error("failed to build connect URL", "error", err)
-		w.WriteHeader(http.StatusInternalServerError)
-		return json.NewEncoder(w).Encode(responses.Error{
-			Error: "failed to build connect URL",
-		})
-	}
-
-	w.Header().Set("Content-Type", "application/json")
-	w.WriteHeader(http.StatusOK)
-	return json.NewEncoder(w).Encode(map[string]string{
-		"connect_url": connectURL,
-	})
-}
-
-func (s *Server) health(w http.ResponseWriter, r *http.Request) error {
-	if r.Method != http.MethodGet {
-		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
-		return nil
-	}
-
-	healthy := false
-	c, err := api.ClientFromEnvironment()
-	if err == nil {
-		if _, err := c.Version(r.Context()); err == nil {
-			healthy = true
-		}
-	}
-
-	w.Header().Set("Content-Type", "application/json")
-	w.WriteHeader(http.StatusOK)
-	return json.NewEncoder(w).Encode(responses.HealthResponse{
-		Healthy: healthy,
-	})
-}
-
 func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
 	ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
 	defer cancel()
@@ -1659,13 +1558,13 @@ func convertToOllamaTool(toolSchema map[string]any) api.Tool {
 	tool.Function.Parameters.Type = "object"
 	tool.Function.Parameters.Required = []string{}
-	tool.Function.Parameters.Properties = make(map[string]api.ToolProperty)
+	tool.Function.Parameters.Properties = api.NewToolPropertiesMap()
 
 	if schemaProps, ok := toolSchema["schema"].(map[string]any); ok {
 		tool.Function.Parameters.Type = getStringFromMap(schemaProps, "type", "object")
 
 		if props, ok := schemaProps["properties"].(map[string]any); ok {
-			tool.Function.Parameters.Properties = make(map[string]api.ToolProperty)
+			tool.Function.Parameters.Properties = api.NewToolPropertiesMap()
 
 			for propName, propDef := range props {
 				if propMap, ok := propDef.(map[string]any); ok {
@@ -1673,7 +1572,7 @@
 						Type:        api.PropertyType{getStringFromMap(propMap, "type", "string")},
 						Description: getStringFromMap(propMap, "description", ""),
 					}
-					tool.Function.Parameters.Properties[propName] = prop
+					tool.Function.Parameters.Properties.Set(propName, prop)
 				}
 			}
 		}
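With `Properties` now an ordered map rather than a plain Go map, call sites build it through the constructor and `Set`, as above. A small sketch of constructing a tool this way — the types mirror the `api` identifiers used in this diff, while the specific tool (`get_weather`) is made up for illustration:

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	// Properties keep their insertion order when marshaled to JSON,
	// so "city" is guaranteed to appear before "unit" in the schema.
	props := api.NewToolPropertiesMap()
	props.Set("city", api.ToolProperty{
		Type:        api.PropertyType{"string"},
		Description: "City to look up",
	})
	props.Set("unit", api.ToolProperty{
		Type:        api.PropertyType{"string"},
		Description: "Temperature unit",
	})

	var tool api.Tool
	tool.Function.Name = "get_weather" // hypothetical tool name
	tool.Function.Parameters.Type = "object"
	tool.Function.Parameters.Required = []string{"city"}
	tool.Function.Parameters.Properties = props

	fmt.Println(tool.Function.Parameters.String())
}
```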
@@ -158,16 +158,16 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui
 	case uint32(UI_REQUEST_MSG_ID):
 		// Requests for the UI must always come from the main event thread
 		l := int(wParam)
-		path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l)
+		path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l) //nolint:govet,gosec
 		t.app.UIRun(path)
 	case WM_COPYDATA:
 		// Handle URL scheme requests from other instances
 		if lParam != 0 {
-			cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam))
+			cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam)) //nolint:govet,gosec
 			if cds.DwData == 1 { // Our identifier for URL scheme messages
 				// Convert the data back to string
 				data := make([]byte, cds.CbData)
-				copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData])
+				copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData]) //nolint:govet,gosec
 				urlScheme := string(data)
 				handleURLSchemeRequest(urlScheme)
 				lResult = 1 // Return non-zero to indicate success
@@ -15,7 +15,7 @@ A Go-based command-line tool for benchmarking Ollama models with configurable pa
 
 ```
 go build -o ollama-bench bench.go
-./bench -model gpt-oss:20b -epochs 6 -format csv
+./ollama-bench -model gpt-oss:20b -epochs 6 -format csv
 ```
 
 Using Go Run (without building)
@@ -29,31 +29,32 @@ go run bench.go -model gpt-oss:20b -epochs 3
 ### Basic Example
 
 ```
-./bench -model gemma3 -epochs 6
+./ollama-bench -model gemma3 -epochs 6
 ```
 
 ### Benchmark Multiple Models
 
 ```
-./bench -model gemma3,gemma3n -epochs 6 -max-tokens 100 -p "Write me a short story" | tee gemma.bench
+./ollama-bench -model gemma3,gemma3n -epochs 6 -max-tokens 100 -p "Write me a short story" | tee gemma.bench
 benchstat -col /name gemma.bench
 ```
 
 ### With Image Prompt
 
 ```
-./bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
+./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
 ```
 
 ### Advanced Example
 
 ```
-./bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
+./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
 ```
 
 ## Command Line Options
 
 | Option | Description | Default |
+|----------|-------------|---------|
 | -model | Comma-separated list of models to benchmark | (required) |
 | -epochs | Number of iterations per model | 1 |
 | -max-tokens | Maximum tokens for model response | 0 (unlimited) |
@@ -48,8 +48,8 @@ func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool)
     case "benchstat":
         if verbose {
             printHeader := func() {
-                fmt.Printf("sysname: %s\n", runtime.GOOS)
-                fmt.Printf("machine: %s\n", runtime.GOARCH)
+                fmt.Fprintf(w, "sysname: %s\n", runtime.GOOS)
+                fmt.Fprintf(w, "machine: %s\n", runtime.GOARCH)
             }
             once.Do(printHeader)
         }
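Switching `Printf` to `Fprintf(w, ...)` matters because `OutputMetrics` receives the destination writer: with plain `Printf` the header would always land on stdout even when `-output` redirects results to a file. A small self-contained sketch of the behavior being fixed:

```go
package main

import (
	"bytes"
	"fmt"
	"runtime"
)

func main() {
	var buf bytes.Buffer
	// Fprintf honors the chosen writer; fmt.Printf would bypass it
	// and print to stdout regardless of where metrics are routed.
	fmt.Fprintf(&buf, "sysname: %s\n", runtime.GOOS)
	fmt.Fprintf(&buf, "machine: %s\n", runtime.GOARCH)
	fmt.Print(buf.String())
}
```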
@@ -147,6 +147,17 @@ func BenchmarkChat(fOpt flagOptions) error {
         return err
     }
 
+    var out io.Writer = os.Stdout
+    if fOpt.outputFile != nil && *fOpt.outputFile != "" {
+        f, err := os.OpenFile(*fOpt.outputFile, os.O_CREATE|os.O_WRONLY, 0o644)
+        if err != nil {
+            fmt.Fprintf(os.Stderr, "ERROR: cannot open output file %s: %v\n", *fOpt.outputFile, err)
+            return err
+        }
+        defer f.Close()
+        out = f
+    }
+
     for _, model := range models {
         for range *fOpt.epochs {
             options := make(map[string]interface{})
@@ -241,13 +252,14 @@ func BenchmarkChat(fOpt flagOptions) error {
                 },
             }
 
-            OutputMetrics(os.Stdout, *fOpt.format, metrics, *fOpt.verbose)
+            OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)
 
             if *fOpt.keepAlive > 0 {
                 time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
             }
         }
     }
 
     return nil
 }
cmd/cmd.go (42 changes)
@@ -45,6 +45,9 @@ import (
     "github.com/ollama/ollama/types/model"
     "github.com/ollama/ollama/types/syncmap"
     "github.com/ollama/ollama/version"
+    xcmd "github.com/ollama/ollama/x/cmd"
+    "github.com/ollama/ollama/x/imagegen"
+    imagegenclient "github.com/ollama/ollama/x/imagegen/client"
 )
 
 const ConnectInstructions = "To sign in, navigate to:\n    %s\n\n"
@@ -95,6 +98,11 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
     filename, err := getModelfileName(cmd)
     if os.IsNotExist(err) {
         if filename == "" {
+            // No Modelfile found - check if current directory is an image gen model
+            if imagegen.IsTensorModelDir(".") {
+                quantize, _ := cmd.Flags().GetString("quantize")
+                return imagegenclient.CreateModel(args[0], ".", quantize, p)
+            }
             reader = strings.NewReader("FROM .\n")
         } else {
             return errModelfileNotFound
@@ -456,6 +464,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
     }
 
     name := args[0]
+
     info, err := func() (*api.ShowResponse, error) {
         showReq := &api.ShowRequest{Name: name}
         info, err := client.Show(cmd.Context(), showReq)
@@ -517,6 +526,19 @@ func RunHandler(cmd *cobra.Command, args []string) error {
         return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
     }
 
+    // Check if this is an image generation model
+    if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
+        if opts.Prompt == "" && !interactive {
+            return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
+        }
+        return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
+    }
+
+    // Check for experimental flag
+    isExperimental, _ := cmd.Flags().GetBool("experimental")
+    yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
+    enableWebsearch, _ := cmd.Flags().GetBool("experimental-websearch")
+
     if interactive {
         if err := loadOrUnloadModel(cmd, &opts); err != nil {
             var sErr api.AuthorizationError
@@ -543,6 +565,11 @@ func RunHandler(cmd *cobra.Command, args []string) error {
             }
         }
 
+        // Use experimental agent loop with tools
+        if isExperimental {
+            return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode, enableWebsearch)
+        }
+
         return generateInteractive(cmd, opts)
     }
     return generate(cmd, opts)
@@ -646,7 +673,11 @@ func PushHandler(cmd *cobra.Command, args []string) error {
 
         bar, ok := bars[resp.Digest]
         if !ok {
-            bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
+            msg := resp.Status
+            if msg == "" {
+                msg = fmt.Sprintf("pushing %s...", resp.Digest[7:19])
+            }
+            bar = progress.NewBar(msg, resp.Total, resp.Completed)
             bars[resp.Digest] = bar
             p.Add(resp.Digest, bar)
         }
@@ -943,6 +974,9 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
             rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
         }
         rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
+        if resp.Requires != "" {
+            rows = append(rows, []string{"", "requires", resp.Requires})
+        }
         return
     })
 
@@ -1751,6 +1785,12 @@ func NewCLI() *cobra.Command {
     runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
     runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
     runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
+    runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
+    runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
+    runCmd.Flags().Bool("experimental-websearch", false, "Enable web search tool in experimental mode")
+
+    // Image generation flags (width, height, steps, seed, etc.)
+    imagegen.RegisterFlags(runCmd)
 
     stopCmd := &cobra.Command{
         Use:   "stop MODEL",
@@ -291,6 +291,31 @@ Weigh anchor!
             t.Errorf("unexpected output (-want +got):\n%s", diff)
         }
     })
+
+    t.Run("min version", func(t *testing.T) {
+        var b bytes.Buffer
+        if err := showInfo(&api.ShowResponse{
+            Details: api.ModelDetails{
+                Family:            "test",
+                ParameterSize:     "7B",
+                QuantizationLevel: "FP16",
+            },
+            Requires: "0.14.0",
+        }, false, &b); err != nil {
+            t.Fatal(err)
+        }
+
+        expect := `  Model
+    architecture    test
+    parameters      7B
+    quantization    FP16
+    requires        0.14.0
+
+`
+        if diff := cmp.Diff(expect, b.String()); diff != "" {
+            t.Errorf("unexpected output (-want +got):\n%s", diff)
+        }
+    })
 }
 
 func TestDeleteHandler(t *testing.T) {
@@ -1522,6 +1547,79 @@ func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) {
     }
 }
 
+func TestShowInfoImageGen(t *testing.T) {
+    var b bytes.Buffer
+    err := showInfo(&api.ShowResponse{
+        Details: api.ModelDetails{
+            Family:            "ZImagePipeline",
+            ParameterSize:     "10.3B",
+            QuantizationLevel: "FP8",
+        },
+        Capabilities: []model.Capability{model.CapabilityImageGeneration},
+        Requires:     "0.14.0",
+    }, false, &b)
+    if err != nil {
+        t.Fatal(err)
+    }
+
+    expect := "  Model\n" +
+        "    architecture    ZImagePipeline    \n" +
+        "    parameters      10.3B             \n" +
+        "    quantization    FP8               \n" +
+        "    requires        0.14.0            \n" +
+        "\n" +
+        "  Capabilities\n" +
+        "    image    \n" +
+        "\n"
+    if diff := cmp.Diff(expect, b.String()); diff != "" {
+        t.Errorf("unexpected output (-want +got):\n%s", diff)
+    }
+}
+
+func TestPushProgressMessage(t *testing.T) {
+    tests := []struct {
+        name    string
+        status  string
+        digest  string
+        wantMsg string
+    }{
+        {
+            name:    "uses status when provided",
+            status:  "uploading model",
+            digest:  "sha256:abc123456789def",
+            wantMsg: "uploading model",
+        },
+        {
+            name:    "falls back to digest when status empty",
+            status:  "",
+            digest:  "sha256:abc123456789def",
+            wantMsg: "pushing abc123456789...",
+        },
+        {
+            name:    "handles short digest gracefully",
+            status:  "",
+            digest:  "sha256:abc",
+            wantMsg: "pushing sha256:abc...",
+        },
+    }
+
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            msg := tt.status
+            if msg == "" {
+                if len(tt.digest) >= 19 {
+                    msg = fmt.Sprintf("pushing %s...", tt.digest[7:19])
+                } else {
+                    msg = fmt.Sprintf("pushing %s...", tt.digest)
+                }
+            }
+            if msg != tt.wantMsg {
+                t.Errorf("got %q, want %q", msg, tt.wantMsg)
+            }
+        })
+    }
+}
+
 func TestRunOptions_Copy_Independence(t *testing.T) {
     // Test that modifications to original don't affect copy
     originalThink := &api.ThinkValue{Value: "original"}
@@ -40,6 +40,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
     fmt.Fprintln(os.Stderr, "  /bye            Exit")
     fmt.Fprintln(os.Stderr, "  /?, /help       Help for a command")
     fmt.Fprintln(os.Stderr, "  /? shortcuts    Help for keyboard shortcuts")
+
     fmt.Fprintln(os.Stderr, "")
     fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
@@ -6,11 +6,14 @@ import (
     "errors"
     "fmt"
     "io/fs"
+    "iter"
     "log/slog"
+    "maps"
     "os"
     "slices"
     "strings"
 
+    ofs "github.com/ollama/ollama/fs"
     "github.com/ollama/ollama/fs/ggml"
 )
 
@@ -18,8 +21,13 @@ type ModelParameters struct {
     Architectures []string `json:"architectures"`
     VocabSize     uint32   `json:"vocab_size"`
 
+    // TODO is this needed?
+    ModelType string `json:"model_type"`
+
     TextModel struct {
         VocabSize  uint32 `json:"vocab_size"`
+        HiddenSize uint32 `json:"hidden_size"`
+        ModelType  string `json:"model_type"`
     } `json:"text_config"`
 }
 
@@ -33,8 +41,94 @@ type AdapterParameters struct {
     } `json:"lora_parameters"`
 }
 
-func (ModelParameters) KV(t *Tokenizer) ggml.KV {
-    kv := ggml.KV{
+type KV map[string]any
+
+func (kv KV) Architecture() string {
+    return kv.String("general.architecture", "unknown")
+}
+
+type valueTypes interface {
+    uint8 | int8 | uint16 | int16 |
+        uint32 | int32 | uint64 | int64 |
+        string | float32 | float64 | bool
+}
+
+type arrayValueTypes interface {
+    []uint8 | []int8 | []uint16 | []int16 |
+        []uint32 | []int32 | []uint64 | []int64 |
+        []string | []float32 | []float64 | []bool
+}
+
+func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
+    if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
+        key = kv.Architecture() + "." + key
+    }
+
+    if val, ok := kv[key].(T); ok {
+        return val, true
+    }
+    return defaultValue[0], false
+}
+
+func (kv KV) String(key string, defaultValue ...string) string {
+    val, _ := keyValue(kv, key, append(defaultValue, "")...)
+    return val
+}
+
+func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
+    val, _ := keyValue(kv, key, append(defaultValue, 0)...)
+    return val
+}
+
+func (kv KV) Float(key string, defaultValue ...float32) float32 {
+    val, _ := keyValue(kv, key, append(defaultValue, 0)...)
+    return val
+}
+
+func (kv KV) Bool(key string, defaultValue ...bool) bool {
+    val, _ := keyValue(kv, key, append(defaultValue, false)...)
+    return val
+}
+
+func (kv KV) Strings(key string, defaultValue ...[]string) []string {
+    val, _ := keyValue(kv, key, append(defaultValue, []string{""})...)
+    return val
+}
+
+func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
+    val, _ := keyValue(kv, key, append(defaultValue, []int32{0})...)
+    return val
+}
+
+func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
+    val, _ := keyValue(kv, key, append(defaultValue, []uint32{0})...)
+    return val
+}
+
+func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
+    val, _ := keyValue(kv, key, append(defaultValue, []float32{0})...)
+    return val
+}
+
+func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
+    val, _ := keyValue(kv, key, append(defaultValue, []bool{false})...)
+    return val
+}
+
+func (kv KV) Len() int {
+    return len(kv)
+}
+
+func (kv KV) Keys() iter.Seq[string] {
+    return maps.Keys(kv)
+}
+
+func (kv KV) Value(key string) any {
+    return kv[key]
+}
+
+func (ModelParameters) KV(t *Tokenizer) KV {
+    kv := KV{
         "general.file_type":            uint32(1),
         "general.quantization_version": uint32(2),
         "tokenizer.ggml.pre":           t.Pre,
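A quick illustration of how these accessors resolve keys (my reading of `keyValue` above, not code from the patch): unprefixed keys are namespaced under the architecture, `general.` and `tokenizer.` keys are looked up verbatim, and the variadic default is returned on a miss.

```go
kv := KV{
	"general.architecture": "llama",
	"llama.block_count":    uint32(32),
}
_ = kv.Uint("block_count")             // 32: resolved as "llama.block_count"
_ = kv.String("general.architecture") // "llama": used as-is
_ = kv.Uint("missing", 7)             // 7: the supplied default
```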
@@ -63,7 +157,7 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV {
     return kv
 }
 
-func (p AdapterParameters) KV() ggml.KV {
+func (p AdapterParameters) KV() KV {
     var alpha float32
     if p.LoraParameters.Alpha == 0 {
         alpha = float32(p.Alpha)
@@ -71,7 +165,7 @@ func (p AdapterParameters) KV() ggml.KV {
         alpha = p.LoraParameters.Alpha
     }
 
-    kv := ggml.KV{
+    kv := KV{
         "adapter.lora.alpha": alpha,
         "adapter.type":       "lora",
         "general.file_type":  uint32(1),
@@ -88,9 +182,14 @@ func (ModelParameters) specialTokenTypes() []string {
     }
 }
 
-type ModelConverter interface {
+type ModelKV interface {
     // KV maps parameters to LLM key-values
-    KV(*Tokenizer) ggml.KV
+    KV(*Tokenizer) KV
+}
+
+type ModelConverter interface {
+    ModelKV
+
     // Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
     Tensors([]Tensor) []*ggml.Tensor
     // Replacements returns a list of string pairs to replace in tensor names.
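The point of this split is that every `ModelConverter` also satisfies `ModelKV`, so metadata loading can return the narrow interface and conversion can assert back to the wide one (as `ConvertModel` does further down). A toy, self-contained sketch of the pattern with stand-in types:

```go
package main

import "fmt"

type modelKV interface{ Arch() string }

type modelConverter interface {
	modelKV // embedding: every converter is also a KV source
	TensorCount() int
}

type llamaish struct{}

func (llamaish) Arch() string     { return "llama" }
func (llamaish) TensorCount() int { return 0 }

func main() {
	var mk modelKV = llamaish{} // metadata-only view
	fmt.Println(mk.Arch())
	if conv, ok := mk.(modelConverter); ok { // recover the full converter
		fmt.Println(conv.TensorCount())
	}
}
```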
@@ -107,7 +206,7 @@ type moreParser interface {
 
 type AdapterConverter interface {
     // KV maps parameters to LLM key-values
-    KV(ggml.KV) ggml.KV
+    KV(ofs.Config) KV
     // Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
     Tensors([]Tensor) []*ggml.Tensor
     // Replacements returns a list of string pairs to replace in tensor names.
@@ -115,7 +214,7 @@ type AdapterConverter interface {
     Replacements() []string
 }
 
-func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
+func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
     bts, err := fs.ReadFile(fsys, "adapter_config.json")
     if err != nil {
         return err
@@ -126,8 +225,8 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
         return err
     }
 
-    arch, ok := baseKV["general.architecture"]
-    if !ok {
+    arch := baseKV.Architecture()
+    if arch == "" {
         return errors.New("architecture not set for the base model")
     }
 
@@ -153,23 +252,19 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
     return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
 }
 
-// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
-// and files it finds in the input path.
-// Supported input model formats include safetensors.
-// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
-func ConvertModel(fsys fs.FS, f *os.File) error {
+func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
     bts, err := fs.ReadFile(fsys, "config.json")
     if err != nil {
-        return err
+        return nil, nil, err
     }
 
     var p ModelParameters
     if err := json.Unmarshal(bts, &p); err != nil {
-        return err
+        return nil, nil, err
     }
 
     if len(p.Architectures) < 1 {
-        return errors.New("unknown architecture")
+        return nil, nil, errors.New("unknown architecture")
     }
 
     var conv ModelConverter
@@ -182,6 +277,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
         conv = &llama4Model{}
     case "Mistral3ForConditionalGeneration":
         conv = &mistral3Model{}
+    case "Ministral3ForCausalLM":
+        conv = &mistral3CausalModel{}
     case "MixtralForCausalLM":
         conv = &mixtralModel{}
     case "GemmaForCausalLM":
@@ -200,31 +297,37 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
         conv = &qwen25VLModel{}
     case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration":
         conv = &qwen3VLModel{}
+    case "Olmo3ForCausalLM":
+        conv = &olmoModel{}
     case "BertModel":
         conv = &bertModel{}
+    case "NomicBertModel", "NomicBertMoEModel":
+        conv = &nomicbertModel{}
     case "CohereForCausalLM":
        conv = &commandrModel{}
     case "GptOssForCausalLM":
         conv = &gptossModel{}
     case "DeepseekOCRForCausalLM":
         conv = &deepseekocr{}
+    case "DeepseekV3ForCausalLM":
+        conv = &deepseek2Model{}
     default:
-        return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
+        return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
     }
 
     if err := json.Unmarshal(bts, conv); err != nil {
-        return err
+        return nil, nil, err
     }
 
     if t, ok := conv.(moreParser); ok {
         if err := t.parseMore(fsys); err != nil {
-            return err
+            return nil, nil, err
         }
     }
 
     t, err := parseTokenizer(fsys, conv.specialTokenTypes())
     if err != nil {
-        return err
+        return nil, nil, err
     }
 
     vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
@@ -246,6 +349,19 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
     default:
         slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
     }
+    return conv, t, nil
+}
+
+// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
+// and files it finds in the input path.
+// Supported input model formats include safetensors.
+// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
+func ConvertModel(fsys fs.FS, f *os.File) error {
+    kv, t, err := LoadModelMetadata(fsys)
+    if err != nil {
+        return err
+    }
+    conv := kv.(ModelConverter)
 
     ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
     if err != nil {
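One practical consequence of the split: architecture and tokenizer metadata can now be inspected without touching tensors. A hypothetical caller-side sketch (the import path and model directory are assumptions, not part of the patch):

```go
package main

import (
	"fmt"
	"log"
	"os"

	"github.com/ollama/ollama/convert" // assumed import path
)

func main() {
	// Sketch: reads config.json and the tokenizer only; no tensor I/O.
	mk, tok, err := convert.LoadModelMetadata(os.DirFS("/path/to/hf-model"))
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(mk.KV(tok).Architecture()) // e.g. "llama"
}
```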
@@ -255,7 +371,7 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
     return writeFile(f, conv.KV(t), conv.Tensors(ts))
 }
 
-func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
+func writeFile(f *os.File, kv KV, ts []*ggml.Tensor) error {
     for i := range ts {
         ts[i].Shape = slices.Clone(ts[i].Shape)
         slices.Reverse(ts[i].Shape)
@@ -88,7 +88,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
     return nil
 }
 
-func (p *bertModel) KV(t *Tokenizer) ggml.KV {
+func (p *bertModel) KV(t *Tokenizer) KV {
     kv := p.ModelParameters.KV(t)
     kv["general.architecture"] = "bert"
     kv["bert.attention.causal"] = false
@@ -24,7 +24,7 @@ type commandrModel struct {
 
 var _ ModelConverter = (*commandrModel)(nil)
 
-func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
+func (p *commandrModel) KV(t *Tokenizer) KV {
     kv := p.ModelParameters.KV(t)
     kv["general.architecture"] = "command-r"
     kv["general.name"] = "command-r"
convert/convert_deepseek2.go (new file, 173 lines)
@@ -0,0 +1,173 @@
+package convert
+
+import (
+    "cmp"
+    "fmt"
+    "log/slog"
+    "regexp"
+    "strconv"
+
+    "github.com/ollama/ollama/fs/ggml"
+)
+
+type deepseek2Model struct {
+    ModelParameters               // architectures, vocab_size
+    MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
+    HiddenSize            uint32  `json:"hidden_size"`
+    HiddenLayers          uint32  `json:"num_hidden_layers"`
+    IntermediateSize      uint32  `json:"intermediate_size"`
+    NumAttentionHeads     uint32  `json:"num_attention_heads"`
+    NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
+    RMSNormEPS            float32 `json:"rms_norm_eps"`
+
+    RopeTheta     float32 `json:"rope_theta"`
+    QKNopeHeadDim uint32  `json:"qk_nope_head_dim"`
+    QKRopeHeadDim uint32  `json:"qk_rope_head_dim"`
+    KVLoraRank    uint32  `json:"kv_lora_rank"`
+    QLoraRank     uint32  `json:"q_lora_rank"`
+    VHeadDim      uint32  `json:"v_head_dim"`
+
+    ExpertCount            uint32  `json:"n_routed_experts"`
+    ExpertSharedCount      uint32  `json:"n_shared_experts"`
+    ExpertIntermediateSize uint32  `json:"moe_intermediate_size"`
+    ExpertUsedCount        uint32  `json:"num_experts_per_tok"`
+    ExpertWeightsNorm      bool    `json:"norm_topk_prob"`
+    ExpertWeightsScale     float32 `json:"routed_scaling_factor"`
+
+    ScoringFunc            string `json:"scoring_func"`
+    LeadingDenseBlockCount uint32 `json:"first_k_dense_replace"`
+
+    RopeScaling struct {
+        Factor                        float32 `json:"factor"`
+        OriginalMaxPositionEmbeddings uint32  `json:"original_max_position_embeddings"`
+        Type                          string  `json:"type"`
+        MScaleAllDim                  float32 `json:"mscale_all_dim"`
+    } `json:"rope_scaling"`
+
+    Architecture string
+}
+
+func (p *deepseek2Model) KV(t *Tokenizer) KV {
+    kv := p.ModelParameters.KV(t)
+    kv["general.architecture"] = "deepseek2"
+    kv["general.type"] = "model"
+    kv["deepseek2.block_count"] = p.HiddenLayers
+
+    numHeads := p.NumAttentionHeads
+    numKVHeads := p.NumKeyValueHeads
+
+    kv["deepseek2.attention.head_count"] = numHeads
+    kv["deepseek2.attention.head_count_kv"] = numKVHeads
+    kv["deepseek2.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim
+    kv["deepseek2.attention.kv_lora_rank"] = p.KVLoraRank
+    kv["deepseek2.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
+    kv["deepseek2.attention.q_lora_rank"] = p.QLoraRank
+    kv["deepseek2.attention.value_length"] = p.VHeadDim
+    kv["deepseek2.context_length"] = p.MaxPositionEmbeddings
+    kv["deepseek2.embedding_length"] = p.HiddenSize
+    kv["deepseek2.expert_count"] = p.ExpertCount
+    kv["deepseek2.expert_feed_forward_length"] = p.ExpertIntermediateSize
+    kv["deepseek2.expert_shared_count"] = p.ExpertSharedCount
+
+    var scoringFunc uint32
+    switch p.ScoringFunc {
+    case "softmax":
+        // not currently supported in the model, but needed for Deepseek-OCR
+        scoringFunc = 1
+    case "sigmoid":
+        scoringFunc = 2
+    }
+    kv["deepseek2.expert_gating_func"] = scoringFunc
+    kv["deepseek2.expert_used_count"] = p.ExpertUsedCount
+    kv["deepseek2.expert_weights_norm"] = p.ExpertWeightsNorm
+    kv["deepseek2.expert_weights_scale"] = p.ExpertWeightsScale
+    kv["deepseek2.feed_forward_length"] = p.IntermediateSize
+    kv["deepseek2.leading_dense_block_count"] = p.LeadingDenseBlockCount
+
+    kv["deepseek2.rope.dimension_count"] = p.QKRopeHeadDim
+    kv["deepseek2.rope.freq_base"] = cmp.Or(p.RopeTheta, 10000.0)
+    kv["deepseek2.rope.scaling.factor"] = p.RopeScaling.Factor
+    kv["deepseek2.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings
+    kv["deepseek2.rope.scaling.type"] = p.RopeScaling.Type
+    kv["deepseek2.rope.scaling.yarn_log_multiplier"] = 0.1 * p.RopeScaling.MScaleAllDim
+
+    kv["tokenizer.ggml.pre"] = "deepseek-v3"
+
+    return kv
+}
+
+func (p *deepseek2Model) Replacements() []string {
+    return []string{
+        "lm_head", "output",
+        "model.embed_tokens", "token_embd",
+        "model.norm", "output_norm",
+        "language_model.", "",
+        "model.layers", "blk",
+        "input_layernorm", "attn_norm",
+        "self_attn.kv_a_proj_with_mqa", "attn_kv_a_mqa",
+        "self_attn.kv_a_layernorm", "attn_kv_a_norm",
+        "self_attn.kv_b_proj", "attn_kv_b",
+        "self_attn.q_a_proj", "attn_q_a",
+        "self_attn.q_a_layernorm", "attn_q_a_norm",
+        "self_attn.q_b_proj", "attn_q_b",
+        "self_attn.o_proj", "attn_output",
+        "post_attention_layernorm", "ffn_norm",
+        "mlp.shared_experts.down_proj", "ffn_down_shexp",
+        "mlp.shared_experts.gate_proj", "ffn_gate_shexp",
+        "mlp.shared_experts.up_proj", "ffn_up_shexp",
+        "mlp.gate_proj", "ffn_gate",
+        "mlp.down_proj", "ffn_down",
+        "mlp.up_proj", "ffn_up",
+        "mlp.gate.e_score_correction_bias", "exp_probs_b.bias",
+        "mlp.gate", "ffn_gate_inp",
+    }
+}
+
+func (p *deepseek2Model) Tensors(s []Tensor) (out []*ggml.Tensor) {
+    merges := make([]merge, p.HiddenLayers*3)
+    for i := range p.HiddenLayers {
+        merges[i*3+0] = merge{
+            fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
+            fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
+        }
+        merges[i*3+1] = merge{
+            fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
+            fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
+        }
+        merges[i*3+2] = merge{
+            fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
+            fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
+        }
+    }
+
+    skipLayer := func(n string, minValue uint32) bool {
+        re := regexp.MustCompile(`^blk\.(\d+)`)
+        matches := re.FindStringSubmatch(n)
+        if matches == nil {
+            return false
+        }
+
+        blkNum, err := strconv.Atoi(matches[1])
+        if err != nil {
+            return false
+        }
+
+        return uint32(blkNum) >= minValue
+    }
+
+    out, s = mergeTensors(s, merges...)
+    for _, t := range s {
+        // skip any additional layers (such as the Multi-Token Prediction layer)
+        if skipLayer(t.Name(), p.HiddenLayers) {
+            slog.Debug("skipping layer", "name", t.Name())
+            continue
+        }
+        out = append(out, &ggml.Tensor{
+            Name:     t.Name(),
+            Kind:     t.Kind(),
+            Shape:    t.Shape(),
+            WriterTo: t,
+        })
+    }
+    return out
+}
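To see what the Multi-Token-Prediction filter in `Tensors` does, here is the same `skipLayer` helper exercised standalone (tensor names are illustrative):

```go
package main

import (
	"fmt"
	"regexp"
	"strconv"
)

// Same logic as the converter above: block indices at or beyond
// num_hidden_layers (e.g. DeepSeek's MTP layer) are skipped.
func skipLayer(n string, minValue uint32) bool {
	re := regexp.MustCompile(`^blk\.(\d+)`)
	matches := re.FindStringSubmatch(n)
	if matches == nil {
		return false
	}
	blkNum, err := strconv.Atoi(matches[1])
	if err != nil {
		return false
	}
	return uint32(blkNum) >= minValue
}

func main() {
	fmt.Println(skipLayer("blk.60.ffn_up.weight", 61)) // false: regular layer
	fmt.Println(skipLayer("blk.61.ffn_up.weight", 61)) // true: extra layer, dropped
	fmt.Println(skipLayer("output_norm.weight", 61))   // false: not a block tensor
}
```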
@@ -41,7 +41,7 @@ type deepseekocr struct {
     } `json:"vision_config"`
 }
 
-func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
+func (m *deepseekocr) KV(t *Tokenizer) KV {
     kv := m.ModelParameters.KV(t)
     kv["general.architecture"] = "deepseekocr"
     kv["block_count"] = m.LanguageConfig.HiddenLayers
@@ -23,7 +23,7 @@ type gemmaModel struct {
 
 var _ ModelConverter = (*gemmaModel)(nil)
 
-func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
+func (p *gemmaModel) KV(t *Tokenizer) KV {
     kv := p.ModelParameters.KV(t)
     kv["general.architecture"] = "gemma"
     kv["gemma.context_length"] = p.MaxPositionEmbeddings
@@ -1,7 +1,5 @@
 package convert
 
-import "github.com/ollama/ollama/fs/ggml"
-
 type gemma2Model struct {
     gemmaModel
     SlidingWindow uint32 `json:"sliding_window"`
@@ -9,7 +7,7 @@ type gemma2Model struct {
     FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
 }
 
-func (p *gemma2Model) KV(t *Tokenizer) ggml.KV {
+func (p *gemma2Model) KV(t *Tokenizer) KV {
     kv := p.ModelParameters.KV(t)
     kv["general.architecture"] = "gemma2"
     kv["gemma2.context_length"] = p.MaxPositionEmbeddings
@@ -6,6 +6,7 @@ import (
     "github.com/pdevine/tensor"
     "github.com/pdevine/tensor/native"
 
+    "github.com/ollama/ollama/fs"
     "github.com/ollama/ollama/fs/ggml"
 )
 
@@ -15,7 +16,7 @@ type gemma2Adapter struct {
 
 var _ AdapterConverter = (*gemma2Adapter)(nil)
 
-func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
+func (p *gemma2Adapter) KV(baseKV fs.Config) KV {
     kv := p.AdapterParameters.KV()
     kv["general.architecture"] = "gemma2"
     return kv
@@ -2,8 +2,7 @@ package convert
 
 import (
     "cmp"
-
-    "github.com/ollama/ollama/fs/ggml"
+    "slices"
 )
 
 type gemma3Model struct {
@@ -26,16 +25,26 @@ type gemma3Model struct {
         NumChannels uint32 `json:"num_channels"` // num_channels 3
         PatchSize   uint32 `json:"patch_size"`   // patch_size 14
     } `json:"vision_config"`
     MaxPositionEmbeddings    uint32  `json:"max_position_embeddings"`
     NumAttentionHeads        uint32  `json:"num_attention_heads"`
     NumKeyValueHeads         uint32  `json:"num_key_value_heads"`
     RMSNormEPS               float32 `json:"rms_norm_eps"`
     HeadDim                  uint32  `json:"head_dim"`
     FinalLogitSoftcap        float32 `json:"final_logit_softcapping"`
     RopeLocalTheta           float32 `json:"rope_local_base_freq"`
-    RopeGlobalTheta          float32 `json:"rope_global_base_freq"`
+    RopeTheta                float32 `json:"rope_theta"`
     SlidingWindow            uint32  `json:"sliding_window"`
-    MultiModalTokensPerImage uint32  `json:"mm_tokens_per_image"`
+    SlidingWindowPattern     *uint32  `json:"sliding_window_pattern"`
+    LayerTypes               []string `json:"layer_types"`
+    MultiModalTokensPerImage uint32   `json:"mm_tokens_per_image"`
+    RopeScaling              *struct {
+        Type                          string  `json:"rope_type"`
+        Factor                        float32 `json:"factor"`
+        OriginalMaxPositionEmbeddings uint32  `json:"original_max_position_embeddings"`
+        ExtrapolationFactor           float32 `json:"extrapolation_factor"`
+        BetaFast                      float32 `json:"beta_fast"`
+        BetaSlow                      float32 `json:"beta_slow"`
+    } `json:"rope_scaling"`
 }
 
 const (
@@ -44,7 +53,7 @@ const (
     gemma27BLayerCount = 62
 )
 
-func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
+func (p *gemma3Model) KV(t *Tokenizer) KV {
     kv := p.ModelParameters.KV(t)
     kv["general.architecture"] = "gemma3"
 
@@ -81,9 +90,38 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
         kv["gemma3.attention.key_length"] = p.HeadDim
         kv["gemma3.attention.value_length"] = p.HeadDim
         kv["gemma3.attention.sliding_window"] = p.SlidingWindow
-        kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
+
+        // The sliding window pattern is either provided as the sliding_window_pattern
+        // key (an int) or as the layer_types key (a list of strings).
+        if p.SlidingWindowPattern != nil || len(p.LayerTypes) > 0 {
+            kv["gemma3.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
+                for i := range numBlocks {
+                    var isLocal bool
+                    if len(p.LayerTypes) > 0 && int(i) < len(p.LayerTypes) {
+                        isLocal = p.LayerTypes[i] == "sliding_attention"
+                    } else if p.SlidingWindowPattern != nil && *p.SlidingWindowPattern > 0 {
+                        isLocal = (i+1)%*p.SlidingWindowPattern != 0
+                    }
+                    if !yield(isLocal) {
+                        break
+                    }
+                }
+            })
+        }
+        if p.FinalLogitSoftcap > 0 {
+            kv["gemma3.final_logit_softcapping"] = p.FinalLogitSoftcap
+        }
         kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
-        kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
+        kv["gemma3.rope.freq_base"] = cmp.Or(p.RopeTheta, 1000000.0)
+        if p.RopeScaling != nil && p.RopeScaling.Type == "yarn" && p.RopeScaling.Factor > 0 {
+            kv["gemma3.rope.scaling.type"] = "yarn"
+            kv["gemma3.rope.scaling.factor"] = p.RopeScaling.Factor
+            kv["gemma3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings
+            kv["gemma3.rope.scaling.extrapolation_factor"] = cmp.Or(p.RopeScaling.ExtrapolationFactor, float32(1.0))
+            kv["gemma3.rope.scaling.beta_fast"] = cmp.Or(p.RopeScaling.BetaFast, float32(64.0))
+            kv["gemma3.rope.scaling.beta_slow"] = cmp.Or(p.RopeScaling.BetaSlow, float32(1.0))
+        }
 
         kv["gemma3.embedding_length"] = p.HiddenSize
         kv["gemma3.feed_forward_length"] = p.IntermediateSize
     default:
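A worked example of the integer-pattern branch above, with assumed values: given `sliding_window_pattern = 6`, every sixth layer is global and the rest use local (sliding) attention.

```go
package main

import "fmt"

func main() {
	var pattern uint32 = 6  // assumed sliding_window_pattern
	numBlocks := uint32(12) // assumed block count
	for i := range numBlocks {
		isLocal := (i+1)%pattern != 0 // same test as the converter
		fmt.Printf("layer %2d local=%v\n", i, isLocal)
	}
	// layers 5 and 11 come out local=false: one global layer per group of six.
}
```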
@@ -38,7 +38,7 @@ type gemma3nModel struct {
     VisionModel struct{} `json:"vision_config"`
 }
 
-func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
+func (m *gemma3nModel) KV(t *Tokenizer) KV {
     kv := m.ModelParameters.KV(t)
     kv["general.architecture"] = "gemma3n"
     kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {
@@ -37,7 +37,7 @@ type gptossModel struct {
 
 var _ ModelConverter = (*gptossModel)(nil)
 
-func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
+func (m *gptossModel) KV(t *Tokenizer) KV {
     kv := m.ModelParameters.KV(t)
     kv["general.architecture"] = "gptoss"
     kv["general.file_type"] = uint32(4)
@@ -48,7 +48,7 @@ type llamaModel struct {
 
 var _ ModelConverter = (*llamaModel)(nil)
 
-func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
+func (p *llamaModel) KV(t *Tokenizer) KV {
     kv := p.ModelParameters.KV(t)
     kv["general.architecture"] = "llama"
     kv["llama.vocab_size"] = p.VocabSize
@@ -35,7 +35,7 @@ type llama4Model struct {
 }
 
 // KV implements ModelConverter.
-func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
+func (p *llama4Model) KV(t *Tokenizer) KV {
     kv := p.ModelParameters.KV(t)
     kv["general.architecture"] = "llama4"
 
@@ -7,6 +7,7 @@ import (
     "github.com/pdevine/tensor"
     "github.com/pdevine/tensor/native"
 
+    "github.com/ollama/ollama/fs"
     "github.com/ollama/ollama/fs/ggml"
 )
 
@@ -18,13 +19,13 @@ type llamaAdapter struct {
 
 var _ AdapterConverter = (*llamaAdapter)(nil)
 
-func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
+func (p *llamaAdapter) KV(baseKV fs.Config) KV {
     kv := p.AdapterParameters.KV()
     kv["general.architecture"] = "llama"
-    kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
-    kv["llama.attention.head_count_kv"] = baseKV["llama.attention.head_count_kv"]
+    kv["llama.attention.head_count"] = baseKV.Value("llama.attention.head_count")
+    kv["llama.attention.head_count_kv"] = baseKV.Value("llama.attention.head_count_kv")
 
-    p.NumAttentionHeads = baseKV["llama.attention.head_count"].(uint32)
+    p.NumAttentionHeads = baseKV.Value("llama.attention.head_count").(uint32)
 
     return kv
 }
@@ -30,13 +30,15 @@ type mistral3Model struct {
         HiddenAct      string `json:"hidden_act"`
         VocabSize      uint32 `json:"vocab_size"`
         RopeParameters struct {
             BetaFast                  float32  `json:"beta_fast"`
             BetaSlow                  float32  `json:"beta_slow"`
             Factor                    float32  `json:"factor"`
-            ScalingBeta               float32  `json:"llama_4_scaling_beta"`
+            Llama4ScalingBeta         *float32 `json:"llama_4_scaling_beta"`
             OrigMaxPositionEmbeddings uint32   `json:"original_max_position_embeddings"`
             RopeType                  string   `json:"rope_type"`
             RopeTheta                 float32  `json:"rope_theta"`
+            Mscale                    *float32 `json:"mscale"`
+            MscaleAllDim              *float32 `json:"mscale_all_dim"`
         } `json:"rope_parameters"`
     } `json:"text_config"`
     VisionModel struct {
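The switch to pointer fields here is how the converter distinguishes "key absent" from "explicit zero" in `config.json`: a `*float32` stays `nil` when the key is missing, so the corresponding KV entries further down are only emitted when the checkpoint actually sets them. A small standalone sketch of that behavior (toy struct, not the patch's types):

```go
package main

import (
	"encoding/json"
	"fmt"
)

type ropeParams struct {
	Mscale *float32 `json:"mscale"` // nil when the key is absent
	Factor float32  `json:"factor"` // zero value when absent: ambiguous
}

func main() {
	var p ropeParams
	if err := json.Unmarshal([]byte(`{"factor": 0}`), &p); err != nil {
		panic(err)
	}
	fmt.Println(p.Mscale == nil, p.Factor) // true 0
}
```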
@@ -50,12 +52,15 @@ type mistral3Model struct {
         HeadDim        uint32  `json:"head_dim"`
         HiddenAct      string  `json:"hidden_act"`
         RopeTheta      float32 `json:"rope_theta"`
+        RopeParameters struct {
+            RopeTheta float32 `json:"rope_theta"`
+        } `json:"rope_parameters"`
     } `json:"vision_config"`
     MultiModalProjectorBias bool   `json:"multimodal_projector_bias"`
     ProjectorHiddenAct      string `json:"projector_hidden_act"`
 }
 
-func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
+func (p *mistral3Model) KV(t *Tokenizer) KV {
     kv := p.ModelParameters.KV(t)
     kv["general.architecture"] = "mistral3"
     kv["mistral3.vocab_size"] = p.TextModel.VocabSize
@@ -72,10 +77,22 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
     kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
     kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads)
     kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta)
+    kv["mistral3.rope.scaling.factor"] = p.TextModel.RopeParameters.Factor
+    kv["mistral3.rope.scaling.type"] = p.TextModel.RopeParameters.RopeType
+    kv["mistral3.rope.scaling.beta_fast"] = p.TextModel.RopeParameters.BetaFast
+    kv["mistral3.rope.scaling.beta_slow"] = p.TextModel.RopeParameters.BetaSlow
+
+    if p.TextModel.RopeParameters.Mscale != nil {
+        kv["mistral3.rope.scaling.mscale"] = *p.TextModel.RopeParameters.Mscale
+    }
+    if p.TextModel.RopeParameters.MscaleAllDim != nil {
+        kv["mistral3.rope.scaling.mscale_all_dim"] = *p.TextModel.RopeParameters.MscaleAllDim
+    }
     if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 {
         kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings
-        kv["mistral3.rope.scaling_beta"] = p.TextModel.RopeParameters.ScalingBeta
+    }
+    if p.TextModel.RopeParameters.Llama4ScalingBeta != nil {
+        kv["mistral3.rope.scaling_beta"] = *p.TextModel.RopeParameters.Llama4ScalingBeta
     }
 
     // Vision configuration
@@ -88,7 +105,7 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
     kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
     kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
     // kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
-    kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta
+    kv["mistral3.vision.rope.freq_base"] = cmp.Or(p.VisionModel.RopeTheta, p.VisionModel.RopeParameters.RopeTheta)
 
     // Multimodal configuration
     kv["mistral3.image_token_index"] = p.ImageTokenIndex
convert/convert_mistral_causal.go (new file, 181 lines)
@@ -0,0 +1,181 @@
|
|||||||
|
package convert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pdevine/tensor"
|
||||||
|
"github.com/pdevine/tensor/native"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mistral3CausalModel struct {
|
||||||
|
ModelParameters
|
||||||
|
|
||||||
|
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||||
|
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||||
|
HiddenSize uint32 `json:"hidden_size"`
|
||||||
|
IntermediateSize uint32 `json:"intermediate_size"`
|
||||||
|
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||||
|
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||||
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
|
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||||
|
HeadDim uint32 `json:"head_dim"`
|
||||||
|
SlidingWindow *uint32 `json:"sliding_window"`
|
||||||
|
HiddenAct string `json:"hidden_act"`
|
||||||
|
VocabSize uint32 `json:"vocab_size"`
|
||||||
|
RopeParameters struct {
|
||||||
|
BetaFast float32 `json:"beta_fast"`
|
		BetaSlow                  float32  `json:"beta_slow"`
		Factor                    float32  `json:"factor"`
		Llama4ScalingBeta         *float32 `json:"llama_4_scaling_beta"`
		OrigMaxPositionEmbeddings uint32   `json:"original_max_position_embeddings"`
		RopeType                  string   `json:"rope_type"`
		RopeTheta                 float32  `json:"rope_theta"`
		Mscale                    *float32 `json:"mscale"`
		MscaleAllDim              *float32 `json:"mscale_all_dim"`
	} `json:"rope_parameters"`
}

func (p *mistral3CausalModel) KV(t *Tokenizer) KV {
	kv := p.ModelParameters.KV(t)
	kv["general.architecture"] = "mistral3"
	kv["mistral3.vocab_size"] = p.VocabSize

	// Text configuration
	kv["mistral3.block_count"] = p.NumHiddenLayers
	kv["mistral3.context_length"] = p.MaxPositionEmbeddings
	kv["mistral3.embedding_length"] = p.HiddenSize
	kv["mistral3.feed_forward_length"] = p.IntermediateSize
	kv["mistral3.attention.head_count"] = p.NumAttentionHeads
	kv["mistral3.attention.head_count_kv"] = p.NumKeyValueHeads
	kv["mistral3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
	kv["mistral3.attention.key_length"] = p.HeadDim
	kv["mistral3.attention.value_length"] = p.HeadDim
	kv["mistral3.rope.dimension_count"] = cmp.Or(p.HeadDim, p.HiddenSize/p.NumAttentionHeads)
	kv["mistral3.rope.freq_base"] = cmp.Or(p.RopeTheta, p.RopeParameters.RopeTheta)
	kv["mistral3.rope.scaling.factor"] = p.RopeParameters.Factor
	kv["mistral3.rope.scaling.type"] = p.RopeParameters.RopeType
	kv["mistral3.rope.scaling.beta_fast"] = p.RopeParameters.BetaFast
	kv["mistral3.rope.scaling.beta_slow"] = p.RopeParameters.BetaSlow

	if p.RopeParameters.Mscale != nil {
		kv["mistral3.rope.scaling.mscale"] = *p.RopeParameters.Mscale
	}

	if p.RopeParameters.MscaleAllDim != nil {
		kv["mistral3.rope.scaling.mscale_all_dim"] = *p.RopeParameters.MscaleAllDim
	}

	if p.RopeParameters.OrigMaxPositionEmbeddings > 0 {
		kv["mistral3.rope.scaling.original_context_length"] = p.RopeParameters.OrigMaxPositionEmbeddings
		// scaling_beta is set below, guarded by a nil check on Llama4ScalingBeta;
		// dereferencing it here without that check could panic.
	}

	if p.RopeParameters.Llama4ScalingBeta != nil {
		kv["mistral3.rope.scaling_beta"] = *p.RopeParameters.Llama4ScalingBeta
	}

	return kv
}

func (p *mistral3CausalModel) Tensors(ts []Tensor) []*ggml.Tensor {
	var out []*ggml.Tensor

	for _, t := range ts {
		if !strings.HasPrefix(t.Name(), "v.") {
			if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
				strings.HasSuffix(t.Name(), ".attn_k.weight") {
				t.SetRepacker(p.repack)
			}
		}

		out = append(out, &ggml.Tensor{
			Name:     t.Name(),
			Kind:     t.Kind(),
			Shape:    t.Shape(),
			WriterTo: t,
		})
	}

	return out
}

func (p *mistral3CausalModel) Replacements() []string {
	return []string{
		"model.norm", "output_norm",
		"model.", "",
		"layers", "blk",
		"transformer.layers", "blk",
		"vision_tower", "v",
		"ln_pre", "encoder_norm",
		"input_layernorm", "attn_norm",
		"post_attention_layernorm", "ffn_norm",
		"embed_tokens", "token_embd",
		"self_attn.q_proj", "attn_q",
		"self_attn.k_proj", "attn_k",
		"self_attn.v_proj", "attn_v",
		"self_attn.o_proj", "attn_output",
		"mlp.down_proj", "ffn_down",
		"mlp.gate_proj", "ffn_gate",
		"mlp.up_proj", "ffn_up",
		"attention.q_proj", "attn_q",
		"attention.k_proj", "attn_k",
		"attention.v_proj", "attn_v",
		"attention.o_proj", "attn_output",
		"attention_norm", "attn_norm",
		"feed_forward.gate_proj", "ffn_gate",
		"feed_forward.down_proj", "ffn_down",
		"feed_forward.up_proj", "ffn_up",
		"multi_modal_projector", "mm",
		"ffn_norm", "ffn_norm",
		"lm_head", "output",
	}
}

func (p *mistral3CausalModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
	var dims []int
	for _, dim := range shape {
		dims = append(dims, int(dim))
	}

	var heads uint32
	if strings.HasSuffix(name, ".attn_q.weight") {
		heads = p.NumAttentionHeads
	} else if strings.HasSuffix(name, ".attn_k.weight") {
		heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
	} else {
		return nil, fmt.Errorf("unknown tensor for repack: %s", name)
	}

	// split each head's rows into two rotary halves, swap the pair and half
	// axes, then flatten back to the original shape
	n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
	if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
		return nil, err
	}

	if err := n.T(0, 2, 1, 3); err != nil {
		return nil, err
	}

	if err := n.Reshape(dims...); err != nil {
		return nil, err
	}

	if err := n.Transpose(); err != nil {
		return nil, err
	}

	ts, err := native.SelectF32(n, 1)
	if err != nil {
		return nil, err
	}

	var f32s []float32
	for _, t := range ts {
		f32s = append(f32s, t...)
	}

	return f32s, nil
}
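The reshape and axis swap in `repack` are easier to see on a toy input. Below is a minimal standalone sketch of the row permutation that the `Reshape`/`T(0, 2, 1, 3)` pair performs on one attention head, assuming the incoming weights store the two rotary halves of each head contiguously; the helper name and the tiny integer "tensor" are illustrative only, not part of the converter.

```go
package main

import "fmt"

// interleaveHalves mirrors the reshape(heads, 2, headDim/2) + T(0, 2, 1, 3)
// step above for one weight column: row a*(headDim/2)+b of a head moves to
// position 2*b+a, so the two halves of the head end up interleaved pair by pair.
func interleaveHalves(rows []int, headDim int) []int {
	out := make([]int, 0, len(rows))
	for h := 0; h < len(rows); h += headDim {
		head := rows[h : h+headDim]
		half := headDim / 2
		for b := 0; b < half; b++ {
			out = append(out, head[b], head[half+b])
		}
	}
	return out
}

func main() {
	// one head of 8 rows: [0 1 2 3 4 5 6 7] -> [0 4 1 5 2 6 3 7]
	fmt.Println(interleaveHalves([]int{0, 1, 2, 3, 4, 5, 6, 7}, 8))
}
```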
@@ -12,7 +12,7 @@ type mixtralModel struct {
 	NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
 }
 
-func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
+func (p *mixtralModel) KV(t *Tokenizer) KV {
 	kv := p.llamaModel.KV(t)
 
 	if p.NumLocalExperts > 0 {
@@ -34,7 +34,7 @@ type mllamaModel struct {
 	} `json:"vision_config"`
 }
 
-func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
+func (m *mllamaModel) KV(t *Tokenizer) KV {
 	kv := m.ModelParameters.KV(t)
 	kv["general.architecture"] = "mllama"

convert/convert_nomicbert.go (new file, 213 lines)
@@ -0,0 +1,213 @@
package convert

import (
	"cmp"
	"encoding/json"
	"io/fs"
	"path/filepath"
	"slices"
	"strings"

	"github.com/ollama/ollama/fs/ggml"
)

type nomicbertModel struct {
	ModelParameters
	NLayers               uint32  `json:"n_layers"`
	NumHiddenLayers       uint32  `json:"num_hidden_layers"`
	MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
	HiddenSize            uint32  `json:"hidden_size"`
	IntermediateSize      uint32  `json:"intermediate_size"`
	NumAttentionHeads     uint32  `json:"num_attention_heads"`
	NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
	LayerNormEPS          float32 `json:"layer_norm_eps"`
	LayerNormEpsilon      float32 `json:"layer_norm_epsilon"`
	RopeFreqBase          float32 `json:"rope_theta"`
	normalizeEmbeddings   bool
	PoolingType           uint32

	// MoE parameters (only present in v2 models)
	NumExperts      uint32 `json:"num_local_experts"`
	NumExpertsUsed  uint32 `json:"num_experts_per_tok"`
	MoEEveryNLayers uint32 `json:"moe_every_n_layers"`
}

var (
	_ ModelConverter = (*nomicbertModel)(nil)
	_ moreParser     = (*nomicbertModel)(nil)
)

func (p *nomicbertModel) parseMore(fsys fs.FS) error {
	bts, err := fs.ReadFile(fsys, "modules.json")
	if err != nil {
		return err
	}

	var modules []struct {
		Type string `json:"type"`
		Path string `json:"path"`
	}

	if err := json.Unmarshal(bts, &modules); err != nil {
		return err
	}

	var pooling string
	for _, m := range modules {
		switch m.Type {
		case "sentence_transformers.models.Pooling":
			pooling = m.Path
		case "sentence_transformers.models.Normalize":
			p.normalizeEmbeddings = true
		}
	}

	if pooling != "" {
		bts, err := fs.ReadFile(fsys, filepath.Join(pooling, "config.json"))
		if err != nil {
			return err
		}

		var pc struct {
			PoolingModeCLSToken   bool `json:"pooling_mode_cls_token"`
			PoolingModeMeanTokens bool `json:"pooling_mode_mean_tokens"`
		}

		if err := json.Unmarshal(bts, &pc); err != nil {
			return err
		}

		if pc.PoolingModeMeanTokens {
			p.PoolingType = 1
		} else if pc.PoolingModeCLSToken {
			p.PoolingType = 2
		}
	}

	return nil
}

func (p *nomicbertModel) KV(t *Tokenizer) KV {
	kv := p.ModelParameters.KV(t)

	// Determine architecture based on MoE parameters (following qwen3 pattern)
	arch := "nomic-bert"
	if p.MoEEveryNLayers > 0 {
		arch += "-moe"
	}

	kv["general.architecture"] = arch
	kv["attention.causal"] = false
	kv["pooling_type"] = p.PoolingType
	kv["normalize_embeddings"] = p.normalizeEmbeddings

	kv["block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers)

	if contextLength := p.MaxPositionEmbeddings; contextLength > 0 {
		kv["context_length"] = contextLength
	}

	if embeddingLength := p.HiddenSize; embeddingLength > 0 {
		kv["embedding_length"] = p.HiddenSize
	}

	if feedForwardLength := p.IntermediateSize; feedForwardLength > 0 {
		kv["feed_forward_length"] = p.IntermediateSize
	}

	if headCount := p.NumAttentionHeads; headCount > 0 {
		kv["attention.head_count"] = p.NumAttentionHeads
	}

	if kvHeadCount := p.NumKeyValueHeads; kvHeadCount > 0 {
		kv["attention.head_count_kv"] = p.NumKeyValueHeads
	}

	if layerNormEpsilon := cmp.Or(p.LayerNormEPS, p.LayerNormEpsilon); layerNormEpsilon > 0 {
		kv["attention.layer_norm_epsilon"] = layerNormEpsilon
	}

	if p.RopeFreqBase > 0 {
		kv["rope.freq_base"] = p.RopeFreqBase
	}

	// MoE specific parameters (only if MoE is enabled)
	if p.NumExperts > 0 {
		kv["expert_count"] = p.NumExperts
	}

	if p.NumExpertsUsed > 0 {
		kv["expert_used_count"] = p.NumExpertsUsed
	}

	if p.MoEEveryNLayers > 0 {
		kv["moe_every_n_layers"] = p.MoEEveryNLayers
	}

	kv["tokenizer.ggml.model"] = "bert"
	kv["tokenizer.ggml.token_type_count"] = uint32(2)

	// convert to phantom space tokens
	for i, e := range t.Tokens {
		switch {
		case strings.HasPrefix(e, "[") && strings.HasSuffix(e, "]"):
			// noop - keep special tokens as-is
		case strings.HasPrefix(e, "##"):
			t.Tokens[i] = e[2:]
		default:
			t.Tokens[i] = "\u2581" + e
		}
	}

	kv["tokenizer.ggml.tokens"] = t.Tokens

	return kv
}

func (p *nomicbertModel) Tensors(ts []Tensor) []*ggml.Tensor {
	out := make([]*ggml.Tensor, 0, len(ts))
	for _, t := range ts {
		if slices.Contains([]string{
			"embeddings.position_ids",
			"pooler.dense.weight",
			"pooler.dense.bias",
		}, t.Name()) {
			continue
		}

		out = append(out, &ggml.Tensor{
			Name:     t.Name(),
			Kind:     t.Kind(),
			Shape:    t.Shape(),
			WriterTo: t,
		})
	}

	return out
}

func (nomicbertModel) Replacements() []string {
	return []string{
		"encoder.layer", "blk",
		"encoder.layers", "blk",
		"embeddings.word_embeddings", "token_embd",
		"embeddings.token_type_embeddings", "token_types",
		"embeddings.LayerNorm", "token_embd_norm",

		"attention.self.qkv", "attn_qkv",

		"attention.output.dense", "attn_output",
		"attention.output.LayerNorm", "attn_output_norm",

		"mlp.up", "ffn_up",
		"mlp.down", "ffn_down",

		"mlp.router", "ffn_gate_inp",
		"mlp.experts.up", "ffn_up_exps",
		"mlp.experts.down", "ffn_down_exps",

		"intermediate.dense", "ffn_up",
		"output.dense", "ffn_down",
		"output.LayerNorm", "layer_output_norm",
	}
}
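The phantom-space conversion above is the least obvious part of this converter: WordPiece continuation tokens lose their `##` marker while word-initial tokens gain a U+2581 prefix, and bracketed special tokens pass through untouched. A small standalone sketch of that loop (the sample tokens are made up):

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	tokens := []string{"[CLS]", "hello", "##ing", "[SEP]"}
	for i, e := range tokens {
		switch {
		case strings.HasPrefix(e, "[") && strings.HasSuffix(e, "]"):
			// special tokens pass through unchanged
		case strings.HasPrefix(e, "##"):
			tokens[i] = e[2:] // continuation piece: drop the WordPiece marker
		default:
			tokens[i] = "\u2581" + e // word-initial piece: add the phantom space
		}
	}
	fmt.Println(tokens) // [[CLS] ▁hello ing [SEP]]
}
```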
convert/convert_olmo.go (new file, 117 lines)
@@ -0,0 +1,117 @@
package convert

import (
	"cmp"

	"github.com/ollama/ollama/fs/ggml"
)

type ropeScaling struct {
	Factor                    float32 `json:"factor"`
	OriginalMaxPositionEmbeds uint32  `json:"original_max_position_embeddings"`
	AttentionFactor           float32 `json:"attention_factor"`
	BetaFast                  float32 `json:"beta_fast"`
	BetaSlow                  float32 `json:"beta_slow"`
	RopeType                  string  `json:"rope_type"`
	ExtrapolationFactor       float32 `json:"extrapolation_factor"`
}

type olmoModel struct {
	ModelParameters

	HiddenSize            uint32       `json:"hidden_size"`
	NumHiddenLayers       uint32       `json:"num_hidden_layers"`
	IntermediateSize      uint32       `json:"intermediate_size"`
	NumAttentionHeads     uint32       `json:"num_attention_heads"`
	NumKeyValueHeads      uint32       `json:"num_key_value_heads"`
	MaxPositionEmbeddings uint32       `json:"max_position_embeddings"`
	RMSNormEPS            float32      `json:"rms_norm_eps"`
	RopeTheta             float32      `json:"rope_theta"`
	RopeScaling           *ropeScaling `json:"rope_scaling"`
	SlidingWindow         uint32       `json:"sliding_window"`
	LayerTypes            []string     `json:"layer_types"`
}

var _ ModelConverter = (*olmoModel)(nil)

func (p *olmoModel) KV(t *Tokenizer) KV {
	kv := p.ModelParameters.KV(t)
	kv["general.architecture"] = "olmo3"
	kv["olmo3.block_count"] = p.NumHiddenLayers
	kv["olmo3.context_length"] = p.MaxPositionEmbeddings
	kv["olmo3.embedding_length"] = p.HiddenSize
	kv["olmo3.feed_forward_length"] = p.IntermediateSize
	kv["olmo3.attention.head_count"] = p.NumAttentionHeads
	kv["olmo3.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)

	if p.RopeTheta > 0 {
		kv["olmo3.rope.freq_base"] = p.RopeTheta
	}

	if p.RopeScaling != nil {
		if p.RopeScaling.Factor > 0 {
			kv["olmo3.rope.scaling.factor"] = p.RopeScaling.Factor
		}
		if p.RopeScaling.OriginalMaxPositionEmbeds > 0 {
			kv["olmo3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeds
		}
		if p.RopeScaling.AttentionFactor > 0 {
			kv["olmo3.rope.scaling.attn_factor"] = p.RopeScaling.AttentionFactor
		}
		if p.RopeScaling.RopeType != "" {
			kv["olmo3.rope.scaling.type"] = p.RopeScaling.RopeType
		}
	}

	if p.RMSNormEPS > 0 {
		kv["olmo3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
	}

	if p.SlidingWindow > 0 {
		kv["olmo3.attention.sliding_window"] = p.SlidingWindow
	}

	if len(p.LayerTypes) > 0 {
		slidingPattern := make([]bool, len(p.LayerTypes))
		for i, layerType := range p.LayerTypes {
			slidingPattern[i] = (layerType == "sliding_attention")
		}
		kv["olmo3.attention.sliding_window_pattern"] = slidingPattern
	}

	return kv
}

func (p *olmoModel) Tensors(ts []Tensor) []*ggml.Tensor {
	out := make([]*ggml.Tensor, 0, len(ts))
	for _, t := range ts {
		out = append(out, &ggml.Tensor{
			Name:     t.Name(),
			Kind:     t.Kind(),
			Shape:    t.Shape(),
			WriterTo: t,
		})
	}

	return out
}

func (p *olmoModel) Replacements() []string {
	return []string{
		"lm_head", "output",
		"model.embed_tokens", "token_embd",
		"model.layers", "blk",
		"model.norm", "output_norm",
		"self_attn.q_proj", "attn_q",
		"self_attn.k_proj", "attn_k",
		"self_attn.v_proj", "attn_v",
		"self_attn.o_proj", "attn_output",
		"self_attn.q_norm", "attn_q_norm",
		"self_attn.k_norm", "attn_k_norm",
		"post_attention_layernorm", "post_attention_norm",
		"post_feedforward_layernorm", "post_ffw_norm",
		"mlp.gate_proj", "ffn_gate",
		"mlp.down_proj", "ffn_down",
		"mlp.up_proj", "ffn_up",
	}
}
@@ -37,7 +37,7 @@ type phi3Model struct {
 
 var _ ModelConverter = (*phi3Model)(nil)
 
-func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
+func (p *phi3Model) KV(t *Tokenizer) KV {
 	kv := p.ModelParameters.KV(t)
 	kv["general.architecture"] = "phi3"
 	kv["phi3.context_length"] = p.MaxPositionEmbeddings

@@ -22,7 +22,7 @@ type qwen2Model struct {
 
 var _ ModelConverter = (*qwen2Model)(nil)
 
-func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
+func (q *qwen2Model) KV(t *Tokenizer) KV {
 	kv := q.ModelParameters.KV(t)
 	kv["general.architecture"] = "qwen2"
 	kv["qwen2.block_count"] = q.HiddenLayers

@@ -29,7 +29,7 @@ type qwen25VLModel struct {
 
 var _ ModelConverter = (*qwen25VLModel)(nil)
 
-func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
+func (q *qwen25VLModel) KV(t *Tokenizer) KV {
 	kv := q.ModelParameters.KV(t)
 	kv["general.architecture"] = "qwen25vl"
 

@@ -32,7 +32,7 @@ type qwen3Model struct {
 }
 
 // KV implements ModelConverter.
-func (q *qwen3Model) KV(t *Tokenizer) ggml.KV {
+func (q *qwen3Model) KV(t *Tokenizer) KV {
 	arch := "qwen3"
 	if q.NumExperts > 0 {
 		arch += "moe"

@@ -45,7 +45,7 @@ func (m *qwen3VLModel) parseMore(fsys fs.FS) error {
 	return json.Unmarshal(bts, &m.VisionModel)
 }
 
-func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV {
+func (m *qwen3VLModel) KV(t *Tokenizer) KV {
 	kv := m.qwen3Model.KV(t)
 
 	arch := "qwen3vl"

@@ -19,6 +19,7 @@ import (
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
+	fsc "github.com/ollama/ollama/fs"
 	"github.com/ollama/ollama/fs/ggml"
 )
 

@@ -28,7 +29,7 @@ type tensorData struct {
 	Shape []int `json:"shape"`
 }
 
-func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
+func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors) {
 	t.Helper()
 
 	f, err := os.CreateTemp(t.TempDir(), "f16")

@@ -59,9 +60,10 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
 	return r, m.KV(), m.Tensors()
 }
 
-func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string {
+func generateResultsJSON(t *testing.T, f *os.File, kv fsc.Config, tensors ggml.Tensors) map[string]string {
 	actual := make(map[string]string)
-	for k, v := range kv {
+	for k := range kv.Keys() {
+		v := kv.Value(k)
 		if s, ok := v.(json.Marshaler); !ok {
 			actual[k] = fmt.Sprintf("%v", v)
 		} else {

@@ -277,7 +279,7 @@ func generateSafetensorTestData(t *testing.T, tempDir string, tensorData map[str
 func TestConvertAdapter(t *testing.T) {
 	type AdapterCase struct {
 		Name     string
-		BaseKV   map[string]any
+		BaseKV   KV
 		Expected map[string]string
 	}
 

@@ -49,7 +49,8 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
 		tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
 
 		// temporary fix to handle gemma3 broken configs
-		if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
+		// TODO(parthsareen): allow reading of tokenizer.json to allow managing special tokens when using spm
+		if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>", "<start_function_declaration>", "<end_function_declaration>", "<start_function_call>", "<end_function_call>", "<start_function_response>", "<end_function_response>", "<escape>"}, piece.GetPiece()) {
 			tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
 		}
 

@@ -147,7 +147,7 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
 		wg.Add(1)
 		go func(i int) {
 			defer wg.Done()
-			extraEnvs := ml.GetVisibleDevicesEnv(devices[i : i+1])
+			extraEnvs := ml.GetVisibleDevicesEnv(devices[i:i+1], true)
 			devices[i].AddInitValidation(extraEnvs)
 			if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 {
 				slog.Debug("filtering device which didn't fully initialize",

@@ -333,7 +333,8 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
 	defer cancel()
 
 	// Apply any dev filters to avoid re-discovering unsupported devices, and get IDs correct
-	devFilter := ml.GetVisibleDevicesEnv(devices)
+	// We avoid CUDA filters here to keep ROCm from failing to discover GPUs in a mixed environment
+	devFilter := ml.GetVisibleDevicesEnv(devices, false)
 
 	for dir := range libDirs {
 		updatedDevices := bootstrapDevices(ctx, []string{ml.LibOllamaPath, dir}, devFilter)
@@ -14,6 +14,7 @@
 * [API Reference](https://docs.ollama.com/api)
 * [Modelfile Reference](https://docs.ollama.com/modelfile)
 * [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility)
+* [Anthropic Compatibility](./api/anthropic-compatibility.mdx)
 
 ### Resources

docs/api.md
@@ -50,7 +50,7 @@ Generate a response for a given prompt with a provided model. This is a streamin
 Advanced parameters (optional):
 
 - `format`: the format to return a response in. Format can be `json` or a JSON schema
-- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
+- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
 - `system`: system message (overrides what is defined in the `Modelfile`)
 - `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
 - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects

@@ -507,7 +507,7 @@ The `message` object has the following fields:
 Advanced parameters (optional):
 
 - `format`: the format to return a response in. Format can be `json` or a JSON schema.
-- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
+- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
 - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
 - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)

@@ -895,11 +895,11 @@ curl http://localhost:11434/api/chat -d '{
       "tool_calls": [
         {
           "function": {
-            "name": "get_temperature",
+            "name": "get_weather",
             "arguments": {
               "city": "Toronto"
             }
-          },
+          }
         }
       ]
     },

@@ -907,7 +907,7 @@ curl http://localhost:11434/api/chat -d '{
     {
       "role": "tool",
       "content": "11 degrees celsius",
-      "tool_name": "get_temperature",
+      "tool_name": "get_weather"
     }
   ],
   "stream": false,

@@ -1189,7 +1189,7 @@ If you are creating a model from a safetensors directory or from a GGUF file, yo
 - `template`: (optional) the prompt template for the model
 - `license`: (optional) a string or list of strings containing the license or licenses for the model
 - `system`: (optional) a string containing the system prompt for the model
-- `parameters`: (optional) a dictionary of parameters for the model (see [Modelfile](./modelfile.md#valid-parameters-and-values) for a list of parameters)
+- `parameters`: (optional) a dictionary of parameters for the model (see [Modelfile](./modelfile.mdx#valid-parameters-and-values) for a list of parameters)
 - `messages`: (optional) a list of message objects used to create a conversation
 - `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
 - `quantize` (optional): quantize a non-quantized (e.g. float16) model

@@ -1698,7 +1698,7 @@ Generate embeddings from a model
 Advanced parameters:
 
 - `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
-- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
+- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
 - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
 - `dimensions`: number of dimensions for the embedding

@@ -1817,7 +1817,7 @@ Generate embeddings from a model
 
 Advanced parameters:
 
-- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
+- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
 - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
 
 ### Examples

docs/api/anthropic-compatibility.mdx (new file, 406 lines)
@@ -0,0 +1,406 @@
---
title: Anthropic compatibility
---

Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.

## Recommended models

For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.

Pull a model before use:

```shell
ollama pull qwen3-coder
ollama pull glm-4.7:cloud
```

## Usage

### Environment variables

To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:

```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama # required but ignored
```

### Simple `/v1/messages` example

<CodeGroup dropdown>

```python basic.py
import anthropic

client = anthropic.Anthropic(
    base_url='http://localhost:11434',
    api_key='ollama',  # required but ignored
)

message = client.messages.create(
    model='qwen3-coder',
    max_tokens=1024,
    messages=[
        {'role': 'user', 'content': 'Hello, how are you?'}
    ]
)
print(message.content[0].text)
```

```javascript basic.js
import Anthropic from "@anthropic-ai/sdk";

const anthropic = new Anthropic({
  baseURL: "http://localhost:11434",
  apiKey: "ollama", // required but ignored
});

const message = await anthropic.messages.create({
  model: "qwen3-coder",
  max_tokens: 1024,
  messages: [{ role: "user", content: "Hello, how are you?" }],
});

console.log(message.content[0].text);
```

```shell basic.sh
curl -X POST http://localhost:11434/v1/messages \
  -H "Content-Type: application/json" \
  -H "x-api-key: ollama" \
  -H "anthropic-version: 2023-06-01" \
  -d '{
    "model": "qwen3-coder",
    "max_tokens": 1024,
    "messages": [{ "role": "user", "content": "Hello, how are you?" }]
  }'
```

</CodeGroup>

### Streaming example

<CodeGroup dropdown>

```python streaming.py
import anthropic

client = anthropic.Anthropic(
    base_url='http://localhost:11434',
    api_key='ollama',
)

with client.messages.stream(
    model='qwen3-coder',
    max_tokens=1024,
    messages=[{'role': 'user', 'content': 'Count from 1 to 10'}]
) as stream:
    for text in stream.text_stream:
        print(text, end='', flush=True)
```

```javascript streaming.js
import Anthropic from "@anthropic-ai/sdk";

const anthropic = new Anthropic({
  baseURL: "http://localhost:11434",
  apiKey: "ollama",
});

const stream = await anthropic.messages.stream({
  model: "qwen3-coder",
  max_tokens: 1024,
  messages: [{ role: "user", content: "Count from 1 to 10" }],
});

for await (const event of stream) {
  if (
    event.type === "content_block_delta" &&
    event.delta.type === "text_delta"
  ) {
    process.stdout.write(event.delta.text);
  }
}
```

```shell streaming.sh
curl -X POST http://localhost:11434/v1/messages \
  -H "Content-Type: application/json" \
  -d '{
    "model": "qwen3-coder",
    "max_tokens": 1024,
    "stream": true,
    "messages": [{ "role": "user", "content": "Count from 1 to 10" }]
  }'
```

</CodeGroup>

### Tool calling example

<CodeGroup dropdown>

```python tools.py
import anthropic

client = anthropic.Anthropic(
    base_url='http://localhost:11434',
    api_key='ollama',
)

message = client.messages.create(
    model='qwen3-coder',
    max_tokens=1024,
    tools=[
        {
            'name': 'get_weather',
            'description': 'Get the current weather in a location',
            'input_schema': {
                'type': 'object',
                'properties': {
                    'location': {
                        'type': 'string',
                        'description': 'The city and state, e.g. San Francisco, CA'
                    }
                },
                'required': ['location']
            }
        }
    ],
    messages=[{'role': 'user', 'content': "What's the weather in San Francisco?"}]
)

for block in message.content:
    if block.type == 'tool_use':
        print(f'Tool: {block.name}')
        print(f'Input: {block.input}')
```

```javascript tools.js
import Anthropic from "@anthropic-ai/sdk";

const anthropic = new Anthropic({
  baseURL: "http://localhost:11434",
  apiKey: "ollama",
});

const message = await anthropic.messages.create({
  model: "qwen3-coder",
  max_tokens: 1024,
  tools: [
    {
      name: "get_weather",
      description: "Get the current weather in a location",
      input_schema: {
        type: "object",
        properties: {
          location: {
            type: "string",
            description: "The city and state, e.g. San Francisco, CA",
          },
        },
        required: ["location"],
      },
    },
  ],
  messages: [{ role: "user", content: "What's the weather in San Francisco?" }],
});

for (const block of message.content) {
  if (block.type === "tool_use") {
    console.log("Tool:", block.name);
    console.log("Input:", block.input);
  }
}
```

```shell tools.sh
curl -X POST http://localhost:11434/v1/messages \
  -H "Content-Type: application/json" \
  -d '{
    "model": "qwen3-coder",
    "max_tokens": 1024,
    "tools": [
      {
        "name": "get_weather",
        "description": "Get the current weather in a location",
        "input_schema": {
          "type": "object",
          "properties": {
            "location": {
              "type": "string",
              "description": "The city and state"
            }
          },
          "required": ["location"]
        }
      }
    ],
    "messages": [{ "role": "user", "content": "What is the weather in San Francisco?" }]
  }'
```

</CodeGroup>

## Using with Claude Code

[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:

```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```

Or set the environment variables in your shell profile:

```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```

Then run Claude Code with any Ollama model:

```shell
# Local models
claude --model qwen3-coder
claude --model gpt-oss:20b

# Cloud models
claude --model glm-4.7:cloud
claude --model minimax-m2.1:cloud
```

## Endpoints

### `/v1/messages`

#### Supported features

- [x] Messages
- [x] Streaming
- [x] System prompts
- [x] Multi-turn conversations
- [x] Vision (images)
- [x] Tools (function calling)
- [x] Tool results
- [x] Thinking/extended thinking

#### Supported request fields

- [x] `model`
- [x] `max_tokens`
- [x] `messages`
  - [x] Text `content`
  - [x] Image `content` (base64)
  - [x] Array of content blocks
  - [x] `tool_use` blocks
  - [x] `tool_result` blocks
  - [x] `thinking` blocks
- [x] `system` (string or array)
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `top_k`
- [x] `stop_sequences`
- [x] `tools`
- [x] `thinking`
- [ ] `tool_choice`
- [ ] `metadata`

#### Supported response fields

- [x] `id`
- [x] `type`
- [x] `role`
- [x] `model`
- [x] `content` (text, tool_use, thinking blocks)
- [x] `stop_reason` (end_turn, max_tokens, tool_use)
- [x] `usage` (input_tokens, output_tokens)

#### Streaming events

- [x] `message_start`
- [x] `content_block_start`
- [x] `content_block_delta` (text_delta, input_json_delta, thinking_delta)
- [x] `content_block_stop`
- [x] `message_delta`
- [x] `message_stop`
- [x] `ping`
- [x] `error`
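For readers not using the official SDKs, here is a rough Go sketch of consuming this stream, assuming the standard Anthropic SSE framing of `event:`/`data:` line pairs and a locally pulled model; error handling is kept minimal:

```go
package main

import (
	"bufio"
	"bytes"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	body := []byte(`{
		"model": "qwen3-coder",
		"max_tokens": 1024,
		"stream": true,
		"messages": [{"role": "user", "content": "Count from 1 to 10"}]
	}`)

	resp, err := http.Post("http://localhost:11434/v1/messages", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Events arrive as "event: <name>" / "data: <json>" pairs; the names
	// match the checklist above (message_start, content_block_delta, ...).
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := scanner.Text()
		if strings.HasPrefix(line, "data: ") {
			fmt.Println(strings.TrimPrefix(line, "data: "))
		}
	}
}
```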
## Models

Ollama supports both local and cloud models.

### Local models

Pull a local model before use:

```shell
ollama pull qwen3-coder
```

Recommended local models:

- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model

### Cloud models

Cloud models are available immediately without pulling:

- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model

### Default model names

For tooling that relies on default Anthropic model names such as `claude-3-5-sonnet`, use `ollama cp` to copy an existing model name:

```shell
ollama cp qwen3-coder claude-3-5-sonnet
```

Afterwards, this new model name can be specified in the `model` field:

```shell
curl http://localhost:11434/v1/messages \
  -H "Content-Type: application/json" \
  -d '{
    "model": "claude-3-5-sonnet",
    "max_tokens": 1024,
    "messages": [
      {
        "role": "user",
        "content": "Hello!"
      }
    ]
  }'
```

## Differences from the Anthropic API

### Behavior differences

- API key is accepted but not validated (see the sketch below)
- `anthropic-version` header is accepted but not used
- Token counts are approximations based on the underlying model's tokenizer
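The following Go sketch exercises these behaviors directly with `net/http`: the `x-api-key` and `anthropic-version` headers are sent the way an Anthropic client would send them, but only their presence matters, not their values. It assumes `qwen3-coder` has already been pulled.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	body := []byte(`{
		"model": "qwen3-coder",
		"max_tokens": 1024,
		"messages": [{"role": "user", "content": "Hello, how are you?"}]
	}`)

	req, err := http.NewRequest(http.MethodPost, "http://localhost:11434/v1/messages", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	// Both headers are accepted for SDK compatibility; neither value is checked.
	req.Header.Set("x-api-key", "ollama")
	req.Header.Set("anthropic-version", "2023-06-01")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out))
}
```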
### Not supported

The following Anthropic API features are not currently supported:

| Feature | Description |
|---------|-------------|
| `/v1/messages/count_tokens` | Token counting endpoint |
| `tool_choice` | Forcing specific tool use or disabling tools |
| `metadata` | Request metadata (user_id) |
| Prompt caching | `cache_control` blocks for caching prefixes |
| Batches API | `/v1/messages/batches` for async batch processing |
| Citations | `citations` content blocks |
| PDF support | `document` content blocks with PDF files |
| Server-sent errors | `error` events during streaming (errors return HTTP status) |

### Partial support

| Feature | Status |
|---------|--------|
| Image content | Base64 images supported; URL images not supported |
| Extended thinking | Basic support; `budget_tokens` accepted but not enforced |

File diff suppressed because one or more lines are too long
@@ -15,7 +15,7 @@ Also known as "single-shot" tool calling.
 ```shell
 curl -s http://localhost:11434/api/chat -H "Content-Type: application/json" -d '{
   "model": "qwen3",
-  "messages": [{"role": "user", "content": "What's the temperature in New York?"}],
+  "messages": [{"role": "user", "content": "What is the temperature in New York?"}],
   "stream": false,
   "tools": [
     {

@@ -41,7 +41,7 @@ Also known as "single-shot" tool calling.
 curl -s http://localhost:11434/api/chat -H "Content-Type: application/json" -d '{
   "model": "qwen3",
   "messages": [
-    {"role": "user", "content": "What's the temperature in New York?"},
+    {"role": "user", "content": "What is the temperature in New York?"},
     {
       "role": "assistant",
       "tool_calls": [

@@ -90,7 +90,7 @@ Also known as "single-shot" tool calling.
     }
     return temperatures.get(city, "Unknown")
 
-messages = [{"role": "user", "content": "What's the temperature in New York?"}]
+messages = [{"role": "user", "content": "What is the temperature in New York?"}]
 
 # pass functions directly as tools in the tools list or as a JSON schema
 response = chat(model="qwen3", messages=messages, tools=[get_temperature], think=True)

@@ -146,7 +146,7 @@ Also known as "single-shot" tool calling.
   },
 ]
 
-const messages = [{ role: 'user', content: "What's the temperature in New York?" }]
+const messages = [{ role: 'user', content: "What is the temperature in New York?" }]
 
 const response = await ollama.chat({
   model: 'qwen3',

@@ -609,7 +609,7 @@ def get_temperature(city: str) -> str:
     return temperatures.get(city, 'Unknown')
 
 
-messages = [{'role': 'user', 'content': "What's the temperature in New York?"}]
+messages = [{'role': 'user', 'content': "What is the temperature in New York?"}]
 
 while True:
     stream = chat(

@@ -684,7 +684,7 @@ const getTemperatureTool = {
 }
 
 async function agentLoop() {
-  const messages = [{ role: 'user', content: "What's the temperature in New York?" }]
+  const messages = [{ role: 'user', content: "What is the temperature in New York?" }]
 
   while (true) {
     const stream = await ollama.chat({
@@ -36,7 +36,6 @@ Provide an `images` array. SDKs accept file paths, URLs or raw bytes while the R
   }],
   "stream": false
 }'
-"
 ```
 </Tab>
 <Tab title="Python">
@@ -49,6 +49,8 @@ Install prerequisites:
 - [Ninja](https://github.com/ninja-build/ninja/releases)
 - (Optional) NVIDIA GPU support
   - [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
+- (Optional) VULKAN GPU support
+  - [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
 
 Then, configure and build the project:
 
@@ -57,6 +59,17 @@ cmake -B build
 cmake --build build --config Release
 ```
 
+> Building for Vulkan requires the VULKAN_SDK environment variable:
+>
+> PowerShell
+> ```powershell
+> $env:VULKAN_SDK="C:\VulkanSDK\<version>"
+> ```
+> CMD
+> ```cmd
+> set VULKAN_SDK=C:\VulkanSDK\<version>
+> ```
+
 > [!IMPORTANT]
 > Building for ROCm requires additional flags:
 > ```
@@ -65,6 +78,7 @@ cmake --build build --config Release
 > ```
 
 
+
 Lastly, run Ollama:
 
 ```shell
@@ -84,7 +98,9 @@ Install prerequisites:
 - [ROCm](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/install/quick-start.html)
 - (Optional) NVIDIA GPU support
   - [CUDA SDK](https://developer.nvidia.com/cuda-downloads)
+- (Optional) VULKAN GPU support
+  - [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
+  - Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS)
 
 > [!IMPORTANT]
 > Ensure prerequisites are in `PATH` before running CMake.
@@ -32,7 +32,9 @@
       "codeblocks": "system"
     },
     "contextual": {
-      "options": ["copy"]
+      "options": [
+        "copy"
+      ]
     },
     "navbar": {
       "links": [

@@ -52,7 +54,9 @@
       "display": "simple"
     },
     "examples": {
-      "languages": ["curl"]
+      "languages": [
+        "curl"
+      ]
     }
   },
   "redirects": [

@@ -97,6 +101,7 @@
     {
       "group": "Integrations",
       "pages": [
+        "/integrations/claude-code",
         "/integrations/vscode",
         "/integrations/jetbrains",
         "/integrations/codex",

@@ -139,7 +144,8 @@
         "/api/streaming",
         "/api/usage",
         "/api/errors",
-        "/api/openai-compatibility"
+        "/api/openai-compatibility",
+        "/api/anthropic-compatibility"
       ]
     },
     {
@@ -14,11 +14,11 @@ curl -fsSL https://ollama.com/install.sh | sh
 
 ## How can I view the logs?
 
-Review the [Troubleshooting](./troubleshooting.md) docs for more about using logs.
+Review the [Troubleshooting](./troubleshooting) docs for more about using logs.
 
 ## Is my GPU compatible with Ollama?
 
-Please refer to the [GPU docs](./gpu.md).
+Please refer to the [GPU docs](./gpu).
 
 ## How can I specify the context window size?

docs/gpu.mdx
@@ -33,7 +33,7 @@ Check your compute compatibility to see if your card is supported:
 | 5.0 | GeForce GTX | `GTX 750 Ti` `GTX 750` `NVS 810` |
 | | Quadro | `K2200` `K1200` `K620` `M1200` `M520` `M5000M` `M4000M` `M3000M` `M2000M` `M1000M` `K620M` `M600M` `M500M` |
 
-For building locally to support older GPUs, see [developer.md](./development.md#linux-cuda-nvidia)
+For building locally to support older GPUs, see [developer](./development#linux-cuda-nvidia)
 
 ### GPU Selection
 
@@ -54,7 +54,7 @@ sudo modprobe nvidia_uvm`
 
 Ollama supports the following AMD GPUs via the ROCm library:
 
-> [!NOTE]
+> **NOTE:**
 > Additional AMD GPU support is provided by the Vulkan Library - see below.
 
 
@@ -132,9 +132,9 @@ Ollama supports GPU acceleration on Apple devices via the Metal API.
 
 ## Vulkan GPU Support
 
-> [!NOTE]
+> **NOTE:**
 > Vulkan is currently an Experimental feature. To enable, you must set OLLAMA_VULKAN=1 for the Ollama server as
-described in the [FAQ](faq.md#how-do-i-configure-ollama-server)
+described in the [FAQ](faq#how-do-i-configure-ollama-server)
 
 Additional GPU support on Windows and Linux is provided via
 [Vulkan](https://www.vulkan.org/). On Windows most GPU vendors drivers come
 
@@ -161,6 +161,6 @@ sudo setcap cap_perfmon+ep /usr/local/bin/ollama
 
 To select specific Vulkan GPU(s), you can set the environment variable
 `GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
-described in the [FAQ](faq.md#how-do-i-configure-ollama-server). If you
+described in the [FAQ](faq#how-do-i-configure-ollama-server). If you
 encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
 by setting `GGML_VK_VISIBLE_DEVICES=-1`

docs/integrations/claude-code.mdx (new file, 69 lines)
@@ -0,0 +1,69 @@
---
title: Claude Code
---

## Install

Install [Claude Code](https://code.claude.com/docs/en/overview):

<CodeGroup>

```shell macOS / Linux
curl -fsSL https://claude.ai/install.sh | bash
```

```powershell Windows
irm https://claude.ai/install.ps1 | iex
```

</CodeGroup>

## Usage with Ollama

Claude Code connects to Ollama using the Anthropic-compatible API.

1. Set the environment variables:

```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```

2. Run Claude Code with an Ollama model:

```shell
claude --model qwen3-coder
```

Or run with environment variables inline:

```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```

## Connecting to ollama.com

1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
2. Set the environment variables:

```shell
export ANTHROPIC_BASE_URL=https://ollama.com
export ANTHROPIC_API_KEY=<your-api-key>
```

3. Run Claude Code with a cloud model:

```shell
claude --model glm-4.7:cloud
```

## Recommended Models

### Cloud models

- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
- `qwen3-coder:480b` - Large coding model

### Local models

- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model
@@ -1,5 +1,5 @@
 ---
-title: Linux
+title: "Linux"
 ---
 
 ## Install

@@ -13,8 +13,7 @@ curl -fsSL https://ollama.com/install.sh | sh
 ## Manual install
 
 <Note>
-If you are upgrading from a prior version, you should remove the old libraries
-with `sudo rm -rf /usr/lib/ollama` first.
+If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
 </Note>
 
 Download and extract the package:

@@ -113,11 +112,7 @@ sudo systemctl status ollama
 ```
 
 <Note>
-While AMD has contributed the `amdgpu` driver upstream to the official linux
-kernel source, the version is older and may not support all ROCm features. We
-recommend you install the latest driver from
-https://www.amd.com/en/support/linux-drivers for best support of your Radeon
-GPU.
+While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
 </Note>
 
 ## Customizing

@@ -196,4 +191,4 @@ Remove the downloaded models and Ollama service user and group:
 sudo userdel ollama
 sudo groupdel ollama
 sudo rm -r /usr/share/ollama
 ```
@@ -41,6 +41,7 @@ INSTRUCTION arguments
 | [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. |
 | [`LICENSE`](#license) | Specifies the legal license. |
 | [`MESSAGE`](#message) | Specify message history. |
+| [`REQUIRES`](#requires) | Specify the minimum version of Ollama required by the model. |

 ## Examples

@@ -248,6 +249,16 @@ MESSAGE user Is Ontario in Canada?
 MESSAGE assistant yes
 ```

+### REQUIRES
+
+The `REQUIRES` instruction allows you to specify the minimum version of Ollama required by the model.
+
+```
+REQUIRES <version>
+```
+
+The version should be a valid Ollama version (e.g. 0.14.0).
+
 ## Notes

 - the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments.
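For illustration, a complete Modelfile using the new instruction might look like the sketch below. The model name and version are hypothetical; `FROM`, `PARAMETER`, and `SYSTEM` are existing Modelfile instructions, and `REQUIRES` gates the file to servers at or above the stated version.

```
FROM qwen3-coder
REQUIRES 0.14.0
PARAMETER temperature 0.7
SYSTEM You are a concise coding assistant.
```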

docs/tools/extract-examples/README.md (new file)
@@ -0,0 +1,46 @@
# extract-examples

Extracts code examples from MDX files to a temp directory so you can run them.

## Usage

```shell
go run docs/tools/extract-examples/main.go <mdx-file>
```

## Example

```shell
go run docs/tools/extract-examples/main.go docs/api/openai-compatibility.mdx
```

Output:

```
Extracting code examples to: /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368

 - 01_basic.py
 - 01_basic.js
 - 01_basic.sh
 - 02_responses.py
 - 02_responses.js
 - 02_responses.sh
 - 03_vision.py
 - 03_vision.js
 - 03_vision.sh

Extracted 9 file(s) to /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368

To run examples:

 cd /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
 npm install # for JS examples

then run individual files with `node file.js`, `python file.py`, `bash file.sh`
```

## How it works

- Parses MDX files looking for fenced code blocks with filenames (e.g., ` ```python basic.py `)
- Groups examples by their `<CodeGroup>` and prefixes filenames with `01_`, `02_`, etc.
- Writes all extracted files to a temp directory
docs/tools/extract-examples/main.go (new file)
@@ -0,0 +1,137 @@
package main

import (
	"bufio"
	"fmt"
	"os"
	"path/filepath"
	"regexp"
	"strings"
)

func main() {
	if len(os.Args) < 2 {
		fmt.Fprintln(os.Stderr, "Usage: go run extract-examples.go <mdx-file>")
		os.Exit(1)
	}

	mdxFile := os.Args[1]

	f, err := os.Open(mdxFile)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error: %v\n", err)
		os.Exit(1)
	}
	defer f.Close()

	// Create temp directory
	tempDir, err := os.MkdirTemp("", "mdx-examples-*")
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error creating temp dir: %v\n", err)
		os.Exit(1)
	}

	fmt.Printf("Extracting code examples to: %s\n\n", tempDir)

	// Patterns
	codeBlockStart := regexp.MustCompile("^```([a-zA-Z0-9_-]+)\\s+([^\\s]+)$")
	codeGroupStart := regexp.MustCompile("^<CodeGroup")
	codeGroupEnd := regexp.MustCompile("^</CodeGroup>")

	scanner := bufio.NewScanner(f)
	inCodeBlock := false
	inCodeGroup := false
	var currentFile string
	var content strings.Builder
	count := 0
	codeGroupNum := 0

	for scanner.Scan() {
		line := scanner.Text()

		// Track CodeGroup boundaries
		if codeGroupStart.MatchString(line) {
			inCodeGroup = true
			codeGroupNum++
			continue
		}
		if codeGroupEnd.MatchString(line) {
			inCodeGroup = false
			continue
		}

		if inCodeBlock {
			if line == "```" {
				// End of code block - write file
				if currentFile != "" {
					outPath := filepath.Join(tempDir, currentFile)
					if err := os.WriteFile(outPath, []byte(content.String()), 0o644); err != nil {
						fmt.Fprintf(os.Stderr, "Error writing %s: %v\n", currentFile, err)
					} else {
						fmt.Printf(" - %s\n", currentFile)
						count++
					}
				}
				inCodeBlock = false
				currentFile = ""
				content.Reset()
			} else {
				content.WriteString(line)
				content.WriteString("\n")
			}
		} else {
			if matches := codeBlockStart.FindStringSubmatch(line); matches != nil {
				inCodeBlock = true
				filename := matches[2]
				// Prefix with CodeGroup number if inside a CodeGroup
				if inCodeGroup {
					currentFile = fmt.Sprintf("%02d_%s", codeGroupNum, filename)
				} else {
					currentFile = filename
				}
				content.Reset()
			}
		}
	}

	if err := scanner.Err(); err != nil {
		fmt.Fprintf(os.Stderr, "Error reading file: %v\n", err)
		os.Exit(1)
	}

	// Write package.json for JavaScript dependencies
	packageJSON := `{
  "name": "mdx-examples",
  "type": "module",
  "dependencies": {
    "openai": "^4",
    "ollama": "^0.5"
  }
}
`
	if err := os.WriteFile(filepath.Join(tempDir, "package.json"), []byte(packageJSON), 0o644); err != nil {
		fmt.Fprintf(os.Stderr, "Error writing package.json: %v\n", err)
	}

	// Write pyproject.toml for Python dependencies
	pyprojectTOML := `[project]
name = "mdx-examples"
version = "0.0.0"
dependencies = [
  "openai",
  "ollama",
]
`
	if err := os.WriteFile(filepath.Join(tempDir, "pyproject.toml"), []byte(pyprojectTOML), 0o644); err != nil {
		fmt.Fprintf(os.Stderr, "Error writing pyproject.toml: %v\n", err)
	}

	fmt.Printf("\n")
	fmt.Printf("Extracted %d file(s) to %s\n", count, tempDir)
	fmt.Printf("\n")
	fmt.Printf("To run examples:\n")
	fmt.Printf("\n")
	fmt.Printf(" cd %s\n npm install # for JS examples\n", tempDir)
	fmt.Printf("\n")
	fmt.Printf("then run individual files with `node file.js`, `python file.py`, `bash file.sh`\n")
}
@@ -1,3 +0,0 @@
-# Troubleshooting
-
-For troubleshooting, see [https://docs.ollama.com/troubleshooting](https://docs.ollama.com/troubleshooting)
@@ -87,7 +87,7 @@ When Ollama starts up, it takes inventory of the GPUs present in the system to d

 ### Linux NVIDIA Troubleshooting

-If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker.md](./docker.md)
+If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker](./docker)

 Sometimes the Ollama can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem
@@ -1,5 +1,7 @@
 package fs

+import "iter"
+
 type Config interface {
 	Architecture() string
 	String(string, ...string) string

@@ -11,4 +13,8 @@ type Config interface {
 	Ints(string, ...[]int32) []int32
 	Floats(string, ...[]float32) []float32
 	Bools(string, ...[]bool) []bool
+
+	Len() int
+	Keys() iter.Seq[string]
+	Value(key string) any
 }
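A minimal, runnable sketch (not from the repo) of a map-backed type satisfying the three new methods, mirroring how `ggml.KV` implements them in the next file. It requires Go 1.23+, where `maps.Keys` returns an `iter.Seq`:

```go
package main

import (
	"fmt"
	"iter"
	"maps"
)

// kv is a stand-in for a config map like ggml.KV.
type kv map[string]any

func (m kv) Len() int               { return len(m) }
func (m kv) Keys() iter.Seq[string] { return maps.Keys(m) }
func (m kv) Value(key string) any   { return m[key] }

func main() {
	m := kv{"general.architecture": "test", "general.alignment": uint32(16)}
	// iter.Seq values are range-able, so callers can iterate
	// without materializing a slice first.
	for k := range m.Keys() {
		fmt.Println(k, "=", m.Value(k))
	}
}
```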
@@ -6,13 +6,16 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"iter"
 	"log/slog"
+	"maps"
 	"math"
 	"slices"
 	"strings"

 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/fs/util/bufioutil"
+	"github.com/ollama/ollama/ml"
 )

 type GGML struct {

@@ -238,20 +241,34 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
 	return val.values
 }

+func (kv KV) Len() int {
+	return len(kv)
+}
+
+func (kv KV) Keys() iter.Seq[string] {
+	return maps.Keys(kv)
+}
+
+func (kv KV) Value(key string) any {
+	return kv[key]
+}
+
 func (kv KV) OllamaEngineRequired() bool {
 	return slices.Contains([]string{
+		"bert",
+		"deepseek2",
+		"deepseekocr",
 		"gemma3",
 		"gemma3n",
 		"gptoss", "gpt-oss",
 		"llama4",
 		"mistral3",
 		"mllama",
+		"nomic-bert",
+		"olmo3",
 		"qwen25vl",
 		"qwen3", "qwen3moe",
 		"qwen3vl", "qwen3vlmoe",
-		"deepseekocr",
-		"deepseek2",
-		"nomic-bert",
 	}, kv.Architecture())
 }
@@ -550,7 +567,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
 	}, nil
 }

-func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
+func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention ml.FlashAttentionType) (kv []uint64, partialOffload, fullOffload uint64) {
 	context *= uint64(numParallel)

 	embedding := f.KV().EmbeddingLength()

@@ -791,7 +808,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
 	}

 	partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
-	if useFlashAttention {
+	if useFlashAttention == ml.FlashAttentionEnabled {
 		// rough estimate of graph size with flash attention on
 		partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
 	}

@@ -809,6 +826,14 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool {
 	return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
 }

+// KVCacheTypeIsQuantized checks if the requested cache type is a quantized type
+func (f GGML) KVCacheTypeIsQuantized(cacheType string) bool {
+	if cacheType == "" || cacheType == "f16" || cacheType == "f32" || cacheType == "bf16" {
+		return false
+	}
+	return true
+}
+
 // SupportsFlashAttention checks if the model supports flash attention
 func (f GGML) SupportsFlashAttention() bool {
 	_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]

@@ -829,8 +854,11 @@ func (f GGML) SupportsFlashAttention() bool {
 // FlashAttention checks if the model should enable flash attention
 func (f GGML) FlashAttention() bool {
 	return slices.Contains([]string{
+		"bert",
 		"gemma3",
 		"gptoss", "gpt-oss",
+		"mistral3",
+		"olmo3",
 		"qwen3", "qwen3moe",
 		"qwen3vl", "qwen3vlmoe",
 	}, f.KV().String("general.architecture"))
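The `GraphSize` signature change from `bool` to `ml.FlashAttentionType`, compared against `ml.FlashAttentionEnabled`, suggests flash attention is now a multi-valued setting rather than a flag. A hedged sketch of what such a type could look like; only `FlashAttentionEnabled` is confirmed by this diff, and the other constants and their ordering are assumptions:

```go
// FlashAttentionType lets callers distinguish "explicitly on" from
// "explicitly off" and from "let the runtime decide", which a bool cannot.
type FlashAttentionType int

const (
	FlashAttentionAuto     FlashAttentionType = iota // assumed default
	FlashAttentionEnabled                            // the only constant named in the diff
	FlashAttentionDisabled                           // assumed
)
```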
@@ -8,12 +8,12 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
-	"maps"
 	"os"
 	"runtime"
 	"slices"
 	"strings"

+	"github.com/ollama/ollama/fs"
 	"golang.org/x/sync/errgroup"
 )

@@ -508,7 +508,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
 	return binary.Write(w, binary.LittleEndian, s)
 }

-func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
+func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
 	arch := kv.String("general.architecture")
 	if arch == "" {
 		return fmt.Errorf("architecture not set")

@@ -526,12 +526,12 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
 		return err
 	}

-	if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
+	if err := binary.Write(f, binary.LittleEndian, uint64(kv.Len())); err != nil {
 		return err
 	}

-	for _, key := range slices.Sorted(maps.Keys(kv)) {
-		if err := ggufWriteKV(f, arch, key, kv[key]); err != nil {
+	for _, key := range slices.Sorted(kv.Keys()) {
+		if err := ggufWriteKV(f, arch, key, kv.Value(key)); err != nil {
 			return err
 		}
 	}

@@ -597,6 +597,10 @@ func ggufWriteKV(ws io.WriteSeeker, arch, k string, v any) error {

 	var err error
 	switch v := v.(type) {
+	case int32:
+		err = writeGGUF(ws, ggufTypeInt32, v)
+	case int64:
+		err = writeGGUF(ws, ggufTypeInt64, v)
 	case uint32, FileType:
 		err = writeGGUF(ws, ggufTypeUint32, v)
 	case uint64:

@@ -611,6 +615,10 @@ func ggufWriteKV(ws io.WriteSeeker, arch, k string, v any) error {
 		err = writeGGUFArray(ws, ggufTypeInt32, v)
 	case *array[int32]:
 		err = writeGGUFArray(ws, ggufTypeInt32, v.values)
+	case []int64:
+		err = writeGGUFArray(ws, ggufTypeInt64, v)
+	case *array[int64]:
+		err = writeGGUFArray(ws, ggufTypeInt64, v.values)
 	case []uint32:
 		err = writeGGUFArray(ws, ggufTypeUint32, v)
 	case *array[uint32]:
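Since `WriteGGUF` now accepts the `fs.Config` interface instead of the concrete `KV` map, any key/value source implementing the interface can be serialized. A hedged sketch of a call site; it assumes `ggml.KV` satisfies `fs.Config` after this change (its existing accessors plus the new `Len`/`Keys`/`Value` methods) and that an empty tensor list is accepted:

```go
package main

import (
	"fmt"
	"os"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/fs/ggml"
)

func main() {
	f, err := os.CreateTemp("", "example-*.gguf")
	if err != nil {
		panic(err)
	}
	defer os.Remove(f.Name())

	// KV is passed through the interface, decoupling the writer from the map type.
	var cfg fs.Config = ggml.KV{
		"general.architecture": "test",
		"test.int32_key":       int32(-42), // signed scalars are newly supported above
	}
	if err := ggml.WriteGGUF(f, cfg, nil); err != nil {
		panic(err)
	}
	fmt.Println("wrote", f.Name())
}
```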
@@ -42,6 +42,10 @@ func TestWriteGGUF(t *testing.T) {
 		"general.architecture": "test",
 		"general.alignment":    uint32(16),
 		"test.key":             "value",
+		"test.int32_key":       int32(-42),
+		"test.int64_key":       int64(-9223372036854775808),
+		"test.int32_array":     []int32{-1, 0, 1, 2147483647, -2147483648},
+		"test.int64_array":     []int64{-1, 0, 1, 9223372036854775807, -9223372036854775808},
 		"attention.key":        "value2",
 		"tokenizer.key":        "value3",
 		"adapter.key":          "value4",

@@ -55,7 +59,7 @@ func TestWriteGGUF(t *testing.T) {
 	}
 	defer r.Close()

-	ff, err := Decode(r, 0)
+	ff, err := Decode(r, -1)
 	if err != nil {
 		t.Fatal(err)
 	}

@@ -65,15 +69,19 @@ func TestWriteGGUF(t *testing.T) {
 		"general.alignment":       uint32(16),
 		"general.parameter_count": uint64(54),
 		"test.key":                "value",
+		"test.int32_key":          int32(-42),
+		"test.int64_key":          int64(-9223372036854775808),
+		"test.int32_array":        &array[int32]{size: 5, values: []int32{-1, 0, 1, 2147483647, -2147483648}},
+		"test.int64_array":        &array[int64]{size: 5, values: []int64{-1, 0, 1, 9223372036854775807, -9223372036854775808}},
 		"test.attention.key":      "value2",
 		"tokenizer.key":           "value3",
 		"adapter.key":             "value4",
-	}, ff.KV()); diff != "" {
+	}, ff.KV(), cmp.AllowUnexported(array[int32]{}, array[int64]{})); diff != "" {
 		t.Errorf("Mismatch (-want +got):\n%s", diff)
 	}

 	if diff := cmp.Diff(Tensors{
-		Offset: 800,
+		Offset: 992,
 		items: []*Tensor{
 			{Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}},
 			{Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
go.mod
@@ -15,8 +15,8 @@ require (
 	github.com/spf13/cobra v1.7.0
 	github.com/stretchr/testify v1.9.0
 	github.com/x448/float16 v0.8.4
-	golang.org/x/sync v0.12.0
-	golang.org/x/sys v0.36.0
+	golang.org/x/sync v0.17.0
+	golang.org/x/sys v0.37.0
 )

@@ -28,13 +28,17 @@ require (
 	github.com/nlpodyssey/gopickle v0.3.0
 	github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
 	github.com/tkrajina/typescriptify-golang-structs v0.2.0
+	github.com/wk8/go-ordered-map/v2 v2.1.8
 	golang.org/x/image v0.22.0
-	golang.org/x/tools v0.30.0
+	golang.org/x/mod v0.30.0
+	golang.org/x/tools v0.38.0
 	gonum.org/v1/gonum v0.15.0
 )

 require (
 	github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
+	github.com/bahlo/generic-list-go v0.2.0 // indirect
+	github.com/buger/jsonparser v1.1.1 // indirect
 	github.com/bytedance/sonic/loader v0.1.1 // indirect
 	github.com/chewxy/hm v1.0.0 // indirect
 	github.com/chewxy/math32 v1.11.0 // indirect

@@ -44,6 +48,7 @@ require (
 	github.com/gogo/protobuf v1.3.2 // indirect
 	github.com/google/flatbuffers v24.3.25+incompatible // indirect
 	github.com/kr/text v0.2.0 // indirect
+	github.com/mailru/easyjson v0.7.7 // indirect
 	github.com/pkg/errors v0.9.1 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/rivo/uniseg v0.2.0 // indirect

@@ -76,11 +81,11 @@ require (
 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 	github.com/ugorji/go/codec v1.2.12 // indirect
 	golang.org/x/arch v0.8.0 // indirect
-	golang.org/x/crypto v0.36.0
+	golang.org/x/crypto v0.43.0
 	golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
-	golang.org/x/net v0.38.0 // indirect
-	golang.org/x/term v0.30.0
-	golang.org/x/text v0.23.0
+	golang.org/x/net v0.46.0 // indirect
+	golang.org/x/term v0.36.0
+	golang.org/x/text v0.30.0
 	google.golang.org/protobuf v1.34.1
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
go.sum
@@ -14,7 +14,11 @@ github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6IC
 github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
 github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
 github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
+github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
+github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
 github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
+github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
+github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
 github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
 github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
 github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=

@@ -123,6 +127,7 @@ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
 github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
 github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
+github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
 github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
 github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
 github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=

@@ -143,6 +148,8 @@ github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 h1:QwWKgMY28TAXaDl+
 github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4=
 github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
 github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
+github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
+github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
 github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
 github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
 github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=

@@ -207,6 +214,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
 github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
 github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
 github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
+github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
+github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
 github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
 github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
 github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=

@@ -224,8 +233,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
 golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
-golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
+golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
+golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
 golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=

@@ -255,6 +264,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
 golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
+golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
+golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
 golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=

@@ -267,8 +278,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
 golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
 golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
 golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
-golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
-golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
+golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
+golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
 golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
 golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

@@ -278,8 +289,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
 golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
-golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
+golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
+golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
 golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

@@ -295,17 +306,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
 golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
-golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
+golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
+golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
-golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
+golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
+golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
-golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
+golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
+golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
 golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

@@ -319,8 +330,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
 golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
 golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
 golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
-golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
-golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
+golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
+golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -4,7 +4,9 @@ package integration
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"math"
|
"math"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -204,8 +206,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
|
|||||||
t.Fatalf("expected %v, got %v (similarity: %f)", expected[0:5], res.Embeddings[0][0:5], sim)
|
t.Fatalf("expected %v, got %v (similarity: %f)", expected[0:5], res.Embeddings[0][0:5], sim)
|
||||||
}
|
}
|
||||||
|
|
||||||
if res.PromptEvalCount != 6 {
|
if res.PromptEvalCount != 8 {
|
||||||
t.Fatalf("expected 6 prompt tokens, got %d", res.PromptEvalCount)
|
t.Fatalf("expected 8 prompt tokens, got %d", res.PromptEvalCount)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -251,8 +253,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
|
|||||||
t.Fatalf("expected %v, got %v (similarity: %f)", expected[1][0:5], res.Embeddings[1][0:5], sim)
|
t.Fatalf("expected %v, got %v (similarity: %f)", expected[1][0:5], res.Embeddings[1][0:5], sim)
|
||||||
}
|
}
|
||||||
|
|
||||||
if res.PromptEvalCount != 12 {
|
if res.PromptEvalCount != 16 {
|
||||||
t.Fatalf("expected 12 prompt tokens, got %d", res.PromptEvalCount)
|
t.Fatalf("expected 16 prompt tokens, got %d", res.PromptEvalCount)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -275,7 +277,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
|||||||
cases := []struct {
|
cases := []struct {
|
||||||
name string
|
name string
|
||||||
request api.EmbedRequest
|
request api.EmbedRequest
|
||||||
check func(*api.EmbedResponse, error)
|
check func(*testing.T, *api.EmbedResponse, error)
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "target truncation",
|
name: "target truncation",
|
||||||
@@ -283,7 +285,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
|||||||
Model: "all-minilm",
|
Model: "all-minilm",
|
||||||
Input: "why",
|
Input: "why",
|
||||||
},
|
},
|
||||||
check: func(got *api.EmbedResponse, err error) {
|
check: func(t *testing.T, got *api.EmbedResponse, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -300,10 +302,11 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
|||||||
Input: "why is the sky blue?",
|
Input: "why is the sky blue?",
|
||||||
Options: map[string]any{"num_ctx": 3},
|
Options: map[string]any{"num_ctx": 3},
|
||||||
},
|
},
|
||||||
check: func(got *api.EmbedResponse, err error) {
|
check: func(t *testing.T, got *api.EmbedResponse, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
t.Logf("PromptEvalCount: want=%d got=%d", want.PromptEvalCount, got.PromptEvalCount)
|
||||||
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||||
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
@@ -317,10 +320,11 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
|||||||
Truncate: &truncTrue,
|
Truncate: &truncTrue,
|
||||||
Options: map[string]any{"num_ctx": 3},
|
Options: map[string]any{"num_ctx": 3},
|
||||||
},
|
},
|
||||||
check: func(got *api.EmbedResponse, err error) {
|
check: func(t *testing.T, got *api.EmbedResponse, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
t.Logf("PromptEvalCount: want=%d got=%d", want.PromptEvalCount, got.PromptEvalCount)
|
||||||
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||||
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
@@ -334,21 +338,21 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
|||||||
Truncate: &truncFalse,
|
Truncate: &truncFalse,
|
||||||
Options: map[string]any{"num_ctx": 3},
|
Options: map[string]any{"num_ctx": 3},
|
||||||
},
|
},
|
||||||
check: func(res *api.EmbedResponse, err error) {
|
check: func(t *testing.T, res *api.EmbedResponse, err error) {
|
||||||
if err.Error() != "input exceeds maximum context length" {
|
if err.Error() != "the input length exceeds the context length" {
|
||||||
t.Fatalf("expected truncation error, got: %v", err)
|
t.Fatalf("expected truncation error, got: %v", err)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "input after truncate error",
|
name: "input after truncate error with context length of 1",
|
||||||
request: api.EmbedRequest{
|
request: api.EmbedRequest{
|
||||||
Model: "all-minilm",
|
Model: "all-minilm",
|
||||||
Input: "why is the sky blue?",
|
Input: "why is the sky blue?",
|
||||||
Truncate: &truncTrue,
|
Truncate: &truncTrue,
|
||||||
Options: map[string]any{"num_ctx": 1},
|
Options: map[string]any{"num_ctx": 1},
|
||||||
},
|
},
|
||||||
check: func(res *api.EmbedResponse, err error) {
|
check: func(t *testing.T, res *api.EmbedResponse, err error) {
|
||||||
if err.Error() != "input after truncation exceeds maximum context length" {
|
if err.Error() != "input after truncation exceeds maximum context length" {
|
||||||
t.Fatalf("expected truncation error, got: %v", err)
|
t.Fatalf("expected truncation error, got: %v", err)
|
||||||
}
|
}
|
||||||
@@ -362,7 +366,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
|||||||
Truncate: &truncTrue,
|
Truncate: &truncTrue,
|
||||||
Options: map[string]any{"num_ctx": 0},
|
Options: map[string]any{"num_ctx": 0},
|
||||||
},
|
},
|
||||||
check: func(res *api.EmbedResponse, err error) {
|
check: func(t *testing.T, res *api.EmbedResponse, err error) {
|
||||||
if err.Error() != "input after truncation exceeds maximum context length" {
|
if err.Error() != "input after truncation exceeds maximum context length" {
|
||||||
t.Fatalf("expected truncation error, got: %v", err)
|
t.Fatalf("expected truncation error, got: %v", err)
|
||||||
}
|
}
|
||||||
@@ -375,7 +379,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
|||||||
Input: "why is the sky blue? Why is the sky blue? hi there my",
|
Input: "why is the sky blue? Why is the sky blue? hi there my",
|
||||||
Options: map[string]any{"num_ctx": 16},
|
Options: map[string]any{"num_ctx": 16},
|
||||||
},
|
},
|
||||||
check: func(res *api.EmbedResponse, err error) {
|
check: func(t *testing.T, res *api.EmbedResponse, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -385,7 +389,8 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
|||||||
|
|
||||||
for _, req := range cases {
|
for _, req := range cases {
|
||||||
t.Run(req.name, func(t *testing.T) {
|
t.Run(req.name, func(t *testing.T) {
|
||||||
req.check(embedTestHelper(ctx, client, t, req.request))
|
resp, err := embedTestHelper(ctx, client, t, req.request)
|
||||||
|
req.check(t, resp, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -409,3 +414,230 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req
|
|||||||
|
|
||||||
return client.Embed(ctx, &req)
|
return client.Embed(ctx, &req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEmbedTruncation(t *testing.T) {
|
||||||
|
// Use test deadline if set, otherwise default to 2 minutes
|
||||||
|
timeout := 2 * time.Minute
|
||||||
|
if deadline, ok := t.Deadline(); ok {
|
||||||
|
timeout = time.Until(deadline) - 10*time.Second // Reserve 10s buffer
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
for _, model := range libraryEmbedModels {
|
||||||
|
model := model
|
||||||
|
t.Run(model, func(t *testing.T) {
|
||||||
|
// Check if we're running out of time (reserve 20s for current model)
|
||||||
|
if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second {
|
||||||
|
t.Skip("skipping remaining tests to avoid timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Give each model its own budget to account for first-time pulls/loads
|
||||||
|
mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute)
|
||||||
|
defer mcancel()
|
||||||
|
|
||||||
|
t.Run("truncation batch", func(t *testing.T) {
|
||||||
|
truncTrue := true
|
||||||
|
req := api.EmbedRequest{
|
||||||
|
Model: model,
|
||||||
|
Input: []string{"short", strings.Repeat("long ", 100), "medium text"},
|
||||||
|
Truncate: &truncTrue,
|
||||||
|
Options: map[string]any{"num_ctx": 30},
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := embedTestHelper(mctx, client, t, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(res.Embeddings) != 3 {
|
||||||
|
t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.PromptEvalCount > 90 {
|
||||||
|
t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("runner token count accuracy", func(t *testing.T) {
|
||||||
|
baseline := api.EmbedRequest{Model: model, Input: "test"}
|
||||||
|
baseRes, err := embedTestHelper(mctx, client, t, baseline)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
batch := api.EmbedRequest{
|
||||||
|
Model: model,
|
||||||
|
Input: []string{"test", "test", "test"},
|
||||||
|
}
|
||||||
|
batchRes, err := embedTestHelper(mctx, client, t, batch)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedCount := baseRes.PromptEvalCount * 3
|
||||||
|
if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 {
|
||||||
|
t.Fatalf("expected ~%d tokens (3 × %d), got %d",
|
||||||
|
expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEmbedLargeInput tests that embedding models can handle large inputs that would exceed typical batch sizes.
|
||||||
|
func TestEmbedLargeInput(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
for _, model := range libraryEmbedModels {
|
||||||
|
model := model
|
||||||
|
t.Run(model, func(t *testing.T) {
|
||||||
|
mctx, mcancel := context.WithTimeout(ctx, 2*time.Minute)
|
||||||
|
defer mcancel()
|
||||||
|
|
||||||
|
// Test with progressively larger inputs
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
				inputWords int
			}{
				{"medium_input_256_words", 256},
				{"large_input_512_words", 512},
				{"very_large_input_800_words", 800},
			}

			for _, tc := range testCases {
				t.Run(tc.name, func(t *testing.T) {
					words := make([]string, tc.inputWords)
					for i := range words {
						words[i] = "word"
					}
					input := strings.Join(words, " ")

					req := api.EmbedRequest{
						Model:     model,
						Input:     input,
						KeepAlive: &api.Duration{Duration: 30 * time.Second},
					}

					res, err := embedTestHelper(mctx, client, t, req)
					if err != nil {
						t.Fatalf("embedding failed for %d words: %v", tc.inputWords, err)
					}

					if len(res.Embeddings) != 1 {
						t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
					}

					if len(res.Embeddings[0]) == 0 {
						t.Fatal("expected non-empty embedding")
					}

					t.Logf("Successfully embedded %d words (%d tokens)", tc.inputWords, res.PromptEvalCount)
				})
			}
		})
	}
}

// TestEmbedStatusCode tests that errors from the embedding endpoint
// properly preserve their HTTP status codes when returned to the client.
// This test specifically checks the error handling path in EmbedHandler
// where api.StatusError errors should maintain their original status code.
func TestEmbedStatusCode(t *testing.T) {
	// Use test deadline if set, otherwise default to 2 minutes
	timeout := 2 * time.Minute
	if deadline, ok := t.Deadline(); ok {
		timeout = time.Until(deadline) - 10*time.Second // Reserve 10s buffer
	}
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	for _, model := range libraryEmbedModels {
		model := model
		t.Run(model, func(t *testing.T) {
			// Check if we're running out of time (reserve 20s for current model)
			if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second {
				t.Skip("skipping remaining tests to avoid timeout")
			}

			mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute)
			defer mcancel()

			// Pull the model if needed
			if err := PullIfMissing(mctx, client, model); err != nil {
				t.Fatal(err)
			}

			t.Run("truncation error status code", func(t *testing.T) {
				truncFalse := false
				longInput := strings.Repeat("word ", 100)

				req := api.EmbedRequest{
					Model:    model,
					Input:    longInput,
					Truncate: &truncFalse,
					Options:  map[string]any{"num_ctx": 10},
				}

				_, err := embedTestHelper(mctx, client, t, req)
				if err == nil {
					t.Fatal("expected error when truncate=false with long input")
				}

				// Check that it's a StatusError with the correct status code
				var statusErr api.StatusError
				if !errors.As(err, &statusErr) {
					t.Fatalf("expected api.StatusError, got %T: %v", err, err)
				}

				// The error should be a 4xx client error (likely 400 Bad Request)
				// not a 500 Internal Server Error
				if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
					t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
				}

				// Verify the error message is meaningful
				if !strings.Contains(err.Error(), "context length") {
					t.Errorf("expected error message to mention context length, got: %v", err)
				}
			})

			t.Run("batch truncation error status code", func(t *testing.T) {
				truncFalse := false
				req := api.EmbedRequest{
					Model: model,
					Input: []string{
						"short input",
						strings.Repeat("very long input ", 100),
						"another short input",
					},
					Truncate: &truncFalse,
					Options:  map[string]any{"num_ctx": 10},
				}

				_, err := embedTestHelper(mctx, client, t, req)
				if err == nil {
					t.Fatal("expected error when one input exceeds context with truncate=false")
				}

				// Check that it's a StatusError with the correct status code
				var statusErr api.StatusError
				if !errors.As(err, &statusErr) {
					t.Fatalf("expected api.StatusError, got %T: %v", err, err)
				}

				// The error should be a 4xx client error, not a 500 Internal Server Error
				if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
					t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
				}
			})
		})
	}
}
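The status-code contract asserted above can also be exercised from the client side. Below is a minimal Go sketch, assuming a locally running server; the model name is illustrative, and `Client.Embed` / `api.StatusError` are the public API the test helper wraps:

package main

import (
	"context"
	"errors"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	truncate := false
	res, err := client.Embed(context.Background(), &api.EmbedRequest{
		Model:    "all-minilm", // hypothetical model name for illustration
		Input:    "some input that may exceed num_ctx",
		Truncate: &truncate,
	})
	if err != nil {
		var statusErr api.StatusError
		if errors.As(err, &statusErr) && statusErr.StatusCode >= 400 && statusErr.StatusCode < 500 {
			// Client error: the request was rejected (e.g. input exceeds context length).
			log.Fatalf("request rejected (%d): %v", statusErr.StatusCode, err)
		}
		log.Fatalf("server or transport error: %v", err) // 5xx or connection failure
	}
	fmt.Println("embedding length:", len(res.Embeddings[0]))
}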
@@ -11,6 +11,15 @@ import (
 	"github.com/ollama/ollama/api"
 )
 
+// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
+func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
+	props := api.NewToolPropertiesMap()
+	for k, v := range m {
+		props.Set(k, v)
+	}
+	return props
+}
+
 func TestAPIToolCalling(t *testing.T) {
 	initialTimeout := 60 * time.Second
 	streamTimeout := 60 * time.Second
@@ -30,6 +39,7 @@ func TestAPIToolCalling(t *testing.T) {
 		"mistral":       6,
 		"qwen2.5":       6,
 		"qwen2":         6,
+		"ministral-3":   20,
 		"mistral-nemo":  9,
 		"mistral-small": 16,
 		"mixtral:8x22b": 80,
@@ -56,12 +66,12 @@ func TestAPIToolCalling(t *testing.T) {
 				Parameters: api.ToolFunctionParameters{
 					Type:     "object",
 					Required: []string{"location"},
-					Properties: map[string]api.ToolProperty{
+					Properties: testPropsMap(map[string]api.ToolProperty{
 						"location": {
 							Type:        api.PropertyType{"string"},
 							Description: "The city and state, e.g. San Francisco, CA",
 						},
-					},
+					}),
 				},
 			},
 		},
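The point of the `testPropsMap` switch is that tool properties now live in an ordered container rather than a plain Go map, so the schema sent to models keeps its declaration order. A hedged sketch of that behavior: `NewToolPropertiesMap` and `Set` are taken from the diff, the second key and the order-preserving marshaling are assumptions based on the ordered map the type appears to wrap:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	// Build the same property set the test constructs via testPropsMap.
	props := api.NewToolPropertiesMap()
	props.Set("location", api.ToolProperty{
		Type:        api.PropertyType{"string"},
		Description: "The city and state, e.g. San Francisco, CA",
	})
	props.Set("unit", api.ToolProperty{ // "unit" is an illustrative second key
		Type:        api.PropertyType{"string"},
		Description: "celsius or fahrenheit",
	})

	data, err := json.Marshal(props)
	if err != nil {
		panic(err)
	}
	// Assuming ToolPropertiesMap marshals like the ordered map it wraps,
	// "location" always precedes "unit" in the output.
	fmt.Println(string(data))
}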
94 internal/orderedmap/orderedmap.go Normal file
@@ -0,0 +1,94 @@
// Package orderedmap provides a generic ordered map that maintains insertion order.
// It wraps github.com/wk8/go-ordered-map/v2 to encapsulate the dependency.
package orderedmap

import (
	"encoding/json"
	"iter"

	orderedmap "github.com/wk8/go-ordered-map/v2"
)

// Map is a generic ordered map that maintains insertion order.
type Map[K comparable, V any] struct {
	om *orderedmap.OrderedMap[K, V]
}

// New creates a new empty ordered map.
func New[K comparable, V any]() *Map[K, V] {
	return &Map[K, V]{
		om: orderedmap.New[K, V](),
	}
}

// Get retrieves a value by key.
func (m *Map[K, V]) Get(key K) (V, bool) {
	if m == nil || m.om == nil {
		var zero V
		return zero, false
	}
	return m.om.Get(key)
}

// Set sets a key-value pair. If the key already exists, its value is updated
// but its position in the iteration order is preserved. If the key is new,
// it is appended to the end.
func (m *Map[K, V]) Set(key K, value V) {
	if m == nil {
		return
	}
	if m.om == nil {
		m.om = orderedmap.New[K, V]()
	}
	m.om.Set(key, value)
}

// Len returns the number of entries.
func (m *Map[K, V]) Len() int {
	if m == nil || m.om == nil {
		return 0
	}
	return m.om.Len()
}

// All returns an iterator over all key-value pairs in insertion order.
func (m *Map[K, V]) All() iter.Seq2[K, V] {
	return func(yield func(K, V) bool) {
		if m == nil || m.om == nil {
			return
		}
		for pair := m.om.Oldest(); pair != nil; pair = pair.Next() {
			if !yield(pair.Key, pair.Value) {
				return
			}
		}
	}
}

// ToMap converts to a regular Go map.
// Note: The resulting map does not preserve order.
func (m *Map[K, V]) ToMap() map[K]V {
	if m == nil || m.om == nil {
		return nil
	}
	result := make(map[K]V, m.om.Len())
	for pair := m.om.Oldest(); pair != nil; pair = pair.Next() {
		result[pair.Key] = pair.Value
	}
	return result
}

// MarshalJSON implements json.Marshaler. The JSON output preserves key order.
func (m *Map[K, V]) MarshalJSON() ([]byte, error) {
	if m == nil || m.om == nil {
		return []byte("null"), nil
	}
	return json.Marshal(m.om)
}

// UnmarshalJSON implements json.Unmarshaler. The insertion order matches the
// order of keys in the JSON input.
func (m *Map[K, V]) UnmarshalJSON(data []byte) error {
	m.om = orderedmap.New[K, V]()
	return json.Unmarshal(data, &m.om)
}
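A short usage sketch of the package above; the behavior shown follows directly from the listed methods (`New`, `Set`, `All`, `MarshalJSON`). Note the import path only resolves from inside the ollama module, since the package is internal:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/internal/orderedmap"
)

func main() {
	m := orderedmap.New[string, int]()
	m.Set("z", 1)
	m.Set("a", 2)
	m.Set("z", 3) // update: the value changes, the position stays first

	for k, v := range m.All() {
		fmt.Println(k, v) // z 3, then a 2 — insertion order, not key order
	}

	out, _ := json.Marshal(m)
	fmt.Println(string(out)) // {"z":3,"a":2}
}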
348 internal/orderedmap/orderedmap_test.go Normal file
@@ -0,0 +1,348 @@
package orderedmap

import (
	"encoding/json"
	"slices"
	"testing"
)

func TestMap_BasicOperations(t *testing.T) {
	m := New[string, int]()

	// Test empty map
	if m.Len() != 0 {
		t.Errorf("expected Len() = 0, got %d", m.Len())
	}
	v, ok := m.Get("a")
	if ok {
		t.Error("expected Get on empty map to return false")
	}
	if v != 0 {
		t.Errorf("expected zero value, got %d", v)
	}

	// Test Set and Get
	m.Set("a", 1)
	m.Set("b", 2)
	m.Set("c", 3)

	if m.Len() != 3 {
		t.Errorf("expected Len() = 3, got %d", m.Len())
	}

	v, ok = m.Get("a")
	if !ok || v != 1 {
		t.Errorf("expected Get(a) = (1, true), got (%d, %v)", v, ok)
	}

	v, ok = m.Get("b")
	if !ok || v != 2 {
		t.Errorf("expected Get(b) = (2, true), got (%d, %v)", v, ok)
	}

	v, ok = m.Get("c")
	if !ok || v != 3 {
		t.Errorf("expected Get(c) = (3, true), got (%d, %v)", v, ok)
	}

	// Test updating existing key preserves position
	m.Set("a", 10)
	v, ok = m.Get("a")
	if !ok || v != 10 {
		t.Errorf("expected Get(a) = (10, true), got (%d, %v)", v, ok)
	}
	if m.Len() != 3 {
		t.Errorf("expected Len() = 3 after update, got %d", m.Len())
	}
}

func TestMap_InsertionOrderPreserved(t *testing.T) {
	m := New[string, int]()

	// Insert in non-alphabetical order
	m.Set("z", 1)
	m.Set("a", 2)
	m.Set("m", 3)
	m.Set("b", 4)

	// Verify iteration order matches insertion order
	var keys []string
	var values []int
	for k, v := range m.All() {
		keys = append(keys, k)
		values = append(values, v)
	}

	expectedKeys := []string{"z", "a", "m", "b"}
	expectedValues := []int{1, 2, 3, 4}

	if !slices.Equal(keys, expectedKeys) {
		t.Errorf("expected keys %v, got %v", expectedKeys, keys)
	}
	if !slices.Equal(values, expectedValues) {
		t.Errorf("expected values %v, got %v", expectedValues, values)
	}
}

func TestMap_UpdatePreservesPosition(t *testing.T) {
	m := New[string, int]()

	m.Set("first", 1)
	m.Set("second", 2)
	m.Set("third", 3)

	// Update middle element
	m.Set("second", 20)

	var keys []string
	for k := range m.All() {
		keys = append(keys, k)
	}

	// Order should still be first, second, third
	expected := []string{"first", "second", "third"}
	if !slices.Equal(keys, expected) {
		t.Errorf("expected keys %v, got %v", expected, keys)
	}
}

func TestMap_MarshalJSON_PreservesOrder(t *testing.T) {
	m := New[string, int]()

	// Insert in non-alphabetical order
	m.Set("z", 1)
	m.Set("a", 2)
	m.Set("m", 3)

	data, err := json.Marshal(m)
	if err != nil {
		t.Fatalf("Marshal failed: %v", err)
	}

	// JSON should preserve insertion order, not alphabetical
	expected := `{"z":1,"a":2,"m":3}`
	if string(data) != expected {
		t.Errorf("expected %s, got %s", expected, string(data))
	}
}

func TestMap_UnmarshalJSON_PreservesOrder(t *testing.T) {
	// JSON with non-alphabetical key order
	jsonData := `{"z":1,"a":2,"m":3}`

	m := New[string, int]()
	if err := json.Unmarshal([]byte(jsonData), m); err != nil {
		t.Fatalf("Unmarshal failed: %v", err)
	}

	// Verify iteration order matches JSON order
	var keys []string
	for k := range m.All() {
		keys = append(keys, k)
	}

	expected := []string{"z", "a", "m"}
	if !slices.Equal(keys, expected) {
		t.Errorf("expected keys %v, got %v", expected, keys)
	}
}

func TestMap_JSONRoundTrip(t *testing.T) {
	// Test that unmarshal -> marshal produces identical JSON
	original := `{"zebra":"z","apple":"a","mango":"m","banana":"b"}`

	m := New[string, string]()
	if err := json.Unmarshal([]byte(original), m); err != nil {
		t.Fatalf("Unmarshal failed: %v", err)
	}

	data, err := json.Marshal(m)
	if err != nil {
		t.Fatalf("Marshal failed: %v", err)
	}

	if string(data) != original {
		t.Errorf("round trip failed: expected %s, got %s", original, string(data))
	}
}

func TestMap_ToMap(t *testing.T) {
	m := New[string, int]()
	m.Set("a", 1)
	m.Set("b", 2)

	regular := m.ToMap()

	if len(regular) != 2 {
		t.Errorf("expected len 2, got %d", len(regular))
	}
	if regular["a"] != 1 {
		t.Errorf("expected regular[a] = 1, got %d", regular["a"])
	}
	if regular["b"] != 2 {
		t.Errorf("expected regular[b] = 2, got %d", regular["b"])
	}
}

func TestMap_NilSafety(t *testing.T) {
	var m *Map[string, int]

	// All operations should be safe on nil
	if m.Len() != 0 {
		t.Errorf("expected Len() = 0 on nil map, got %d", m.Len())
	}

	v, ok := m.Get("a")
	if ok {
		t.Error("expected Get on nil map to return false")
	}
	if v != 0 {
		t.Errorf("expected zero value from nil map, got %d", v)
	}

	// Set on nil is a no-op
	m.Set("a", 1)
	if m.Len() != 0 {
		t.Errorf("expected Len() = 0 after Set on nil, got %d", m.Len())
	}

	// All returns empty iterator
	var keys []string
	for k := range m.All() {
		keys = append(keys, k)
	}
	if len(keys) != 0 {
		t.Errorf("expected empty iteration on nil map, got %v", keys)
	}

	// ToMap returns nil
	if m.ToMap() != nil {
		t.Error("expected ToMap to return nil on nil map")
	}

	// MarshalJSON returns null
	data, err := json.Marshal(m)
	if err != nil {
		t.Fatalf("Marshal failed: %v", err)
	}
	if string(data) != "null" {
		t.Errorf("expected null, got %s", string(data))
	}
}

func TestMap_EmptyMapMarshal(t *testing.T) {
	m := New[string, int]()

	data, err := json.Marshal(m)
	if err != nil {
		t.Fatalf("Marshal failed: %v", err)
	}
	if string(data) != "{}" {
		t.Errorf("expected {}, got %s", string(data))
	}
}

func TestMap_NestedValues(t *testing.T) {
	m := New[string, any]()
	m.Set("string", "hello")
	m.Set("number", 42)
	m.Set("bool", true)
	m.Set("nested", map[string]int{"x": 1})

	data, err := json.Marshal(m)
	if err != nil {
		t.Fatalf("Marshal failed: %v", err)
	}

	expected := `{"string":"hello","number":42,"bool":true,"nested":{"x":1}}`
	if string(data) != expected {
		t.Errorf("expected %s, got %s", expected, string(data))
	}
}

func TestMap_AllIteratorEarlyExit(t *testing.T) {
	m := New[string, int]()
	m.Set("a", 1)
	m.Set("b", 2)
	m.Set("c", 3)
	m.Set("d", 4)

	// Collect only first 2
	var keys []string
	for k := range m.All() {
		keys = append(keys, k)
		if len(keys) == 2 {
			break
		}
	}

	expected := []string{"a", "b"}
	if !slices.Equal(keys, expected) {
		t.Errorf("expected %v, got %v", expected, keys)
	}
}

func TestMap_IntegerKeys(t *testing.T) {
	m := New[int, string]()
	m.Set(3, "three")
	m.Set(1, "one")
	m.Set(2, "two")

	var keys []int
	for k := range m.All() {
		keys = append(keys, k)
	}

	// Should preserve insertion order, not numerical order
	expected := []int{3, 1, 2}
	if !slices.Equal(keys, expected) {
		t.Errorf("expected %v, got %v", expected, keys)
	}
}

func TestMap_UnmarshalIntoExisting(t *testing.T) {
	m := New[string, int]()
	m.Set("existing", 999)

	// Unmarshal should replace contents
	if err := json.Unmarshal([]byte(`{"new":1}`), m); err != nil {
		t.Fatalf("Unmarshal failed: %v", err)
	}

	_, ok := m.Get("existing")
	if ok {
		t.Error("existing key should be gone after unmarshal")
	}

	v, ok := m.Get("new")
	if !ok || v != 1 {
		t.Errorf("expected Get(new) = (1, true), got (%d, %v)", v, ok)
	}
}

func TestMap_LargeOrderPreservation(t *testing.T) {
	m := New[string, int]()

	// Create many keys in specific order
	keys := make([]string, 100)
	for i := range 100 {
		keys[i] = string(rune('a' + (99 - i))) // reverse order: 'd', 'c', 'b', 'a' (extended)
		if i >= 26 {
			keys[i] = string(rune('A'+i-26)) + string(rune('a'+i%26))
		}
	}

	for i, k := range keys {
		m.Set(k, i)
	}

	// Verify order preserved
	var resultKeys []string
	for k := range m.All() {
		resultKeys = append(resultKeys, k)
	}

	if !slices.Equal(keys, resultKeys) {
		t.Error("large map should preserve insertion order")
	}
}
@@ -140,10 +140,6 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
 		c.config.CachePadding = 1
 	}
 
-	if c.config.MaskBatchPadding == 0 {
-		c.config.MaskBatchPadding = 1
-	}
-
 	if c.config.MaskDType == ml.DTypeOther {
 		c.config.MaskDType = ml.DTypeF32
 	}
@@ -364,15 +360,12 @@ func roundUp(length, pad int) int {
 // token in the history should apply. This is based on both the sequence and causality (the
 // position of the history is not ahead of the token in the batch).
 func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
-	// Align and pad the two dimensions as required by the backend
-	batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
-
 	c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
 	c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
 
 	length := c.curCellRange.max - c.curCellRange.min + 1
 
-	mask := make([]float32, batchSize*length)
+	mask := make([]float32, c.curBatchSize*length)
 
 	for i := range c.curBatchSize {
 		enabled := !slices.Contains(c.opts.Except, i)
@@ -386,13 +379,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
 		}
 	}
 
-	// Mask out any padding tokens we added. For padding that we added to the cache history, this
-	// has already been masked out because the sequence doesn't match.
-	for i := c.curBatchSize * length; i < len(mask); i++ {
-		mask[i] = float32(math.Inf(-1))
-	}
-
-	maskTensor := ctx.Input().FromFloats(mask, length, batchSize)
+	maskTensor := ctx.Input().FromFloats(mask, length, c.curBatchSize)
 
 	if c.config.MaskDType != ml.DTypeF32 {
 		maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
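A worked sketch of the cell-range alignment the hunk above keeps: only the cache dimension is still padded, while the batch dimension now uses `c.curBatchSize` directly (so the -Inf fill loop for batch padding is gone). The padding value and cell range below are illustrative; `roundDown`/`roundUp` mirror the helpers referenced in the hunk:

package main

import "fmt"

func roundDown(length, pad int) int { return (length / pad) * pad }
func roundUp(length, pad int) int   { return ((length + pad - 1) / pad) * pad }

func main() {
	const cachePadding = 32 // illustrative; the real value comes from c.config.CachePadding
	lo, hi := 37, 70        // illustrative curCellRange.min / curCellRange.max

	lo = roundDown(lo, cachePadding)     // 32
	hi = roundUp(hi+1, cachePadding) - 1 // 95
	length := hi - lo + 1                // 64

	// The mask now covers curBatchSize*length entries, with no rounded-up batch dimension.
	fmt.Println(lo, hi, length) // 32 95 64
}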
2 llama/build-info.cpp generated vendored
@@ -1,4 +1,4 @@
 int LLAMA_BUILD_NUMBER = 0;
-char const *LLAMA_COMMIT = "3cfa9c3f125763305b4226bc032f1954f08990dc";
+char const *LLAMA_COMMIT = "ec98e2002";
 char const *LLAMA_COMPILER = "";
 char const *LLAMA_BUILD_TARGET = "";
@@ -17,11 +17,17 @@ include /tools/mtmd/clip.cpp
 include /tools/mtmd/mtmd.cpp
 include /tools/mtmd/mtmd-audio.cpp
 include /tools/mtmd/mtmd-helper.cpp
+include /tools/mtmd/models/
+include /tools/mtmd/models/*.h
+include /tools/mtmd/models/*.cpp
 include /src/
 include /src/llama.*
 include /src/llama-*.*
 include /src/unicode-data.*
 include /src/unicode.*
+include /src/models/
+include /src/models/*.h
+include /src/models/*.cpp
 include /vendor/
 include /vendor/miniaudio/
 include /vendor/miniaudio/*.h
359 llama/llama.cpp/common/common.cpp vendored
@@ -8,6 +8,7 @@
 #include "common.h"
 #include "log.h"
 #include "llama.h"
+#include "sampling.h"
 
 #include <algorithm>
 #include <cinttypes>
@@ -26,7 +27,6 @@
 #include <sstream>
 #include <string>
 #include <thread>
-#include <unordered_map>
 #include <unordered_set>
 #include <vector>
 
@@ -60,6 +60,14 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+common_time_meas::common_time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
+
+common_time_meas::~common_time_meas() {
+    if (t_start_us >= 0) {
+        t_acc += ggml_time_us() - t_start_us;
+    }
+}
+
 //
 // CPU utils
 //
@@ -355,11 +363,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
 }
 
 void common_init() {
-    llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) {
-        if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
-            common_log_add(common_log_main(), level, "%s", text);
-        }
-    }, NULL);
+    llama_log_set(common_log_default_callback, NULL);
 
 #ifdef NDEBUG
     const char * build_type = "";
@@ -690,7 +694,7 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
 
 // Validate if a filename is safe to use
 // To validate a full path, split the path by the OS-specific path separator, and validate each part with this function
-bool fs_validate_filename(const std::string & filename) {
+bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
     if (!filename.length()) {
         // Empty filename invalid
         return false;
@@ -750,10 +754,14 @@ bool fs_validate_filename(const std::string & filename) {
             || (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
             || c == 0xFFFD // Replacement Character (UTF-8)
             || c == 0xFEFF // Byte Order Mark (BOM)
-            || c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters
+            || c == ':' || c == '*' // Illegal characters
             || c == '?' || c == '"' || c == '<' || c == '>' || c == '|') {
             return false;
         }
+        if (!allow_subdirs && (c == '/' || c == '\\')) {
+            // Subdirectories not allowed, reject path separators
+            return false;
+        }
     }
 
     // Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
@@ -778,11 +786,29 @@ bool fs_validate_filename(const std::string & filename) {
 #include <iostream>
 
 
+#ifdef _WIN32
+static std::wstring utf8_to_wstring(const std::string & str) {
+    if (str.empty()) {
+        return std::wstring();
+    }
+
+    int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0);
+
+    if (size <= 0) {
+        return std::wstring();
+    }
+
+    std::wstring wstr(size, 0);
+    MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstr[0], size);
+
+    return wstr;
+}
+#endif
+
 // returns true if successful, false otherwise
 bool fs_create_directory_with_parents(const std::string & path) {
 #ifdef _WIN32
-    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
-    std::wstring wpath = converter.from_bytes(path);
+    std::wstring wpath = utf8_to_wstring(path);
 
     // if the path already exists, check whether it's a directory
     const DWORD attributes = GetFileAttributesW(wpath.c_str());
@@ -855,6 +881,11 @@ bool fs_create_directory_with_parents(const std::string & path) {
 #endif // _WIN32
 }
 
+bool fs_is_directory(const std::string & path) {
+    std::filesystem::path dir(path);
+    return std::filesystem::exists(dir) && std::filesystem::is_directory(dir);
+}
+
 std::string fs_get_cache_directory() {
     std::string cache_directory = "";
     auto ensure_trailing_slash = [](std::string p) {
@@ -889,6 +920,8 @@ std::string fs_get_cache_directory() {
         cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
 #elif defined(_WIN32)
         cache_directory = std::getenv("LOCALAPPDATA");
+#elif defined(__EMSCRIPTEN__)
+        GGML_ABORT("not implemented on this platform");
 #else
 #  error Unknown architecture
 #endif
@@ -908,34 +941,258 @@ std::string fs_get_cache_file(const std::string & filename) {
     return cache_directory + filename;
 }
 
+std::vector<common_file_info> fs_list(const std::string & path, bool include_directories) {
+    std::vector<common_file_info> files;
+    if (path.empty()) return files;
+
+    std::filesystem::path dir(path);
+    if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) {
+        return files;
+    }
+
+    for (const auto & entry : std::filesystem::directory_iterator(dir)) {
+        try {
+            // Only include regular files (skip directories)
+            const auto & p = entry.path();
+            if (std::filesystem::is_regular_file(p)) {
+                common_file_info info;
+                info.path = p.string();
+                info.name = p.filename().string();
+                info.is_dir = false;
+                try {
+                    info.size = static_cast<size_t>(std::filesystem::file_size(p));
+                } catch (const std::filesystem::filesystem_error &) {
+                    info.size = 0;
+                }
+                files.push_back(std::move(info));
+            } else if (include_directories && std::filesystem::is_directory(p)) {
+                common_file_info info;
+                info.path = p.string();
+                info.name = p.filename().string();
+                info.size = 0; // Directories have no size
+                info.is_dir = true;
+                files.push_back(std::move(info));
+            }
+        } catch (const std::filesystem::filesystem_error &) {
+            // skip entries we cannot inspect
+            continue;
+        }
+    }
+
+    return files;
+}
+
+//
+// TTY utils
+//
+
+bool tty_can_use_colors() {
+    // Check NO_COLOR environment variable (https://no-color.org/)
+    if (const char * no_color = std::getenv("NO_COLOR")) {
+        if (no_color[0] != '\0') {
+            return false;
+        }
+    }
+
+    // Check TERM environment variable
+    if (const char * term = std::getenv("TERM")) {
+        if (std::strcmp(term, "dumb") == 0) {
+            return false;
+        }
+    }
+
+    // Check if stdout and stderr are connected to a terminal
+    // We check both because log messages can go to either
+    bool stdout_is_tty = isatty(fileno(stdout));
+    bool stderr_is_tty = isatty(fileno(stderr));
+
+    return stdout_is_tty || stderr_is_tty;
+}
+
 //
 // Model utils
 //
 
-struct common_init_result common_init_from_params(common_params & params) {
-    common_init_result iparams;
+// TODO: move to common/sampling
+static void common_init_sampler_from_model(
+        const llama_model * model,
+        common_params_sampling & sparams) {
+
+    const uint64_t config = sparams.user_sampling_config;
+
+    auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
+        if (config & user_config) {
+            return;
+        }
+
+        char buf[64] = {0};
+        if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
+            char * end = nullptr;
+            int32_t v = strtol(buf, &end, 10);
+            if (end && end != buf) {
+                dst = v;
+            }
+        }
+    };
+
+    auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
+        if (config & user_config) {
+            return;
+        }
+
+        char buf[128] = {0};
+        if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
+            char * end = nullptr;
+            float v = strtof(buf, &end);
+            if (end && end != buf) {
+                dst = v;
+            }
+        }
+    };
+
+    // Sampling sequence
+    if (!(config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS)) {
+        char buf[512] = {0};
+        if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) {
+            const std::vector<std::string> sampler_names = string_split<std::string>(std::string(buf), ';');
+            if (!sampler_names.empty()) {
+                sparams.samplers = common_sampler_types_from_names(sampler_names, true);
+            }
+        }
+    }
+
+    get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_K), sparams.top_k, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K);
+    get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_P), sparams.top_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P);
+    get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIN_P), sparams.min_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P);
+    get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY), sparams.xtc_probability, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY);
+    get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD), sparams.xtc_threshold, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD);
+    get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TEMP), sparams.temp, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP);
+    get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N), sparams.penalty_last_n, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N);
+    get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT), sparams.penalty_repeat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT);
+    get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT), sparams.mirostat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT);
+    get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU), sparams.mirostat_tau, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU);
+    get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
+}
+
+struct common_init_result::impl {
+    impl() = default;
+    ~impl() = default;
+
+    llama_model_ptr model;
+    llama_context_ptr context;
+
+    std::vector<llama_adapter_lora_ptr> lora;
+
+    std::vector<common_sampler_ptr> samplers;
+};
+
+common_init_result::common_init_result(common_params & params) :
+        pimpl(new impl{}) {
     auto mparams = common_model_params_to_llama(params);
+    auto cparams = common_context_params_to_llama(params);
+
+    if (params.fit_params) {
+        LOG_INF("%s: fitting params to device memory, to report bugs during this step use -fit off (or --verbose if you can't)\n", __func__);
+        llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
+                params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
+                params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
+    }
+
     llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
     if (model == NULL) {
-        LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
-                __func__, params.model.path.c_str());
-        return iparams;
+        return;
     }
 
+    pimpl->model.reset(model);
+
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
-    auto cparams = common_context_params_to_llama(params);
+    // updates params.sampling
+    // TODO: fix naming
+    common_init_sampler_from_model(model, params.sampling);
+
+    if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
+        LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
+        params.sampling.ignore_eos = false;
+    }
+
+    // initialize once
+    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
+        if (llama_vocab_is_eog(vocab, i)) {
+            LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
+            params.sampling.logit_bias_eog.push_back({i, -INFINITY});
+        }
+    }
+
+    if (params.sampling.ignore_eos) {
+        // add EOG biases to the active set of logit biases
+        params.sampling.logit_bias.insert(
+                params.sampling.logit_bias.end(),
+                params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
+    }
+
+    //if (params.sampling.penalty_last_n == -1) {
+    //    LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+    //    params.sampling.penalty_last_n = llama_n_ctx(lctx);
+    //}
+
+    //if (params.sampling.dry_penalty_last_n == -1) {
+    //    LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+    //    params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
+    //}
+
+    pimpl->samplers.resize(cparams.n_seq_max);
+
+    for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
+        pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
+    }
+
     llama_context * lctx = llama_init_from_model(model, cparams);
     if (lctx == NULL) {
-        LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
-                __func__, params.model.path.c_str());
-        llama_model_free(model);
-        return iparams;
+        LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
+        return;
     }
 
+    pimpl->context.reset(lctx);
+}
+
+llama_model * common_init_result::model() {
+    return pimpl->model.get();
+}
+
+llama_context * common_init_result::context() {
+    return pimpl->context.get();
+}
+
+common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
+    return pimpl->samplers[seq_id].get();
+}
+
+std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
+    return pimpl->lora;
+}
+
+void common_init_result::free_context() {
+    pimpl->context.reset();
+}
+
+common_init_result_ptr common_init_from_params(common_params & params) {
+    common_init_result_ptr res(new common_init_result(params));
+
+    llama_model * model = res->model();
+    if (model == NULL) {
+        LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
+        return res;
+    }
+
+    llama_context * lctx = res->context();
+    if (lctx == NULL) {
+        LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
+        return res;
+    }
+
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
     if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
         LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
         params.ctx_shift = false;
@@ -947,10 +1204,7 @@ struct common_init_result common_init_from_params(common_params & params) {
 
     const auto cvec = common_control_vector_load(params.control_vectors);
     if (cvec.n_embd == -1) {
-        llama_free(lctx);
-        llama_model_free(model);
-
-        return iparams;
+        return res;
     }
 
     int err = llama_apply_adapter_cvec(
@@ -961,10 +1215,7 @@ struct common_init_result common_init_from_params(common_params & params) {
             params.control_vector_layer_start,
             params.control_vector_layer_end);
     if (err) {
-        llama_free(lctx);
-        llama_model_free(model);
-
-        return iparams;
+        return res;
     }
 }
 
@@ -988,10 +1239,7 @@ struct common_init_result common_init_from_params(common_params & params) {
     }
 
     if (!ok) {
-        llama_free(lctx);
-        llama_model_free(model);
-
-        return iparams;
+        return res;
     }
 }
 
@@ -1001,9 +1249,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
         if (lora == nullptr) {
             LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
-            llama_free(lctx);
-            llama_model_free(model);
-            return iparams;
+            return res;
         }
 
         char buf[1024];
@@ -1012,43 +1258,13 @@ struct common_init_result common_init_from_params(common_params & params) {
         la.task_name = buf;
         llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
         la.prompt_prefix = buf;
-        iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
+        res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters
     }
 
     if (!params.lora_init_without_apply) {
        common_set_adapter_lora(lctx, params.lora_adapters);
    }
 
-    if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
-        LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
-        params.sampling.ignore_eos = false;
-    }
-
-    // initialize once
-    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
-        if (llama_vocab_is_eog(vocab, i)) {
-            LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
-            params.sampling.logit_bias_eog.push_back({i, -INFINITY});
-        }
-    }
-
-    if (params.sampling.ignore_eos) {
-        // add EOG biases to the active set of logit biases
-        params.sampling.logit_bias.insert(
-                params.sampling.logit_bias.end(),
-                params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
-    }
-
-    if (params.sampling.penalty_last_n == -1) {
-        LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
-        params.sampling.penalty_last_n = llama_n_ctx(lctx);
-    }
-
-    if (params.sampling.dry_penalty_last_n == -1) {
-        LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
-        params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
-    }
-
     if (params.warmup) {
         LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
 
@@ -1087,12 +1303,11 @@ struct common_init_result common_init_from_params(common_params & params) {
         llama_set_warmup(lctx, false);
     }
 
-    iparams.model.reset(model);
-    iparams.context.reset(lctx);
-
-    return iparams;
+    return res;
 }
 
+common_init_result::~common_init_result() = default;
+
 std::string get_model_endpoint() {
     const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
     // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
@@ -1101,7 +1316,9 @@ std::string get_model_endpoint() {
     std::string model_endpoint = "https://huggingface.co/";
     if (endpoint_env) {
         model_endpoint = endpoint_env;
-        if (model_endpoint.back() != '/') model_endpoint += '/';
+        if (model_endpoint.back() != '/') {
+            model_endpoint += '/';
+        }
     }
     return model_endpoint;
 }
125 llama/llama.cpp/common/common.h vendored
@@ -2,17 +2,19 @@
 
 #pragma once
 
+#include "ggml-opt.h"
+#include "llama-cpp.h"
+
 #include <set>
 #include <sstream>
 #include <string>
 #include <string_view>
 #include <vector>
 #include <map>
-#include <sstream>
-#include <cmath>
 
-#include "ggml-opt.h"
-#include "llama-cpp.h"
+#if defined(_WIN32) && !defined(_WIN32_WINNT)
+#define _WIN32_WINNT 0x0A00
+#endif
 
 #ifdef _WIN32
 #define DIRECTORY_SEPARATOR '\\'
@@ -28,7 +30,14 @@
     fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \
 } while(0)
 
-#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
+struct common_time_meas {
+    common_time_meas(int64_t & t_acc, bool disable = false);
+    ~common_time_meas();
+
+    const int64_t t_start_us;
+
+    int64_t & t_acc;
+};
 
 struct common_adapter_lora_info {
     std::string path;
@@ -73,7 +82,8 @@ int32_t cpu_get_num_math();
 enum llama_example {
     LLAMA_EXAMPLE_COMMON,
     LLAMA_EXAMPLE_SPECULATIVE,
-    LLAMA_EXAMPLE_MAIN,
+    LLAMA_EXAMPLE_COMPLETION,
+    LLAMA_EXAMPLE_CLI,
     LLAMA_EXAMPLE_EMBEDDING,
     LLAMA_EXAMPLE_PERPLEXITY,
     LLAMA_EXAMPLE_RETRIEVAL,
@@ -89,6 +99,7 @@ enum llama_example {
     LLAMA_EXAMPLE_TTS,
     LLAMA_EXAMPLE_DIFFUSION,
     LLAMA_EXAMPLE_FINETUNE,
+    LLAMA_EXAMPLE_FIT_PARAMS,
 
     LLAMA_EXAMPLE_COUNT,
 };
@@ -133,6 +144,22 @@ struct common_grammar_trigger {
     llama_token token = LLAMA_TOKEN_NULL;
 };
 
+enum common_params_sampling_config : uint64_t {
+    COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS        = 1 << 0,
+    COMMON_PARAMS_SAMPLING_CONFIG_TOP_K           = 1 << 1,
+    COMMON_PARAMS_SAMPLING_CONFIG_TOP_P           = 1 << 2,
+    COMMON_PARAMS_SAMPLING_CONFIG_MIN_P           = 1 << 3,
+    COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY = 1 << 4,
+    COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD   = 1 << 5,
+    COMMON_PARAMS_SAMPLING_CONFIG_TEMP            = 1 << 6,
+    COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N  = 1 << 7,
+    COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT  = 1 << 8,
+    COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT        = 1 << 9,
+    COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU    = 1 << 10,
+    COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA    = 1 << 11,
+};
+
+
 // sampling parameters
 struct common_params_sampling {
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
@@ -165,8 +192,9 @@ struct common_params_sampling {
     bool no_perf = false; // disable performance metrics
     bool timing_per_token = false;
 
-    std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
+    uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
+
+    std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
 
     std::vector<enum common_sampler_type> samplers = {
         COMMON_SAMPLER_TYPE_PENALTIES,
@@ -188,6 +216,10 @@ struct common_params_sampling {
     std::vector<llama_logit_bias> logit_bias;     // logit biases to apply
     std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
 
+    bool has_logit_bias() const {
+        return !logit_bias.empty();
+    }
+
     // print the parameters into a string
     std::string print() const;
 };
@@ -198,6 +230,7 @@ struct common_params_model {
     std::string hf_repo = "";     // HF repo     // NOLINT
     std::string hf_file = "";     // HF file     // NOLINT
     std::string docker_repo = ""; // Docker repo // NOLINT
+    std::string name = "";        // in format <user>/<model>[:<tag>] (tag is optional) // NOLINT
 };
 
 struct common_params_speculative {
@@ -274,8 +307,8 @@ struct lr_opt {
 struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata);
 
 struct common_params {
-    int32_t n_predict = -1;    // new tokens to predict
-    int32_t n_ctx = 4096;      // context size
+    int32_t n_predict = -1;    // max. number of new tokens to predict, -1 == no limit
+    int32_t n_ctx = 0;         // context size, 0 == context the model was trained with
     int32_t n_batch = 2048;    // logical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_ubatch = 512;    // physical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep = 0;        // number of tokens to keep from initial prompt
@@ -296,9 +329,12 @@ struct common_params {
     // offload params
     std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
 
     int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
     int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
     float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
+    bool fit_params = true; // whether to fit unset model/context parameters to free device memory
+    size_t fit_params_target = 1024 * 1024*1024; // margin per device in bytes for fitting parameters to free memory
+    int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use
 
     enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
 
@@ -344,7 +380,7 @@ struct common_params {
 
     std::vector<common_control_vector_load_info> control_vectors; // control vector with user defined scale
 
-    int32_t verbosity = 0;
+    int32_t verbosity = 3; // LOG_LEVEL_INFO
     int32_t control_vector_layer_start = -1; // layer range for control vector
     int32_t control_vector_layer_end = -1; // layer range for control vector
     bool offline = false;
@@ -378,6 +414,7 @@ struct common_params {
     bool simple_io = false; // improves compatibility with subprocesses and limited consoles
     bool cont_batching = true; // insert new sequences for decoding on-the-fly
     bool no_perf = false; // disable performance metrics
+    bool show_timings = true; // show timing information on CLI
    bool ctx_shift = false; // context shift on infinite text generation
     bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
     bool kv_unified = false; // enable unified KV cache
@@ -406,6 +443,8 @@ struct common_params {
     bool mmproj_use_gpu = true; // use GPU for multimodal model
     bool no_mmproj = false; // explicitly disable multimodal model
     std::vector<std::string> image; // path to image file(s)
+    int image_min_tokens = -1;
+    int image_max_tokens = -1;
 
     // finetune
     struct lr_opt lr;
@@ -432,7 +471,7 @@ struct common_params {
     std::string public_path = "";   // NOLINT
     std::string api_prefix = "";    // NOLINT
     std::string chat_template = ""; // NOLINT
-    bool use_jinja = false; // NOLINT
+    bool use_jinja = true; // NOLINT
     bool enable_chat_template = true;
     common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
     int reasoning_budget = -1;
@@ -451,14 +490,22 @@ struct common_params {
     bool endpoint_props = false; // only control POST requests, not GET
     bool endpoint_metrics = false;
 
+    // router server configs
+    std::string models_dir = "";    // directory containing models for the router server
+    std::string models_preset = ""; // directory containing model presets for the router server
+    int models_max = 4;             // maximum number of models to load simultaneously
+    bool models_autoload = true;    // automatically load models when requested via the router server
+
     bool log_json = false;
 
     std::string slot_save_path;
+    std::string media_path; // path to directory for loading media files
 
     float slot_prompt_similarity = 0.1f;
 
     // batched-bench params
     bool is_pp_shared = false;
+    bool is_tg_separate = false;
 
     std::vector<int32_t> n_pp;
     std::vector<int32_t> n_tg;
@@ -505,6 +552,10 @@ struct common_params {
|
|||||||
// return false from callback to abort model loading or true to continue
|
// return false from callback to abort model loading or true to continue
|
||||||
llama_progress_callback load_progress_callback = NULL;
|
llama_progress_callback load_progress_callback = NULL;
|
||||||
void * load_progress_callback_user_data = NULL;
|
void * load_progress_callback_user_data = NULL;
|
||||||
|
|
||||||
|
bool has_speculative() const {
|
||||||
|
return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// call once at the start of a program if it uses libcommon
|
// call once at the start of a program if it uses libcommon
|
||||||
@@ -599,25 +650,55 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
|
|||||||
// Filesystem utils
|
// Filesystem utils
|
||||||
//
|
//
|
||||||
|
|
||||||
bool fs_validate_filename(const std::string & filename);
|
bool fs_validate_filename(const std::string & filename, bool allow_subdirs = false);
|
||||||
bool fs_create_directory_with_parents(const std::string & path);
|
bool fs_create_directory_with_parents(const std::string & path);
|
||||||
|
bool fs_is_directory(const std::string & path);
|
||||||
|
|
||||||
std::string fs_get_cache_directory();
|
std::string fs_get_cache_directory();
|
||||||
std::string fs_get_cache_file(const std::string & filename);
|
std::string fs_get_cache_file(const std::string & filename);
|
||||||
|
|
||||||
|
struct common_file_info {
|
||||||
|
std::string path;
|
||||||
|
std::string name;
|
||||||
|
size_t size = 0; // in bytes
|
||||||
|
bool is_dir = false;
|
||||||
|
};
|
||||||
|
std::vector<common_file_info> fs_list(const std::string & path, bool include_directories);
|
||||||
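
One plausible caller of the new listing helper, sketched against the declarations above. This is a hedged fragment, not vendored code, and the models directory path is a placeholder:

// Hypothetical usage of fs_list()/common_file_info; the path is illustrative only.
for (const common_file_info & info : fs_list("/path/to/models", /*include_directories=*/true)) {
    if (info.is_dir) {
        continue; // report plain files only
    }
    printf("%s (%zu bytes)\n", info.name.c_str(), info.size);
}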

+//
+// TTY utils
+//
+
+// Auto-detect if colors can be enabled based on terminal and environment
+bool tty_can_use_colors();
+
 //
 // Model utils
 //

-// note: defines object's lifetime
-struct common_init_result {
-    llama_model_ptr model;
-    llama_context_ptr context;
-
-    std::vector<llama_adapter_lora_ptr> lora;
+struct common_sampler;
+
+// note: defines the model, context, samplers, etc. lifetimes
+struct common_init_result {
+    common_init_result(common_params & params);
+    ~common_init_result();
+
+    llama_model * model();
+    llama_context * context();
+    common_sampler * sampler(llama_seq_id seq_id);
+
+    std::vector<llama_adapter_lora_ptr> & lora();
+
+    void free_context();
+
+private:
+    struct impl;
+    std::unique_ptr<impl> pimpl;
 };

-struct common_init_result common_init_from_params(common_params & params);
+using common_init_result_ptr = std::unique_ptr<common_init_result>;
+
+common_init_result_ptr common_init_from_params(common_params & params);

 struct llama_model_params common_model_params_to_llama ( common_params & params);
 struct llama_context_params common_context_params_to_llama(const common_params & params);
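
The hunk above turns common_init_result from an aggregate with smart-pointer members into a pimpl class, and common_init_from_params now returns a unique_ptr. A hedged migration sketch based only on these declarations; actual call sites may differ:

// Before (aggregate members, per the removed lines):
//   common_init_result llama_init = common_init_from_params(params);
//   llama_model * model = llama_init.model.get();
// After (accessors on a heap-allocated result; lifetimes are owned by the object):
common_init_result_ptr llama_init = common_init_from_params(params);
llama_model    * model = llama_init->model();
llama_context  * lctx  = llama_init->context();
common_sampler * smpl  = llama_init->sampler(0); // sampler for sequence 0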
llama/llama.cpp/common/json-schema-to-grammar.cpp (vendored, 165 changes)
@@ -268,10 +268,10 @@ static bool is_reserved_name(const std::string & name) {
 }

 std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");
-std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]");
+std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]");
 std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");
 std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
-    {'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}
+    {'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}, {'\\', "\\\\"}
 };
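
The two changes above add '\' to the literal-escape set, so a backslash inside a schema string constant is emitted as an escaped \\ in the generated GBNF instead of leaking through raw. A minimal self-contained sketch of that behavior (not the vendored format_literal, which drives the same table through the regex):

#include <cstdio>
#include <string>
#include <unordered_map>

// Sketch: escape a literal the way the updated table does, backslash included.
static std::string escape_literal(const std::string & literal) {
    static const std::unordered_map<char, std::string> escapes = {
        {'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'\\', "\\\\"},
    };
    std::string out;
    for (char c : literal) {
        auto it = escapes.find(c);
        if (it != escapes.end()) {
            out += it->second;
        } else {
            out += c;
        }
    }
    return "\"" + out + "\"";
}

int main() {
    // A schema constant like C:\tmp previously emitted a bare backslash,
    // producing an invalid GBNF literal; now it is escaped.
    std::printf("%s\n", escape_literal("C:\\tmp").c_str()); // prints "C:\\tmp"
}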

 std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
@@ -303,8 +303,11 @@ static std::string format_literal(const std::string & literal) {
     return "\"" + escaped + "\"";
 }

-class SchemaConverter {
+std::string gbnf_format_literal(const std::string & literal) { return format_literal(literal); }
+
+class common_schema_converter {
 private:
+    friend class common_schema_info;
     friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
     std::function<json(const std::string &)> _fetch_json;
     bool _dotall;
@@ -601,7 +604,10 @@ private:
     }

     std::string _resolve_ref(const std::string & ref) {
-        std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
+        auto it = ref.find('#');
+        std::string ref_fragment = it != std::string::npos ? ref.substr(it + 1) : ref;
+        static const std::regex nonalphanumeric_regex(R"([^a-zA-Z0-9-]+)");
+        std::string ref_name = "ref" + std::regex_replace(ref_fragment, nonalphanumeric_regex, "-");
         if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
             _refs_being_resolved.insert(ref);
             json resolved = _refs[ref];
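
Previously _resolve_ref kept only the text after the last '/', so distinct pointers such as #/defs/a/items and #/defs/b/items collided on the rule name "items". A self-contained sketch of the new derivation (same regex as above; the helper name is hypothetical):

#include <cstdio>
#include <regex>
#include <string>

// Sketch of the new rule-name scheme: keep the fragment after '#' and
// replace each run of characters outside [a-zA-Z0-9-] with '-'.
static std::string ref_rule_name(const std::string & ref) {
    auto it = ref.find('#');
    std::string fragment = it != std::string::npos ? ref.substr(it + 1) : ref;
    static const std::regex nonalnum(R"([^a-zA-Z0-9-]+)");
    return "ref" + std::regex_replace(fragment, nonalnum, "-");
}

int main() {
    std::printf("%s\n", ref_rule_name("#/definitions/foo").c_str()); // ref-definitions-foo
    std::printf("%s\n", ref_rule_name("#/anyOf/0").c_str());         // ref-anyOf-0
}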
@@ -724,7 +730,7 @@ private:
     }

 public:
-    SchemaConverter(
+    common_schema_converter(
         const std::function<json(const std::string &)> & fetch_json,
         bool dotall)
         : _fetch_json(fetch_json), _dotall(dotall)
@@ -774,11 +780,24 @@ public:
             std::vector<std::string> tokens = string_split(pointer, "/");
             for (size_t i = 1; i < tokens.size(); ++i) {
                 std::string sel = tokens[i];
-                if (target.is_null() || !target.contains(sel)) {
+                if (target.is_object() && target.contains(sel)) {
+                    target = target[sel];
+                } else if (target.is_array()) {
+                    size_t sel_index;
+                    try {
+                        sel_index = std::stoul(sel);
+                    } catch (const std::invalid_argument & e) {
+                        sel_index = target.size();
+                    }
+                    if (sel_index >= target.size()) {
+                        _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
+                        return;
+                    }
+                    target = target[sel_index];
+                } else {
                     _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
                     return;
                 }
-                target = target[sel];
             }
             _refs[ref] = target;
         }
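
With the reworked walk, numeric JSON-pointer segments now index into arrays rather than failing the contains() check on a non-object. A self-contained sketch of the traversal using nlohmann::json (error handling elided):

#include <nlohmann/json.hpp>
#include <cstdio>
#include <string>
#include <vector>

using json = nlohmann::ordered_json;

// Sketch of the traversal above: object keys by name, array elements by index.
static json walk(json target, const std::vector<std::string> & tokens) {
    for (const std::string & sel : tokens) {
        if (target.is_object() && target.contains(sel)) {
            target = target[sel];
        } else if (target.is_array()) {
            target = target[std::stoul(sel)]; // numeric segment, e.g. "1"
        }
    }
    return target;
}

int main() {
    json schema = json::parse(R"({"anyOf":[{"type":"integer"},{"type":"string"}]})");
    // Resolving the pointer /anyOf/1 now succeeds:
    std::printf("%s\n", walk(schema, {"anyOf", "1"}).dump().c_str()); // {"type":"string"}
}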
@@ -956,7 +975,7 @@ public:

     void check_errors() {
         if (!_errors.empty()) {
-            throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
+            throw std::invalid_argument("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
         }
         if (!_warnings.empty()) {
             fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());
@@ -972,6 +991,134 @@ public:
     }
 };

+// common_schema_info implementation (pimpl)
+
+common_schema_info::common_schema_info()
+    : impl_(std::make_unique<common_schema_converter>(
+          [](const std::string &) { return json(); },
+          false)) {}
+
+common_schema_info::~common_schema_info() = default;
+
+common_schema_info::common_schema_info(common_schema_info &&) noexcept = default;
+common_schema_info & common_schema_info::operator=(common_schema_info &&) noexcept = default;
+
+void common_schema_info::resolve_refs(nlohmann::ordered_json & schema) {
+    impl_->resolve_refs(schema, "");
+}
+
+// Determines if a JSON schema can resolve to a string type through any path.
+// Some models emit raw string values rather than JSON-encoded strings for string parameters.
+// If any branch of the schema (via oneOf, anyOf, $ref, etc.) permits a string, this returns
+// true, allowing callers to handle the value as a raw string for simplicity.
+bool common_schema_info::resolves_to_string(const nlohmann::ordered_json & schema) {
+    std::unordered_set<std::string> visited_refs;
+
+    std::function<bool(const json &)> check = [&](const json & s) -> bool {
+        if (!s.is_object()) {
+            return false;
+        }
+
+        // Handle $ref
+        if (s.contains("$ref")) {
+            const std::string & ref = s["$ref"];
+            if (visited_refs.find(ref) != visited_refs.end()) {
+                // Circular reference, assume not a string to be safe
+                return false;
+            }
+            visited_refs.insert(ref);
+            auto it = impl_->_refs.find(ref);
+            if (it != impl_->_refs.end()) {
+                return check(it->second);
+            }
+            return false;
+        }
+
+        // Check type field
+        if (s.contains("type")) {
+            const json & schema_type = s["type"];
+            if (schema_type.is_string()) {
+                if (schema_type == "string") {
+                    return true;
+                }
+            } else if (schema_type.is_array()) {
+                // Type can be an array like ["string", "null"]
+                for (const auto & t : schema_type) {
+                    if (t == "string") {
+                        return true;
+                    }
+                }
+            }
+        }
+
+        // Check oneOf/anyOf - if any alternative can be a string
+        if (s.contains("oneOf")) {
+            for (const auto & alt : s["oneOf"]) {
+                if (check(alt)) {
+                    return true;
+                }
+            }
+        }
+        if (s.contains("anyOf")) {
+            for (const auto & alt : s["anyOf"]) {
+                if (check(alt)) {
+                    return true;
+                }
+            }
+        }
+
+        // Check allOf - all components must be compatible with string type
+        if (s.contains("allOf")) {
+            bool all_string = true;
+            for (const auto & component : s["allOf"]) {
+                if (!check(component)) {
+                    all_string = false;
+                    break;
+                }
+            }
+            if (all_string) {
+                return true;
+            }
+        }
+
+        // Check const - if the constant value is a string
+        if (s.contains("const")) {
+            if (s["const"].is_string()) {
+                return true;
+            }
+        }
+
+        // Check enum - if any enum value is a string
+        if (s.contains("enum")) {
+            for (const auto & val : s["enum"]) {
+                if (val.is_string()) {
+                    return true;
+                }
+            }
+        }
+
+        // String-specific keywords imply string type
+        if (s.contains("pattern") || s.contains("minLength") || s.contains("maxLength")) {
+            return true;
+        }
+
+        // Check format - many formats imply string
+        if (s.contains("format")) {
+            const std::string & fmt = s["format"];
+            if (fmt == "date" || fmt == "time" || fmt == "date-time" ||
+                fmt == "uri" || fmt == "email" || fmt == "hostname" ||
+                fmt == "ipv4" || fmt == "ipv6" || fmt == "uuid" ||
+                fmt.find("uuid") == 0) {
+                return true;
+            }
+        }
+
+        return false;
+    };
+
+    return check(schema);
+}
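
A hypothetical use of the new helper, assuming a common_schema_info declaration in the matching header (only members defined above are used): a tool-call argument whose schema permits a string on any branch may be accepted as a raw string.

common_schema_info info;
auto schema = nlohmann::ordered_json::parse(
    R"({"anyOf": [{"type": "number"}, {"type": "string"}]})");
info.resolve_refs(schema);                     // record $ref targets first (none here)
bool raw_ok = info.resolves_to_string(schema); // true: the second branch is a string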

 std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
 #ifdef LLAMA_USE_LLGUIDANCE
     if (!force_gbnf) {
@@ -988,7 +1135,7 @@ std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
 }

 std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
-    SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
+    common_schema_converter converter([&](const std::string &) { return json(); }, options.dotall);
     common_grammar_builder builder {
         /* .add_rule = */ [&](const std::string & name, const std::string & rule) {
             return converter._add_rule(name, rule);
Some files were not shown because too many files have changed in this diff.