Compare commits

...

89 Commits

Author SHA1 Message Date
ParthSareen
92af238208 wip 2025-12-02 12:17:36 -08:00
ParthSareen
7461faf651 script to render templates 2025-12-01 18:03:04 -08:00
Daniel Hiltgen
554172759c win: warn if ggml-base detected in PATH (#13289)
If the user has somehow installed another GGML based app which places a
ggml-base lib somewhere in their PATH, we can experience runtime problems
due to incompatibilities.  This change adds a warning message if we detect
a ggml-base outside of our install location to aid in troubleshooting.
2025-12-01 15:36:47 -08:00
Bruce MacDonald
5b6a8e6001 api/client: handle non-json streaming errors (#13007)
While processing the response stream during a chat or generation if an error is occurred it is parsed and returned to the user. The issue with the existing code is that this assumed the response would be valid JSON, which is not a safe assumption and caused cryptic error messages to be displayed due to parsing failures:
`invalid character 'i' looking for beginning of value`

This change updates the stream function to return the raw error string if it cant be parsed as JSON. This should help with debugging issues by making sure the actual error reaches the user.
2025-12-01 15:10:16 -08:00
Daniel Hiltgen
467bbc0dd5 jetpack: require exact match or skip cuda_jetpack* (#13288)
The cuda_jetpack libs will enumerate discrete GPUs on SBSA systems
which leads to runtime failures of missing kernels.  This fix
requires an exact match to enable jetpacks instead of relying on
enumeration to filter out supported libraries.
2025-12-01 12:48:16 -08:00
Jeffrey Morgan
6d9f9323c5 .gitattributes: add app/webview to linguist-vendored (#13274) 2025-11-29 23:46:10 -05:00
Ondrej Kokes
0c2489605d docs: fix output formatting in faq.mdx (#13231)
There were a few Markdown typos in one FAQ answer. It now renders as a proper ascii table.
2025-11-28 19:19:21 -05:00
EntropyYue
8b1b89a984 docs: remove deprecated parameters (#13237) 2025-11-26 11:03:09 +09:00
Eva H
47e272c35a app/cmd: update ollama help to navigate to ollama doc instead of github page (#13174) 2025-11-20 16:30:35 -05:00
Jeffrey Morgan
417a81fda3 app: open app instead of always navigating to / on connect (#13164) 2025-11-20 12:59:17 -08:00
Daniel Hiltgen
dba62ff3a5 discovery: fix cuda overlap case (#13176)
Recent refactoring introduced a regression for filtering cuda overlap to favor newest supported version.
2025-11-20 12:15:37 -08:00
Grace
d70e935526 Parser for Cogito v2 (#13145) 2025-11-19 17:21:07 -08:00
Michael Yang
5c1063df7f deepseek2: upgrade to run v3+ models (#13166)
the check for mla omits v3 and r1 which should not return unsupported.
instead check the tokenizer for compatibility
2025-11-19 17:05:39 -08:00
Jesse Gross
cb485b2019 kvcache: Run tests both with and without PermutedV
The causal cache can store data differently depending on what is
best for the backend. We should run tests both ways.
2025-11-19 16:45:30 -08:00
nicole pardal
b2af50960f nomic-embed: nomic-embed-text defaulted to ollama runner (#13144) 2025-11-19 13:03:44 -08:00
Michael Yang
eac5b8bfbd chore: mark vulkan shaders as vendored files 2025-11-19 12:01:23 -08:00
Patrick Devine
604e43b28d models: enable deepseek2 (deepseek v3.1 w/ MLA) on the new engine (#13151) 2025-11-18 22:03:50 -08:00
Jesse Gross
53985b3c4d kvcache: Use SetRows to store cache data
We currently copy data into the KV cache in contiguous buffers using
ggml_cpy(). ggml_set_rows() was introduced to allow scatter operation
so that contiguous buffers are no longer required. The direct primary
benefit of this is that we no longer need to perform defragmentation.

However, GGML recently removed an optimization for ggml_cpy() and
we picked it up in 544b673 "ggml update to b6840 (#12791)". This
caused a roughly 40% drop in token generation performance on CUDA
due to CUDA graphs no longer being used. By switching to
ggml_set_rows(), the original optimization is no longer necessary
and CUDA performance is restored.

Fixes #13112
2025-11-18 20:42:28 -08:00
Jesse Gross
b6e02cbbd2 ggml: Automatically make tensors contiguous on reshape
GGML requires tensors to be contiguous for reshape and if
this is not the case, it will assert fail. Contiguous is an
expensive operation, so it's best to do it lazily when it is
actually required rather than ahead of time when it may not
be needed.
2025-11-18 20:42:28 -08:00
Grace
91935631ac Renderer for Cogito v2 (#13139) 2025-11-18 19:06:34 -08:00
nicole pardal
8de30b568a nomic-embed-text model implementation (#13071) 2025-11-18 18:28:10 -08:00
Daniel Hiltgen
485da9fd35 win: exit instead of abort (#13138)
Calling abort on windows triggers the C++ runtime to attempt a debugger
attach, which causes the crashed runners to hang instead of exit, leading
to a timeout instead of a fast failure during discovery.
2025-11-18 16:33:33 -08:00
Michael Yang
0796d79d19 cuda: skip large batches
cuda panics on batches larger than 1024 so skip those and fallback to
cpu
2025-11-18 16:11:37 -08:00
Michael Yang
92981ae3f2 deepseekocr 2025-11-18 16:11:37 -08:00
Lhiam Andrei Lingco
8ed1adf3db docs: fix typo in vscode.mdx (#13116) 2025-11-18 13:18:42 -08:00
Michael Yang
440a3823a6 fix(tokenizer): add special tokens to empty inputs (#13091) 2025-11-18 11:16:56 -08:00
Michael Yang
718961de68 migrate to golangci-lint v2 (#13109)
* migrate to golangci-lint v2
* copyloopvar
2025-11-18 11:00:26 -08:00
SamareshSingh
330f62a7fa docs: add Void Editor to community integrations (#13124)
Void is an open source AI code editor and Cursor alternative that supports
Ollama. It's built on VS Code and allows users to connect directly to Ollama
for private LLM usage without going through a middleman backend.

Key features:
- Open source Cursor alternative
- Direct Ollama integration
- VS Code fork with full compatibility
- Agent mode and MCP support
- Works with any open source model

Fixes #12919

Signed-off-by: Samaresh Kumar Singh <ssam3003@gmail.com>
2025-11-17 19:20:36 -08:00
Grace
584e2d646f Add deepseek v3.1 (#13063)
* Add mla for flash attention
* Revert to using chunks
2025-11-17 18:03:21 -08:00
Eva H
1fd4cb87b2 app/cmd: restrict ollama:// URL scheme to supported paths (#13120) 2025-11-17 20:10:45 -05:00
Cerussite
4aba2e8b72 discover: Support cgroups cores and memory limitations (#10292)
* Add supports for cgroups cores and memory limitations

* fix compile error and add logs

* remove cpu info log
2025-11-17 16:13:03 -08:00
Daniel Hiltgen
2f36d769aa bring back sysfs based VRAM information for AMD (#12871)
* build: optimize dockerfile context for iterating

This moves the copy of the source into the layer AFTER
doing software installs so we don't have to go through
the RPM install for cuda, etc. every time you touch a
source file.

* amd: implement linux sysfs based VRAM lookup

This adds a C++ implementation of sysfs DRM VRAM discovery
for more accurate free VRAM data on linux for AMD GPUs.
2025-11-17 15:40:58 -08:00
Daniel Hiltgen
399eacf486 ci: fix missing vulkan binaries in linux bundles (#13123) 2025-11-17 15:39:59 -08:00
Eva H
231cc878cb app/ui: fix to point ollama client to ui backend in dev mode (#13079) 2025-11-17 12:58:35 -05:00
Jeffrey Morgan
aa676b313f docs: link to ollama.com instead of hardcoding list of cloud models (#13110) 2025-11-16 20:56:09 -08:00
omahs
dd0ed0ef17 docs: fix typos in repository documentation (#10683) 2025-11-15 20:22:29 -08:00
Joel Bryan Juliano
d5649821ae readme: add Kdeps to community integrations (#11877)
Kdeps is an AI framework for building Dockerized full-stack AI
applications declaratively and uses Ollama LLM models on the
backend
2025-11-15 19:19:03 -08:00
pierwill
4cea757e70 server: clean up manifest documentation (#12995)
Co-authored-by: pierwill <pierwill@users.noreply.github.com>
2025-11-15 19:13:15 -08:00
Vignesh Skanda
a751bc159c llama: test case typo and readability improvements (#13078) 2025-11-15 18:54:27 -08:00
Laurențiu Nicola
5d31242fbf discover: fix typos in runner.go (#13096) 2025-11-15 18:52:54 -08:00
Patrick Devine
d7fd72193f tests: basic benchmarking test framework (#12964)
This change adds a basic benchmarking test framework for Ollama which can
be used to determine the prefill, eval, load duration, and total duration
for running a given model or models.
2025-11-15 18:17:40 -08:00
Daniel Hiltgen
72ff5b9d8c log: warn if user overrides detected (#13088)
Many failed GPU discovery issues recently can be traced to incorrect override settings.
This extra logging should help quickly spot these and guide users to try unsetting them first.
2025-11-14 14:36:28 -08:00
Parth Sareen
ce29f695b4 docs: add logprobs to openapi (#13090) 2025-11-14 14:14:58 -08:00
Michael Yang
12b174b10e fix tensor merge (#13053) 2025-11-13 15:32:34 -08:00
Michael Yang
333203d871 chore: update models to use slice/chunk/chunksections (#12934)
* use slice/chunks

* bert

* llama4

* gemma3n

* gptoss

* mistral3

* qwen3vl

* qwen25vl

* deepseek2

* remove unused ops
2025-11-13 15:20:12 -08:00
Parth Sareen
c114987523 logprob: add bytes to logprobs (#13068) 2025-11-13 13:49:25 -08:00
Michael Yang
b48083f33f ml: add slice operation (#12870)
* slice

* chunk, chunksections
2025-11-13 13:28:21 -08:00
nicole pardal
482bec824f embeddings: added cli command to embedding docs (#12993) 2025-11-13 13:24:13 -08:00
Kowyo
684a9a8c5a docs: fix typo (VSCode -> VS Code) (#13072) 2025-11-12 20:49:33 -08:00
Jeffrey Morgan
54a76d3773 app: remove source code for previous JavaScript-based macOS app (#13067)
The code in this directory has been replaced with the
new Go version in the 'app' directory.
2025-11-12 20:37:43 -08:00
Radhi
8a75d8b015 readme: add AI UI to community integrations (#13035) 2025-11-12 17:08:50 -08:00
Jeffrey Morgan
f206357412 readme: fix incorrect header in community integrations (#13065) 2025-11-12 17:00:16 -08:00
Daniel Hiltgen
8224cd9063 ci: fix win vulkan (#13062) 2025-11-12 10:32:24 -08:00
Daniel Hiltgen
6286d9a3a5 Enable Vulkan with a temporary opt-in setting (#12931)
* docs: vulkan information

* Revert "CI: Set up temporary opt-out Vulkan support (#12614)"

This reverts commit 8b6e5baee7.

* vulkan: temporary opt-in for Vulkan support

Revert this once we're ready to enable by default.

* win: add vulkan CI build
2025-11-12 08:40:38 -08:00
Daniel Hiltgen
3a9e8e9fd4 vulkan: temporary cary of vulkan fixes (#12971)
This should be reverted once we update ggml past b6897
2025-11-12 08:31:40 -08:00
Jeffrey Morgan
cb1cb06478 docs: rename api-reference.md back to api.md since redirect stopped working (#13056) 2025-11-11 15:53:06 -08:00
Jeffrey Morgan
2d5e066c8c docs: fix openapi.yaml warnings, rename api.md to api-reference.md (#12904) 2025-11-11 15:39:35 -08:00
Bruce MacDonald
15968714bd docs/openapi: document that delete and copy responses are empty (#13055)
Some route endpoints return an empty response with a 200 OK. These should be documented in the OpenAPI doc. Note that the previous deletion response was not correct.
2025-11-11 15:07:21 -08:00
Jesse Gross
8bf38552de llm: Prefer dedicated GPUs over iGPUs when allocating memory
We currently assign model layers to GPUs according to free VRAM,
which assumes that GPU performance is roughly equal. This does not
work well for mixed dGPU and iGPU systems because iGPUs typically
use system memory which is large but their performance is slow.
This instead assigns layers to dGPUs first and then iGPUs.

In the future, this could be generalized to have a more fine grained
notion of GPU performance but dGPU vs. iGPU performance is the most
extreme.
2025-11-11 13:11:08 -08:00
Jesse Gross
b13fbad0fe llm: Separate llamaServer and ollamaServer code paths
Originally, llamaServer represented old memory estimates, which
could be used with either the old or new engine. ollamaServer was
used only for the new estimates and new engine. Since these
implementations did not map directly to engine, there was engine-
specific code in common code paths.

Now that new estimates are always used for the new engine, there is
a direct mapping between server type and engine. This separates out
most of the engine-specific code into the correct implementation
to make things easier to understand.
2025-11-11 13:11:08 -08:00
Jesse Gross
f560bd077f llm: Use Ollama engine memory layouts for both old and new engines
Currently for both the old and new engines, there is code to
calculate how much memory is required for a model and lay out
the layers onto GPUs. This reuses the new engine's lay out code
for the old engine as well, bringing them closer together. The
old engine continues to use its current method of estimating
required memory.

This reduces maintainence effort and improves consistency, as new
features only need to be implemented in one place. The newer code
is also more accurate, especially with multiple GPUs.
2025-11-11 13:11:08 -08:00
Jesse Gross
4372d0bfef llamarunner: Respect device ordering for offloaded layers
We used to control the way that llama.cpp saw devices using
CUDA_VISIBLE_DEVICES or similar. This would ensure that the layers
offloaded to a device were actually the ones intended. This is
particularly important because we might reorder devices based on
free memory or performance.

When we started explicitly scheduling layers, this logic went
away but the llamarunner didn't have any way to set the correct
order of devices. This meant that the correct number of layers
would be assigned to a device but not necessarily the layers
that were expected. This change sets up the devices correctly
based on the offload information.
2025-11-11 13:11:08 -08:00
Eva H
31361c4d3c app/ui: do not send thinking to prevent errors with cloud provider 2025-11-11 16:09:24 -05:00
Baptiste Jamin
59241c5bee server: add logprobs and top_logprobs support to Ollama's API (#12899)
Adds logprobs support to Ollama's API including support for Ollama's
OpenAI-compatible API. By specifying the new 'logprobs' boolean parameter
in the API, Ollama will return the log probabilities for each token generated.
'top_logprobs', an integer value can also be specified up to the value 20.
When specified, the API will also provide the number of most likely tokens to
return at each token position

Co-authored-by: Baptiste Jamin <baptiste@crisp.chat>
2025-11-11 08:49:50 -08:00
Eva Ho
2a9b61f099 address comment 2025-11-11 08:58:55 -05:00
Sheikh
6df4208836 docs: fix metal gpu section header (#13045) 2025-11-10 21:51:22 -08:00
Eva Ho
9d615cdaa0 fix test 2025-11-10 20:13:50 -05:00
Eva Ho
6a818b8a09 clean up 2025-11-10 19:08:42 -05:00
Eva Ho
2aaf29acb5 app/ui: do not send to prevent errors with cloud provider 2025-11-10 19:05:00 -05:00
Eva H
a42f826acb app/ui: using streamdown AI elements for markdown rendering 2025-11-10 12:05:59 -05:00
Bruce MacDonald
e10a3533a5 app/docs: remove out of date storybook instructions (#13006) 2025-11-08 13:28:18 -08:00
Patrick Devine
91ec3ddbeb bugfix: don't include both consolidated.safetensors and model-*.safetensors (#13010) 2025-11-07 22:41:57 -08:00
Parth Sareen
755ac3b069 docs: update n8n URL for Ollama (#12994) 2025-11-07 20:07:26 -08:00
Daniel Hiltgen
60b8973559 doc: re-add login autostart faq and GPU updates (#12975)
* doc: re-add login autostart faq

This appears to have been accidentally dropped during the doc migration.

* docs: GPU updates lost on the doc update

* review comments: improve windows login disable instructions
2025-11-07 11:21:44 -08:00
Tomoya Fujita
d2ef679d42 docs: fix 404 link to modelfile documentation (#12996) 2025-11-07 10:06:46 -08:00
Thomas Stocker
d4e0da0890 Remove unnecessary MacOs 13 and lower Patches (#12656)
* Remove unnecessary macos 13 Patch

* Remove unnecessary MacOs Version Guard patch

* rename patchesw

* remove again macos13 patch

* rename files
2025-11-06 15:52:56 -08:00
Jeffrey Morgan
565b802a6b openai: fix tool call ID mapping (#12988) 2025-11-06 15:26:25 -08:00
Saifeddine ALOUI
6c79e6c09a readme: add security tools section and Ollama fortress to community integrations (#12981) 2025-11-06 15:21:13 -08:00
breatn
780762f9d2 server: fix duplicate 'is' typo in comment (#12985) 2025-11-06 14:44:44 -08:00
Jeffrey Morgan
30fcc71983 api: add omitempty to required tool function parameter type (#12989) 2025-11-06 14:08:55 -08:00
Eva Ho
3501a4bdf9 address comment 2025-11-06 16:49:22 -05:00
Eva H
73a0cafc1e Merge pull request #12973 from macarronesc/main
feat: add support for WebP images in Ollama's app
2025-11-06 16:31:46 -05:00
Eva Ho
e309c80474 address comments 2025-11-06 13:49:59 -05:00
Daniel Alejandro Coll Tejeda
a4a53692f8 refactor: remove GIF support from image validation tests and logging 2025-11-06 09:09:51 +00:00
Eva Ho
536c987c39 address comment 2025-11-05 20:19:34 -05:00
Eva Ho
a534d4e9e1 fixing thinking not scrolling issue 2025-11-05 16:06:55 -05:00
Eva Ho
74586aa9df address comments 2025-11-05 16:06:55 -05:00
Eva Ho
8c74f5ddfd ui: using streamdown AI elements for markdown rendering 2025-11-05 16:06:55 -05:00
Daniel Alejandro Coll Tejeda
bddfa2100f feat: add support for WebP images in Ollama's app 2025-11-05 21:23:20 +01:00
216 changed files with 17134 additions and 23170 deletions

4
.gitattributes vendored
View File

@@ -15,8 +15,12 @@ ml/backend/**/*.cu linguist-vendored
ml/backend/**/*.cuh linguist-vendored
ml/backend/**/*.m linguist-vendored
ml/backend/**/*.metal linguist-vendored
ml/backend/**/*.comp linguist-vendored
ml/backend/**/*.glsl linguist-vendored
ml/backend/**/CMakeLists.txt linguist-vendored
app/webview linguist-vendored
llama/build-info.cpp linguist-generated
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.s linguist-generated

View File

@@ -104,6 +104,13 @@ jobs:
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
rocm-version: '6.2'
flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
runner_dir: 'rocm'
- os: windows
arch: amd64
preset: Vulkan
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
flags: ''
runner_dir: 'vulkan'
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
env:
@@ -113,13 +120,14 @@ jobs:
run: |
choco install -y --no-progress ccache ninja
ccache -o cache_dir=${{ github.workspace }}\.ccache
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ')
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan')
id: cache-install
uses: actions/cache/restore@v4
with:
path: |
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
- if: startsWith(matrix.preset, 'CUDA ')
name: Install CUDA ${{ matrix.cuda-version }}
@@ -149,6 +157,18 @@ jobs:
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: matrix.preset == 'Vulkan'
name: Install Vulkan ${{ matrix.rocm-version }}
run: |
$ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
Start-Process -FilePath .\install.exe -ArgumentList "-c","--am","--al","in" -NoNewWindow -Wait
}
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
- if: matrix.preset == 'CPU'
run: |
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
@@ -159,6 +179,7 @@ jobs:
path: |
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
- uses: actions/checkout@v4
- uses: actions/cache@v4
@@ -171,7 +192,7 @@ jobs:
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
env:
CMAKE_GENERATOR: Ninja
@@ -312,13 +333,13 @@ jobs:
include:
- os: linux
arch: amd64
target: archive_novulkan
target: archive
- os: linux
arch: amd64
target: rocm
- os: linux
arch: arm64
target: archive_novulkan
target: archive
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
needs: setup-environment
@@ -345,6 +366,7 @@ jobs:
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
@@ -374,14 +396,12 @@ jobs:
include:
- os: linux
arch: arm64
target: novulkan
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
GOFLAGS
- os: linux
arch: amd64
target: novulkan
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
@@ -394,14 +414,6 @@ jobs:
CGO_CXXFLAGS
GOFLAGS
FLAVOR=rocm
- os: linux
arch: amd64
suffix: '-vulkan'
target: default
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
GOFLAGS
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
needs: setup-environment
@@ -419,7 +431,6 @@ jobs:
with:
context: .
platforms: ${{ matrix.os }}/${{ matrix.arch }}
target: ${{ matrix.preset }}
build-args: ${{ matrix.build-args }}
outputs: type=image,name=${{ vars.DOCKER_REPO }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest

View File

@@ -172,6 +172,7 @@ jobs:
path: |
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
- uses: actions/checkout@v4
- uses: actions/cache@v4
@@ -225,12 +226,9 @@ jobs:
if: always()
run: go test -count=1 -benchtime=1x ./...
# TODO(bmizerany): replace this heavy tool with just the
# tools/checks/binaries we want and then make them all run in parallel
# across jobs, not on a single tiny vm on Github Actions.
- uses: golangci/golangci-lint-action@v6
- uses: golangci/golangci-lint-action@v9
with:
args: --timeout 10m0s -v
only-new-issues: true
patches:
runs-on: ubuntu-latest
@@ -239,4 +237,4 @@ jobs:
- name: Verify patches apply cleanly and do not change files
run: |
make -f Makefile.sync clean checkout apply-patches sync
git diff --compact-summary --exit-code
git diff --compact-summary --exit-code

View File

@@ -1,41 +1,77 @@
run:
timeout: 5m
version: "2"
linters:
default: none
enable:
- asasalint
- bidichk
- bodyclose
- containedctx
- copyloopvar
- errcheck
- errorlint
- exptostd
- gocheckcompilerdirectives
- gofmt
- gofumpt
- gosimple
- gocritic
- govet
- ineffassign
- intrange
- makezero
- misspell
- modernize
- nilerr
- nilnil
- nolintlint
- nosprintfhostport
- perfsprint
- prealloc
- sloglint
- staticcheck
- unconvert
- unused
- usestdlibvars
- usetesting
- wastedassign
- whitespace
disable:
- usestdlibvars
- errcheck
linters-settings:
staticcheck:
checks:
- all
- -SA1019 # omit Deprecated check
severity:
default-severity: error
rules:
- linters:
- gofmt
- goimports
- intrange
severity: info
settings:
errcheck:
exclude-functions:
- fmt.Fprintf
perfsprint:
strconcat: false
concat-loop: false
staticcheck:
checks:
- all
# Using a deprecated function, variable, constant or field.
# https://staticcheck.dev/docs/checks/#SA1019
- -SA1019
# Incorrect or missing package comment.
# https://staticcheck.dev/docs/checks/#ST1000
- -ST1000
# Poorly chosen identifier.
# https://staticcheck.dev/docs/checks/#ST1003
- -ST1003
# The documentation of an exported function should start with the function's name.
# https://staticcheck.dev/docs/checks/#ST1020
- -ST1020
# The documentation of an exported type should start with type's name.
# https://staticcheck.dev/docs/checks/#ST1021
- -ST1021
# The documentation of an exported variable or constant should start with variable's name.
# https://staticcheck.dev/docs/checks/#ST1022
- -ST1022
usestdlibvars:
http-method: false
http-status-code: false
formatters:
enable:
- gci
- gofmt
- gofumpt
settings:
gci:
sections:
- standard
- default
- localmodule

View File

@@ -16,7 +16,7 @@ See the [development documentation](./docs/development.md) for instructions on h
* New features: new features (e.g. API fields, environment variables) add surface area to Ollama and make it harder to maintain in the long run as they cannot be removed without potentially breaking users in the future.
* Refactoring: large code improvements are important, but can be harder or take longer to review and merge.
* Documentation: small updates to fill in or correct missing documentation is helpful, however large documentation additions can be hard to maintain over time.
* Documentation: small updates to fill in or correct missing documentation are helpful, however large documentation additions can be hard to maintain over time.
### Issues that may not be accepted
@@ -43,7 +43,7 @@ Tips for proposals:
* Explain how the change will be tested.
Additionally, for bonus points: Provide draft documentation you would expect to
see if the change were accepted.
see if the changes were accepted.
## Pull requests
@@ -66,7 +66,6 @@ Examples:
llm/backend/mlx: support the llama architecture
CONTRIBUTING: provide clarity on good commit messages, and bad
docs: simplify manual installation with shorter curl commands
Bad Examples:

View File

@@ -39,14 +39,14 @@ ENV CC=clang CXX=clang++
FROM base-${TARGETARCH} AS base
ARG CMAKEVERSION
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ENV LDFLAGS=-s
FROM base AS cpu
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CPU' \
&& cmake --build --parallel ${PARALLEL} --preset 'CPU' \
@@ -57,6 +57,8 @@ ARG CUDA11VERSION=11.8
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 11' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
@@ -67,6 +69,8 @@ ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
ENV PATH=/usr/local/cuda-12/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 12' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
@@ -78,6 +82,8 @@ ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
ENV PATH=/usr/local/cuda-13/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 13' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
@@ -87,6 +93,8 @@ RUN --mount=type=cache,target=/root/.ccache \
FROM base AS rocm-6
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'ROCm 6' \
&& cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
@@ -118,6 +126,8 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS vulkan
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'Vulkan' \
&& cmake --build --parallel --preset 'Vulkan' \
@@ -159,32 +169,7 @@ ARG VULKANVERSION
COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama
# Temporary opt-out stages for Vulkan
FROM --platform=linux/amd64 scratch AS amd64_novulkan
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
FROM arm64 AS arm64_novulkan
FROM ${FLAVOR}_novulkan AS archive_novulkan
COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04 AS novulkan
RUN apt-get update \
&& apt-get install -y ca-certificates \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
COPY --from=archive_novulkan /bin /usr/bin
ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
COPY --from=archive_novulkan /lib/ollama /usr/lib/ollama
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
ENV NVIDIA_VISIBLE_DEVICES=all
ENV OLLAMA_HOST=0.0.0.0:11434
EXPOSE 11434
ENTRYPOINT ["/bin/ollama"]
CMD ["serve"]
FROM ubuntu:24.04 AS default
FROM ubuntu:24.04
RUN apt-get update \
&& apt-get install -y ca-certificates libvulkan1 \
&& apt-get clean \

View File

@@ -299,6 +299,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LibreChat](https://github.com/danny-avila/LibreChat)
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [AI-UI](https://github.com/bajahaw/ai-ui)
- [Saddle](https://github.com/jikkuatwork/saddle)
- [TagSpaces](https://www.tagspaces.org) (A platform for file-based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
@@ -365,7 +366,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot, and Ollama4j
- [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VS Code extension for multi-file/whole-repo coding
- [Void](https://github.com/voideditor/void) (Open source AI code editor and Cursor alternative)
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
- [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
@@ -397,7 +399,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [aidful-ollama-model-delete](https://github.com/AidfulAI/aidful-ollama-model-delete) (User interface for simplified model cleanup)
- [Perplexica](https://github.com/ItzCrazyKns/Perplexica) (An AI-powered search engine & an open-source alternative to Perplexity AI)
- [Ollama Chat WebUI for Docker ](https://github.com/oslook/ollama-webui) (Support for local docker deployment, lightweight ollama webui)
- [AI Toolkit for Visual Studio Code](https://aka.ms/ai-tooklit/ollama-docs) (Microsoft-official VSCode extension to chat, test, evaluate models with Ollama support, and use them in your AI applications.)
- [AI Toolkit for Visual Studio Code](https://aka.ms/ai-tooklit/ollama-docs) (Microsoft-official VS Code extension to chat, test, evaluate models with Ollama support, and use them in your AI applications.)
- [MinimalNextOllamaChat](https://github.com/anilkay/MinimalNextOllamaChat) (Minimal Web UI for Chat and Model Control)
- [Chipper](https://github.com/TilmanGriesel/chipper) AI interface for tinkerers (Ollama, Haystack RAG, Python)
- [ChibiChat](https://github.com/CosmicEventHorizon/ChibiChat) (Kotlin-based Android app to chat with Ollama and Koboldcpp API endpoints)
@@ -426,6 +428,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
- [KDeps](https://github.com/kdeps/kdeps) (Kdeps is an offline-first AI framework for building Dockerized full-stack AI applications declaratively using Apple PKL and integrates APIs with Ollama on the backend.)
- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)
- [Hillnote](https://hillnote.com) (A Markdown-first workspace designed to supercharge your AI workflow. Create documents ready to integrate with Claude, ChatGPT, Gemini, Cursor, and more - all while keeping your work on your device.)
@@ -615,7 +618,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LSP-AI](https://github.com/SilasMarvin/lsp-ai) (Open-source language server for AI-powered functionality)
- [QodeAssist](https://github.com/Palm1r/QodeAssist) (AI-powered coding assistant plugin for Qt Creator)
- [Obsidian Quiz Generator plugin](https://github.com/ECuiDev/obsidian-quiz-generator)
- [AI Summmary Helper plugin](https://github.com/philffm/ai-summary-helper)
- [AI Summary Helper plugin](https://github.com/philffm/ai-summary-helper)
- [TextCraft](https://github.com/suncloudsmoon/TextCraft) (Copilot in Word alternative using Ollama)
- [Alfred Ollama](https://github.com/zeitlings/alfred-ollama) (Alfred Workflow)
- [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language
@@ -623,7 +626,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
- [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime)
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Editor tool to analyze scripts via Ollama)
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)
- [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases)
- [NOMYO Router](https://github.com/nomyo-ai/nomyo-router) (A transparent Ollama proxy with model deployment aware routing which auto-manages multiple Ollama instances in a given network)
@@ -633,9 +636,12 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
### Observability
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native intergration to Ollama.
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.
- [Langfuse](https://langfuse.com/docs/integrations/ollama) is an open source LLM observability platform that enables teams to collaboratively monitor, evaluate and debug AI applications.
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
### Security
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)

View File

@@ -14,7 +14,7 @@ Please include the following details in your report:
## Security best practices
While the maintainer team does their best to secure Ollama, users are encouraged to implement their own security best practices, such as:
While the maintainer team does its best to secure Ollama, users are encouraged to implement their own security best practices, such as:
- Regularly updating to the latest version of Ollama
- Securing access to hosted instances of Ollama

View File

@@ -226,7 +226,14 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
bts := scanner.Bytes()
if err := json.Unmarshal(bts, &errorResponse); err != nil {
return fmt.Errorf("unmarshal: %w", err)
if response.StatusCode >= http.StatusBadRequest {
return StatusError{
StatusCode: response.StatusCode,
Status: response.Status,
ErrorMessage: string(bts),
}
}
return errors.New(string(bts))
}
if response.StatusCode == http.StatusUnauthorized {

View File

@@ -55,6 +55,7 @@ func TestClientFromEnvironment(t *testing.T) {
type testError struct {
message string
statusCode int
raw bool // if true, write message as-is instead of JSON encoding
}
func (e testError) Error() string {
@@ -111,6 +112,20 @@ func TestClientStream(t *testing.T) {
},
},
},
{
name: "plain text error response",
responses: []any{
"internal server error",
},
wantErr: "internal server error",
},
{
name: "HTML error page",
responses: []any{
"<html><body>404 Not Found</body></html>",
},
wantErr: "404 Not Found",
},
}
for _, tc := range testCases {
@@ -135,6 +150,12 @@ func TestClientStream(t *testing.T) {
return
}
if str, ok := resp.(string); ok {
fmt.Fprintln(w, str)
flusher.Flush()
continue
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("failed to encode response: %v", err)
}
@@ -173,9 +194,10 @@ func TestClientStream(t *testing.T) {
func TestClientDo(t *testing.T) {
testCases := []struct {
name string
response any
wantErr string
name string
response any
wantErr string
wantStatusCode int
}{
{
name: "immediate error response",
@@ -183,7 +205,8 @@ func TestClientDo(t *testing.T) {
message: "test error message",
statusCode: http.StatusBadRequest,
},
wantErr: "test error message",
wantErr: "test error message",
wantStatusCode: http.StatusBadRequest,
},
{
name: "server error response",
@@ -191,7 +214,8 @@ func TestClientDo(t *testing.T) {
message: "internal error",
statusCode: http.StatusInternalServerError,
},
wantErr: "internal error",
wantErr: "internal error",
wantStatusCode: http.StatusInternalServerError,
},
{
name: "successful response",
@@ -203,6 +227,26 @@ func TestClientDo(t *testing.T) {
Success: true,
},
},
{
name: "plain text error response",
response: testError{
message: "internal server error",
statusCode: http.StatusInternalServerError,
raw: true,
},
wantErr: "internal server error",
wantStatusCode: http.StatusInternalServerError,
},
{
name: "HTML error page",
response: testError{
message: "<html><body>404 Not Found</body></html>",
statusCode: http.StatusNotFound,
raw: true,
},
wantErr: "<html><body>404 Not Found</body></html>",
wantStatusCode: http.StatusNotFound,
},
}
for _, tc := range testCases {
@@ -210,11 +254,16 @@ func TestClientDo(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if errResp, ok := tc.response.(testError); ok {
w.WriteHeader(errResp.statusCode)
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
if !errResp.raw {
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
} else {
// Write raw message (simulates non-JSON error responses)
fmt.Fprint(w, errResp.message)
}
return
}
@@ -241,6 +290,15 @@ func TestClientDo(t *testing.T) {
if err.Error() != tc.wantErr {
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
}
if tc.wantStatusCode != 0 {
if statusErr, ok := err.(StatusError); ok {
if statusErr.StatusCode != tc.wantStatusCode {
t.Errorf("status code mismatch: got %d, want %d", statusErr.StatusCode, tc.wantStatusCode)
}
} else {
t.Errorf("expected StatusError, got %T", err)
}
}
return
}

View File

@@ -117,6 +117,14 @@ type GenerateRequest struct {
// DebugRenderOnly is a debug option that, when set to true, returns the rendered
// template instead of calling the model.
DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
// Logprobs specifies whether to return log probabilities of the output tokens.
Logprobs bool `json:"logprobs,omitempty"`
// TopLogprobs is the number of most likely tokens to return at each token position,
// each with an associated log probability. Only applies when Logprobs is true.
// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
TopLogprobs int `json:"top_logprobs,omitempty"`
}
// ChatRequest describes a request sent by [Client.Chat].
@@ -159,6 +167,14 @@ type ChatRequest struct {
// DebugRenderOnly is a debug option that, when set to true, returns the rendered
// template instead of calling the model.
DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
// Logprobs specifies whether to return log probabilities of the output tokens.
Logprobs bool `json:"logprobs,omitempty"`
// TopLogprobs is the number of most likely tokens to return at each token position,
// each with an associated log probability. Only applies when Logprobs is true.
// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
TopLogprobs int `json:"top_logprobs,omitempty"`
}
type Tools []Tool
@@ -323,7 +339,7 @@ type ToolFunctionParameters struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Required []string `json:"required,omitempty"`
Properties map[string]ToolProperty `json:"properties"`
}
@@ -343,6 +359,27 @@ func (t *ToolFunction) String() string {
return string(bts)
}
// TokenLogprob represents log probability information for a single token alternative.
type TokenLogprob struct {
// Token is the text representation of the token.
Token string `json:"token"`
// Logprob is the log probability of this token.
Logprob float64 `json:"logprob"`
// Bytes contains the raw byte representation of the token
Bytes []int `json:"bytes,omitempty"`
}
// Logprob contains log probability information for a generated token.
type Logprob struct {
TokenLogprob
// TopLogprobs contains the most likely tokens and their log probabilities
// at this position, if requested via TopLogprobs parameter.
TopLogprobs []TokenLogprob `json:"top_logprobs,omitempty"`
}
// ChatResponse is the response returned by [Client.Chat]. Its fields are
// similar to [GenerateResponse].
type ChatResponse struct {
@@ -369,6 +406,10 @@ type ChatResponse struct {
DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
// Logprobs contains log probability information for the generated tokens,
// if requested via the Logprobs parameter.
Logprobs []Logprob `json:"logprobs,omitempty"`
Metrics
}
@@ -677,6 +718,10 @@ type GenerateResponse struct {
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
// Logprobs contains log probability information for the generated tokens,
// if requested via the Logprobs parameter.
Logprobs []Logprob `json:"logprobs,omitempty"`
}
// ModelDetails provides details about a model.

View File

@@ -298,6 +298,44 @@ func TestToolFunction_UnmarshalJSON(t *testing.T) {
}
}
func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
tests := []struct {
name string
input ToolFunctionParameters
expected string
}{
{
name: "simple object with string property",
input: ToolFunctionParameters{
Type: "object",
Required: []string{"name"},
Properties: map[string]ToolProperty{
"name": {Type: PropertyType{"string"}},
},
},
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string"}}}`,
},
{
name: "no required",
input: ToolFunctionParameters{
Type: "object",
Properties: map[string]ToolProperty{
"name": {Type: PropertyType{"string"}},
},
},
expected: `{"type":"object","properties":{"name":{"type":"string"}}}`,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
data, err := json.Marshal(test.input)
require.NoError(t, err)
assert.Equal(t, test.expected, string(data))
})
}
}
func TestToolCallFunction_IndexAlwaysMarshals(t *testing.T) {
fn := ToolCallFunction{
Name: "echo",

View File

@@ -48,16 +48,6 @@ The `-dev` flag enables:
- CORS headers for cross-origin requests
- Hot-reload support for UI development
#### Run Storybook
Inside the `ui/app` directory, run:
```bash
npm run storybook
```
For now we're writing stories as siblings of the component they're testing. So for example, `src/components/Message.stories.tsx` is the story for `src/components/Message.tsx`.
## Build

View File

@@ -397,8 +397,8 @@ func checkUserLoggedIn(uiServerPort int) bool {
// handleConnectURLScheme fetches the connect URL and opens it in the browser
func handleConnectURLScheme() {
if checkUserLoggedIn(uiServerPort) {
slog.Info("user is already logged in, opening settings instead")
sendUIRequestMessage("/")
slog.Info("user is already logged in, opening app instead")
showWindow(wv.webview.Window())
return
}
@@ -434,37 +434,30 @@ func openInBrowser(url string) {
}
}
// parseURLScheme parses an ollama:// URL and returns whether it's a connect URL and the UI path
func parseURLScheme(urlSchemeRequest string) (isConnect bool, uiPath string, err error) {
// parseURLScheme parses an ollama:// URL and validates it
// Supports: ollama:// (open app) and ollama://connect (OAuth)
func parseURLScheme(urlSchemeRequest string) (isConnect bool, err error) {
parsedURL, err := url.Parse(urlSchemeRequest)
if err != nil {
return false, "", err
return false, fmt.Errorf("invalid URL: %w", err)
}
// Check if this is a connect URL
if parsedURL.Host == "connect" || strings.TrimPrefix(parsedURL.Path, "/") == "connect" {
return true, "", nil
return true, nil
}
// Extract the UI path
path := "/"
if parsedURL.Path != "" && parsedURL.Path != "/" {
// For URLs like ollama:///settings, use the path directly
path = parsedURL.Path
} else if parsedURL.Host != "" {
// For URLs like ollama://settings (without triple slash),
// the "settings" part is parsed as the host, not the path.
// We need to convert it to a path by prepending "/"
// This also handles ollama://settings/ where Windows adds a trailing slash
path = "/" + parsedURL.Host
// Allow bare ollama:// or ollama:/// to open the app
if (parsedURL.Host == "" && parsedURL.Path == "") || parsedURL.Path == "/" {
return false, nil
}
return false, path, nil
return false, fmt.Errorf("unsupported ollama:// URL path: %s", urlSchemeRequest)
}
// handleURLSchemeInCurrentInstance processes URL scheme requests in the current instance
func handleURLSchemeInCurrentInstance(urlSchemeRequest string) {
isConnect, uiPath, err := parseURLScheme(urlSchemeRequest)
isConnect, err := parseURLScheme(urlSchemeRequest)
if err != nil {
slog.Error("failed to parse URL scheme request", "url", urlSchemeRequest, "error", err)
return
@@ -473,6 +466,8 @@ func handleURLSchemeInCurrentInstance(urlSchemeRequest string) {
if isConnect {
handleConnectURLScheme()
} else {
sendUIRequestMessage(uiPath)
if wv.webview != nil {
showWindow(wv.webview.Window())
}
}
}

View File

@@ -24,27 +24,14 @@ bool firstTimeRun,startHidden; // Set in run before initialization
for (NSURL *url in urls) {
if ([url.scheme isEqualToString:@"ollama"]) {
NSString *path = url.path;
if (!path || [path isEqualToString:@""]) {
// For URLs like ollama://settings (without triple slash),
// the "settings" part is parsed as the host, not the path.
// We need to convert it to a path by prepending "/"
if (url.host && ![url.host isEqualToString:@""]) {
path = [@"/" stringByAppendingString:url.host];
} else {
path = @"/";
}
}
if ([path isEqualToString:@"/connect"] || [url.host isEqualToString:@"connect"]) {
if (path && ([path isEqualToString:@"/connect"] || [url.host isEqualToString:@"connect"])) {
// Special case: handle connect by opening browser instead of app
handleConnectURL();
} else {
// Set app to be active and visible
[NSApp setActivationPolicy:NSApplicationActivationPolicyRegular];
[NSApp activateIgnoringOtherApps:YES];
// Open the path with the UI
[self uiRequest:path];
}
break;
@@ -260,7 +247,7 @@ bool firstTimeRun,startHidden; // Set in run before initialization
}
- (void)openHelp:(id)sender {
NSURL *url = [NSURL URLWithString:@"https://github.com/ollama/ollama/tree/main/docs"];
NSURL *url = [NSURL URLWithString:@"https://docs.ollama.com/"];
[[NSWorkspace sharedWorkspace] openURL:url];
}

View File

@@ -138,7 +138,7 @@ func (app *appCallbacks) HandleURLScheme(urlScheme string) {
// handleURLSchemeRequest processes URL scheme requests from other instances
func handleURLSchemeRequest(urlScheme string) {
isConnect, uiPath, err := parseURLScheme(urlScheme)
isConnect, err := parseURLScheme(urlScheme)
if err != nil {
slog.Error("failed to parse URL scheme request", "url", urlScheme, "error", err)
return
@@ -147,7 +147,9 @@ func handleURLSchemeRequest(urlScheme string) {
if isConnect {
handleConnectURLScheme()
} else {
sendUIRequestMessage(uiPath)
if wv.webview != nil {
showWindow(wv.webview.Window())
}
}
}

View File

@@ -282,7 +282,7 @@ func (w *Webview) Run(path string) unsafe.Pointer {
"go", "rs", "swift", "kt", "scala", "sh", "bat", "yaml", "yml", "toml", "ini",
"cfg", "conf", "log", "rtf",
}
imageExts := []string{"png", "jpg", "jpeg"}
imageExts := []string{"png", "jpg", "jpeg", "webp"}
allowedExts := append(textExts, imageExts...)
// Use native multiple file selection with extension filtering

View File

File diff suppressed because it is too large Load Diff

View File

@@ -34,6 +34,7 @@
"rehype-raw": "^7.0.0",
"rehype-sanitize": "^6.0.0",
"remark-math": "^6.0.0",
"streamdown": "^1.4.0",
"unist-builder": "^4.0.0",
"unist-util-parents": "^3.0.0"
},

View File

@@ -15,6 +15,7 @@ import {
import { parseJsonlFromResponse } from "./util/jsonl-parsing";
import { ollamaClient as ollama } from "./lib/ollama-client";
import type { ModelResponse } from "ollama/browser";
import { API_BASE } from "./lib/config";
// Extend Model class with utility methods
declare module "@/gotypes" {
@@ -27,8 +28,6 @@ Model.prototype.isCloud = function (): boolean {
return this.model.endsWith("cloud");
};
const API_BASE = import.meta.env.DEV ? "http://127.0.0.1:3001" : "";
// Helper function to convert Uint8Array to base64
function uint8ArrayToBase64(uint8Array: Uint8Array): string {
const chunkSize = 0x8000; // 32KB chunks to avoid stack overflow
@@ -205,6 +204,13 @@ export async function* sendMessage(
data: uint8ArrayToBase64(att.data),
}));
// Only send think parameter when actually requesting thinking
// Don't send false as it causes issues with some providers
const shouldSendThink =
think !== undefined &&
((typeof think === "boolean" && think) ||
(typeof think === "string" && think !== ""));
const response = await fetch(`${API_BASE}/api/v1/chat/${chatId}`, {
method: "POST",
headers: {
@@ -222,7 +228,7 @@ export async function* sendMessage(
web_search: webSearch ?? false,
file_tools: fileTools ?? false,
...(forceUpdate !== undefined ? { forceUpdate } : {}),
...(think !== undefined ? { think } : {}),
...(shouldSendThink ? { think } : {}),
}),
),
signal,

View File

File diff suppressed because one or more lines are too long

View File

@@ -1,522 +0,0 @@
import { expect, test, suite } from "vitest";
import { processStreamingMarkdown } from "@/utils/processStreamingMarkdown";
suite("common llm outputs that cause issues", () => {
test("prefix of bolded list item shouldn't make a horizontal line", () => {
// we're going to go in order of incrementally adding characters. This
// happens really commonly with LLMs that like to make lists like so:
//
// * **point 1**: explanatory text
// * **point 2**: more explanatory text
//
// Partial rendering of `*` (A), followed by `* *` (B), followed by `* **`
// (C) is a total mess. (A) renders as a single bullet point in an
// otherwise empty list, (B) renders as two nested lists (and therefore
// two bullet points, styled differently by default in html), and (C)
// renders as a horizontal line because in markdown apparently `***` or `*
// * *` horizontal rules don't have as strict whitespace rules as I
// expected them to
// these are alone (i.e., they would be the first list item)
expect(processStreamingMarkdown("*")).toBe("");
expect(processStreamingMarkdown("* *")).toBe("");
expect(processStreamingMarkdown("* **")).toBe("");
// expect(processStreamingMarkdown("* **b")).toBe("* **b**");
// with a list item before them
expect(
processStreamingMarkdown(
// prettier-ignore
[
"* abc",
"*"
].join("\n"),
),
).toBe("* abc");
expect(
processStreamingMarkdown(
// prettier-ignore
[
"* abc",
"* *"
].join("\n"),
),
).toBe("* abc");
expect(
processStreamingMarkdown(
// prettier-ignore
[
"* abc",
"* **"
].join("\n"),
),
).toBe("* abc");
});
test("bolded list items with text should be rendered properly", () => {
expect(processStreamingMarkdown("* **abc**")).toBe("* **abc**");
});
test("partially bolded list items should be autoclosed", () => {
expect(processStreamingMarkdown("* **abc")).toBe("* **abc**");
});
suite(
"partially bolded list items should be autoclosed, even if the last node isn't a text node",
() => {
test("inline code", () => {
expect(
processStreamingMarkdown("* **Asynchronous Function `async`*"),
).toBe("* **Asynchronous Function `async`**");
});
},
);
});
suite("autoclosing bold", () => {
suite("endings with no asterisks", () => {
test("should autoclose bold", () => {
expect(processStreamingMarkdown("**abc")).toBe("**abc**");
expect(processStreamingMarkdown("abc **abc")).toBe("abc **abc**");
});
suite("should autoclose, even if the last node isn't a text node", () => {
test("inline code", () => {
expect(
processStreamingMarkdown("* **Asynchronous Function `async`"),
).toBe("* **Asynchronous Function `async`**");
});
test("opening ** is at the end of the text", () => {
expect(processStreamingMarkdown("abc **`def` jhk [lmn](opq)")).toBe(
"abc **`def` jhk [lmn](opq)**",
);
});
test("if there's a space after the **, it should NOT be autoclosed", () => {
expect(processStreamingMarkdown("abc ** `def` jhk [lmn](opq)")).toBe(
"abc \\*\\* `def` jhk [lmn](opq)",
);
});
});
test("should autoclose bold, even if the last node isn't a text node", () => {
expect(
processStreamingMarkdown("* **Asynchronous Function ( `async`"),
).toBe("* **Asynchronous Function ( `async`**");
});
test("whitespace fakeouts should not be modified", () => {
expect(processStreamingMarkdown("** abc")).toBe("\\*\\* abc");
});
// TODO(drifkin): arguably this should just be removed entirely, but empty
// isn't so bad
test("should handle empty bolded items", () => {
expect(processStreamingMarkdown("**")).toBe("");
});
});
suite("partially closed bolded items", () => {
test("simple partial", () => {
expect(processStreamingMarkdown("**abc*")).toBe("**abc**");
});
test("partial with non-text node at end", () => {
expect(processStreamingMarkdown("**abc`def`*")).toBe("**abc`def`**");
});
test("partial with multiply nested ending nodes", () => {
expect(processStreamingMarkdown("**abc[abc](`def`)*")).toBe(
"**abc[abc](`def`)**",
);
});
test("normal emphasis should not be affected", () => {
expect(processStreamingMarkdown("*abc*")).toBe("*abc*");
});
test("normal emphasis with nested code should not be affected", () => {
expect(processStreamingMarkdown("*`abc`*")).toBe("*`abc`*");
});
});
test.skip("shouldn't autoclose immediately if there's a space before the closing *", () => {
expect(processStreamingMarkdown("**abc *")).toBe("**abc**");
});
// skipping for now because this requires partial link completion as well
suite.skip("nested blocks that each need autoclosing", () => {
test("emph nested in link nested in strong nested in list item", () => {
expect(processStreamingMarkdown("* **[abc **def")).toBe(
"* **[abc **def**]()**",
);
});
test("* **[ab *`def`", () => {
expect(processStreamingMarkdown("* **[ab *`def`")).toBe(
"* **[ab *`def`*]()**",
);
});
});
});
suite("numbered list items", () => {
test("should remove trailing numbers", () => {
expect(processStreamingMarkdown("1. First\n2")).toBe("1. First");
});
test("should remove trailing numbers with breaks before", () => {
expect(processStreamingMarkdown("1. First \n2")).toBe("1. First");
});
test("should remove trailing numbers that form a new paragraph", () => {
expect(processStreamingMarkdown("1. First\n\n2")).toBe("1. First");
});
test("but should leave list items separated by two newlines", () => {
expect(processStreamingMarkdown("1. First\n\n2. S")).toBe(
"1. First\n\n2. S",
);
});
});
// TODO(drifkin):slop tests ahead, some are decent, but need to manually go
// through them as I implement
/*
describe("StreamingMarkdownContent - processStreamingMarkdown", () => {
describe("Ambiguous endings removal", () => {
it("should remove list markers at the end", () => {
expect(processStreamingMarkdown("Some text\n* ")).toBe("Some text");
expect(processStreamingMarkdown("Some text\n*")).toBe("Some text");
expect(processStreamingMarkdown("* Item 1\n- ")).toBe("* Item 1");
expect(processStreamingMarkdown("* Item 1\n-")).toBe("* Item 1");
expect(processStreamingMarkdown("Text\n+ ")).toBe("Text");
expect(processStreamingMarkdown("Text\n+")).toBe("Text");
expect(processStreamingMarkdown("1. First\n2. ")).toBe("1. First");
});
it("should remove heading markers at the end", () => {
expect(processStreamingMarkdown("Some text\n# ")).toBe("Some text");
expect(processStreamingMarkdown("Some text\n#")).toBe("Some text\n#"); // # without space is not removed
expect(processStreamingMarkdown("# Title\n## ")).toBe("# Title");
expect(processStreamingMarkdown("# Title\n##")).toBe("# Title\n##"); // ## without space is not removed
});
it("should remove ambiguous bold markers at the end", () => {
expect(processStreamingMarkdown("Text **")).toBe("Text ");
expect(processStreamingMarkdown("Some text\n**")).toBe("Some text");
});
it("should remove code block markers at the end", () => {
expect(processStreamingMarkdown("Text\n```")).toBe("Text");
expect(processStreamingMarkdown("```")).toBe("");
});
it("should remove single backtick at the end", () => {
expect(processStreamingMarkdown("Text `")).toBe("Text ");
expect(processStreamingMarkdown("`")).toBe("");
});
it("should remove single asterisk at the end", () => {
expect(processStreamingMarkdown("Text *")).toBe("Text ");
expect(processStreamingMarkdown("*")).toBe("");
});
it("should handle empty content", () => {
expect(processStreamingMarkdown("")).toBe("");
});
it("should handle single line removals correctly", () => {
expect(processStreamingMarkdown("* ")).toBe("");
expect(processStreamingMarkdown("# ")).toBe("");
expect(processStreamingMarkdown("**")).toBe("");
expect(processStreamingMarkdown("`")).toBe("");
});
it("shouldn't have this regexp capture group bug", () => {
expect(
processStreamingMarkdown("Here's a shopping list:\n*"),
).not.toContain("0*");
expect(processStreamingMarkdown("Here's a shopping list:\n*")).toBe(
"Here's a shopping list:",
);
});
});
describe("List markers", () => {
it("should preserve complete list items", () => {
expect(processStreamingMarkdown("* Complete item")).toBe(
"* Complete item",
);
expect(processStreamingMarkdown("- Another item")).toBe("- Another item");
expect(processStreamingMarkdown("+ Plus item")).toBe("+ Plus item");
expect(processStreamingMarkdown("1. Numbered item")).toBe(
"1. Numbered item",
);
});
it("should handle indented list markers", () => {
expect(processStreamingMarkdown(" * ")).toBe(" ");
expect(processStreamingMarkdown(" - ")).toBe(" ");
expect(processStreamingMarkdown("\t+ ")).toBe("\t");
});
});
describe("Heading markers", () => {
it("should preserve complete headings", () => {
expect(processStreamingMarkdown("# Complete Heading")).toBe(
"# Complete Heading",
);
expect(processStreamingMarkdown("## Subheading")).toBe("## Subheading");
expect(processStreamingMarkdown("### H3 Title")).toBe("### H3 Title");
});
it("should not affect # in other contexts", () => {
expect(processStreamingMarkdown("C# programming")).toBe("C# programming");
expect(processStreamingMarkdown("Issue #123")).toBe("Issue #123");
});
});
describe("Bold text", () => {
it("should close incomplete bold text", () => {
expect(processStreamingMarkdown("This is **bold text")).toBe(
"This is **bold text**",
);
expect(processStreamingMarkdown("Start **bold and more")).toBe(
"Start **bold and more**",
);
expect(processStreamingMarkdown("**just bold")).toBe("**just bold**");
});
it("should not affect complete bold text", () => {
expect(processStreamingMarkdown("**complete bold**")).toBe(
"**complete bold**",
);
expect(processStreamingMarkdown("Text **bold** more")).toBe(
"Text **bold** more",
);
});
it("should handle nested bold correctly", () => {
expect(processStreamingMarkdown("**bold** and **another")).toBe(
"**bold** and **another**",
);
});
});
describe("Italic text", () => {
it("should close incomplete italic text", () => {
expect(processStreamingMarkdown("This is *italic text")).toBe(
"This is *italic text*",
);
expect(processStreamingMarkdown("Start *italic and more")).toBe(
"Start *italic and more*",
);
});
it("should differentiate between list markers and italic", () => {
expect(processStreamingMarkdown("* Item\n* ")).toBe("* Item");
expect(processStreamingMarkdown("Some *italic text")).toBe(
"Some *italic text*",
);
expect(processStreamingMarkdown("*just italic")).toBe("*just italic*");
});
it("should not affect complete italic text", () => {
expect(processStreamingMarkdown("*complete italic*")).toBe(
"*complete italic*",
);
expect(processStreamingMarkdown("Text *italic* more")).toBe(
"Text *italic* more",
);
});
});
describe("Code blocks", () => {
it("should close incomplete code blocks", () => {
expect(processStreamingMarkdown("```javascript\nconst x = 42;")).toBe(
"```javascript\nconst x = 42;\n```",
);
expect(processStreamingMarkdown("```\ncode here")).toBe(
"```\ncode here\n```",
);
});
it("should not affect complete code blocks", () => {
expect(processStreamingMarkdown("```\ncode\n```")).toBe("```\ncode\n```");
expect(processStreamingMarkdown("```js\nconst x = 1;\n```")).toBe(
"```js\nconst x = 1;\n```",
);
});
it("should handle nested code blocks correctly", () => {
expect(processStreamingMarkdown("```\ncode\n```\n```python")).toBe(
"```\ncode\n```\n```python\n```",
);
});
it("should not process markdown inside code blocks", () => {
expect(processStreamingMarkdown("```\n* not a list\n**not bold**")).toBe(
"```\n* not a list\n**not bold**\n```",
);
});
});
describe("Inline code", () => {
it("should close incomplete inline code", () => {
expect(processStreamingMarkdown("This is `inline code")).toBe(
"This is `inline code`",
);
expect(processStreamingMarkdown("Use `console.log")).toBe(
"Use `console.log`",
);
});
it("should not affect complete inline code", () => {
expect(processStreamingMarkdown("`complete code`")).toBe(
"`complete code`",
);
expect(processStreamingMarkdown("Use `code` here")).toBe(
"Use `code` here",
);
});
it("should handle multiple inline codes correctly", () => {
expect(processStreamingMarkdown("`code` and `more")).toBe(
"`code` and `more`",
);
});
it("should not confuse inline code with code blocks", () => {
expect(processStreamingMarkdown("```\nblock\n```\n`inline")).toBe(
"```\nblock\n```\n`inline`",
);
});
});
describe("Complex streaming scenarios", () => {
it("should handle progressive streaming of a heading", () => {
const steps = [
{ input: "#", expected: "#" }, // # alone is not removed (needs space)
{ input: "# ", expected: "" },
{ input: "# H", expected: "# H" },
{ input: "# Hello", expected: "# Hello" },
];
steps.forEach(({ input, expected }) => {
expect(processStreamingMarkdown(input)).toBe(expected);
});
});
it("should handle progressive streaming of bold text", () => {
const steps = [
{ input: "*", expected: "" },
{ input: "**", expected: "" },
{ input: "**b", expected: "**b**" },
{ input: "**bold", expected: "**bold**" },
{ input: "**bold**", expected: "**bold**" },
];
steps.forEach(({ input, expected }) => {
expect(processStreamingMarkdown(input)).toBe(expected);
});
});
it("should handle multiline content with various patterns", () => {
const multiline = `# Title
This is a paragraph with **bold text** and *italic text*.
* Item 1
* Item 2
* `;
const expected = `# Title
This is a paragraph with **bold text** and *italic text*.
* Item 1
* Item 2`;
expect(processStreamingMarkdown(multiline)).toBe(expected);
});
it("should only fix the last line", () => {
expect(processStreamingMarkdown("# Complete\n# Another\n# ")).toBe(
"# Complete\n# Another",
);
expect(processStreamingMarkdown("* Item 1\n* Item 2\n* ")).toBe(
"* Item 1\n* Item 2",
);
});
it("should handle mixed content correctly", () => {
const input = `# Header
This has **bold** text and *italic* text.
\`\`\`js
const x = 42;
\`\`\`
Now some \`inline code\` and **unclosed bold`;
const expected = `# Header
This has **bold** text and *italic* text.
\`\`\`js
const x = 42;
\`\`\`
Now some \`inline code\` and **unclosed bold**`;
expect(processStreamingMarkdown(input)).toBe(expected);
});
});
describe("Edge cases with escaping", () => {
it("should handle escaped asterisks (future enhancement)", () => {
// Note: Current implementation doesn't handle escaping
// This is a known limitation - escaped characters still trigger closing
expect(processStreamingMarkdown("Text \\*not italic")).toBe(
"Text \\*not italic*",
);
});
it("should handle escaped backticks (future enhancement)", () => {
// Note: Current implementation doesn't handle escaping
// This is a known limitation - escaped characters still trigger closing
expect(processStreamingMarkdown("Text \\`not code")).toBe(
"Text \\`not code`",
);
});
});
describe("Code block edge cases", () => {
it("should handle triple backticks in the middle of lines", () => {
expect(processStreamingMarkdown("Text ``` in middle")).toBe(
"Text ``` in middle\n```",
);
expect(processStreamingMarkdown("```\nText ``` in code\nmore")).toBe(
"```\nText ``` in code\nmore\n```",
);
});
it("should properly close code blocks with language specifiers", () => {
expect(processStreamingMarkdown("```typescript")).toBe(
"```typescript\n```",
);
expect(processStreamingMarkdown("```typescript\nconst x = 1")).toBe(
"```typescript\nconst x = 1\n```",
);
});
it("should remove a completely empty partial code block", () => {
expect(processStreamingMarkdown("```\n")).toBe("");
});
});
});
*/

View File

@@ -1,66 +1,123 @@
import React from "react";
import Markdown from "react-markdown";
import remarkGfm from "remark-gfm";
import remarkMath from "remark-math";
import rehypeRaw from "rehype-raw";
import rehypeSanitize, { defaultSchema } from "rehype-sanitize";
import rehypePrismPlus from "rehype-prism-plus";
import rehypeKatex from "rehype-katex";
import remarkStreamingMarkdown, {
type LastNodeInfo,
} from "@/utils/remarkStreamingMarkdown";
import type { PluggableList } from "unified";
import { Streamdown, defaultRemarkPlugins } from "streamdown";
import remarkCitationParser from "@/utils/remarkCitationParser";
import CopyButton from "./CopyButton";
import type { BundledLanguage } from "shiki";
import { highlighter } from "@/lib/highlighter";
interface StreamingMarkdownContentProps {
content: string;
isStreaming?: boolean;
size?: "sm" | "md" | "lg";
onLastNode?: (info: LastNodeInfo) => void;
browserToolResult?: any; // TODO: proper type
}
// Helper to extract text from React nodes
const extractText = (node: React.ReactNode): string => {
if (typeof node === "string") return node;
if (typeof node === "number") return String(node);
if (!node) return "";
if (React.isValidElement(node)) {
const props = node.props as any;
if (props?.children) {
return extractText(props.children as React.ReactNode);
}
}
if (Array.isArray(node)) {
return node.map(extractText).join("");
}
return "";
};
const CodeBlock = React.memo(
({ children, className, ...props }: React.HTMLAttributes<HTMLPreElement>) => {
const extractText = React.useCallback((node: React.ReactNode): string => {
if (typeof node === "string") return node;
if (typeof node === "number") return String(node);
if (!node) return "";
({ children }: React.HTMLAttributes<HTMLPreElement>) => {
// Extract code and language from children
const codeElement = children as React.ReactElement<{
className?: string;
children: React.ReactNode;
}>;
const language =
codeElement.props.className?.replace(/language-/, "") || "";
const codeText = extractText(codeElement.props.children);
if (React.isValidElement(node)) {
if (
node.props &&
typeof node.props === "object" &&
"children" in node.props
) {
return extractText(node.props.children as React.ReactNode);
}
// Synchronously highlight code using the pre-loaded highlighter
const tokens = React.useMemo(() => {
if (!highlighter) return null;
try {
return {
light: highlighter.codeToTokensBase(codeText, {
lang: language as BundledLanguage,
theme: "one-light" as any,
}),
dark: highlighter.codeToTokensBase(codeText, {
lang: language as BundledLanguage,
theme: "one-dark" as any,
}),
};
} catch (error) {
console.error("Failed to highlight code:", error);
return null;
}
if (Array.isArray(node)) {
return node.map(extractText).join("");
}
return "";
}, []);
const language = className?.replace(/language-/, "") || "";
}, [codeText, language]);
return (
<div className="relative bg-neutral-100 dark:bg-neutral-800 rounded-2xl overflow-hidden my-6">
<div className="flex justify-between select-none">
<div className="text-[13px] text-neutral-500 dark:text-neutral-400 font-mono px-4 py-2">
{language}
</div>
<div className="flex select-none">
{language && (
<div className="text-[13px] text-neutral-500 dark:text-neutral-400 font-mono px-4 py-2">
{language}
</div>
)}
<CopyButton
content={extractText(children)}
content={codeText}
showLabels={true}
className="copy-button text-neutral-500 dark:text-neutral-400 bg-neutral-100 dark:bg-neutral-800"
className="copy-button text-neutral-500 dark:text-neutral-400 bg-neutral-100 dark:bg-neutral-800 ml-auto"
/>
</div>
<pre className={className} {...props}>
{children}
{/* Light mode */}
<pre className="dark:hidden m-0 bg-neutral-100 text-sm overflow-x-auto p-4">
<code className="font-mono text-sm">
{tokens?.light
? tokens.light.map((line: any, i: number) => (
<React.Fragment key={i}>
{line.map((token: any, j: number) => (
<span
key={j}
style={{
color: token.color,
}}
>
{token.content}
</span>
))}
{i < tokens.light.length - 1 && "\n"}
</React.Fragment>
))
: codeText}
</code>
</pre>
{/* Dark mode */}
<pre className="hidden dark:block m-0 bg-neutral-800 text-sm overflow-x-auto p-4">
<code className="font-mono text-sm">
{tokens?.dark
? tokens.dark.map((line: any, i: number) => (
<React.Fragment key={i}>
{line.map((token: any, j: number) => (
<span
key={j}
style={{
color: token.color,
}}
>
{token.content}
</span>
))}
{i < tokens.dark.length - 1 && "\n"}
</React.Fragment>
))
: codeText}
</code>
</pre>
</div>
);
@@ -68,65 +125,19 @@ const CodeBlock = React.memo(
);
const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
React.memo(
({ content, isStreaming = false, size, onLastNode, browserToolResult }) => {
// Build the remark plugins array
const remarkPlugins = React.useMemo(() => {
const plugins: PluggableList = [
remarkGfm,
[remarkMath, { singleDollarTextMath: false }],
remarkCitationParser,
];
React.memo(({ content, isStreaming = false, size, browserToolResult }) => {
// Build the remark plugins array - keep default GFM and Math, add citations
const remarkPlugins = React.useMemo(() => {
return [
defaultRemarkPlugins.gfm,
defaultRemarkPlugins.math,
remarkCitationParser,
];
}, []);
// Add streaming plugin when in streaming mode
if (isStreaming) {
plugins.push([remarkStreamingMarkdown, { debug: true, onLastNode }]);
}
return plugins;
}, [isStreaming, onLastNode]);
// Create a custom sanitization schema that allows math elements
const sanitizeSchema = React.useMemo(() => {
return {
...defaultSchema,
attributes: {
...defaultSchema.attributes,
span: [
...(defaultSchema.attributes?.span || []),
["className", /^katex/],
],
div: [
...(defaultSchema.attributes?.div || []),
["className", /^katex/],
],
"ol-citation": ["cursor", "start", "end"],
},
tagNames: [
...(defaultSchema.tagNames || []),
"math",
"mrow",
"mi",
"mo",
"mn",
"msup",
"msub",
"mfrac",
"mover",
"munder",
"msqrt",
"mroot",
"merror",
"mspace",
"mpadded",
"ol-citation",
],
};
}, []);
return (
<div
className={`
return (
<div
className={`
max-w-full
${size === "sm" ? "prose-sm" : size === "lg" ? "prose-lg" : ""}
prose
@@ -144,7 +155,27 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
prose-pre:my-0
prose-pre:max-w-full
prose-pre:pt-1
[&_code:not(pre_code)]:text-neutral-700
[&_table]:border-collapse
[&_table]:w-full
[&_table]:border
[&_table]:border-neutral-200
[&_table]:rounded-lg
[&_table]:overflow-hidden
[&_th]:px-3
[&_th]:py-2
[&_th]:text-left
[&_th]:font-semibold
[&_th]:border-b
[&_th]:border-r
[&_th]:border-neutral-200
[&_th:last-child]:border-r-0
[&_td]:px-3
[&_td]:py-2
[&_td]:border-r
[&_td]:border-neutral-200
[&_td:last-child]:border-r-0
[&_tbody_tr:not(:last-child)_td]:border-b
[&_code:not(pre_code)]:text-neutral-700
[&_code:not(pre_code)]:bg-neutral-100
[&_code:not(pre_code)]:font-normal
[&_code:not(pre_code)]:px-1.5
@@ -160,6 +191,10 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
dark:prose-strong:text-neutral-200
dark:prose-pre:text-neutral-200
dark:prose:pre:text-neutral-200
dark:[&_table]:border-neutral-700
dark:[&_thead]:bg-neutral-800
dark:[&_th]:border-neutral-700
dark:[&_td]:border-neutral-700
dark:[&_code:not(pre_code)]:text-neutral-200
dark:[&_code:not(pre_code)]:bg-neutral-800
dark:[&_code:not(pre_code)]:font-normal
@@ -167,104 +202,86 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
dark:prose-li:marker:text-neutral-300
break-words
`}
>
<StreamingMarkdownErrorBoundary
content={content}
isStreaming={isStreaming}
>
<StreamingMarkdownErrorBoundary
content={content}
isStreaming={isStreaming}
>
<Markdown
remarkPlugins={remarkPlugins}
rehypePlugins={
[
[rehypeRaw, { allowDangerousHtml: true }],
[rehypeSanitize, sanitizeSchema],
[rehypePrismPlus, { ignoreMissing: true }],
[
rehypeKatex,
{
errorColor: "#000000", // Black instead of red for errors
strict: false, // Be more lenient with parsing
throwOnError: false,
},
],
] as PluggableList
}
components={{
pre: CodeBlock,
table: ({
children,
...props
}: React.HTMLAttributes<HTMLTableElement>) => (
<div className="overflow-x-auto max-w-full">
<table {...props}>{children}</table>
</div>
),
// @ts-expect-error: custom type
"ol-citation": ({
cursor,
// start,
// end,
}: {
cursor: number;
start: number;
end: number;
}) => {
// Check if we have a page_stack and if the cursor is valid
const pageStack = browserToolResult?.page_stack;
const hasValidPage = pageStack && cursor < pageStack.length;
const pageUrl = hasValidPage ? pageStack[cursor] : null;
<Streamdown
parseIncompleteMarkdown={isStreaming}
isAnimating={isStreaming}
remarkPlugins={remarkPlugins}
controls={false}
components={{
pre: CodeBlock,
table: ({
children,
...props
}: React.HTMLAttributes<HTMLTableElement>) => (
<div className="overflow-x-auto max-w-full">
<table
{...props}
className="border-collapse w-full border border-neutral-200 dark:border-neutral-700 rounded-lg overflow-hidden"
>
{children}
</table>
</div>
),
// @ts-expect-error: custom citation type
"ol-citation": ({
cursor,
}: {
cursor: number;
start: number;
end: number;
}) => {
const pageStack = browserToolResult?.page_stack;
const hasValidPage = pageStack && cursor < pageStack.length;
const pageUrl = hasValidPage ? pageStack[cursor] : null;
// Extract a readable title from the URL if possible
const getPageTitle = (url: string) => {
if (url.startsWith("search_results_")) {
const searchTerm = url.substring(
"search_results_".length,
);
return `Search: ${searchTerm}`;
}
// For regular URLs, try to extract domain or use full URL
try {
const urlObj = new URL(url);
return urlObj.hostname;
} catch {
// If not a valid URL, return as is
return url;
}
};
const citationElement = (
<span className="text-xs text-neutral-500 dark:text-neutral-400 bg-neutral-100 dark:bg-neutral-800 rounded-full px-2 py-1 ml-1">
[{cursor}]
</span>
);
// If we have a valid page URL, wrap in a link
if (pageUrl && pageUrl.startsWith("http")) {
return (
<a
href={pageUrl}
target="_blank"
rel="noopener noreferrer"
className="inline-flex items-center hover:opacity-80 transition-opacity no-underline"
title={getPageTitle(pageUrl)}
>
{citationElement}
</a>
);
const getPageTitle = (url: string) => {
if (url.startsWith("search_results_")) {
const searchTerm = url.substring("search_results_".length);
return `Search: ${searchTerm}`;
}
try {
const urlObj = new URL(url);
return urlObj.hostname;
} catch {
return url;
}
};
// Otherwise, just return the citation without a link
return citationElement;
},
}}
>
{content}
</Markdown>
</StreamingMarkdownErrorBoundary>
</div>
);
},
);
const citationElement = (
<span className="text-xs text-neutral-500 dark:text-neutral-400 bg-neutral-100 dark:bg-neutral-800 rounded-full px-2 py-1 ml-1">
[{cursor}]
</span>
);
if (pageUrl && pageUrl.startsWith("http")) {
return (
<a
href={pageUrl}
target="_blank"
rel="noopener noreferrer"
className="inline-flex items-center hover:opacity-80 transition-opacity no-underline"
title={getPageTitle(pageUrl)}
>
{citationElement}
</a>
);
}
return citationElement;
},
}}
>
{content}
</Streamdown>
</StreamingMarkdownErrorBoundary>
</div>
);
});
interface StreamingMarkdownErrorBoundaryProps {
content: string;

View File

@@ -73,8 +73,9 @@ export default function Thinking({
// Calculate max height for smooth animations
const getMaxHeight = () => {
if (isCollapsed) {
return finishedThinking ? "0px" : "12rem"; // 8rem = 128px (same as max-h-32)
return finishedThinking ? "0px" : "12rem";
}
// When expanded, use the content height or grow naturally
return contentHeight ? `${contentHeight}px` : "none";
};
@@ -131,10 +132,11 @@ export default function Thinking({
</div>
<div
ref={wrapperRef}
className={`text-xs text-neutral-500 dark:text-neutral-500 rounded-md overflow-hidden
transition-[max-height,opacity] duration-300 ease-in-out relative ml-6 mt-2`}
className={`text-xs text-neutral-500 dark:text-neutral-500 rounded-md
transition-[max-height,opacity] duration-300 ease-in-out relative ml-6 mt-2
${isCollapsed ? "overflow-hidden" : "overflow-y-auto"}`}
style={{
maxHeight: getMaxHeight(),
maxHeight: isCollapsed ? getMaxHeight() : undefined,
opacity: isCollapsed && finishedThinking ? 0 : 1,
}}
>

View File

@@ -16,793 +16,6 @@
--text-color: #ffffff;
}
}
@media (prefers-color-scheme: light) {
.prose {
/**
* One Light theme for prism.js
* Based on Atom's One Light theme: https://github.com/atom/atom/tree/master/packages/one-light-syntax
*/
/**
* One Light colours (accurate as of commit eb064bf on 19 Feb 2021)
* From colors.less
* --mono-1: hsl(230, 8%, 24%);
* --mono-2: hsl(230, 6%, 44%);
* --mono-3: hsl(230, 4%, 64%)
* --hue-1: hsl(198, 99%, 37%);
* --hue-2: hsl(221, 87%, 60%);
* --hue-3: hsl(301, 63%, 40%);
* --hue-4: hsl(119, 34%, 47%);
* --hue-5: hsl(5, 74%, 59%);
* --hue-5-2: hsl(344, 84%, 43%);
* --hue-6: hsl(35, 99%, 36%);
* --hue-6-2: hsl(35, 99%, 40%);
* --syntax-fg: hsl(230, 8%, 24%);
* --syntax-bg: hsl(230, 1%, 98%);
* --syntax-gutter: hsl(230, 1%, 62%);
* --syntax-guide: hsla(230, 8%, 24%, 0.2);
* --syntax-accent: hsl(230, 100%, 66%);
* From syntax-variables.less
* --syntax-selection-color: hsl(230, 1%, 90%);
* --syntax-gutter-background-color-selected: hsl(230, 1%, 90%);
* --syntax-cursor-line: hsla(230, 8%, 24%, 0.05);
*/
.token.comment,
.token.prolog,
.token.cdata {
color: hsl(230, 4%, 64%);
}
.token.doctype,
.token.punctuation,
.token.entity {
color: hsl(230, 8%, 24%);
}
.token.attr-name,
.token.class-name,
.token.boolean,
.token.constant,
.token.number,
.token.atrule {
color: hsl(35, 99%, 36%);
}
.token.keyword {
color: hsl(301, 63%, 40%);
}
.token.property,
.token.tag,
.token.symbol,
.token.deleted,
.token.important {
color: hsl(5, 74%, 59%);
}
.token.selector,
.token.string,
.token.char,
.token.builtin,
.token.inserted,
.token.regex,
.token.attr-value,
.token.attr-value > .token.punctuation {
color: hsl(119, 34%, 47%);
}
.token.variable,
.token.operator,
.token.function {
color: hsl(221, 87%, 60%);
}
.token.url {
color: hsl(198, 99%, 37%);
}
/* HTML overrides */
.token.attr-value > .token.punctuation.attr-equals,
.token.special-attr > .token.attr-value > .token.value.css {
color: hsl(230, 8%, 24%);
}
/* CSS overrides */
.language-css .token.selector {
color: hsl(5, 74%, 59%);
}
.language-css .token.property {
color: hsl(230, 8%, 24%);
}
.language-css .token.function,
.language-css .token.url > .token.function {
color: hsl(198, 99%, 37%);
}
.language-css .token.url > .token.string.url {
color: hsl(119, 34%, 47%);
}
.language-css .token.important,
.language-css .token.atrule .token.rule {
color: hsl(301, 63%, 40%);
}
/* JS overrides */
.language-javascript .token.operator {
color: hsl(301, 63%, 40%);
}
.language-javascript
.token.template-string
> .token.interpolation
> .token.interpolation-punctuation.punctuation {
color: hsl(344, 84%, 43%);
}
/* JSON overrides */
.language-json .token.operator {
color: hsl(230, 8%, 24%);
}
.language-json .token.null.keyword {
color: hsl(35, 99%, 36%);
}
/* MD overrides */
.language-markdown .token.url,
.language-markdown .token.url > .token.operator,
.language-markdown .token.url-reference.url > .token.string {
color: hsl(230, 8%, 24%);
}
.language-markdown .token.url > .token.content {
color: hsl(221, 87%, 60%);
}
.language-markdown .token.url > .token.url,
.language-markdown .token.url-reference.url {
color: hsl(198, 99%, 37%);
}
.language-markdown .token.blockquote.punctuation,
.language-markdown .token.hr.punctuation {
color: hsl(230, 4%, 64%);
font-style: italic;
}
.language-markdown .token.code-snippet {
color: hsl(119, 34%, 47%);
}
.language-markdown .token.bold .token.content {
color: hsl(35, 99%, 36%);
}
.language-markdown .token.italic .token.content {
color: hsl(301, 63%, 40%);
}
.language-markdown .token.strike .token.content,
.language-markdown .token.strike .token.punctuation,
.language-markdown .token.list.punctuation,
.language-markdown .token.title.important > .token.punctuation {
color: hsl(5, 74%, 59%);
}
/* General */
.token.bold {
font-weight: bold;
}
.token.comment,
.token.italic {
font-style: italic;
}
.token.entity {
cursor: help;
}
.token.namespace {
opacity: 0.8;
}
/* Plugin overrides */
/* Selectors should have higher specificity than those in the plugins' default stylesheets */
/* Show Invisibles plugin overrides */
.token.token.tab:not(:empty):before,
.token.token.cr:before,
.token.token.lf:before,
.token.token.space:before {
color: hsla(230, 8%, 24%, 0.2);
}
/* Toolbar plugin overrides */
/* Space out all buttons and move them away from the right edge of the code block */
div.code-toolbar > .toolbar.toolbar > .toolbar-item {
margin-right: 0.4em;
}
/* Styling the buttons */
div.code-toolbar > .toolbar.toolbar > .toolbar-item > button,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > a,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > span {
background: hsl(230, 1%, 90%);
color: hsl(230, 6%, 44%);
padding: 0.1em 0.4em;
border-radius: 0.3em;
}
div.code-toolbar > .toolbar.toolbar > .toolbar-item > button:hover,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > button:focus,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > a:hover,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > a:focus,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > span:hover,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > span:focus {
background: hsl(230, 1%, 78%); /* custom: darken(--syntax-bg, 20%) */
color: hsl(230, 8%, 24%);
}
/* Line Highlight plugin overrides */
/* The highlighted line itself */
.line-highlight.line-highlight {
background: hsla(230, 8%, 24%, 0.05);
}
/* Default line numbers in Line Highlight plugin */
.line-highlight.line-highlight:before,
.line-highlight.line-highlight[data-end]:after {
background: hsl(230, 1%, 90%);
color: hsl(230, 8%, 24%);
padding: 0.1em 0.6em;
border-radius: 0.3em;
box-shadow: 0 2px 0 0 rgba(0, 0, 0, 0.2); /* same as Toolbar plugin default */
}
/* Hovering over a linkable line number (in the gutter area) */
/* Requires Line Numbers plugin as well */
pre[id].linkable-line-numbers.linkable-line-numbers
span.line-numbers-rows
> span:hover:before {
background-color: hsla(230, 8%, 24%, 0.05);
}
/* Line Numbers and Command Line plugins overrides */
/* Line separating gutter from coding area */
.line-numbers.line-numbers .line-numbers-rows,
.command-line .command-line-prompt {
border-right-color: hsla(230, 8%, 24%, 0.2);
}
/* Stuff in the gutter */
.line-numbers .line-numbers-rows > span:before,
.command-line .command-line-prompt > span:before {
color: hsl(230, 1%, 62%);
}
/* Match Braces plugin overrides */
/* Note: Outline colour is inherited from the braces */
.rainbow-braces .token.token.punctuation.brace-level-1,
.rainbow-braces .token.token.punctuation.brace-level-5,
.rainbow-braces .token.token.punctuation.brace-level-9 {
color: hsl(5, 74%, 59%);
}
.rainbow-braces .token.token.punctuation.brace-level-2,
.rainbow-braces .token.token.punctuation.brace-level-6,
.rainbow-braces .token.token.punctuation.brace-level-10 {
color: hsl(119, 34%, 47%);
}
.rainbow-braces .token.token.punctuation.brace-level-3,
.rainbow-braces .token.token.punctuation.brace-level-7,
.rainbow-braces .token.token.punctuation.brace-level-11 {
color: hsl(221, 87%, 60%);
}
.rainbow-braces .token.token.punctuation.brace-level-4,
.rainbow-braces .token.token.punctuation.brace-level-8,
.rainbow-braces .token.token.punctuation.brace-level-12 {
color: hsl(301, 63%, 40%);
}
/* Diff Highlight plugin overrides */
/* Taken from https://github.com/atom/github/blob/master/styles/variables.less */
pre.diff-highlight > code .token.token.deleted:not(.prefix),
pre > code.diff-highlight .token.token.deleted:not(.prefix) {
background-color: hsla(353, 100%, 66%, 0.15);
}
pre.diff-highlight > code .token.token.deleted:not(.prefix)::-moz-selection,
pre.diff-highlight
> code
.token.token.deleted:not(.prefix)
*::-moz-selection,
pre > code.diff-highlight .token.token.deleted:not(.prefix)::-moz-selection,
pre
> code.diff-highlight
.token.token.deleted:not(.prefix)
*::-moz-selection {
background-color: hsla(353, 95%, 66%, 0.25);
}
pre.diff-highlight > code .token.token.deleted:not(.prefix)::selection,
pre.diff-highlight > code .token.token.deleted:not(.prefix) *::selection,
pre > code.diff-highlight .token.token.deleted:not(.prefix)::selection,
pre > code.diff-highlight .token.token.deleted:not(.prefix) *::selection {
background-color: hsla(353, 95%, 66%, 0.25);
}
pre.diff-highlight > code .token.token.inserted:not(.prefix),
pre > code.diff-highlight .token.token.inserted:not(.prefix) {
background-color: hsla(137, 100%, 55%, 0.15);
}
pre.diff-highlight
> code
.token.token.inserted:not(.prefix)::-moz-selection,
pre.diff-highlight
> code
.token.token.inserted:not(.prefix)
*::-moz-selection,
pre
> code.diff-highlight
.token.token.inserted:not(.prefix)::-moz-selection,
pre
> code.diff-highlight
.token.token.inserted:not(.prefix)
*::-moz-selection {
background-color: hsla(135, 73%, 55%, 0.25);
}
pre.diff-highlight > code .token.token.inserted:not(.prefix)::selection,
pre.diff-highlight > code .token.token.inserted:not(.prefix) *::selection,
pre > code.diff-highlight .token.token.inserted:not(.prefix)::selection,
pre > code.diff-highlight .token.token.inserted:not(.prefix) *::selection {
background-color: hsla(135, 73%, 55%, 0.25);
}
/* Previewers plugin overrides */
/* Based on https://github.com/atom-community/atom-ide-datatip/blob/master/styles/atom-ide-datatips.less and https://github.com/atom/atom/blob/master/packages/one-light-ui */
/* Border around popup */
.prism-previewer.prism-previewer:before,
.prism-previewer-gradient.prism-previewer-gradient div {
border-color: hsl(0, 0, 95%);
}
/* Angle and time should remain as circles and are hence not included */
.prism-previewer-color.prism-previewer-color:before,
.prism-previewer-gradient.prism-previewer-gradient div,
.prism-previewer-easing.prism-previewer-easing:before {
border-radius: 0.3em;
}
/* Triangles pointing to the code */
.prism-previewer.prism-previewer:after {
border-top-color: hsl(0, 0, 95%);
}
.prism-previewer-flipped.prism-previewer-flipped.after {
border-bottom-color: hsl(0, 0, 95%);
}
/* Background colour within the popup */
.prism-previewer-angle.prism-previewer-angle:before,
.prism-previewer-time.prism-previewer-time:before,
.prism-previewer-easing.prism-previewer-easing {
background: hsl(0, 0%, 100%);
}
/* For angle, this is the positive area (eg. 90deg will display one quadrant in this colour) */
/* For time, this is the alternate colour */
.prism-previewer-angle.prism-previewer-angle circle,
.prism-previewer-time.prism-previewer-time circle {
stroke: hsl(230, 8%, 24%);
stroke-opacity: 1;
}
/* Stroke colours of the handle, direction point, and vector itself */
.prism-previewer-easing.prism-previewer-easing circle,
.prism-previewer-easing.prism-previewer-easing path,
.prism-previewer-easing.prism-previewer-easing line {
stroke: hsl(230, 8%, 24%);
}
/* Fill colour of the handle */
.prism-previewer-easing.prism-previewer-easing circle {
fill: transparent;
}
}
}
@media (prefers-color-scheme: dark) {
.prose {
.token.comment,
.token.prolog,
.token.cdata {
color: hsl(220, 10%, 40%);
}
.token.doctype,
.token.punctuation,
.token.entity {
color: hsl(220, 14%, 71%);
}
.token.attr-name,
.token.class-name,
.token.boolean,
.token.constant,
.token.number,
.token.atrule {
color: hsl(29, 54%, 61%);
}
.token.keyword {
color: hsl(286, 60%, 67%);
}
.token.property,
.token.tag,
.token.symbol,
.token.deleted,
.token.important {
color: hsl(355, 65%, 65%);
}
.token.selector,
.token.string,
.token.char,
.token.builtin,
.token.inserted,
.token.regex,
.token.attr-value,
.token.attr-value > .token.punctuation {
color: hsl(95, 38%, 62%);
}
.token.variable,
.token.operator,
.token.function {
color: hsl(207, 82%, 66%);
}
.token.url {
color: hsl(187, 47%, 55%);
}
/* HTML overrides */
.token.attr-value > .token.punctuation.attr-equals,
.token.special-attr > .token.attr-value > .token.value.css {
color: hsl(220, 14%, 71%);
}
/* CSS overrides */
.language-css .token.selector {
color: hsl(355, 65%, 65%);
}
.language-css .token.property {
color: hsl(220, 14%, 71%);
}
.language-css .token.function,
.language-css .token.url > .token.function {
color: hsl(187, 47%, 55%);
}
.language-css .token.url > .token.string.url {
color: hsl(95, 38%, 62%);
}
.language-css .token.important,
.language-css .token.atrule .token.rule {
color: hsl(286, 60%, 67%);
}
/* JS overrides */
.language-javascript .token.operator {
color: hsl(286, 60%, 67%);
}
.language-javascript
.token.template-string
> .token.interpolation
> .token.interpolation-punctuation.punctuation {
color: hsl(5, 48%, 51%);
}
/* JSON overrides */
.language-json .token.operator {
color: hsl(220, 14%, 71%);
}
.language-json .token.null.keyword {
color: hsl(29, 54%, 61%);
}
/* MD overrides */
.language-markdown .token.url,
.language-markdown .token.url > .token.operator,
.language-markdown .token.url-reference.url > .token.string {
color: hsl(220, 14%, 71%);
}
.language-markdown .token.url > .token.content {
color: hsl(207, 82%, 66%);
}
.language-markdown .token.url > .token.url,
.language-markdown .token.url-reference.url {
color: hsl(187, 47%, 55%);
}
.language-markdown .token.blockquote.punctuation,
.language-markdown .token.hr.punctuation {
color: hsl(220, 10%, 40%);
font-style: italic;
}
.language-markdown .token.code-snippet {
color: hsl(95, 38%, 62%);
}
.language-markdown .token.bold .token.content {
color: hsl(29, 54%, 61%);
}
.language-markdown .token.italic .token.content {
color: hsl(286, 60%, 67%);
}
.language-markdown .token.strike .token.content,
.language-markdown .token.strike .token.punctuation,
.language-markdown .token.list.punctuation,
.language-markdown .token.title.important > .token.punctuation {
color: hsl(355, 65%, 65%);
}
/* General */
.token.bold {
font-weight: bold;
}
.token.comment,
.token.italic {
font-style: italic;
}
.token.entity {
cursor: help;
}
.token.namespace {
opacity: 0.8;
}
/* Plugin overrides */
/* Selectors should have higher specificity than those in the plugins' default stylesheets */
/* Show Invisibles plugin overrides */
.token.token.tab:not(:empty):before,
.token.token.cr:before,
.token.token.lf:before,
.token.token.space:before {
color: hsla(220, 14%, 71%, 0.15);
text-shadow: none;
}
/* Toolbar plugin overrides */
/* Space out all buttons and move them away from the right edge of the code block */
div.code-toolbar > .toolbar.toolbar > .toolbar-item {
margin-right: 0.4em;
}
/* Styling the buttons */
div.code-toolbar > .toolbar.toolbar > .toolbar-item > button,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > a,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > span {
background: hsl(220, 13%, 26%);
color: hsl(220, 9%, 55%);
padding: 0.1em 0.4em;
border-radius: 0.3em;
}
div.code-toolbar > .toolbar.toolbar > .toolbar-item > button:hover,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > button:focus,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > a:hover,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > a:focus,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > span:hover,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > span:focus {
background: hsl(220, 13%, 28%);
color: hsl(220, 14%, 71%);
}
/* Line Highlight plugin overrides */
/* The highlighted line itself */
.line-highlight.line-highlight {
background: hsla(220, 100%, 80%, 0.04);
}
/* Default line numbers in Line Highlight plugin */
.line-highlight.line-highlight:before,
.line-highlight.line-highlight[data-end]:after {
background: hsl(220, 13%, 26%);
color: hsl(220, 14%, 71%);
padding: 0.1em 0.6em;
border-radius: 0.3em;
box-shadow: 0 2px 0 0 rgba(0, 0, 0, 0.2); /* same as Toolbar plugin default */
}
/* Hovering over a linkable line number (in the gutter area) */
/* Requires Line Numbers plugin as well */
pre[id].linkable-line-numbers.linkable-line-numbers
span.line-numbers-rows
> span:hover:before {
background-color: hsla(220, 100%, 80%, 0.04);
}
/* Line Numbers and Command Line plugins overrides */
/* Line separating gutter from coding area */
.line-numbers.line-numbers .line-numbers-rows,
.command-line .command-line-prompt {
border-right-color: hsla(220, 14%, 71%, 0.15);
}
/* Stuff in the gutter */
.line-numbers .line-numbers-rows > span:before,
.command-line .command-line-prompt > span:before {
color: hsl(220, 14%, 45%);
}
/* Match Braces plugin overrides */
/* Note: Outline colour is inherited from the braces */
.rainbow-braces .token.token.punctuation.brace-level-1,
.rainbow-braces .token.token.punctuation.brace-level-5,
.rainbow-braces .token.token.punctuation.brace-level-9 {
color: hsl(355, 65%, 65%);
}
.rainbow-braces .token.token.punctuation.brace-level-2,
.rainbow-braces .token.token.punctuation.brace-level-6,
.rainbow-braces .token.token.punctuation.brace-level-10 {
color: hsl(95, 38%, 62%);
}
.rainbow-braces .token.token.punctuation.brace-level-3,
.rainbow-braces .token.token.punctuation.brace-level-7,
.rainbow-braces .token.token.punctuation.brace-level-11 {
color: hsl(207, 82%, 66%);
}
.rainbow-braces .token.token.punctuation.brace-level-4,
.rainbow-braces .token.token.punctuation.brace-level-8,
.rainbow-braces .token.token.punctuation.brace-level-12 {
color: hsl(286, 60%, 67%);
}
/* Diff Highlight plugin overrides */
/* Taken from https://github.com/atom/github/blob/master/styles/variables.less */
pre.diff-highlight > code .token.token.deleted:not(.prefix),
pre > code.diff-highlight .token.token.deleted:not(.prefix) {
background-color: hsla(353, 100%, 66%, 0.15);
}
pre.diff-highlight > code .token.token.deleted:not(.prefix)::-moz-selection,
pre.diff-highlight
> code
.token.token.deleted:not(.prefix)
*::-moz-selection,
pre > code.diff-highlight .token.token.deleted:not(.prefix)::-moz-selection,
pre
> code.diff-highlight
.token.token.deleted:not(.prefix)
*::-moz-selection {
background-color: hsla(353, 95%, 66%, 0.25);
}
pre.diff-highlight > code .token.token.deleted:not(.prefix)::selection,
pre.diff-highlight > code .token.token.deleted:not(.prefix) *::selection,
pre > code.diff-highlight .token.token.deleted:not(.prefix)::selection,
pre > code.diff-highlight .token.token.deleted:not(.prefix) *::selection {
background-color: hsla(353, 95%, 66%, 0.25);
}
pre.diff-highlight > code .token.token.inserted:not(.prefix),
pre > code.diff-highlight .token.token.inserted:not(.prefix) {
background-color: hsla(137, 100%, 55%, 0.15);
}
pre.diff-highlight
> code
.token.token.inserted:not(.prefix)::-moz-selection,
pre.diff-highlight
> code
.token.token.inserted:not(.prefix)
*::-moz-selection,
pre
> code.diff-highlight
.token.token.inserted:not(.prefix)::-moz-selection,
pre
> code.diff-highlight
.token.token.inserted:not(.prefix)
*::-moz-selection {
background-color: hsla(135, 73%, 55%, 0.25);
}
pre.diff-highlight > code .token.token.inserted:not(.prefix)::selection,
pre.diff-highlight > code .token.token.inserted:not(.prefix) *::selection,
pre > code.diff-highlight .token.token.inserted:not(.prefix)::selection,
pre > code.diff-highlight .token.token.inserted:not(.prefix) *::selection {
background-color: hsla(135, 73%, 55%, 0.25);
}
/* Previewers plugin overrides */
/* Based on https://github.com/atom-community/atom-ide-datatip/blob/master/styles/atom-ide-datatips.less and https://github.com/atom/atom/blob/master/packages/one-dark-ui */
/* Border around popup */
.prism-previewer.prism-previewer:before,
.prism-previewer-gradient.prism-previewer-gradient div {
border-color: hsl(224, 13%, 17%);
}
/* Angle and time should remain as circles and are hence not included */
.prism-previewer-color.prism-previewer-color:before,
.prism-previewer-gradient.prism-previewer-gradient div,
.prism-previewer-easing.prism-previewer-easing:before {
border-radius: 0.3em;
}
/* Triangles pointing to the code */
.prism-previewer.prism-previewer:after {
border-top-color: hsl(224, 13%, 17%);
}
.prism-previewer-flipped.prism-previewer-flipped.after {
border-bottom-color: hsl(224, 13%, 17%);
}
/* Background colour within the popup */
.prism-previewer-angle.prism-previewer-angle:before,
.prism-previewer-time.prism-previewer-time:before,
.prism-previewer-easing.prism-previewer-easing {
background: hsl(219, 13%, 22%);
}
/* For angle, this is the positive area (eg. 90deg will display one quadrant in this colour) */
/* For time, this is the alternate colour */
.prism-previewer-angle.prism-previewer-angle circle,
.prism-previewer-time.prism-previewer-time circle {
stroke: hsl(220, 14%, 71%);
stroke-opacity: 1;
}
/* Stroke colours of the handle, direction point, and vector itself */
.prism-previewer-easing.prism-previewer-easing circle,
.prism-previewer-easing.prism-previewer-easing path,
.prism-previewer-easing.prism-previewer-easing line {
stroke: hsl(220, 14%, 71%);
}
/* Fill colour of the handle */
.prism-previewer-easing.prism-previewer-easing circle {
fill: transparent;
}
}
}
.prose pre {
contain: layout style;
}
/* Or more aggressively */
.prose pre code {
contain: layout style paint;
}
/* messaging-style typing indicator animation */
@keyframes typing {

View File

@@ -0,0 +1,10 @@
// API configuration
const DEV_API_URL = "http://127.0.0.1:3001";
// Base URL for fetch API calls (can be relative in production)
export const API_BASE = import.meta.env.DEV ? DEV_API_URL : "";
// Full host URL for Ollama client (needs full origin in production)
export const OLLAMA_HOST = import.meta.env.DEV
? DEV_API_URL
: window.location.origin;

View File

@@ -0,0 +1,156 @@
import { createHighlighter } from "shiki";
import type { ThemeRegistration } from "shiki";
const oneLightTheme: ThemeRegistration = {
name: "one-light",
type: "light",
colors: {
"editor.background": "#fafafa",
"editor.foreground": "#383a42",
},
tokenColors: [
{
scope: ["comment", "punctuation.definition.comment"],
settings: { foreground: "#a0a1a7" },
},
{
scope: ["keyword", "storage.type", "storage.modifier"],
settings: { foreground: "#a626a4" },
},
{ scope: ["string", "string.quoted"], settings: { foreground: "#50a14f" } },
{
scope: ["function", "entity.name.function", "support.function"],
settings: { foreground: "#4078f2" },
},
{
scope: [
"constant.numeric",
"constant.language",
"constant.character",
"number",
],
settings: { foreground: "#c18401" },
},
{
scope: ["variable", "support.variable"],
settings: { foreground: "#e45649" },
},
{
scope: ["entity.name.tag", "entity.name.type", "entity.name.class"],
settings: { foreground: "#e45649" },
},
{
scope: ["entity.other.attribute-name"],
settings: { foreground: "#c18401" },
},
{
scope: ["keyword.operator", "operator"],
settings: { foreground: "#a626a4" },
},
{ scope: ["punctuation"], settings: { foreground: "#383a42" } },
{
scope: ["markup.heading"],
settings: { foreground: "#e45649", fontStyle: "bold" },
},
{
scope: ["markup.bold"],
settings: { foreground: "#c18401", fontStyle: "bold" },
},
{
scope: ["markup.italic"],
settings: { foreground: "#a626a4", fontStyle: "italic" },
},
],
};
const oneDarkTheme: ThemeRegistration = {
name: "one-dark",
type: "dark",
colors: {
"editor.background": "#282c34",
"editor.foreground": "#abb2bf",
},
tokenColors: [
{
scope: ["comment", "punctuation.definition.comment"],
settings: { foreground: "#5c6370" },
},
{
scope: ["keyword", "storage.type", "storage.modifier"],
settings: { foreground: "#c678dd" },
},
{ scope: ["string", "string.quoted"], settings: { foreground: "#98c379" } },
{
scope: ["function", "entity.name.function", "support.function"],
settings: { foreground: "#61afef" },
},
{
scope: [
"constant.numeric",
"constant.language",
"constant.character",
"number",
],
settings: { foreground: "#d19a66" },
},
{
scope: ["variable", "support.variable"],
settings: { foreground: "#e06c75" },
},
{
scope: ["entity.name.tag", "entity.name.type", "entity.name.class"],
settings: { foreground: "#e06c75" },
},
{
scope: ["entity.other.attribute-name"],
settings: { foreground: "#d19a66" },
},
{
scope: ["keyword.operator", "operator"],
settings: { foreground: "#c678dd" },
},
{ scope: ["punctuation"], settings: { foreground: "#abb2bf" } },
{
scope: ["markup.heading"],
settings: { foreground: "#e06c75", fontStyle: "bold" },
},
{
scope: ["markup.bold"],
settings: { foreground: "#d19a66", fontStyle: "bold" },
},
{
scope: ["markup.italic"],
settings: { foreground: "#c678dd", fontStyle: "italic" },
},
],
};
export let highlighter: Awaited<ReturnType<typeof createHighlighter>> | null =
null;
export const highlighterPromise = createHighlighter({
themes: [oneLightTheme, oneDarkTheme],
langs: [
"javascript",
"typescript",
"python",
"bash",
"shell",
"json",
"html",
"css",
"tsx",
"jsx",
"go",
"rust",
"java",
"c",
"cpp",
"sql",
"yaml",
"markdown",
],
}).then((h) => {
highlighter = h;
return h;
});

View File

@@ -1,4 +1,5 @@
import { Ollama } from "ollama/browser";
import { OLLAMA_HOST } from "./config";
let _ollamaClient: Ollama | null = null;
@@ -6,7 +7,7 @@ export const ollamaClient = new Proxy({} as Ollama, {
get(_target, prop) {
if (!_ollamaClient) {
_ollamaClient = new Ollama({
host: window.location.origin,
host: OLLAMA_HOST,
});
}
const value = _ollamaClient[prop as keyof Ollama];

View File

@@ -0,0 +1,97 @@
import { describe, it, expect } from "vitest";
import { IMAGE_EXTENSIONS, validateFile } from "./fileValidation";
describe("fileValidation", () => {
describe("IMAGE_EXTENSIONS", () => {
it("should include all supported image formats including WebP", () => {
expect(IMAGE_EXTENSIONS).toContain("png");
expect(IMAGE_EXTENSIONS).toContain("jpg");
expect(IMAGE_EXTENSIONS).toContain("jpeg");
expect(IMAGE_EXTENSIONS).toContain("webp");
});
});
describe("validateFile", () => {
const createMockFile = (
name: string,
size: number,
type: string,
): File => {
const blob = new Blob(["test content"], { type });
return new File([blob], name, { type });
};
it("should accept WebP images when vision capability is enabled", () => {
const file = createMockFile("test.webp", 1024, "image/webp");
const result = validateFile(file, {
hasVisionCapability: true,
});
expect(result.valid).toBe(true);
});
it("should reject WebP images when vision capability is disabled", () => {
const file = createMockFile("test.webp", 1024, "image/webp");
const result = validateFile(file, {
hasVisionCapability: false,
});
expect(result.valid).toBe(false);
expect(result.error).toBe("This model does not support images");
});
it("should accept PNG images when vision capability is enabled", () => {
const file = createMockFile("test.png", 1024, "image/png");
const result = validateFile(file, {
hasVisionCapability: true,
});
expect(result.valid).toBe(true);
});
it("should accept JPEG images when vision capability is enabled", () => {
const file = createMockFile("test.jpg", 1024, "image/jpeg");
const result = validateFile(file, {
hasVisionCapability: true,
});
expect(result.valid).toBe(true);
});
it("should reject files that are too large", () => {
// Create a file with size property set correctly
const largeSize = 11 * 1024 * 1024; // 11MB
const content = new Uint8Array(largeSize);
const blob = new Blob([content], { type: "image/webp" });
const file = new File([blob], "large.webp", { type: "image/webp" });
const result = validateFile(file, {
hasVisionCapability: true,
maxFileSize: 10, // 10MB limit
});
expect(result.valid).toBe(false);
expect(result.error).toBe("File too large");
});
it("should reject unsupported file types", () => {
const file = createMockFile("test.xyz", 1024, "application/xyz");
const result = validateFile(file, {
hasVisionCapability: true,
});
expect(result.valid).toBe(false);
expect(result.error).toBe("File type not supported");
});
it("should respect custom validators", () => {
const file = createMockFile("test.webp", 1024, "image/webp");
const result = validateFile(file, {
hasVisionCapability: true,
customValidator: () => ({
valid: false,
error: "Custom error",
}),
});
expect(result.valid).toBe(false);
expect(result.error).toBe("Custom error");
});
});
// Note: processFiles tests are skipped because FileReader is not available in the Node.js test environment
// These functions are tested in browser environment via integration tests
});

View File

@@ -41,7 +41,7 @@ export const TEXT_FILE_EXTENSIONS = [
"rtf",
];
export const IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"];
export const IMAGE_EXTENSIONS = ["png", "jpg", "jpeg", "webp"];
export interface FileValidationOptions {
maxFileSize?: number; // in MB

View File

@@ -1,24 +0,0 @@
import { remark } from "remark";
import remarkStringify from "remark-stringify";
import remarkStreamingMarkdown from "./remarkStreamingMarkdown";
/**
* Process markdown content for streaming display using the remark plugin.
* This is primarily used for testing the remark plugin with string inputs/outputs.
*/
export function processStreamingMarkdown(content: string): string {
if (!content) return content;
const result = remark()
.use(remarkStreamingMarkdown, { debug: false })
.use(remarkStringify)
.processSync(content);
// remove trailing newline to keep tests cleaner
let output = result.toString();
if (output.endsWith("\n")) {
output = output.slice(0, -1);
}
return output;
}

View File

@@ -1,447 +0,0 @@
import { parents, type Proxy } from "unist-util-parents";
import type { Plugin } from "unified";
import type {
Emphasis,
Node,
Parent,
Root,
RootContent,
Text,
Strong,
PhrasingContent,
Paragraph,
} from "mdast";
import { u } from "unist-builder";
declare module "unist" {
interface Node {
/** Added by `unist-util-parents` (or your own walk). */
parent?: Proxy & Parent;
}
}
// interface SimpleTextRule {
// pattern: RegExp;
// transform: (matches: RegExpExecArray[], lastNode: Proxy) => void;
// }
// const simpleTextRules: SimpleTextRule[] = [
// // TODO(drifkin): generalize this for `__`/`_`/`~~`/`~` etc.
// {
// pattern: /(\*\*)(?=\S|$)/g,
// transform: (matchesIterator, lastNode) => {
// const textNode = lastNode.node as Text;
// const matches = [...matchesIterator];
// const lastMatch = matches[matches.length - 1];
// const origValue = textNode.value;
// const start = lastMatch.index;
// const sep = lastMatch[1];
// const before = origValue.slice(0, start);
// const after = origValue.slice(start + sep.length);
// if (lastNode.parent) {
// const index = (lastNode.parent.node as Parent).children.indexOf(
// lastNode.node as RootContent,
// );
// const shouldRemove = before.length === 0;
// if (!shouldRemove) {
// textNode.value = before;
// }
// const newNode = u("strong", {
// children: [u("text", { value: after })],
// });
// (lastNode.parent.node as Parent).children.splice(
// index + (shouldRemove ? 0 : 1),
// shouldRemove ? 1 : 0,
// newNode,
// );
// }
// },
// },
// ];
interface Options {
debug?: boolean;
onLastNode?: (info: LastNodeInfo) => void;
}
export interface LastNodeInfo {
path: string[];
type: string;
value?: string;
lastChars?: string;
fullNode: Node;
}
/**
* Removes `child` from `parent` in-place.
* @returns `true` if the child was found and removed; `false` otherwise.
*/
export function removeChildFromParent(
child: RootContent,
parent: Node,
): boolean {
if (!isParent(parent)) return false; // parent isnt a Parent → nothing to do
const idx = parent.children.indexOf(child);
if (idx < 0) return false; // not a child → nothing to remove
parent.children.splice(idx, 1);
return true; // removal successful
}
/** Narrow a generic `Node` to a `Parent` (i.e. one that really has children). */
function isParent(node: Node): node is Parent {
// A `Parent` always has a `children` array; make sure it's an array first.
return Array.isArray((node as Partial<Parent>).children);
}
/**
* Follow “last-child” pointers until you reach a leaf.
* Returns the right-most, deepest node in source order.
*/
export function findRightmostDeepestNode(root: Node): Node {
let current: Node = root;
// While the current node *is* a Parent and has at least one child…
while (isParent(current) && current.children.length > 0) {
const lastIndex = current.children.length - 1;
current = current.children[lastIndex];
}
return current; // Leaf: no further children
}
const remarkStreamingMarkdown: Plugin<[Options?], Root> = () => {
return (tree) => {
const treeWithParents = parents(tree);
const lastNode = findRightmostDeepestNode(treeWithParents) as Proxy;
const parentNode = lastNode.parent;
const grandparentNode = parentNode?.parent;
let ruleMatched = false;
// handling `* *` -> ``
//
// if the last node is part of a <list item (otherwise empty)> ->
// <list (otherwise empty)> -> <list item (last node, empty)>, then we need to
// remove everything up to and including the first list item. This happens
// when we have `* *`, which can become a bolded list item OR a horizontal
// line
if (
lastNode.type === "listItem" &&
parentNode &&
grandparentNode &&
parentNode.type === "list" &&
grandparentNode.type === "listItem" &&
parentNode.children.length === 1 &&
grandparentNode.children.length === 1
) {
ruleMatched = true;
if (grandparentNode.parent) {
removeChildFromParent(
grandparentNode.node as RootContent,
grandparentNode.parent.node,
);
}
// Handle `*` -> ``:
//
// if the last node is just an empty list item, we need to remove it
// because it could become something else (e.g., a horizontal line)
} else if (
lastNode.type === "listItem" &&
parentNode &&
parentNode.type === "list"
) {
ruleMatched = true;
removeChildFromParent(lastNode.node as RootContent, parentNode.node);
} else if (lastNode.type === "thematicBreak") {
ruleMatched = true;
const parent = lastNode.parent;
if (parent) {
removeChildFromParent(lastNode.node as RootContent, parent.node);
}
} else if (lastNode.type === "text") {
const textNode = lastNode.node as Text;
if (textNode.value.endsWith("**")) {
ruleMatched = true;
textNode.value = textNode.value.slice(0, -2);
// if there's a newline then a number, this is very very likely a
// numbered list item. Let's just hide it until the period comes (or
// other text disambiguates it)
} else {
const match = textNode.value.match(/^([0-9]+)$/m);
if (match) {
const number = match[1];
textNode.value = textNode.value.slice(0, -number.length - 1);
ruleMatched = true;
// if the text node is now empty, then we might want to remove other
// elements, like a now-empty containing paragraph, or a break that
// might disappear once more tokens come in
if (textNode.value.length === 0) {
if (
lastNode.parent?.type === "paragraph" &&
lastNode.parent.children.length === 1
) {
// remove the whole paragraph if it's now empty (otherwise it'll
// cause an extra newline that might not last)
removeChildFromParent(
lastNode.parent.node as Paragraph,
lastNode.parent.parent?.node as Node,
);
} else {
const prev = prevSibling(lastNode);
if (prev?.type === "break") {
removeChildFromParent(
prev.node as RootContent,
lastNode.parent?.node as Node,
);
removeChildFromParent(
lastNode.node as RootContent,
lastNode.parent?.node as Node,
);
}
}
}
}
}
}
if (ruleMatched) {
return tree;
}
// we need to
// a case like
// - *def `abc` [abc **def**](abc)*
// is pretty tricky, because if we land just after def, then we actually
// have two separate tags to process at two different parents. Maybe we
// need to keep iterating up until we find a paragraph, but process each
// parent on the way up. Hmm, well actually after `def` we won't even be a proper link yet
// TODO(drifkin): it's really if the last node's parent is a paragraph, for which the following is a sub-cas where the lastNode is a text node.
// And instead of just processing simple text rules, they need to operate on the whole paragraph
// like `**[abc](def)` needs to become `**[abc](def)**`
// if we're just text at the end, then we should remove some ambiguous characters
if (lastNode.parent) {
const didChange = processParent(lastNode.parent as Parent & Proxy);
if (didChange) {
// TODO(drifkin): need to fix up the tree, but not sure lastNode will still exist? Check all the transforms to see if it's safe to find the last node again
//
// need to regen the tree w/ parents since reparenting could've happened
// treeWithParents = parents(tree);
}
}
const grandparent = lastNode.parent?.parent;
// TODO(drifkin): let's go arbitrarily high up the tree, but limiting it
// to 2 levels for now until I think more about the stop condition
if (grandparent) {
processParent(grandparent as Parent & Proxy);
}
// console.log("ruleMatched", ruleMatched);
// } else if (lastNode.parent?.type === "paragraph") {
// console.log("!!! paragraph");
// console.log("lastNode.parent", lastNode.parent);
// // Handle `**abc*` -> `**abc**`:
// // We detect this when the last child is an emphasis node, and it's preceded by a text node that ends with `*`
// const paragraph = lastNode.parent as Proxy & Paragraph;
// if (paragraph.children.length >= 2) {
// const lastChild = paragraph.children[paragraph.children.length - 1];
// if (lastChild.type === "emphasis") {
// const sibling = paragraph.children[paragraph.children.length - 2];
// if (sibling.type === "text") {
// const siblingText = sibling as Text & Proxy;
// if (siblingText.value.endsWith("*")) {
// ruleMatched = true;
// const textNode = (lastNode as Proxy).node as Text;
// textNode.value = textNode.value.slice(0, -1);
// paragraph.node.type = "strong";
// }
// }
// }
// }
// } else if (lastNode.type === "text") {
// // Handle `**abc*` -> `**abc**`:
// //
// // this gets parsed as a text node ending in `*` followed by an emphasis
// // node. So if we're in text, we need to check if our parent is emphasis,
// // and then get our parent's sibling before it and check if it ends with
// // `*`
// const parent = lastNode.parent;
// if (parent && parent.type === "emphasis") {
// const grandparent = parent.parent;
// if (grandparent) {
// const index = (grandparent.node as Parent).children.indexOf(
// parent.node as RootContent,
// );
// if (index > 0) {
// const prevNode = grandparent.children[index - 1];
// if (
// prevNode.type === "text" &&
// (prevNode as Text).value.endsWith("*")
// ) {
// ruleMatched = true;
// const textNode = (prevNode as Proxy).node as Text;
// textNode.value = textNode.value.slice(0, -1);
// parent.node.type = "strong";
// }
// }
// }
// }
// if (!ruleMatched) {
// // if the last node is just text, then we process it in order to fix up certain unclosed items
// // e.g., `**abc` -> `**abc**`
// const textNode = lastNode.node as Text;
// for (const rule of simpleTextRules) {
// const matchesIterator = textNode.value.matchAll(rule.pattern);
// const matches = [...matchesIterator];
// if (matches.length > 0) {
// rule.transform(matches, lastNode);
// ruleMatched = true;
// break;
// }
// }
// }
// } else if (!ruleMatched) {
// // console.log("no rule matched", lastNode);
// }
return tree;
};
};
function processParent(parent: Parent & Proxy): boolean {
if (parent.type === "emphasis") {
// Handle `**abc*` -> `**abc**`:
// We detect this when we end with an emphasis node, and it's preceded by
// a text node that ends with `*`
// TODO(drifkin): the last node can be more deeply nested (e.g., a code
// literal in a link), so we probably need to walk up the tree until we
// find an emphasis node or a block? For now we'll just go up one layer to
// catch the most common cases
const emphasisNode = parent as Emphasis & Proxy;
const grandparent = emphasisNode.parent;
if (grandparent) {
const indexOfEmphasisNode = (grandparent.node as Parent).children.indexOf(
emphasisNode.node as RootContent,
);
if (indexOfEmphasisNode >= 0) {
const nodeBefore = grandparent.children[indexOfEmphasisNode - 1] as
| (Node & Proxy)
| undefined;
if (nodeBefore?.type === "text") {
const textNode = nodeBefore.node as Text;
if (textNode.value.endsWith("*")) {
const strBefore = textNode.value.slice(0, -1);
textNode.value = strBefore;
const strongNode = u("strong", {
children: emphasisNode.children,
});
(grandparent.node as Parent).children.splice(
indexOfEmphasisNode,
1,
strongNode,
);
return true;
}
}
}
}
}
// Let's check if we have any bold items to close
for (let i = parent.children.length - 1; i >= 0; i--) {
const child = parent.children[i];
if (child.type === "text") {
const textNode = child as Text & Proxy;
const sep = "**";
const index = textNode.value.lastIndexOf(sep);
if (index >= 0) {
let isValidOpening = false;
if (index + sep.length < textNode.value.length) {
const charAfter = textNode.value[index + sep.length];
if (!isWhitespace(charAfter)) {
isValidOpening = true;
}
} else {
if (i < parent.children.length - 1) {
// TODO(drifkin): I'm not sure that this check is strict enough.
// We're trying to detect cases like `**[abc]()` where the char
// after the opening ** is indeed a non-whitespace character. We're
// using the heuristic that there's another item after the current
// one, but I'm not sure if that is good enough. In a well
// constructed tree, there aren't two text nodes in a row, so this
// _seems_ good, but I should think through it more
isValidOpening = true;
}
}
if (isValidOpening) {
// TODO(drifkin): close the bold
const strBefore = textNode.value.slice(0, index);
const strAfter = textNode.value.slice(index + sep.length);
(textNode.node as Text).value = strBefore;
// TODO(drifkin): the node above could be empty in which case we probably want to delete it
const children: PhrasingContent[] = [
...(strAfter.length > 0 ? [u("text", { value: strAfter })] : []),
];
const strongNode: Strong = u("strong", {
children,
});
const nodesAfter = (parent.node as Parent).children.splice(
i + 1,
parent.children.length - i - 1,
strongNode,
);
// TODO(drifkin): this cast seems iffy, should see if we can cast the
// parent instead, which would also help us check some of our
// assumptions
strongNode.children.push(...(nodesAfter as PhrasingContent[]));
return true;
}
}
}
}
return false;
}
function prevSibling(node: Node & Proxy): (Node & Proxy) | null {
const parent = node.parent;
if (parent) {
const index = parent.children.indexOf(node);
return parent.children[index - 1] as Node & Proxy;
}
return null;
}
function isWhitespace(str: string) {
return str.trim() === "";
}
// function debugPrintTreeNoPos(tree: Node) {
// console.log(
// JSON.stringify(
// tree,
// (key, value) => {
// if (key === "position") {
// return undefined;
// }
// return value;
// },
// 2,
// ),
// );
// }
export default remarkStreamingMarkdown;

View File

@@ -1705,7 +1705,7 @@ func getStringFromMap(m map[string]any, key, defaultValue string) string {
// isImageAttachment checks if a filename is an image file
func isImageAttachment(filename string) bool {
ext := strings.ToLower(filename)
return strings.HasSuffix(ext, ".png") || strings.HasSuffix(ext, ".jpg") || strings.HasSuffix(ext, ".jpeg")
return strings.HasSuffix(ext, ".png") || strings.HasSuffix(ext, ".jpg") || strings.HasSuffix(ext, ".jpeg") || strings.HasSuffix(ext, ".webp")
}
// ptr is a convenience function for &literal
@@ -1794,13 +1794,14 @@ func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, ava
var thinkValue *api.ThinkValue
if think != nil {
// Only set Think if it's actually requesting thinking
if boolValue, ok := think.(bool); ok {
thinkValue = &api.ThinkValue{
Value: boolValue,
if boolValue {
thinkValue = &api.ThinkValue{Value: boolValue}
}
} else if stringValue, ok := think.(string); ok {
thinkValue = &api.ThinkValue{
Value: stringValue,
if stringValue != "" && stringValue != "none" {
thinkValue = &api.ThinkValue{Value: stringValue}
}
}
}

114
cmd/bench/README.md Normal file
View File

@@ -0,0 +1,114 @@
Ollama Benchmark Tool
---------------------
A Go-based command-line tool for benchmarking Ollama models with configurable parameters and multiple output formats.
## Features
* Benchmark multiple models in a single run
* Support for both text and image prompts
* Configurable generation parameters (temperature, max tokens, seed, etc.)
* Supports benchstat and CSV output formats
* Detailed performance metrics (prefill, generate, load, total durations)
## Building from Source
```
go build -o ollama-bench bench.go
./bench -model gpt-oss:20b -epochs 6 -format csv
```
Using Go Run (without building)
```
go run bench.go -model gpt-oss:20b -epochs 3
```
## Usage
### Basic Example
```
./bench -model gemma3 -epochs 6
```
### Benchmark Multiple Models
```
./bench -model gemma3,gemma3n -epochs 6 -max-tokens 100 -p "Write me a short story" | tee gemma.bench
benchstat -col /name gemma.bench
```
### With Image Prompt
```
./bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
```
### Advanced Example
```
./bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
```
## Command Line Options
| Option | Description | Default |
| -model | Comma-separated list of models to benchmark | (required) |
| -epochs | Number of iterations per model | 1 |
| -max-tokens | Maximum tokens for model response | 0 (unlimited) |
| -temperature | Temperature parameter | 0.0 |
| -seed | Random seed | 0 (random) |
| -timeout | Timeout in seconds | 300 |
| -p | Prompt text | "Write a long story." |
| -image | Image file to include in prompt | |
| -k | Keep-alive duration in seconds | 0 |
| -format | Output format (benchstat, csv) | benchstat |
| -output | Output file for results | "" (stdout) |
| -v | Verbose mode | false |
| -debug | Show debug information | false |
## Output Formats
### Markdown Format
The default markdown format is suitable for copying and pasting into a GitHub issue and will look like:
```
Model | Step | Count | Duration | nsPerToken | tokensPerSec |
|-------|------|-------|----------|------------|--------------|
| gpt-oss:20b | prefill | 124 | 30.006458ms | 241987.56 | 4132.44 |
| gpt-oss:20b | generate | 200 | 2.646843954s | 13234219.77 | 75.56 |
| gpt-oss:20b | load | 1 | 121.674208ms | - | - |
| gpt-oss:20b | total | 1 | 2.861047625s | - | - |
```
### Benchstat Format
Compatible with Go's benchstat tool for statistical analysis:
```
BenchmarkModel/name=gpt-oss:20b/step=prefill 128 78125.00 ns/token 12800.00 token/sec
BenchmarkModel/name=gpt-oss:20b/step=generate 512 19531.25 ns/token 51200.00 token/sec
BenchmarkModel/name=gpt-oss:20b/step=load 1 1500000000 ns/request
```
### CSV Format
Machine-readable comma-separated values:
```
NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
gpt-oss:20b,prefill,128,78125.00,12800.00
gpt-oss:20b,generate,512,19531.25,51200.00
gpt-oss:20b,load,1,1500000000,0
```
## Metrics Explained
The tool reports four types of metrics for each model:
* prefill: Time spent processing the prompt
* generate: Time spent generating the response
* load: Model loading time (one-time cost)
* total: Total request duration

309
cmd/bench/bench.go Normal file
View File

@@ -0,0 +1,309 @@
package main
import (
"cmp"
"context"
"flag"
"fmt"
"io"
"os"
"runtime"
"slices"
"strings"
"sync"
"time"
"github.com/ollama/ollama/api"
)
type flagOptions struct {
models *string
epochs *int
maxTokens *int
temperature *float64
seed *int
timeout *int
prompt *string
imageFile *string
keepAlive *float64
format *string
outputFile *string
debug *bool
verbose *bool
}
type Metrics struct {
Model string
Step string
Count int
Duration time.Duration
}
var once sync.Once
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
switch format {
case "benchstat":
if verbose {
printHeader := func() {
fmt.Printf("sysname: %s\n", runtime.GOOS)
fmt.Printf("machine: %s\n", runtime.GOARCH)
}
once.Do(printHeader)
}
for _, m := range metrics {
if m.Step == "generate" || m.Step == "prefill" {
if m.Count > 0 {
nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d %.2f ns/token %.2f token/sec\n",
m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
} else {
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d 0 ns/token 0 token/sec\n",
m.Model, m.Step, m.Count)
}
} else {
var suffix string
if m.Step == "load" {
suffix = "/step=load"
}
fmt.Fprintf(w, "BenchmarkModel/name=%s%s 1 %d ns/request\n",
m.Model, suffix, m.Duration.Nanoseconds())
}
}
case "csv":
printHeader := func() {
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
fmt.Fprintln(w, strings.Join(headings, ","))
}
once.Do(printHeader)
for _, m := range metrics {
if m.Step == "generate" || m.Step == "prefill" {
var nsPerToken float64
var tokensPerSec float64
if m.Count > 0 {
nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
}
fmt.Fprintf(w, "%s,%s,%d,%.2f,%.2f\n", m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
} else {
fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
}
}
case "markdown":
printHeader := func() {
fmt.Fprintln(w, "| Model | Step | Count | Duration | nsPerToken | tokensPerSec |")
fmt.Fprintln(w, "|-------|------|-------|----------|------------|--------------|")
}
once.Do(printHeader)
for _, m := range metrics {
var nsPerToken, tokensPerSec float64
var nsPerTokenStr, tokensPerSecStr string
if m.Step == "generate" || m.Step == "prefill" {
nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
nsPerTokenStr = fmt.Sprintf("%.2f", nsPerToken)
tokensPerSecStr = fmt.Sprintf("%.2f", tokensPerSec)
} else {
nsPerTokenStr = "-"
tokensPerSecStr = "-"
}
fmt.Fprintf(w, "| %s | %s | %d | %v | %s | %s |\n",
m.Model, m.Step, m.Count, m.Duration, nsPerTokenStr, tokensPerSecStr)
}
default:
fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
}
}
func BenchmarkChat(fOpt flagOptions) error {
models := strings.Split(*fOpt.models, ",")
// todo - add multi-image support
var imgData api.ImageData
var err error
if *fOpt.imageFile != "" {
imgData, err = readImage(*fOpt.imageFile)
if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: Couldn't read image '%s': %v\n", *fOpt.imageFile, err)
return err
}
}
if *fOpt.debug && imgData != nil {
fmt.Fprintf(os.Stderr, "Read file '%s'\n", *fOpt.imageFile)
}
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: Couldn't create ollama client: %v\n", err)
return err
}
for _, model := range models {
for range *fOpt.epochs {
options := make(map[string]interface{})
if *fOpt.maxTokens > 0 {
options["num_predict"] = *fOpt.maxTokens
}
options["temperature"] = *fOpt.temperature
if fOpt.seed != nil && *fOpt.seed > 0 {
options["seed"] = *fOpt.seed
}
var keepAliveDuration *api.Duration
if *fOpt.keepAlive > 0 {
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
keepAliveDuration = &duration
}
req := &api.ChatRequest{
Model: model,
Messages: []api.Message{
{
Role: "user",
Content: *fOpt.prompt,
},
},
Options: options,
KeepAlive: keepAliveDuration,
}
if imgData != nil {
req.Messages[0].Images = []api.ImageData{imgData}
}
var responseMetrics *api.Metrics
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
defer cancel()
err = client.Chat(ctx, req, func(resp api.ChatResponse) error {
if *fOpt.debug {
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Message.Thinking, resp.Message.Content))
}
if resp.Done {
responseMetrics = &resp.Metrics
}
return nil
})
if *fOpt.debug {
fmt.Fprintln(os.Stderr)
}
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
fmt.Fprintf(os.Stderr, "ERROR: Chat request timed out with model '%s' after %vs\n", model, 1)
continue
}
fmt.Fprintf(os.Stderr, "ERROR: Couldn't chat with model '%s': %v\n", model, err)
continue
}
if responseMetrics == nil {
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
continue
}
metrics := []Metrics{
{
Model: model,
Step: "prefill",
Count: responseMetrics.PromptEvalCount,
Duration: responseMetrics.PromptEvalDuration,
},
{
Model: model,
Step: "generate",
Count: responseMetrics.EvalCount,
Duration: responseMetrics.EvalDuration,
},
{
Model: model,
Step: "load",
Count: 1,
Duration: responseMetrics.LoadDuration,
},
{
Model: model,
Step: "total",
Count: 1,
Duration: responseMetrics.TotalDuration,
},
}
OutputMetrics(os.Stdout, *fOpt.format, metrics, *fOpt.verbose)
if *fOpt.keepAlive > 0 {
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
}
}
}
return nil
}
func readImage(filePath string) (api.ImageData, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, err
}
defer file.Close()
data, err := io.ReadAll(file)
if err != nil {
return nil, err
}
return api.ImageData(data), nil
}
func main() {
fOpt := flagOptions{
models: flag.String("model", "", "Model to benchmark"),
epochs: flag.Int("epochs", 6, "Number of epochs (iterations) per model"),
maxTokens: flag.Int("max-tokens", 200, "Maximum tokens for model response"),
temperature: flag.Float64("temperature", 0, "Temperature parameter"),
seed: flag.Int("seed", 0, "Random seed"),
timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"),
prompt: flag.String("p", DefaultPrompt, "Prompt to use"),
imageFile: flag.String("image", "", "Filename for an image to include"),
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
format: flag.String("format", "markdown", "Output format [benchstat|csv] (default benchstat)"),
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
verbose: flag.Bool("v", false, "Show system information"),
debug: flag.Bool("debug", false, "Show debug information"),
}
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s [OPTIONS]\n\n", os.Args[0])
fmt.Fprintf(os.Stderr, "Description:\n")
fmt.Fprintf(os.Stderr, " Model benchmarking tool with configurable parameters\n\n")
fmt.Fprintf(os.Stderr, "Options:\n")
flag.PrintDefaults()
fmt.Fprintf(os.Stderr, "\nExamples:\n")
fmt.Fprintf(os.Stderr, " bench -model gpt-oss:20b -epochs 3 -temperature 0.7\n")
}
flag.Parse()
if !slices.Contains([]string{"markdown", "benchstat", "csv"}, *fOpt.format) {
fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
os.Exit(1)
}
if len(*fOpt.models) == 0 {
fmt.Fprintf(os.Stderr, "ERROR: No model(s) specified to benchmark.\n")
flag.Usage()
return
}
BenchmarkChat(fOpt)
}

463
cmd/bench/bench_test.go Normal file
View File

@@ -0,0 +1,463 @@
package main
import (
"bytes"
"crypto/rand"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func createTestFlagOptions() flagOptions {
models := "test-model"
format := "benchstat"
epochs := 1
maxTokens := 100
temperature := 0.7
seed := 42
timeout := 30
prompt := "test prompt"
imageFile := ""
keepAlive := 5.0
verbose := false
debug := false
return flagOptions{
models: &models,
format: &format,
epochs: &epochs,
maxTokens: &maxTokens,
temperature: &temperature,
seed: &seed,
timeout: &timeout,
prompt: &prompt,
imageFile: &imageFile,
keepAlive: &keepAlive,
verbose: &verbose,
debug: &debug,
}
}
func captureOutput(f func()) string {
oldStdout := os.Stdout
oldStderr := os.Stderr
defer func() {
os.Stdout = oldStdout
os.Stderr = oldStderr
}()
r, w, _ := os.Pipe()
os.Stdout = w
os.Stderr = w
f()
w.Close()
var buf bytes.Buffer
io.Copy(&buf, r)
return buf.String()
}
func createMockOllamaServer(t *testing.T, responses []api.ChatResponse) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/chat" {
t.Errorf("Expected path /api/chat, got %s", r.URL.Path)
http.Error(w, "Not found", http.StatusNotFound)
return
}
if r.Method != "POST" {
t.Errorf("Expected POST method, got %s", r.Method)
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
for _, resp := range responses {
jsonData, err := json.Marshal(resp)
if err != nil {
t.Errorf("Failed to marshal response: %v", err)
return
}
w.Write(jsonData)
w.Write([]byte("\n"))
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
time.Sleep(10 * time.Millisecond) // Simulate some delay
}
}))
}
func TestBenchmarkChat_Success(t *testing.T) {
fOpt := createTestFlagOptions()
mockResponses := []api.ChatResponse{
{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "test response part 1",
},
Done: false,
},
{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "test response part 2",
},
Done: true,
Metrics: api.Metrics{
PromptEvalCount: 10,
PromptEvalDuration: 100 * time.Millisecond,
EvalCount: 50,
EvalDuration: 500 * time.Millisecond,
TotalDuration: 600 * time.Millisecond,
LoadDuration: 50 * time.Millisecond,
},
},
}
server := createMockOllamaServer(t, mockResponses)
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
if !strings.Contains(output, "BenchmarkModel/name=test-model/step=prefill") {
t.Errorf("Expected output to contain prefill metrics, got: %s", output)
}
if !strings.Contains(output, "BenchmarkModel/name=test-model/step=generate") {
t.Errorf("Expected output to contain generate metrics, got: %s", output)
}
if !strings.Contains(output, "ns/token") {
t.Errorf("Expected output to contain ns/token metric, got: %s", output)
}
}
func TestBenchmarkChat_ServerError(t *testing.T) {
fOpt := createTestFlagOptions()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Internal server error", http.StatusInternalServerError)
}))
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected error to be handled internally, got returned error: %v", err)
}
})
if !strings.Contains(output, "ERROR: Couldn't chat with model") {
t.Errorf("Expected error message about chat failure, got: %s", output)
}
}
func TestBenchmarkChat_Timeout(t *testing.T) {
fOpt := createTestFlagOptions()
shortTimeout := 1 // Very short timeout
fOpt.timeout = &shortTimeout
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate a long delay that will cause timeout
time.Sleep(2 * time.Second)
w.Header().Set("Content-Type", "application/json")
response := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "test response",
},
Done: true,
Metrics: api.Metrics{
PromptEvalCount: 10,
PromptEvalDuration: 100 * time.Millisecond,
EvalCount: 50,
EvalDuration: 500 * time.Millisecond,
TotalDuration: 600 * time.Millisecond,
LoadDuration: 50 * time.Millisecond,
},
}
jsonData, _ := json.Marshal(response)
w.Write(jsonData)
}))
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected timeout to be handled internally, got returned error: %v", err)
}
})
if !strings.Contains(output, "ERROR: Chat request timed out") {
t.Errorf("Expected timeout error message, got: %s", output)
}
}
func TestBenchmarkChat_NoMetrics(t *testing.T) {
fOpt := createTestFlagOptions()
mockResponses := []api.ChatResponse{
{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "test response",
},
Done: false, // Never sends Done=true
},
}
server := createMockOllamaServer(t, mockResponses)
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
if !strings.Contains(output, "ERROR: No metrics received") {
t.Errorf("Expected no metrics error message, got: %s", output)
}
}
func TestBenchmarkChat_MultipleModels(t *testing.T) {
fOpt := createTestFlagOptions()
models := "model1,model2"
epochs := 2
fOpt.models = &models
fOpt.epochs = &epochs
callCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.Header().Set("Content-Type", "application/json")
var req api.ChatRequest
body, _ := io.ReadAll(r.Body)
json.Unmarshal(body, &req)
response := api.ChatResponse{
Model: req.Model,
Message: api.Message{
Role: "assistant",
Content: "test response for " + req.Model,
},
Done: true,
Metrics: api.Metrics{
PromptEvalCount: 10,
PromptEvalDuration: 100 * time.Millisecond,
EvalCount: 50,
EvalDuration: 500 * time.Millisecond,
TotalDuration: 600 * time.Millisecond,
LoadDuration: 50 * time.Millisecond,
},
}
jsonData, _ := json.Marshal(response)
w.Write(jsonData)
}))
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
// Should be called 4 times (2 models × 2 epochs)
if callCount != 4 {
t.Errorf("Expected 4 API calls, got %d", callCount)
}
if !strings.Contains(output, "BenchmarkModel/name=model1") || !strings.Contains(output, "BenchmarkModel/name=model2") {
t.Errorf("Expected output for both models, got: %s", output)
}
}
func TestBenchmarkChat_WithImage(t *testing.T) {
fOpt := createTestFlagOptions()
tmpfile, err := os.CreateTemp(t.TempDir(), "testimage")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tmpfile.Name())
content := []byte("fake image data")
if _, err := tmpfile.Write(content); err != nil {
t.Fatalf("Failed to write to temp file: %v", err)
}
tmpfile.Close()
tmpfileName := tmpfile.Name()
fOpt.imageFile = &tmpfileName
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify the request contains image data
var req api.ChatRequest
body, _ := io.ReadAll(r.Body)
json.Unmarshal(body, &req)
if len(req.Messages) == 0 || len(req.Messages[0].Images) == 0 {
t.Error("Expected request to contain images")
}
w.Header().Set("Content-Type", "application/json")
response := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "test response with image",
},
Done: true,
Metrics: api.Metrics{
PromptEvalCount: 10,
PromptEvalDuration: 100 * time.Millisecond,
EvalCount: 50,
EvalDuration: 500 * time.Millisecond,
TotalDuration: 600 * time.Millisecond,
LoadDuration: 50 * time.Millisecond,
},
}
jsonData, _ := json.Marshal(response)
w.Write(jsonData)
}))
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
if !strings.Contains(output, "BenchmarkModel/name=test-model") {
t.Errorf("Expected benchmark output, got: %s", output)
}
}
func TestBenchmarkChat_ImageError(t *testing.T) {
randFileName := func() string {
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
const length = 8
result := make([]byte, length)
rand.Read(result) // Fill with random bytes
for i := range result {
result[i] = charset[result[i]%byte(len(charset))]
}
return string(result) + ".txt"
}
fOpt := createTestFlagOptions()
imageFile := randFileName()
fOpt.imageFile = &imageFile
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err == nil {
t.Error("Expected error from image reading, got nil")
}
})
if !strings.Contains(output, "ERROR: Couldn't read image") {
t.Errorf("Expected image read error message, got: %s", output)
}
}
func TestReadImage_Success(t *testing.T) {
tmpfile, err := os.CreateTemp(t.TempDir(), "testimage")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tmpfile.Name())
content := []byte("fake image data")
if _, err := tmpfile.Write(content); err != nil {
t.Fatalf("Failed to write to temp file: %v", err)
}
tmpfile.Close()
imgData, err := readImage(tmpfile.Name())
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if imgData == nil {
t.Error("Expected image data, got nil")
}
expected := api.ImageData(content)
if string(imgData) != string(expected) {
t.Errorf("Expected image data %v, got %v", expected, imgData)
}
}
func TestReadImage_FileNotFound(t *testing.T) {
imgData, err := readImage("nonexistentfile.jpg")
if err == nil {
t.Error("Expected error for non-existent file, got nil")
}
if imgData != nil {
t.Error("Expected nil image data for non-existent file")
}
}
func TestOptionsMapCreation(t *testing.T) {
fOpt := createTestFlagOptions()
options := make(map[string]interface{})
if *fOpt.maxTokens > 0 {
options["num_predict"] = *fOpt.maxTokens
}
options["temperature"] = *fOpt.temperature
if fOpt.seed != nil && *fOpt.seed > 0 {
options["seed"] = *fOpt.seed
}
if options["num_predict"] != *fOpt.maxTokens {
t.Errorf("Expected num_predict %d, got %v", *fOpt.maxTokens, options["num_predict"])
}
if options["temperature"] != *fOpt.temperature {
t.Errorf("Expected temperature %f, got %v", *fOpt.temperature, options["temperature"])
}
if options["seed"] != *fOpt.seed {
t.Errorf("Expected seed %d, got %v", *fOpt.seed, options["seed"])
}
}

View File

@@ -0,0 +1,625 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "transformers>=4.57.0",
# "jinja2",
# "fastapi",
# "uvicorn",
# "pydantic",
# "requests",
# ]
# ///
"""
Chat Template Testing Tool
Test HuggingFace chat templates against Ollama renderers.
Usage:
# Run predefined test cases against a HuggingFace model
uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3
# Compare HuggingFace output with Ollama renderer
uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3 --ollama-model intellect3
# Start server for manual curl testing
uv run cmd/chat_template/chat_template.py --serve
# Show chat template for a model
uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3 --show-template
"""
import argparse
import json
import sys
from typing import Any
from transformers import AutoTokenizer
TEST_CASES = [
{
"name": "basic_user_message",
"messages": [{"role": "user", "content": "Hello!"}],
"tools": None,
},
{
"name": "with_system_message",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"},
],
"tools": None,
},
{
"name": "multi_turn_conversation",
"messages": [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"},
],
"tools": None,
},
{
"name": "with_tools",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the weather?"},
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"required": ["location"],
"properties": {
"location": {"type": "string", "description": "The city"}
},
},
},
}
],
},
{
"name": "tool_call_and_response",
"messages": [
{"role": "user", "content": "What is the weather in SF?"},
{
"role": "assistant",
"content": "Let me check the weather.",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": {"location": "San Francisco"},
},
}
],
},
{"role": "tool", "content": '{"temperature": 68}', "tool_call_id": "call_1"},
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"required": ["location"],
"properties": {
"location": {"type": "string", "description": "The city"}
},
},
},
}
],
},
{
"name": "parallel_tool_calls",
"messages": [
{"role": "user", "content": "Get weather in SF and NYC"},
{
"role": "assistant",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": {"location": "San Francisco"},
},
},
{
"id": "call_2",
"type": "function",
"function": {
"name": "get_weather",
"arguments": {"location": "New York"},
},
},
],
},
{"role": "tool", "content": '{"temperature": 68}', "tool_call_id": "call_1"},
{"role": "tool", "content": '{"temperature": 55}', "tool_call_id": "call_2"},
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"parameters": {
"type": "object",
"properties": {"location": {"type": "string"}},
},
},
}
],
},
# Thinking tests
{
"name": "assistant_with_thinking",
"messages": [
{"role": "user", "content": "What is 2+2?"},
{
"role": "assistant",
"content": "The answer is 4.",
"thinking": "Let me calculate: 2 + 2 = 4. This is basic arithmetic.",
},
{"role": "user", "content": "And 3+3?"},
],
"tools": None,
},
{
"name": "thinking_with_tool_call",
"messages": [
{"role": "user", "content": "What's the weather in Paris?"},
{
"role": "assistant",
"content": "I'll check the weather for you.",
"thinking": "The user wants to know the weather in Paris. I should call the get_weather function.",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": {"location": "Paris"},
},
}
],
},
{"role": "tool", "content": '{"temperature": 18, "condition": "cloudy"}', "tool_call_id": "call_1"},
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get current weather",
"parameters": {
"type": "object",
"properties": {"location": {"type": "string"}},
},
},
}
],
},
{
"name": "thinking_only_no_content",
"messages": [
{"role": "user", "content": "Think about this silently."},
{
"role": "assistant",
"content": "", # HuggingFace requires content field
"thinking": "I'm thinking about this but won't respond with visible content.",
},
{"role": "user", "content": "What did you think?"},
],
"tools": None,
},
]
# Cache for tokenizers
_tokenizer_cache: dict[str, Any] = {}
def get_tokenizer(model_name: str):
"""Get or create tokenizer for the given model."""
if model_name not in _tokenizer_cache:
print(f"Loading tokenizer for {model_name}...", file=sys.stderr)
_tokenizer_cache[model_name] = AutoTokenizer.from_pretrained(model_name)
return _tokenizer_cache[model_name]
def apply_template(
model: str,
messages: list[dict],
tools: list[dict] | None = None,
) -> str:
"""Apply HuggingFace chat template to messages."""
tokenizer = get_tokenizer(model)
if tools:
return tokenizer.apply_chat_template(
messages,
tools=tools,
tokenize=False,
add_generation_prompt=True,
)
else:
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
def get_ollama_prompt(
ollama_model: str,
messages: list[dict],
tools: list[dict] | None = None,
ollama_host: str = "http://localhost:11434",
) -> str | None:
"""Get rendered prompt from Ollama using debug_render_only."""
import requests
# Convert messages to Ollama format
ollama_messages = []
for msg in messages:
ollama_msg = {"role": msg["role"]}
if "content" in msg:
ollama_msg["content"] = msg["content"]
if "thinking" in msg:
ollama_msg["thinking"] = msg["thinking"]
if "tool_calls" in msg:
# Convert tool_calls to Ollama format
tool_calls = []
for tc in msg["tool_calls"]:
tool_call = {
"function": {
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"],
}
}
if "id" in tc:
tool_call["id"] = tc["id"]
tool_calls.append(tool_call)
ollama_msg["tool_calls"] = tool_calls
if "tool_call_id" in msg:
ollama_msg["tool_call_id"] = msg["tool_call_id"]
ollama_messages.append(ollama_msg)
payload = {
"model": ollama_model,
"messages": ollama_messages,
"stream": False,
"_debug_render_only": True,
}
if tools:
payload["tools"] = tools
try:
resp = requests.post(f"{ollama_host}/api/chat", json=payload, timeout=30)
resp.raise_for_status()
data = resp.json()
# Field name is _debug_info with underscore prefix
if "_debug_info" in data and "rendered_template" in data["_debug_info"]:
return data["_debug_info"]["rendered_template"]
return None
except requests.exceptions.ConnectionError:
print(f" [ERROR] Cannot connect to Ollama at {ollama_host}", file=sys.stderr)
return None
except Exception as e:
print(f" [ERROR] Ollama request failed: {e}", file=sys.stderr)
return None
def compute_diff(hf_prompt: str, ollama_prompt: str) -> str:
"""Compute a unified diff between HuggingFace and Ollama prompts."""
import difflib
hf_lines = hf_prompt.splitlines(keepends=True)
ollama_lines = ollama_prompt.splitlines(keepends=True)
diff = difflib.unified_diff(
ollama_lines,
hf_lines,
fromfile="Ollama",
tofile="HuggingFace",
lineterm="",
)
return "".join(diff)
def print_test_output(
name: str,
messages: list[dict],
tools: list[dict] | None,
hf_prompt: str,
ollama_prompt: str | None = None,
as_repr: bool = False,
):
"""Print test output in a format suitable for Go test creation and LLM diffing."""
print(f"\n{'='*60}")
print(f"Test: {name}")
print("=" * 60)
print("\n--- Input Messages ---")
print(json.dumps(messages, indent=2))
if tools:
print("\n--- Tools ---")
print(json.dumps(tools, indent=2))
if ollama_prompt is not None:
# Comparison mode
if hf_prompt == ollama_prompt:
print("\n--- Result: MATCH ---")
print("\n--- Prompt (both identical) ---")
if as_repr:
print(repr(hf_prompt))
else:
print(hf_prompt)
else:
print("\n--- Result: MISMATCH ---")
print("\n--- HuggingFace Prompt ---")
if as_repr:
print(repr(hf_prompt))
else:
print(hf_prompt)
print("\n--- Ollama Prompt ---")
if as_repr:
print(repr(ollama_prompt))
else:
print(ollama_prompt)
print("\n--- Diff (Ollama -> HuggingFace) ---")
diff = compute_diff(hf_prompt, ollama_prompt)
if diff:
print(diff)
else:
print("(no line-level diff, check whitespace)")
else:
# HuggingFace only mode
print("\n--- HuggingFace Prompt ---")
if as_repr:
print(repr(hf_prompt))
else:
print(hf_prompt)
print("=" * 60)
def run_tests(
model: str,
as_repr: bool = False,
test_filter: str | None = None,
ollama_model: str | None = None,
ollama_host: str = "http://localhost:11434",
):
"""Run all predefined test cases against a model."""
if ollama_model:
print(f"\nComparing HuggingFace ({model}) vs Ollama ({ollama_model})\n")
else:
print(f"\nRunning tests against: {model}\n")
matches = 0
mismatches = 0
errors = 0
for test_case in TEST_CASES:
name = test_case["name"]
messages = test_case["messages"]
tools = test_case["tools"]
# Filter tests if specified
if test_filter and test_filter.lower() not in name.lower():
continue
try:
hf_prompt = apply_template(model, messages, tools)
ollama_prompt = None
if ollama_model:
ollama_prompt = get_ollama_prompt(
ollama_model, messages, tools, ollama_host
)
if ollama_prompt is None:
errors += 1
elif hf_prompt == ollama_prompt:
matches += 1
else:
mismatches += 1
print_test_output(
name, messages, tools, hf_prompt, ollama_prompt, as_repr=as_repr
)
except Exception as e:
errors += 1
print(f"\n{'='*60}")
print(f"Test: {name} - FAILED")
print(f"--- Input Messages ---")
print(json.dumps(messages, indent=2))
if tools:
print(f"--- Tools ---")
print(json.dumps(tools, indent=2))
print(f"--- Error ---")
print(f"{e}")
print("=" * 60)
# Print summary if comparing
if ollama_model:
total = matches + mismatches + errors
print(f"\n{'='*60}")
print("SUMMARY")
print("=" * 60)
print(f" Total: {total}")
print(f" Matches: {matches}")
print(f" Mismatches: {mismatches}")
print(f" Errors: {errors}")
print("=" * 60)
def show_template(model: str):
"""Show the chat template for a model."""
tokenizer = get_tokenizer(model)
print(f"\nChat template for {model}:\n")
print("-" * 60)
print(tokenizer.chat_template)
print("-" * 60)
def start_server(host: str = "0.0.0.0", port: int = 8000):
"""Start the FastAPI server for manual testing."""
from typing import Optional, List, Dict, Any as TypingAny
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
class Message(BaseModel):
role: str
content: Optional[str] = None
tool_calls: Optional[List[Dict[str, TypingAny]]] = None
tool_call_id: Optional[str] = None
class GeneratePromptRequest(BaseModel):
messages: List[Message]
model: str = "PrimeIntellect/INTELLECT-3"
tools: Optional[List[Dict[str, TypingAny]]] = None
inject_tools_as_functions: bool = False
class GeneratePromptResponse(BaseModel):
prompt: str
model: str
app = FastAPI(title="HuggingFace Prompt Generator", version="1.0.0")
@app.post("/generate-prompt", response_model=GeneratePromptResponse)
async def generate_prompt(request: GeneratePromptRequest):
try:
messages = []
for msg in request.messages:
message_dict = {"role": msg.role}
if msg.content is not None:
message_dict["content"] = msg.content
if msg.tool_calls is not None:
tool_calls = []
for tc in msg.tool_calls:
tc_copy = tc.copy()
if "function" in tc_copy and "arguments" in tc_copy["function"]:
args = tc_copy["function"]["arguments"]
if isinstance(args, str):
try:
tc_copy["function"]["arguments"] = json.loads(args)
except json.JSONDecodeError:
pass
tool_calls.append(tc_copy)
message_dict["tool_calls"] = tool_calls
if msg.tool_call_id is not None:
message_dict["tool_call_id"] = msg.tool_call_id
messages.append(message_dict)
prompt = apply_template(request.model, messages, request.tools)
return GeneratePromptResponse(prompt=prompt, model=request.model)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
return {"status": "healthy"}
print(f"Starting server on http://{host}:{port}")
print("Endpoints:")
print(" POST /generate-prompt - Generate prompt from messages")
print(" GET /health - Health check")
uvicorn.run(app, host=host, port=port)
def main():
parser = argparse.ArgumentParser(
description="HuggingFace Prompt Testing Tool",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__,
)
parser.add_argument(
"--model",
"-m",
type=str,
help="HuggingFace model name (e.g., PrimeIntellect/INTELLECT-3)",
)
parser.add_argument(
"--ollama-model",
"-o",
type=str,
help="Ollama model name to compare against (e.g., qwen3-coder)",
)
parser.add_argument(
"--ollama-host",
type=str,
default="http://localhost:11434",
help="Ollama server URL (default: http://localhost:11434)",
)
parser.add_argument(
"--serve",
"-s",
action="store_true",
help="Start FastAPI server for manual curl testing",
)
parser.add_argument(
"--port",
"-p",
type=int,
default=8000,
help="Server port (default: 8000)",
)
parser.add_argument(
"--show-template",
"-t",
action="store_true",
help="Show the chat template for the model",
)
parser.add_argument(
"--repr",
"-r",
action="store_true",
help="Output prompts as Python repr (shows escape sequences)",
)
parser.add_argument(
"--filter",
"-f",
type=str,
help="Filter tests by name (substring match)",
)
args = parser.parse_args()
if args.serve:
start_server(port=args.port)
elif args.model:
if args.show_template:
show_template(args.model)
else:
run_tests(
args.model,
as_repr=args.repr,
test_filter=args.filter,
ollama_model=args.ollama_model,
ollama_host=args.ollama_host,
)
else:
parser.print_help()
print("\nExample usage:")
print(" uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3")
print(" uv run cmd/chat_template/chat_template.py --model Qwen/Qwen3-Coder-480B-A35B-Instruct --ollama-model qwen3-coder")
print(" uv run cmd/chat_template/chat_template.py --serve")
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -206,6 +206,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &commandrModel{}
case "GptOssForCausalLM":
conv = &gptossModel{}
case "DeepseekOCRForCausalLM":
conv = &deepseekocr{}
default:
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}

View File

@@ -0,0 +1,136 @@
package convert
import (
"fmt"
"github.com/ollama/ollama/fs/ggml"
)
type deepseekocr struct {
ModelParameters
LanguageConfig struct {
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
NumRoutedExperts uint32 `json:"n_routed_experts"`
NumSharedExperts uint32 `json:"n_shared_experts"`
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
FirstKDenseReplace uint32 `json:"first_k_dense_replace"`
} `json:"language_config"`
VisionConfig struct {
ImageSize uint32 `json:"image_size"`
Width struct {
Vision struct {
Heads uint32 `json:"heads"`
ImageSize uint32 `json:"image_size"`
Layers uint32 `json:"layers"`
PatchSize uint32 `json:"patch_size"`
Width uint32 `json:"width"`
} `json:"clip-l-14-224"`
Sam struct {
GlobalAttentionIndexes []int32 `json:"global_attn_indexes"`
Heads uint32 `json:"heads"`
Layers uint32 `json:"layers"`
Width uint32 `json:"width"`
} `json:"sam_vit_b"`
}
} `json:"vision_config"`
}
func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "deepseekocr"
kv["block_count"] = m.LanguageConfig.HiddenLayers
kv["context_length"] = m.LanguageConfig.MaxPositionEmbeddings
kv["embedding_length"] = m.LanguageConfig.HiddenSize
kv["feed_forward_length"] = m.LanguageConfig.IntermediateSize
kv["attention.head_count"] = m.LanguageConfig.NumAttentionHeads
kv["attention.head_count_kv"] = m.LanguageConfig.NumKeyValueHeads
kv["expert_count"] = m.LanguageConfig.NumRoutedExperts
kv["expert_used_count"] = m.LanguageConfig.NumExpertsPerToken
kv["leading_dense_block_count"] = m.LanguageConfig.FirstKDenseReplace
kv["vision.block_count"] = m.VisionConfig.Width.Vision.Layers
kv["vision.embedding_length"] = m.VisionConfig.Width.Vision.Width
kv["vision.head_count"] = m.VisionConfig.Width.Vision.Heads
kv["vision.image_size"] = m.VisionConfig.Width.Vision.ImageSize
kv["vision.patch_size"] = m.VisionConfig.Width.Vision.PatchSize
kv["sam.block_count"] = m.VisionConfig.Width.Sam.Layers
kv["sam.embedding_length"] = m.VisionConfig.Width.Sam.Width
kv["sam.head_count"] = m.VisionConfig.Width.Sam.Heads
kv["sam.global_attention_indexes"] = m.VisionConfig.Width.Sam.GlobalAttentionIndexes
return kv
}
func (m *deepseekocr) Tensors(s []Tensor) (out []*ggml.Tensor) {
merges := make([]merge, m.LanguageConfig.HiddenLayers*3)
for i := range m.LanguageConfig.HiddenLayers {
merges[i*3+0] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}
merges[i*3+1] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}
merges[i*3+2] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}
}
out, s = mergeTensors(s, merges...)
for _, t := range s {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (m *deepseekocr) Replacements() []string {
return []string{
"model.embed_tokens", "token_embd",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"post_attention_layernorm", "ffn_norm",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"mlp.down_proj", "ffn_down",
"mlp.gate", "ffn_gate_inp",
"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
"mlp.shared_experts.up_proj", "ffn_up_shexp",
"mlp.shared_experts.down_proj", "ffn_down_shexp",
"model.norm", "output_norm",
"lm_head", "output",
"model.vision_model", "v",
"embeddings.patch_embedding", "patch_embd",
"embeddings.class_embedding", "class_embd",
"embeddings.position_embedding", "position_embd",
"transformer.layers", "blk",
"model.projector", "mm",
"model.image_newline", "mm.image_newline",
//nolint:misspell // this misspelling is upstream. fixing it breaks the model
"model.view_seperator", "mm.view_seperator",
"model.sam_model.patch_embed.proj", "s.patch_embd",
"model.sam_model.pos_embed", "s.position_embd",
"model.sam_model.blocks", "s.blk",
"model.sam_model.neck", "s.neck",
"model.sam_model.net_", "s.net_",
}
}

View File

@@ -110,9 +110,12 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
for name, mxfp4 := range mxfp4s {
dims := mxfp4.blocks.Shape()
if !strings.HasSuffix(name, ".weight") {
name = name + ".weight"
}
if strings.Contains(name, "ffn_down_exps") {
out = append(out, &ggml.Tensor{
Name: name + ".weight",
Name: name,
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
WriterTo: mxfp4,
@@ -121,12 +124,12 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
// gate_up_exps is interleaved, need to split into gate_exps and up_exps
// e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "gate_up", "gate", 1) + ".weight",
Name: strings.Replace(name, "gate_up", "gate", 1),
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
WriterTo: mxfp4.slice(1, 0, int(dims[1]), 2),
}, &ggml.Tensor{
Name: strings.Replace(name, "gate_up", "up", 1) + ".weight",
Name: strings.Replace(name, "gate_up", "up", 1),
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
WriterTo: mxfp4.slice(1, 1, int(dims[1]), 2),

View File

@@ -44,7 +44,10 @@ func (t tensorBase) Kind() uint32 {
t.name == "v.positional_embedding_vlm" ||
t.name == "v.tile_position_embd.weight" ||
t.name == "v.pre_tile_position_embd.weight" ||
t.name == "v.post_tile_position_embd.weight" {
t.name == "v.post_tile_position_embd.weight" ||
t.name == "s.position_embd" ||
strings.HasSuffix(t.name, "rel_pos_h") ||
strings.HasSuffix(t.name, "rel_pos_w") {
// these tensors are always F32
return tensorKindFP32
}

View File

@@ -96,7 +96,10 @@ type safetensor struct {
func (st safetensor) Kind() uint32 {
kind := st.tensorBase.Kind()
if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 {
if st.dtype == "BF16" &&
!strings.HasPrefix(st.name, "v.") &&
!strings.HasPrefix(st.name, "s.") &&
kind != tensorKindFP32 {
kind = tensorKindBF16
}

View File

@@ -2,10 +2,12 @@ package convert
import (
"cmp"
"errors"
"io"
"iter"
"path"
"slices"
"strconv"
"strings"
"github.com/pdevine/tensor"
@@ -94,6 +96,26 @@ func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []
return matched
})
slices.SortStableFunc(matched, func(a, b Tensor) int {
x := strings.Split(a.Name(), ".")
y := strings.Split(b.Name(), ".")
if len(x) != len(y) {
return cmp.Compare(len(x), len(y))
}
vals := make([]int, len(x))
for i := range x {
vals[i] = strings.Compare(x[i], y[i])
m, err := strconv.ParseInt(x[i], 0, 0)
n, err2 := strconv.ParseInt(y[i], 0, 0)
if errors.Join(err, err2) == nil {
vals[i] = cmp.Compare(m, n)
}
}
return cmp.Or(vals...)
})
if len(matched) > 0 {
out = append(out, &ggml.Tensor{
Name: merges[i].name,

View File

@@ -3,8 +3,10 @@ package convert
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"iter"
"math/rand/v2"
"slices"
"strings"
"testing"
@@ -951,3 +953,45 @@ func TestMerge(t *testing.T) {
}
})
}
func TestMergeOrder(t *testing.T) {
for range 8 {
t.Run("", func(t *testing.T) {
tensors := make([]Tensor, 16)
for i := range tensors {
tensors[i] = &fakeTensor{
name: fmt.Sprintf("layer.%d.weight", i),
shape: []uint64{1},
data: []float32{float32(i)},
}
}
rand.Shuffle(len(tensors), func(i, j int) {
tensors[i], tensors[j] = tensors[j], tensors[i]
})
matched, unmatched := mergeTensors(tensors, merge{"layer.*.weight", "layer.weight"})
if len(unmatched) != 0 {
t.Error("expected no remaining tensors, got", len(unmatched))
}
if len(matched) != 1 {
t.Error("expected 1 merged tensor, got", len(matched))
}
var b bytes.Buffer
if _, err := matched[0].WriteTo(&b); err != nil {
t.Fatal(err)
}
var f32s [16]float32
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.IsSorted(f32s[:]) {
t.Errorf("merged tensor data is not in order: %+v", f32s)
}
})
}
}

View File

@@ -2,6 +2,7 @@ package discover
import (
"bufio"
"errors"
"fmt"
"io"
"log/slog"
@@ -10,12 +11,21 @@ import (
"reflect"
"regexp"
"sort"
"strconv"
"strings"
"github.com/ollama/ollama/format"
)
func GetCPUMem() (memInfo, error) {
mem, err := getCPUMem()
if err != nil {
return memInfo{}, err
}
return getCPUMemByCgroups(mem), nil
}
func getCPUMem() (memInfo, error) {
var mem memInfo
var total, available, free, buffers, cached, freeSwap uint64
f, err := os.Open("/proc/meminfo")
@@ -56,6 +66,32 @@ func GetCPUMem() (memInfo, error) {
return mem, nil
}
func getCPUMemByCgroups(mem memInfo) memInfo {
total, err := getUint64ValueFromFile("/sys/fs/cgroup/memory.max")
if err == nil {
mem.TotalMemory = total
}
used, err := getUint64ValueFromFile("/sys/fs/cgroup/memory.current")
if err == nil {
mem.FreeMemory = mem.TotalMemory - used
}
return mem
}
func getUint64ValueFromFile(path string) (uint64, error) {
f, err := os.Open(path)
if err != nil {
return 0, err
}
defer f.Close()
s := bufio.NewScanner(f)
for s.Scan() {
line := s.Text()
return strconv.ParseUint(line, 10, 64)
}
return 0, errors.New("empty file content")
}
const CpuInfoFilename = "/proc/cpuinfo"
type linuxCpuInfo struct {
@@ -74,7 +110,41 @@ func GetCPUDetails() []CPU {
return nil
}
defer file.Close()
return linuxCPUDetails(file)
cpus := linuxCPUDetails(file)
return overwriteThreadCountByLinuxCgroups(cpus)
}
func overwriteThreadCountByLinuxCgroups(cpus []CPU) []CPU {
file, err := os.Open("/sys/fs/cgroup/cpu.max")
if err != nil {
return cpus
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
if sl := strings.Split(line, " "); len(sl) == 2 {
allowdUs, err := strconv.ParseInt(sl[0], 10, 64)
if err != nil {
slog.Warn("failed to parse CPU allowed micro secs", "error", err)
return cpus
}
unitUs, err := strconv.ParseInt(sl[1], 10, 64)
if err != nil {
slog.Warn("failed to parse CPU unit micro secs", "error", err)
return cpus
}
threads := int(max(allowdUs/unitUs, 1))
cpu := cpus[0]
cpu.CoreCount = threads
cpu.ThreadCount = threads
return []CPU{cpu}
}
}
return cpus
}
func linuxCPUDetails(file io.Reader) []CPU {

View File

@@ -65,6 +65,11 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
}
slog.Info("discovering available GPUs...")
detectIncompatibleLibraries()
// Warn if any user-overrides are set which could lead to incorrect GPU discovery
overrideWarnings()
requested := envconfig.LLMLibrary()
jetpack := cudaJetpack()
@@ -90,10 +95,16 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
var dirs []string
if dir != "" {
if requested != "" && filepath.Base(dir) != requested {
slog.Debug("skipping available library at users request", "requested", requested, "libDir", dir)
slog.Debug("skipping available library at user's request", "requested", requested, "libDir", dir)
continue
} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
continue
} else if jetpack == "" && strings.Contains(filepath.Base(dir), "cuda_jetpack") {
slog.Debug("jetpack not detected (set JETSON_JETPACK or OLLAMA_LLM_LIBRARY to override), skipping", "libDir", dir)
continue
} else if !envconfig.EnableVulkan() && strings.Contains(filepath.Base(dir), "vulkan") {
slog.Info("experimental Vulkan support disabled. To enable, set OLLAMA_VULKAN=1")
continue
}
dirs = []string{ml.LibOllamaPath, dir}
} else {
@@ -110,7 +121,7 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
// In the second pass, we more deeply initialize the GPUs to weed out devices that
// aren't supported by a given library. We run this phase in parallel to speed up discovery.
// Only devices that need verification are included in this pass
slog.Debug("evluating which if any devices to filter out", "initial_count", len(devices))
slog.Debug("evaluating which, if any, devices to filter out", "initial_count", len(devices))
ctx2ndPass, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
var wg sync.WaitGroup
@@ -118,11 +129,21 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
supportedMu := sync.Mutex{}
supported := make(map[string]map[string]map[string]int) // [Library][libDir][ID] = pre-deletion devices index
for i := range devices {
libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1]
if !devices[i].NeedsInitValidation() {
// No need to validate, add to the supported map
supportedMu.Lock()
if _, ok := supported[devices[i].Library]; !ok {
supported[devices[i].Library] = make(map[string]map[string]int)
}
if _, ok := supported[devices[i].Library][libDir]; !ok {
supported[devices[i].Library][libDir] = make(map[string]int)
}
supported[devices[i].Library][libDir][devices[i].ID] = i
supportedMu.Unlock()
continue
}
libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1]
slog.Debug("verifying device is supported", "library", libDir, "description", devices[i].Description, "compute", devices[i].Compute(), "id", devices[i].ID, "pci_id", devices[i].PCIID)
slog.Debug("verifying if device is supported", "library", libDir, "description", devices[i].Description, "compute", devices[i].Compute(), "id", devices[i].ID, "pci_id", devices[i].PCIID)
wg.Add(1)
go func(i int) {
defer wg.Done()
@@ -446,3 +467,37 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs map
return devices
}
func overrideWarnings() {
anyFound := false
m := envconfig.AsMap()
for _, k := range []string{
"CUDA_VISIBLE_DEVICES",
"HIP_VISIBLE_DEVICES",
"ROCR_VISIBLE_DEVICES",
"GGML_VK_VISIBLE_DEVICES",
"GPU_DEVICE_ORDINAL",
"HSA_OVERRIDE_GFX_VERSION",
} {
if e, found := m[k]; found && e.Value != "" {
anyFound = true
slog.Warn("user overrode visible devices", k, e.Value)
}
}
if anyFound {
slog.Warn("if GPUs are not correctly discovered, unset and try again")
}
}
func detectIncompatibleLibraries() {
if runtime.GOOS != "windows" {
return
}
basePath, err := exec.LookPath("ggml-base.dll")
if err != nil || basePath == "" {
return
}
if !strings.HasPrefix(basePath, ml.LibOllamaPath) {
slog.Warn("potentially incompatible library detected in PATH", "location", basePath)
}
}

View File

@@ -12,7 +12,7 @@
### Reference
* [API Reference](https://docs.ollama.com/api)
* [Modelfile Reference](./modelfile.md)
* [Modelfile Reference](https://docs.ollama.com/modelfile)
* [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility)
### Resources

View File

@@ -13,9 +13,23 @@ Embeddings turn text into numeric vectors you can store in a vector database, se
## Generate embeddings
Use `/api/embed` with a single string.
<Tabs>
<Tab title="CLI">
Generate embeddings directly from the command line:
```shell
ollama run embeddinggemma "Hello world"
```
You can also pipe text to generate embeddings:
```shell
echo "Hello world" | ollama run embeddinggemma
```
Output is a JSON array.
</Tab>
<Tab title="cURL">
```shell
curl -X POST http://localhost:11434/api/embed \

View File

@@ -9,15 +9,9 @@ sidebarTitle: Cloud
Ollama's cloud models are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud service while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldn't fit on a personal computer.
Ollama currently supports the following cloud models, with more coming soon:
### Supported models
- `deepseek-v3.1:671b-cloud`
- `gpt-oss:20b-cloud`
- `gpt-oss:120b-cloud`
- `kimi-k2:1t-cloud`
- `qwen3-coder:480b-cloud`
- `glm-4.6:cloud`
- `minimax-m2:cloud`
For a list of supported models, see Ollama's [model library](https://ollama.com/search?c=cloud).
### Running Cloud models

View File

@@ -68,6 +68,15 @@ To run Ollama using Docker with AMD GPUs, use the `rocm` tag and the following c
docker run -d --device /dev/kfd --device /dev/dri -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama:rocm
```
## Vulkan Support
Vulkan is bundled into the `ollama/ollama` image.
```shell
docker run -d --device /dev/kfd --device /dev/dri -v ollama:/root/.ollama -p 11434:11434 -e OLLAMA_VULKAN=1 --name ollama ollama/ollama
```
## Run model locally
Now you can run a model:
@@ -79,3 +88,4 @@ docker exec -it ollama ollama run llama3.2
## Try different models
More models can be found on the [Ollama library](https://ollama.com/library).

View File

@@ -63,6 +63,10 @@
{
"source": "/api/openai",
"destination": "/api/openai-compatibility"
},
{
"source": "/api",
"destination": "/api/introduction"
}
],
"navigation": {
@@ -130,7 +134,7 @@
{
"group": "API Reference",
"pages": [
"/api/index",
"/api/introduction",
"/api/authentication",
"/api/streaming",
"/api/usage",

View File

@@ -57,8 +57,13 @@ ollama ps
```
<Info>
**Output**: ``` NAME ID SIZE PROCESSOR UNTIL llama3:70b bcfb190ca3a7 42 GB
100% GPU 4 minutes from now ```
**Output**:
```
NAME ID SIZE PROCESSOR UNTIL
llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
```
</Info>
The `Processor` column will show which memory the model was loaded in to:
@@ -223,7 +228,7 @@ Refer to the section [above](#how-do-i-configure-ollama-server) for how to set e
## How can I use Ollama in Visual Studio Code?
There is already a large collection of plugins available for VSCode as well as other editors that leverage Ollama. See the list of [extensions & plugins](https://github.com/ollama/ollama#extensions--plugins) at the bottom of the main repository readme.
There is already a large collection of plugins available for VS Code as well as other editors that leverage Ollama. See the list of [extensions & plugins](https://github.com/ollama/ollama#extensions--plugins) at the bottom of the main repository readme.
## How do I use Ollama with GPU acceleration in Docker?
@@ -376,3 +381,13 @@ ollama signin
<Note>
Replace &lt;username&gt; with your actual Windows user name.
</Note>
## How can I stop Ollama from starting when I login to my computer
Ollama for Windows and macOS register as a login item during installation. You can disable this if you prefer not to have Ollama automatically start. Ollama will respect this setting across upgrades, unless you uninstall the application.
**Windows**
- In `Task Manager` go to the `Startup apps` tab, search for `ollama` then click `Disable`
**MacOS**
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.

View File

@@ -3,34 +3,35 @@ title: Hardware support
---
## Nvidia
Ollama supports Nvidia GPUs with compute capability 5.0+.
Ollama supports Nvidia GPUs with compute capability 5.0+ and driver version 531 and newer.
Check your compute compatibility to see if your card is supported:
[https://developer.nvidia.com/cuda-gpus](https://developer.nvidia.com/cuda-gpus)
| Compute Capability | Family | Cards |
| ------------------ | ------------------- | ----------------------------------------------------------------------------------------------------------------------------- |
| 9.0 | NVIDIA | `H200` `H100` |
| 8.9 | GeForce RTX 40xx | `RTX 4090` `RTX 4080 SUPER` `RTX 4080` `RTX 4070 Ti SUPER` `RTX 4070 Ti` `RTX 4070 SUPER` `RTX 4070` `RTX 4060 Ti` `RTX 4060` |
| | NVIDIA Professional | `L4` `L40` `RTX 6000` |
| 8.6 | GeForce RTX 30xx | `RTX 3090 Ti` `RTX 3090` `RTX 3080 Ti` `RTX 3080` `RTX 3070 Ti` `RTX 3070` `RTX 3060 Ti` `RTX 3060` `RTX 3050 Ti` `RTX 3050` |
| | NVIDIA Professional | `A40` `RTX A6000` `RTX A5000` `RTX A4000` `RTX A3000` `RTX A2000` `A10` `A16` `A2` |
| 8.0 | NVIDIA | `A100` `A30` |
| 7.5 | GeForce GTX/RTX | `GTX 1650 Ti` `TITAN RTX` `RTX 2080 Ti` `RTX 2080` `RTX 2070` `RTX 2060` |
| | NVIDIA Professional | `T4` `RTX 5000` `RTX 4000` `RTX 3000` `T2000` `T1200` `T1000` `T600` `T500` |
| | Quadro | `RTX 8000` `RTX 6000` `RTX 5000` `RTX 4000` |
| 7.0 | NVIDIA | `TITAN V` `V100` `Quadro GV100` |
| 6.1 | NVIDIA TITAN | `TITAN Xp` `TITAN X` |
| | GeForce GTX | `GTX 1080 Ti` `GTX 1080` `GTX 1070 Ti` `GTX 1070` `GTX 1060` `GTX 1050 Ti` `GTX 1050` |
| | Quadro | `P6000` `P5200` `P4200` `P3200` `P5000` `P4000` `P3000` `P2200` `P2000` `P1000` `P620` `P600` `P500` `P520` |
| | Tesla | `P40` `P4` |
| 6.0 | NVIDIA | `Tesla P100` `Quadro GP100` |
| 5.2 | GeForce GTX | `GTX TITAN X` `GTX 980 Ti` `GTX 980` `GTX 970` `GTX 960` `GTX 950` |
| | Quadro | `M6000 24GB` `M6000` `M5000` `M5500M` `M4000` `M2200` `M2000` `M620` |
| | Tesla | `M60` `M40` |
| 5.0 | GeForce GTX | `GTX 750 Ti` `GTX 750` `NVS 810` |
| | Quadro | `K2200` `K1200` `K620` `M1200` `M520` `M5000M` `M4000M` `M3000M` `M2000M` `M1000M` `K620M` `M600M` `M500M` |
| Compute Capability | Family | Cards |
| ------------------ | ------------------- | ------------------------------------------------------------------------------------------------------------------------------ |
| 12.0 | GeForce RTX 50xx | `RTX 5060` `RTX 5060 Ti` `RTX 5070` `RTX 5070 Ti` `RTX 5080` `RTX 5090` |
| | NVIDIA Professional | `RTX PRO 4000 Blackwell` `RTX PRO 4500 Blackwell` `RTX PRO 5000 Blackwell` `RTX PRO 6000 Blackwell` |
| 9.0 | NVIDIA | `H200` `H100` |
| 8.9 | GeForce RTX 40xx | `RTX 4090` `RTX 4080 SUPER` `RTX 4080` `RTX 4070 Ti SUPER` `RTX 4070 Ti` `RTX 4070 SUPER` `RTX 4070` `RTX 4060 Ti` `RTX 4060` |
| | NVIDIA Professional | `L4` `L40` `RTX 6000` |
| 8.6 | GeForce RTX 30xx | `RTX 3090 Ti` `RTX 3090` `RTX 3080 Ti` `RTX 3080` `RTX 3070 Ti` `RTX 3070` `RTX 3060 Ti` `RTX 3060` `RTX 3050 Ti` `RTX 3050` |
| | NVIDIA Professional | `A40` `RTX A6000` `RTX A5000` `RTX A4000` `RTX A3000` `RTX A2000` `A10` `A16` `A2` |
| 8.0 | NVIDIA | `A100` `A30` |
| 7.5 | GeForce GTX/RTX | `GTX 1650 Ti` `TITAN RTX` `RTX 2080 Ti` `RTX 2080` `RTX 2070` `RTX 2060` |
| | NVIDIA Professional | `T4` `RTX 5000` `RTX 4000` `RTX 3000` `T2000` `T1200` `T1000` `T600` `T500` |
| | Quadro | `RTX 8000` `RTX 6000` `RTX 5000` `RTX 4000` |
| 7.0 | NVIDIA | `TITAN V` `V100` `Quadro GV100` |
| 6.1 | NVIDIA TITAN | `TITAN Xp` `TITAN X` |
| | GeForce GTX | `GTX 1080 Ti` `GTX 1080` `GTX 1070 Ti` `GTX 1070` `GTX 1060` `GTX 1050 Ti` `GTX 1050` |
| | Quadro | `P6000` `P5200` `P4200` `P3200` `P5000` `P4000` `P3000` `P2200` `P2000` `P1000` `P620` `P600` `P500` `P520` |
| | Tesla | `P40` `P4` |
| 6.0 | NVIDIA | `Tesla P100` `Quadro GP100` |
| 5.2 | GeForce GTX | `GTX TITAN X` `GTX 980 Ti` `GTX 980` `GTX 970` `GTX 960` `GTX 950` |
| | Quadro | `M6000 24GB` `M6000` `M5000` `M5500M` `M4000` `M2200` `M2000` `M620` |
| | Tesla | `M60` `M40` |
| 5.0 | GeForce GTX | `GTX 750 Ti` `GTX 750` `NVS 810` |
| | Quadro | `K2200` `K1200` `K620` `M1200` `M520` `M5000M` `M4000M` `M3000M` `M2000M` `M1000M` `K620M` `M600M` `M500M` |
For building locally to support older GPUs, see [developer.md](./development.md#linux-cuda-nvidia)
@@ -51,24 +52,28 @@ sudo modprobe nvidia_uvm`
## AMD Radeon
Ollama supports the following AMD GPUs:
Ollama supports the following AMD GPUs via the ROCm library:
> [!NOTE]
> Additional AMD GPU support is provided by the Vulkan Library - see below.
### Linux Support
| Family | Cards and accelerators |
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` `Vega 56` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `VII` `SSG` |
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` `MI50` |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `SSG` |
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` |
### Windows Support
With ROCm v6.1, the following GPUs are supported on Windows.
| Family | Cards and accelerators |
| -------------- | ------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
| Family | Cards and accelerators |
| -------------- | -------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
### Overrides on Linux
@@ -90,8 +95,6 @@ At this time, the known supported GPU types on linux are the following LLVM Targ
This table shows some example GPUs that map to these LLVM targets:
| **LLVM Target** | **An Example GPU** |
|-----------------|---------------------|
| gfx900 | Radeon RX Vega 56 |
| gfx906 | Radeon Instinct MI50 |
| gfx908 | Radeon Instinct MI100 |
| gfx90a | Radeon Instinct MI210 |
| gfx940 | Radeon Instinct MI300 |
@@ -122,6 +125,42 @@ In some Linux distributions, SELinux can prevent containers from
accessing the AMD GPU devices. On the host system you can run
`sudo setsebool container_use_devices=1` to allow containers to use devices.
### Metal (Apple GPUs)
## Metal (Apple GPUs)
Ollama supports GPU acceleration on Apple devices via the Metal API.
## Vulkan GPU Support
> [!NOTE]
> Vulkan is currently an Experimental feature. To enable, you must set OLLAMA_VULKAN=1 for the Ollama server as
described in the [FAQ](faq.md#how-do-i-configure-ollama-server)
Additional GPU support on Windows and Linux is provided via
[Vulkan](https://www.vulkan.org/). On Windows most GPU vendors drivers come
bundled with Vulkan support and require no additional setup steps. Most Linux
distributions require installing additional components, and you may have
multiple options for Vulkan drivers between Mesa and GPU Vendor specific packages
- Linux Intel GPU Instructions - https://dgpu-docs.intel.com/driver/client/overview.html
- Linux AMD GPU Instructions - https://amdgpu-install.readthedocs.io/en/latest/install-script.html#specifying-a-vulkan-implementation
For AMD GPUs on some Linux distributions, you may need to add the `ollama` user to the `render` group.
The Ollama scheduler leverages available VRAM data reported by the GPU libraries to
make optimal scheduling decisions. Vulkan requires additional capabilities or
running as root to expose this available VRAM data. If neither root access or this
capability are granted, Ollama will use approximate sizes of the models
to make best effort scheduling decisions.
```bash
sudo setcap cap_perfmon+ep /usr/local/bin/ollama
```
### GPU Selection
To select specific Vulkan GPU(s), you can set the environment variable
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
described in the [FAQ](faq.md#how-do-i-configure-ollama-server). If you
encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
by setting `GGML_VK_VISIBLE_DEVICES=-1`

View File

@@ -25,8 +25,23 @@ Install [n8n](https://docs.n8n.io/choose-n8n/).
width="75%"
/>
</div>
3. Confirm Base URL is set to `http://localhost:11434` and click **Save**
<Note> If connecting to `http://localhost:11434` fails, use `http://127.0.0.1:11434`</Note>
3. Confirm Base URL is set to `http://localhost:11434` if running locally or `http://host.docker.internal:11434` if running through docker and click **Save**
<Note>
In environments that don't use Docker Desktop (ie, Linux server installations), `host.docker.internal` is not automatically added.
Run n8n in docker with `--add-host=host.docker.internal:host-gateway`
or add the following to a docker compose file:
```yaml
extra_hosts:
- "host.docker.internal:host-gateway"
```
</Note>
You should see a `Connection tested successfully` message.
4. When creating a new workflow, select **Add a first step** and select an **Ollama node**
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img

View File

@@ -1,34 +1,34 @@
---
title: VS Code
title: VS Code
---
## Install
Install [VSCode](https://code.visualstudio.com/download).
Install [VS Code](https://code.visualstudio.com/download).
## Usage with Ollama
## Usage with Ollama
1. Open Copilot side bar found in top right window
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/vscode-sidebar.png"
alt="VSCode chat Sidebar"
width="75%"
/>
</div>
2. Select the model drowpdown > **Manage models**
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/vscode-models.png"
alt="VSCode model picker"
width="75%"
/>
</div>
<div style={{ display: "flex", justifyContent: "center" }}>
<img
src="/images/vscode-sidebar.png"
alt="VS Code chat Sidebar"
width="75%"
/>
</div>
2. Select the model dropdown > **Manage models**
<div style={{ display: "flex", justifyContent: "center" }}>
<img
src="/images/vscode-models.png"
alt="VS Code model picker"
width="75%"
/>
</div>
3. Enter **Ollama** under **Provider Dropdown** and select desired models (e.g `qwen3, qwen3-coder:480b-cloud`)
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/vscode-model-options.png"
alt="VSCode model options dropdown"
width="75%"
/>
</div>
<div style={{ display: "flex", justifyContent: "center" }}>
<img
src="/images/vscode-model-options.png"
alt="VS Code model options dropdown"
width="75%"
/>
</div>

View File

@@ -149,9 +149,6 @@ PARAMETER <parameter> <parametervalue>
| Parameter | Description | Value Type | Example Usage |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
| mirostat | Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | int | mirostat 0 |
| mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
| mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |

View File

@@ -2,12 +2,15 @@ openapi: 3.1.0
info:
title: Ollama API
version: 0.1.0
license:
name: MIT
url: https://opensource.org/licenses/MIT
description: |
OpenAPI specification for the Ollama HTTP API
servers:
- url: http://localhost:11434
description: Local Ollama instance
description: Ollama
security: []
components:
securitySchemes:
bearerAuth:
@@ -93,8 +96,11 @@ components:
type: boolean
default: true
think:
type: boolean
description: When true, returns separate thinking output in addition to content
oneOf:
- type: boolean
- type: string
enum: [high, medium, low]
description: When true, returns separate thinking output in addition to content. Can be a boolean (true/false) or a string ("high", "medium", "low") for supported models.
raw:
type: boolean
description: When true, returns the raw response from the model without any prompt templating
@@ -105,6 +111,12 @@ components:
description: Model keep-alive duration (for example `5m` or `0` to unload immediately)
options:
$ref: "#/components/schemas/ModelOptions"
logprobs:
type: boolean
description: Whether to return log probabilities of the output tokens
top_logprobs:
type: integer
description: Number of most likely tokens to return at each token position when logprobs are enabled
GenerateResponse:
type: object
properties:
@@ -144,6 +156,11 @@ components:
eval_duration:
type: integer
description: Time spent generating tokens in nanoseconds
logprobs:
type: array
items:
$ref: "#/components/schemas/Logprob"
description: Log probability information for the generated tokens when logprobs are enabled
GenerateStreamEvent:
type: object
properties:
@@ -271,13 +288,22 @@ components:
type: boolean
default: true
think:
type: boolean
description: When true, returns separate thinking output in addition to content
oneOf:
- type: boolean
- type: string
enum: [high, medium, low]
description: When true, returns separate thinking output in addition to content. Can be a boolean (true/false) or a string ("high", "medium", "low") for supported models.
keep_alive:
oneOf:
- type: string
- type: number
description: Model keep-alive duration (for example `5m` or `0` to unload immediately)
logprobs:
type: boolean
description: Whether to return log probabilities of the output tokens
top_logprobs:
type: integer
description: Number of most likely tokens to return at each token position when logprobs are enabled
ChatResponse:
type: object
properties:
@@ -310,7 +336,6 @@ components:
type: array
items:
type: string
nullable: true
description: Optional base64-encoded images in the response
done:
type: boolean
@@ -336,6 +361,11 @@ components:
eval_duration:
type: integer
description: Time spent generating tokens in nanoseconds
logprobs:
type: array
items:
$ref: "#/components/schemas/Logprob"
description: Log probability information for the generated tokens when logprobs are enabled
ChatStreamEvent:
type: object
properties:
@@ -367,7 +397,6 @@ components:
type: array
items:
type: string
nullable: true
description: Partial base64-encoded images, when present
done:
type: boolean
@@ -543,6 +572,9 @@ components:
license:
type: string
description: The license of the model
modified_at:
type: string
description: Last modified timestamp in ISO 8601 format
details:
type: object
description: High-level model details
@@ -622,6 +654,9 @@ components:
size_vram:
type: integer
description: VRAM usage in bytes
context_length:
type: integer
description: Context length for the running model
PsResponse:
type: object
properties:
@@ -693,6 +728,41 @@ components:
version:
type: string
description: Version of Ollama
TokenLogprob:
type: object
description: Log probability information for a single token alternative
properties:
token:
type: string
description: The text representation of the token
logprob:
type: number
description: The log probability of this token
bytes:
type: array
items:
type: integer
description: The raw byte representation of the token
Logprob:
type: object
description: Log probability information for a generated token
properties:
token:
type: string
description: The text representation of the token
logprob:
type: number
description: The log probability of this token
bytes:
type: array
items:
type: integer
description: The raw byte representation of the token
top_logprobs:
type: array
items:
$ref: "#/components/schemas/TokenLogprob"
description: Most likely tokens and their log probabilities at this position
ErrorResponse:
type: object
properties:
@@ -1275,6 +1345,9 @@ paths:
example:
source: gemma3
destination: gemma3-backup
responses:
"200":
description: Model successfully copied
/api/pull:
post:
summary: Pull a model
@@ -1382,16 +1455,7 @@ paths:
model: gemma3
responses:
"200":
description: Deletion status updates.
content:
application/json:
schema:
$ref: "#/components/schemas/StatusResponse"
example:
status: "success"
application/x-ndjson:
schema:
$ref: "#/components/schemas/StatusEvent"
description: Model successfully deleted
/api/version:
get:
summary: Get version

View File

@@ -196,8 +196,6 @@ var (
NoPrune = Bool("OLLAMA_NOPRUNE")
// SchedSpread allows scheduling models across all GPUs.
SchedSpread = Bool("OLLAMA_SCHED_SPREAD")
// IntelGPU enables experimental Intel GPU detection.
IntelGPU = Bool("OLLAMA_INTEL_GPU")
// MultiUserCache optimizes prompt caching for multi-user scenarios
MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
// Enable the new Ollama engine
@@ -206,6 +204,8 @@ var (
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
// Auth enables authentication between the Ollama client and server
UseAuth = Bool("OLLAMA_AUTH")
// Enable Vulkan backend
EnableVulkan = Bool("OLLAMA_VULKAN")
)
func String(s string) func() string {
@@ -314,7 +314,7 @@ func AsMap() map[string]EnvVar {
ret["GGML_VK_VISIBLE_DEVICES"] = EnvVar{"GGML_VK_VISIBLE_DEVICES", VkVisibleDevices(), "Set which Vulkan devices are visible by numeric ID"}
ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal(), "Set which AMD devices are visible by numeric ID"}
ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion(), "Override the gfx used for all detected AMD GPUs"}
ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGPU(), "Enable experimental Intel GPU detection"}
ret["OLLAMA_VULKAN"] = EnvVar{"OLLAMA_VULKAN", EnableVulkan(), "Enable experimental Vulkan support"}
}
return ret

View File

@@ -249,6 +249,9 @@ func (kv KV) OllamaEngineRequired() bool {
"qwen25vl",
"qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe",
"deepseekocr",
"deepseek2",
"nomic-bert",
}, kv.Architecture())
}
@@ -797,73 +800,6 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
return
}
func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
if llm.KV().Uint("vision.block_count") == 0 {
return
}
for name, layer := range llm.Tensors().GroupLayers() {
if name == "v" || strings.HasPrefix(name, "v.") {
for _, tensor := range layer {
weights += tensor.Size()
}
}
}
imageSize := uint64(llm.KV().Uint("vision.image_size"))
patchSize := uint64(llm.KV().Uint("vision.patch_size"))
if patchSize == 0 {
slog.Warn("unknown patch size for vision model")
return
}
numChannels := uint64(llm.KV().Uint("vision.num_channels"))
numPatches := (imageSize / patchSize) * (imageSize / patchSize)
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
numPatches++
}
headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
switch llm.KV().Architecture() {
case "mllama":
numPaddedPatches := numPatches + 8 - (numPatches%8)%8
maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
graphSize = 4 * (8 +
imageSize*imageSize*numChannels*maxNumTiles +
embeddingLength*numPatches*maxNumTiles +
9*embeddingLength*numPaddedPatches*maxNumTiles +
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
case "gemma3", "mistral3":
graphSize = 4 * (imageSize*imageSize*numChannels +
embeddingLength*patchSize +
numPatches*numPatches*headCount)
case "qwen25vl":
maxPixels := uint64(llm.KV().Uint("vision.max_pixels", 28*28*1280))
numPatches := maxPixels / (patchSize * patchSize)
graphSize = 4 * (maxPixels*numChannels + // Original image storage
// Normalized pixels
maxPixels*numChannels +
// Patches storage (numPatches * channels * patchSize^2)
numPatches*numChannels*patchSize*patchSize +
// Self-attention calculations
numPatches*numPatches*headCount +
// Additional buffer for processing
embeddingLength*numPatches)
case "llama4":
// vision graph is computed independently in the same schedule
// and is negligible compared to the worst case text graph
}
return weights, graphSize
}
// SupportsKVCacheType checks if the requested cache type is supported
func (f GGML) SupportsKVCacheType(cacheType string) bool {
if cacheType == "" || cacheType == "f16" {

View File

@@ -305,7 +305,7 @@ func readGGUFV1StringsData(llm *gguf, r io.Reader, a *array[string]) (any, error
a.values[i] = e
} else {
discardGGUFString(llm, r)
_ = discardGGUFString(llm, r)
}
}
@@ -568,7 +568,6 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
g.SetLimit(runtime.GOMAXPROCS(0))
// TODO consider reducing if tensors size * gomaxprocs is larger than free memory
for _, t := range ts {
t := t
w := io.NewOffsetWriter(f, offset+int64(t.Offset))
g.Go(func() error {
_, err := t.WriteTo(w)

1
go.mod
View File

@@ -17,7 +17,6 @@ require (
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.12.0
golang.org/x/sys v0.36.0
)
require (

View File

@@ -388,9 +388,9 @@ func NewFunctionNameMap() *FunctionNameMap {
}
}
// Init initializes the handler with tools and optional last message
// Init initializes the handler with tools, optional last message, and think value
// Implements the Parser interface
func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
// Initialize the harmony parser
if h.HarmonyParser == nil {
h.HarmonyParser = &HarmonyParser{

View File

@@ -14,6 +14,23 @@ import (
"github.com/ollama/ollama/api"
)
func assertBytesMatchToken(t *testing.T, label, token string, ints []int) {
t.Helper()
raw := []byte(token)
if len(ints) != len(raw) {
t.Errorf("%s expected %d bytes for token %q, got %d (%v)", label, len(raw), token, len(ints), ints)
return
}
for i, b := range raw {
if ints[i] != int(b) {
t.Errorf("%s byte[%d] mismatch for token %q: got %d want %d", label, i, token, ints[i], int(b))
return
}
}
}
func TestAPIGenerate(t *testing.T) {
initialTimeout := 60 * time.Second
streamTimeout := 30 * time.Second
@@ -381,3 +398,182 @@ func TestAPIShowModel(t *testing.T) {
t.Errorf("%s missing modified_at: %#v", modelName, resp)
}
}
func TestAPIGenerateLogprobs(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, smol); err != nil {
t.Fatalf("pull failed %s", err)
}
enableLogprobs := true
noStream := false
tests := []struct {
name string
logprobs *bool
topLogprobs int
expectCount int
}{
{
name: "no_logprobs",
logprobs: nil,
topLogprobs: 0,
expectCount: 0,
},
{
name: "logprobs_only",
logprobs: &enableLogprobs,
topLogprobs: 0,
expectCount: 1,
},
{
name: "logprobs_with_top_5",
logprobs: &enableLogprobs,
topLogprobs: 5,
expectCount: 1,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
req := api.GenerateRequest{
Model: smol,
Prompt: "Why is the sky blue?",
Stream: &noStream,
Logprobs: test.logprobs != nil && *test.logprobs,
TopLogprobs: test.topLogprobs,
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
"num_predict": 10,
},
}
var response api.GenerateResponse
err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error {
if resp.Done {
response = resp
}
return nil
})
if err != nil {
t.Fatalf("generate failed: %s", err)
}
// Check logprobs based on expectation
if test.expectCount == 0 {
if len(response.Logprobs) > 0 {
t.Errorf("expected no logprobs but got %d", len(response.Logprobs))
}
} else {
if len(response.Logprobs) == 0 {
t.Errorf("expected logprobs but got none")
}
// Validate each logprob entry
for i, lp := range response.Logprobs {
if lp.Token == "" {
t.Errorf("logprob[%d] has empty token", i)
}
if lp.Logprob > 0 {
t.Errorf("logprob[%d] has positive logprob %f (should be <= 0)", i, lp.Logprob)
}
assertBytesMatchToken(t, fmt.Sprintf("generate logprob[%d]", i), lp.Token, lp.Bytes)
// Check top_logprobs if requested
if test.topLogprobs > 0 {
if len(lp.TopLogprobs) == 0 {
t.Errorf("logprob[%d] expected top_logprobs but got none", i)
}
if len(lp.TopLogprobs) > test.topLogprobs {
t.Errorf("logprob[%d] has %d top_logprobs, expected max %d", i, len(lp.TopLogprobs), test.topLogprobs)
}
// Verify top_logprobs are sorted by probability (descending)
for j := 1; j < len(lp.TopLogprobs); j++ {
if lp.TopLogprobs[j-1].Logprob < lp.TopLogprobs[j].Logprob {
t.Errorf("logprob[%d].top_logprobs not sorted: %f < %f", i, lp.TopLogprobs[j-1].Logprob, lp.TopLogprobs[j].Logprob)
}
}
for j, top := range lp.TopLogprobs {
assertBytesMatchToken(t, fmt.Sprintf("generate logprob[%d].top[%d]", i, j), top.Token, top.Bytes)
}
} else if len(lp.TopLogprobs) > 0 {
t.Errorf("logprob[%d] has top_logprobs but none were requested", i)
}
}
}
})
}
}
func TestAPIChatLogprobs(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, smol); err != nil {
t.Fatalf("pull failed %s", err)
}
enableLogprobs := true
noStream := false
req := api.ChatRequest{
Model: smol,
Messages: []api.Message{
{Role: "user", Content: "Say hello in one word"},
},
Stream: &noStream,
Logprobs: enableLogprobs,
TopLogprobs: 3,
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
"num_predict": 5,
},
}
var response api.ChatResponse
err := client.Chat(ctx, &req, func(resp api.ChatResponse) error {
if resp.Done {
response = resp
}
return nil
})
if err != nil {
t.Fatalf("chat failed: %s", err)
}
if len(response.Logprobs) == 0 {
t.Fatal("expected logprobs in response but got none")
}
t.Logf("received %d logprobs for chat response", len(response.Logprobs))
for i, lp := range response.Logprobs {
if lp.Token == "" {
t.Errorf("logprob[%d] has empty token", i)
}
if lp.Logprob > 0 {
t.Errorf("logprob[%d] has positive logprob %f", i, lp.Logprob)
}
assertBytesMatchToken(t, fmt.Sprintf("chat logprob[%d]", i), lp.Token, lp.Bytes)
if len(lp.TopLogprobs) == 0 {
t.Errorf("logprob[%d] expected top_logprobs but got none", i)
}
if len(lp.TopLogprobs) > 3 {
t.Errorf("logprob[%d] has %d top_logprobs, expected max 3", i, len(lp.TopLogprobs))
}
for j, top := range lp.TopLogprobs {
assertBytesMatchToken(t, fmt.Sprintf("chat logprob[%d].top[%d]", i, j), top.Token, top.Bytes)
}
}
}

View File

@@ -3,7 +3,6 @@ package kvcache
import (
"errors"
"fmt"
"log/slog"
"math"
"slices"
@@ -40,18 +39,18 @@ type Causal struct {
// ** current forward pass **
// the active layer for Get and Put
curLayer int
// starting location for data storage for this batch
curLoc int
// size of the current batch
curBatchSize int
// locations for data storage for this batch
curLoc ml.Tensor
// mask of the cache as used by this batch
curMask ml.Tensor
// the active layer for Get and Put
curLayer int
// locations in the cache that are needed for this batch
curCellRange cellRange
@@ -206,45 +205,47 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
c.curPositions = batch.Positions
c.opts.Except = nil
var locs []int32
if !reserve {
c.updateSlidingWindow()
var err error
c.curLoc, err = c.findStartLoc()
if errors.Is(err, ErrKvCacheFull) {
c.defrag()
c.curLoc, err = c.findStartLoc()
}
locs, err = c.findLocs()
if err != nil {
return err
}
for i, pos := range batch.Positions {
seq := batch.Sequences[i]
loc := int(locs[i])
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}
seqRange, ok := c.cellRanges[seq]
if !ok {
seqRange = newRange()
}
seqRange.min = min(seqRange.min, c.curLoc+i)
c.curCellRange.min = min(c.curCellRange.min, c.curLoc+i)
seqRange.min = min(seqRange.min, loc)
c.curCellRange.min = min(c.curCellRange.min, loc)
seqRange.max = max(seqRange.max, c.curLoc+i)
c.curCellRange.max = max(c.curCellRange.max, c.curLoc+i)
seqRange.max = max(seqRange.max, loc)
c.curCellRange.max = max(c.curCellRange.max, loc)
c.cellRanges[seq] = seqRange
}
} else {
// If we are reserving memory, don't update any of the cache metadata but set the size
// to the worst case.
c.curLoc = 0
locs = make([]int32, c.curBatchSize)
for i := range locs {
locs[i] = int32(i)
}
c.curCellRange.min = 0
c.curCellRange.max = len(c.cells) - 1
}
c.curLoc = ctx.Input().FromInts(locs, len(locs))
c.curMask = c.buildMask(ctx)
return nil
@@ -257,22 +258,20 @@ func newRange() cellRange {
}
}
// Find the first contiguous block of at least curBatchSize
func (c *Causal) findStartLoc() (int, error) {
var start, count int
// Returns a slice of locations where each token in the batch should be stored
func (c *Causal) findLocs() ([]int32, error) {
loc := make([]int32, 0, c.curBatchSize)
for i := range c.cells {
if len(c.cells[i].sequences) == 0 {
count++
if count >= c.curBatchSize {
return start, nil
loc = append(loc, int32(i))
if len(loc) >= c.curBatchSize {
return loc, nil
}
} else {
start = i + 1
count = 0
}
}
return 0, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
}
func (c *Causal) updateSlidingWindow() {
@@ -402,145 +401,6 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
return maskTensor
}
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
for i, key := range c.keys {
if key == nil {
continue
}
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
value := c.values[i]
var vSrcView, vDstView ml.Tensor
if c.config.PermutedV {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)
vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
}
ctx.Forward(
kSrcView.Copy(ctx, kDstView),
vSrcView.Copy(ctx, vDstView),
)
}
}
func (c *Causal) defrag() {
slog.Debug("defragmenting kv cache")
// Defrag strategy:
// - Search for empty holes at the beginning of the cache,
// filling them with active data starting at the end
// - If there are contiguous elements that need to be moved,
// combine them into a single operation by holding new moves
// until we see that the next one is non-contiguous
// - Fill up the context with the maximum number of operations it
// can hold then compute that and continue with a new context
//
// We could try to optimize placement by grouping blocks from
// the same sequences together but most likely the next forward
// pass will disrupt this anyways, so the real world benefit
// seems limited as this time.
ctx := c.backend.NewContext()
// For every move, 6 tensors are required per layer (2 views and a
// copy for each of k and v). We also need to refer to the original
// k and v cache tensors - once per layer, not per move.
layers := 0
for _, key := range c.keys {
if key == nil {
continue
}
layers++
}
maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
moves := 0
var pendingSrc, pendingDst, pendingLen int
src := len(c.cells) - 1
for dst := 0; dst < src; dst++ {
if len(c.cells[dst].sequences) == 0 {
for ; src > dst; src-- {
if len(c.cells[src].sequences) != 0 {
c.cells[dst] = c.cells[src]
c.cells[src] = cacheCell{}
if pendingLen > 0 {
if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
pendingSrc = src
pendingLen++
break
} else {
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}
}
pendingSrc = src
pendingDst = dst
pendingLen = 1
break
}
}
}
if moves >= maxMoves {
ctx.Compute()
ctx.Close()
ctx = c.backend.NewContext()
moves = 0
}
}
if pendingLen > 0 {
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}
if moves > 0 {
ctx.Compute()
}
ctx.Close()
// Reset range metadata
for seq := range c.cellRanges {
seqRange := newRange()
for i, cell := range c.cells {
if slices.Contains(cell.sequences, seq) {
if i < seqRange.min {
seqRange.min = i
}
if i > seqRange.max {
seqRange.max = i
}
}
}
c.cellRanges[seq] = seqRange
}
c.updateSlidingWindow()
}
func (c *Causal) SetLayer(layer int) {
c.curLayer = layer
}
@@ -625,18 +485,25 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
}
}
rowSize := c.keys[c.curLayer].Stride(2)
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
key = key.Reshape(ctx, kHeadDim*numKVHeads, batchSize)
keyCache := c.keys[c.curLayer]
keyCache = keyCache.Reshape(ctx, kHeadDim*numKVHeads, len(c.cells))
ctx.Forward(keyCache.SetRows(ctx, key, c.curLoc))
if c.config.PermutedV {
elemSize := c.values[c.curLayer].Stride(0)
value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
value = value.Permute(ctx, 2, 0, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
valueCache := c.values[c.curLayer]
valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
} else {
rowSize := c.values[c.curLayer].Stride(2)
value = value.Reshape(ctx, vHeadDim*numKVHeads, batchSize)
valueCache := c.values[c.curLayer]
valueCache = valueCache.Reshape(ctx, vHeadDim*numKVHeads, len(c.cells))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
}
}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -63,8 +63,13 @@ func BackendInit() {
C.llama_backend_init()
}
func EnumerateGPUs() []ml.DeviceID {
var ids []ml.DeviceID
type Devices struct {
ml.DeviceID
LlamaID uint64
}
func EnumerateGPUs() []Devices {
var ids []Devices
for i := range C.ggml_backend_dev_count() {
device := C.ggml_backend_dev_get(i)
@@ -74,9 +79,12 @@ func EnumerateGPUs() []ml.DeviceID {
C.GGML_BACKEND_DEVICE_TYPE_IGPU:
var props C.struct_ggml_backend_dev_props
C.ggml_backend_dev_get_props(device, &props)
ids = append(ids, ml.DeviceID{
ID: C.GoString(props.id),
Library: C.GoString(props.library),
ids = append(ids, Devices{
DeviceID: ml.DeviceID{
ID: C.GoString(props.id),
Library: C.GoString(props.library),
},
LlamaID: uint64(i),
})
}
}
@@ -217,7 +225,21 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
return embeddings
}
// GetLogitsIth gets the logits for the ith token
func (c *Context) GetLogitsIth(i int) []float32 {
logits := unsafe.Pointer(C.llama_get_logits_ith(c.c, C.int32_t(i)))
if logits == nil {
return nil
}
vocabSize := c.Model().NumVocab()
result := make([]float32, vocabSize)
_ = copy(result, unsafe.Slice((*float32)(logits), vocabSize))
return result
}
type ModelParams struct {
Devices []uint64
NumGpuLayers int
MainGpu int
UseMmap bool
@@ -241,6 +263,21 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
cparams.use_mmap = C.bool(params.UseMmap)
cparams.vocab_only = C.bool(params.VocabOnly)
var devices []C.ggml_backend_dev_t
for _, llamaID := range params.Devices {
devices = append(devices, C.ggml_backend_dev_get(C.size_t(llamaID)))
}
if len(devices) > 0 {
devices = append(devices, C.ggml_backend_dev_t(C.NULL))
devicesData := &devices[0]
var devicesPin runtime.Pinner
devicesPin.Pin(devicesData)
defer devicesPin.Unpin()
cparams.devices = devicesData
}
if len(params.TensorSplit) > 0 {
tensorSplitData := &params.TensorSplit[0]

View File

@@ -80,10 +80,10 @@ func TestIssue7978(t *testing.T) {
}
}
func TestSchemaToGrammer(t *testing.T) {
func TestSchemaToGrammar(t *testing.T) {
cases := []struct {
schema string
prefix []byte // nil is check as nil
prefix []byte // nil is checked as nil
}{
{`invalid`, nil},
@@ -92,7 +92,7 @@ func TestSchemaToGrammer(t *testing.T) {
}
for _, c := range cases {
t.Run("x", func(t *testing.T) {
t.Run(c.schema, func(t *testing.T) {
g := SchemaToGrammar([]byte(c.schema))
if c.prefix == nil && g != nil {
t.Fatalf("grammar = %v, want nil", g)

View File

@@ -1,28 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen <daniel@ollama.com>
Date: Wed, 30 Jul 2025 08:43:46 -0700
Subject: [PATCH] BF16 macos version guard
Only enable BF16 on supported MacOS versions (v14+)
---
ggml/src/ggml-metal/ggml-metal-context.m | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/ggml/src/ggml-metal/ggml-metal-context.m b/ggml/src/ggml-metal/ggml-metal-context.m
index 052efb7ac..b47dc7879 100644
--- a/ggml/src/ggml-metal/ggml-metal-context.m
+++ b/ggml/src/ggml-metal/ggml-metal-context.m
@@ -125,7 +125,12 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
- res->use_bfloat = props_dev->has_bfloat;
+ if (@available(macOS 14.0, *)) {
+ res->use_bfloat = props_dev->has_bfloat;
+ } else {
+ res->use_bfloat = false;
+ }
+
res->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
res->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil;

View File

@@ -1,25 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen <daniel@ollama.com>
Date: Sun, 3 Aug 2025 10:00:20 -0700
Subject: [PATCH] Disable ggml-blas on macos v13 and older
---
ggml/src/ggml-blas/ggml-blas.cpp | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp
index 88d088952..6a38a51a2 100644
--- a/ggml/src/ggml-blas/ggml-blas.cpp
+++ b/ggml/src/ggml-blas/ggml-blas.cpp
@@ -507,6 +507,11 @@ static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = {
};
ggml_backend_reg_t ggml_backend_blas_reg(void) {
+ // MacOS prior to v14 does not include cblas_sgemm - disable this backend if it isn't available
+ if (&cblas_sgemm == NULL) {
+ GGML_LOG_INFO("Disabling ggml-blas backend on old MacOS version\n");
+ return NULL;
+ }
static struct ggml_backend_reg ggml_backend_blas_reg = {
/* .api_version = */ GGML_BACKEND_API_VERSION,
/* .iface = */ ggml_backend_blas_reg_i,

View File

@@ -20,10 +20,10 @@ fix vulkan PCI ID and ID handling
ggml/src/ggml-cuda/vendors/hip.h | 3 +
ggml/src/ggml-impl.h | 8 +
ggml/src/ggml-metal/ggml-metal.cpp | 2 +
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 209 +++++++++++--
ggml/src/mem_hip.cpp | 452 +++++++++++++++++++++++++++
ggml/src/mem_nvml.cpp | 209 +++++++++++++
9 files changed, 926 insertions(+), 30 deletions(-)
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 209 +++++++++--
ggml/src/mem_hip.cpp | 529 +++++++++++++++++++++++++++
ggml/src/mem_nvml.cpp | 209 +++++++++++
9 files changed, 1003 insertions(+), 30 deletions(-)
create mode 100644 ggml/src/mem_hip.cpp
create mode 100644 ggml/src/mem_nvml.cpp
@@ -58,7 +58,7 @@ index f9a6587f1..03f359ae9 100644
target_include_directories(ggml-base PRIVATE .)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index c9333689f..41b00af83 100644
index c9333689f..f1a20e7fe 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -261,6 +261,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
@@ -111,7 +111,7 @@ index c9333689f..41b00af83 100644
+ if (ggml_hip_mgmt_init() == 0) {
+ int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total);
+ if (status == 0) {
+ GGML_LOG_DEBUG("%s device %s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total);
+ GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total);
+ ggml_hip_mgmt_release();
+ return;
+ }
@@ -243,7 +243,7 @@ index 05ff6a5a6..032dee76d 100644
/* .async = */ true,
/* .host_buffer = */ false,
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 3a6bbe564..d2c278a35 100644
index 3a6bbe564..ca02ea079 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -229,6 +229,7 @@ class vk_memory_logger;
@@ -337,7 +337,7 @@ index 3a6bbe564..d2c278a35 100644
+ if (ggml_hip_mgmt_init() == 0) {
+ int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total);
+ if (status == 0) {
+ GGML_LOG_DEBUG("%s device %s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
+ GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
+ ggml_hip_mgmt_release();
+ return;
+ }
@@ -548,11 +548,12 @@ index 3a6bbe564..d2c278a35 100644
}
diff --git a/ggml/src/mem_hip.cpp b/ggml/src/mem_hip.cpp
new file mode 100644
index 000000000..5a7f5d465
index 000000000..c1949b899
--- /dev/null
+++ b/ggml/src/mem_hip.cpp
@@ -0,0 +1,452 @@
@@ -0,0 +1,529 @@
+#include "ggml.h"
+#include "ggml-impl.h"
+
+#ifdef _WIN32
+// AMD Device Library eXtra (ADLX)
@@ -570,7 +571,6 @@ index 000000000..5a7f5d465
+// Unused function parameters are commented out to avoid unnecessary type
+// definitions.
+
+#include "ggml-impl.h"
+#include <filesystem>
+#include <mutex>
+
@@ -990,15 +990,92 @@ index 000000000..5a7f5d465
+
+#else // #ifdef _WIN32
+
+#include <fstream>
+#include <iostream>
+#include <sstream>
+#include <string>
+#include <vector>
+#include <filesystem>
+
+#include <sys/stat.h>
+#include <dirent.h>
+#include <unistd.h>
+#include <glob.h>
+namespace fs = std::filesystem;
+
+extern "C" {
+
+// TODO Linux implementation of accurate VRAM reporting
+int ggml_hip_mgmt_init() {
+ return -1;
+ return 0;
+}
+void ggml_hip_mgmt_release() {}
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
+ return -1;
+ GGML_LOG_INFO("%s searching for device %s\n", __func__, id);
+ const std::string drmDeviceGlob = "/sys/class/drm/card*/device/uevent";
+ const std::string drmTotalMemoryFile = "mem_info_vram_total";
+ const std::string drmUsedMemoryFile = "mem_info_vram_used";
+ const std::string drmUeventPCISlotLabel = "PCI_SLOT_NAME=";
+
+ glob_t glob_result;
+ glob(drmDeviceGlob.c_str(), GLOB_NOSORT, NULL, &glob_result);
+
+ for (size_t i = 0; i < glob_result.gl_pathc; ++i) {
+ const char* device_file = glob_result.gl_pathv[i];
+ std::ifstream file(device_file);
+ if (!file.is_open()) {
+ std::cerr << "Failed to open sysfs node" << std::endl;
+ globfree(&glob_result);
+ return 1;
+ }
+
+ std::string line;
+ while (std::getline(file, line)) {
+ // Check for PCI_SLOT_NAME label
+ if (line.find(drmUeventPCISlotLabel) == 0) {
+ std::istringstream iss(line.substr(drmUeventPCISlotLabel.size()));
+ std::string pciSlot;
+ iss >> pciSlot;
+ if (pciSlot == std::string(id)) {
+ std::string dir = fs::path(device_file).parent_path().string();
+
+ std::string totalFile = dir + "/" + drmTotalMemoryFile;
+ std::ifstream totalFileStream(totalFile.c_str());
+ if (!totalFileStream.is_open()) {
+ GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, totalFile.c_str());
+ file.close();
+ globfree(&glob_result);
+ return 1;
+ }
+
+ uint64_t memory;
+ totalFileStream >> memory;
+ *total = memory;
+
+ std::string usedFile = dir + "/" + drmUsedMemoryFile;
+ std::ifstream usedFileStream(usedFile.c_str());
+ if (!usedFileStream.is_open()) {
+ GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, usedFile.c_str());
+ file.close();
+ globfree(&glob_result);
+ return 1;
+ }
+
+ uint64_t memoryUsed;
+ usedFileStream >> memoryUsed;
+ *free = memory - memoryUsed;
+
+ file.close();
+ globfree(&glob_result);
+ return 0;
+ }
+ }
+ }
+
+ file.close();
+ }
+ GGML_LOG_DEBUG("%s unable to find matching device\n", __func__);
+ globfree(&glob_result);
+ return 1;
+}
+
+} // extern "C"

View File

@@ -38,7 +38,7 @@ index 44ae76d66..639d551a2 100644
#ifdef __cplusplus
}
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index d2c278a35..221e29509 100644
index ca02ea079..c12b069e5 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -73,6 +73,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();

View File

@@ -0,0 +1,32 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Wed, 29 Oct 2025 03:53:04 -0500
Subject: [PATCH] vulkan: Call ggml_vk_buffer_write_2d from ggml_vk_buffer_copy
(#16793)
This lets the copy to the destination device use the host-visible
vidmem optimization.
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index c12b069e5..76c78c2ea 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -5654,14 +5654,11 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
// Copy device to device
ggml_vk_ensure_sync_staging_buffer(src->device, size);
- ggml_vk_ensure_sync_staging_buffer(dst->device, size);
// Copy to src staging buffer
ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
- // memcpy to dst staging buffer
- memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size);
// Copy to dst buffer
- ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size);
+ ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1);
}
}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,657 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Wed, 29 Oct 2025 08:44:29 -0500
Subject: [PATCH] vulkan: Update topk_moe fusion to handle gpt's late softmax
(#16656)
* vulkan: Update topk_moe fusion to handle gpt's late softmax
Based on #16649.
* Add ggml_check_edges
* Add sync logging to show fusion effects
* handle clamp added in #16655
* Update ggml/src/ggml-impl.h
Co-authored-by: Diego Devesa <slarengh@gmail.com>
---
ggml/src/ggml-impl.h | 16 +
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 304 +++++++++++-------
.../ggml-vulkan/vulkan-shaders/topk_moe.comp | 90 ++++--
3 files changed, 272 insertions(+), 138 deletions(-)
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index 639d551a2..e5c446d1d 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -693,6 +693,7 @@ GGML_API void ggml_dxgi_pdh_release();
#endif
#ifdef __cplusplus
+#include <array>
#include <initializer_list>
#include <vector>
@@ -708,6 +709,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
}
+// Return true if the edges in the graph match expectations.
+inline bool ggml_check_edges(const struct ggml_cgraph * cgraph,
+ int start_idx,
+ std::initializer_list<std::array<int, 3>> edges) {
+ for (const auto & edge : edges) {
+ int dst_node = edge[0];
+ int src_idx = edge[1];
+ int src_node = edge[2];
+ if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) {
+ return false;
+ }
+ }
+ return true;
+}
+
// expose GGUF internals for test code
GGML_API size_t gguf_type_size(enum gguf_type type);
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 7669ed206..63a762ec2 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -387,12 +387,76 @@ static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
static constexpr uint32_t num_topk_moe_pipelines = 10;
-static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
- GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
- GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
-static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
- GGML_OP_VIEW, GGML_OP_GET_ROWS };
+static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
+ GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+ GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
+ GGML_OP_RESHAPE };
+static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
+ GGML_OP_VIEW, GGML_OP_GET_ROWS };
+static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW,
+ GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
+
+//node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ]
+//node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
+//node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
+//node #981 ( VIEW): ffn_moe_topk-15 ( 0K) [Vulka ] use=4: ffn_moe_argsort-15 ( 0K) [Vulka ]
+//node #982 ( GET_ROWS): ffn_moe_weights-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 0K) [Vulka ] ffn_moe_topk-15 ( 0K) [Vulka ]
+//node #983 ( RESHAPE): ffn_moe_weights-15 ( ( 0K) [Vulka ] use=2: ffn_moe_weights-15 ( 0K) [Vulka ]
+//node #984 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ]
+//node #985 ( CLAMP): ffn_moe_weights_sum_ ( 0K) [Vulka ] use=1: ffn_moe_weights_sum- ( 0K) [Vulka ]
+//node #986 ( DIV): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] ffn_moe_weights_sum_ ( 0K) [Vulka ]
+//node #987 ( RESHAPE): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights_norm ( 0K) [Vulka ]
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges {
+ { 1, 0, 0 }, // reshape->src[0] == softmax
+ { 2, 0, 0 }, // argsort->src[0] == softmax
+ { 3, 0, 2 }, // view->src[0] == argsort
+ { 4, 0, 1 }, // get_rows->src[0] == reshape
+ { 4, 1, 3 }, // get_rows->src[1] == view
+ { 5, 0, 4 }, // reshape->src[0] == get_rows
+ { 6, 0, 5 }, // sum_rows->src[0] == reshape
+ { 7, 0, 6 }, // clamp->src[0] == sum_rows
+ { 8, 0, 5 }, // div->src[0] == reshape
+ { 8, 1, 7 }, // div->src[1] == clamp
+ { 9, 0, 8 }, // reshape->src[0] == div
+};
+
+// same as early_softmax_norm but ending after the get_rows
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
+ { 1, 0, 0 }, // reshape->src[0] == softmax
+ { 2, 0, 0 }, // argsort->src[0] == softmax
+ { 3, 0, 2 }, // view->src[0] == argsort
+ { 4, 0, 1 }, // get_rows->src[0] == reshape
+ { 4, 1, 3 }, // get_rows->src[1] == view
+};
+//node #652 ( ARGSORT): ffn_moe_argsort-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 ( 0K) [Vulka ]
+//node #653 ( VIEW): ffn_moe_topk-11 ( 0K) [Vulka ] use=7: ffn_moe_argsort-11 ( 0K) [Vulka ]
+//node #654 ( GET_ROWS): ffn_moe_weights-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 (re ( 0K) [Vulka ] ffn_moe_topk-11 ( 0K) [Vulka ]
+//node #655 ( RESHAPE): ffn_moe_weights-11 ( ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( 0K) [Vulka ]
+//node #656 ( SOFT_MAX): node_656 ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( ( 0K) [Vulka ]
+//node #657 ( RESHAPE): ffn_moe_weights_soft ( 0K) [Vulka ] use=1: node_656 ( 0K) [Vulka ]
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_late_softmax_edges {
+ { 1, 0, 0 }, // view->src[0] == argsort
+ { 2, 1, 1 }, // get_rows->src[1] == view
+ { 3, 0, 2 }, // reshape->src[0] == get_rows
+ { 4, 0, 3 }, // soft_max->src[0] == reshape
+ { 5, 0, 4 }, // reshape->src[0] == soft_max
+};
+
+enum topk_moe_mode {
+ TOPK_MOE_EARLY_SOFTMAX,
+ TOPK_MOE_EARLY_SOFTMAX_NORM,
+ TOPK_MOE_LATE_SOFTMAX,
+ TOPK_MOE_COUNT,
+};
+
+static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
+ topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
+ num == topk_moe_early_softmax.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX :
+ TOPK_MOE_LATE_SOFTMAX;
+ return mode;
+}
struct vk_device_struct {
std::recursive_mutex mutex;
@@ -607,8 +671,7 @@ struct vk_device_struct {
vk_pipeline pipeline_flash_attn_split_k_reduce;
- // [2] is {!norm, norm}
- vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
+ vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT];
std::vector<vk_pipeline_ref> all_pipelines;
@@ -956,6 +1019,8 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
struct vk_op_topk_moe_push_constants {
uint32_t n_rows;
uint32_t n_expert_used;
+ float clamp_min;
+ float clamp_max;
};
struct vk_op_add_id_push_constants {
@@ -3806,8 +3871,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0}, 1, true, true);
- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1}, 1, true, true);
+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true);
+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true);
+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true);
}
for (auto &c : compiles) {
@@ -8085,8 +8151,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
if (ctx->num_additional_fused_ops) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
GGML_ASSERT(idx < num_topk_moe_pipelines);
- bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
- return ctx->device->pipeline_topk_moe[idx][with_norm];
+ topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
+ return ctx->device->pipeline_topk_moe[idx][mode];
}
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
@@ -8141,6 +8207,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
}
case GGML_OP_ARGSORT:
+ if (ctx->num_additional_fused_ops) {
+ uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
+ GGML_ASSERT(idx < num_topk_moe_pipelines);
+ topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
+ return ctx->device->pipeline_topk_moe[idx][mode];
+ }
+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
return ctx->device->pipeline_argsort_f32[idx];
@@ -9676,10 +9749,12 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
- bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
+ topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
- ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
- ggml_tensor * ids = cgraph->nodes[node_idx + 3];
+ ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] :
+ (mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] :
+ cgraph->nodes[node_idx + 5];
+ ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3];
GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32);
@@ -9738,9 +9813,14 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
GGML_ASSERT(d_ids != nullptr);
}
- vk_op_topk_moe_push_constants pc;
+ vk_op_topk_moe_push_constants pc {};
pc.n_rows = n_rows;
pc.n_expert_used = n_expert_used;
+ if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
+ ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
+ pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
+ pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
+ }
GGML_ASSERT(n_expert_used <= n_experts);
@@ -11335,7 +11415,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
}
}
}
+
+#define ENABLE_SYNC_LOGGING 0
+
if (need_sync) {
+#if ENABLE_SYNC_LOGGING
+ std::cerr << "sync" << std::endl;
+#endif
ctx->unsynced_nodes_written.clear();
ctx->unsynced_nodes_read.clear();
ggml_vk_sync_buffers(ctx, compute_ctx);
@@ -11353,6 +11439,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
}
}
}
+#if ENABLE_SYNC_LOGGING
+ if (!dryrun) {
+ for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
+ auto *n = cgraph->nodes[node_idx + i];
+ std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name;
+ if (n->op == GGML_OP_GLU) {
+ std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
+ }
+ std::cerr << std::endl;
+ }
+ }
+#endif
switch (node->op) {
case GGML_OP_REPEAT:
@@ -11531,7 +11629,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_ARGSORT:
- ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
+ if (ctx->num_additional_fused_ops) {
+ ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun);
+ } else {
+ ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
+ }
break;
case GGML_OP_SUM:
@@ -12329,30 +12431,27 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
}
static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
- int node_idx, bool with_norm) {
+ int node_idx, topk_moe_mode mode) {
- if (with_norm) {
- if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) {
- return false;
- }
- for (size_t i = 0; i < topk_moe_norm.size(); ++i) {
- if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) {
- return false;
- }
- }
- } else {
- if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) {
- return false;
- }
- for (size_t i = 0; i < topk_moe.size(); ++i) {
- if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) {
- return false;
- }
- }
- }
+ const ggml_tensor * softmax;
+ const ggml_tensor * weights;
- const ggml_tensor * softmax = cgraph->nodes[node_idx + 0];
- const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
+ switch (mode) {
+ case TOPK_MOE_EARLY_SOFTMAX_NORM:
+ softmax = cgraph->nodes[node_idx + 0];
+ weights = cgraph->nodes[node_idx + 9];
+ break;
+ case TOPK_MOE_EARLY_SOFTMAX:
+ softmax = cgraph->nodes[node_idx + 0];
+ weights = cgraph->nodes[node_idx + 4];
+ break;
+ case TOPK_MOE_LATE_SOFTMAX:
+ softmax = cgraph->nodes[node_idx + 4];
+ weights = cgraph->nodes[node_idx + 5];
+ break;
+ default:
+ return false;
+ }
const float * op_params = (const float *)softmax->op_params;
@@ -12378,60 +12477,6 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
return false;
}
- // Check that the nodes don't have any unexpected uses
- const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
- const ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
- const ggml_tensor * view = cgraph->nodes[node_idx + 3];
- const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
- const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
- const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
- const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
- const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
-
- // softmax is used by reshape and argsort
- if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
- reshape1->src[0] != softmax ||
- argsort->src[0] != softmax) {
- return false;
- }
- // reshape is used by get_rows
- if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
- get_rows->src[0] != reshape1) {
- return false;
- }
- // argsort is used by view
- if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
- view->src[0] != argsort) {
- return false;
- }
- // view is written (via argsort), we can skip checking it
-
- if (with_norm) {
- // get_rows is used by reshape
- if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
- reshape5->src[0] != get_rows) {
- return false;
- }
-
- // reshape is used by sum_rows and div
- if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
- sum_rows->src[0] != reshape5 ||
- div->src[0] != reshape5) {
- return false;
- }
-
- // sum_rows is used by div
- if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
- div->src[1] != sum_rows) {
- return false;
- }
-
- // div/reshape are written
- if (reshape8->src[0] != div) {
- return false;
- }
- }
-
if (!ctx->device->subgroup_arithmetic ||
!ctx->device->subgroup_shuffle ||
!ctx->device->subgroup_require_full_support ||
@@ -12517,10 +12562,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
- ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
- ctx->num_additional_fused_ops = topk_moe.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
+ ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
+ ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
+ ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
}
}
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@@ -12618,10 +12671,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
- ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
- ctx->num_additional_fused_ops = topk_moe.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
+ ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
+ ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
+ ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
}
}
@@ -12754,25 +12815,44 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
while (first_unused < graph->n_nodes) {
std::vector<int> current_set;
- // Avoid reordering topk_moe_norm
- if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) {
- bool is_topk_moe_norm = true;
- for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
- if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) {
- is_topk_moe_norm = false;
+ // Check for fusion patterns and avoid reordering them
+ auto const &match_pattern = [&](const std::initializer_list<ggml_op> &pattern, int start) -> bool {
+ if (start + (int)pattern.size() <= graph->n_nodes) {
+ bool is_pattern = true;
+ for (size_t j = 0; j < pattern.size(); ++j) {
+ if (graph->nodes[start + j]->op != pattern.begin()[j] || used[start + j]) {
+ is_pattern = false;
+ }
}
+ return is_pattern;
}
- if (is_topk_moe_norm) {
- for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
+ return false;
+ };
+
+ auto const &keep_pattern = [&](const std::initializer_list<ggml_op> &pattern) -> bool {
+ if (match_pattern(pattern, first_unused)) {
+ for (size_t j = 0; j < pattern.size(); ++j) {
new_order.push_back(graph->nodes[first_unused + j]);
used[first_unused + j] = true;
}
while (first_unused < graph->n_nodes && used[first_unused]) {
first_unused++;
}
- continue;
+ return true;
}
+ return false;
+ };
+
+ if (keep_pattern(topk_moe_early_softmax_norm)) {
+ continue;
+ }
+ if (keep_pattern(topk_moe_early_softmax)) {
+ continue;
}
+ if (keep_pattern(topk_moe_late_softmax)) {
+ continue;
+ }
+
// First, grab the next unused node.
current_set.push_back(first_unused);
@@ -12790,6 +12870,12 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
if (is_empty(graph->nodes[j])) {
continue;
}
+ // Don't pull forward nodes from fusion patterns
+ if (match_pattern(topk_moe_early_softmax_norm, j) ||
+ match_pattern(topk_moe_early_softmax, j) ||
+ match_pattern(topk_moe_late_softmax, j)) {
+ continue;
+ }
bool ok = true;
for (int c = first_unused; c < j; ++c) {
if (!used[c] &&
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
index 9e56d5f8a..bc1c278bf 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
@@ -11,6 +11,8 @@ layout (push_constant) uniform parameter
{
uint n_rows;
uint n_expert_used;
+ float clamp_min;
+ float clamp_max;
};
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
@@ -18,6 +20,7 @@ layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts = 512;
layout(constant_id = 2) const bool with_norm = true;
+layout(constant_id = 3) const bool late_softmax = false;
const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
@@ -25,53 +28,72 @@ layout (binding = 0, std430) readonly buffer Logits {float logits[];};
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
-void main() {
- const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
- if (row >= n_rows) {
- return;
- }
+const float INFINITY = 1.0 / 0.0;
- const uint logits_offset = n_experts * row;
- const uint weights_offset = n_expert_used * row;
- const uint ids_offset = n_experts * row;
-
- float logits_r[experts_per_thread];
-
- const float INFINITY = 1.0 / 0.0;
+// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
+void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {
+ float max_val = -INFINITY;
[[unroll]]
- for (uint i = 0; i < n_experts; i += WARP_SIZE) {
- const uint expert = i + gl_LocalInvocationID.x;
- logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
+ for (int i = 0; i < experts_per_thread; i++) {
+ const uint idx = lane + i * WARP_SIZE;
+ const bool is_active = !use_limit || (idx < limit);
+ if (is_active) {
+ max_val = max(max_val, vals[i]);
+ }
}
- float max_val = logits_r[0];
+ max_val = subgroupMax(max_val);
+
+ float sum = 0.f;
[[unroll]]
- for (int i = 1; i < experts_per_thread; i++) {
- const float val = logits_r[i];
- max_val = max(val, max_val);
+ for (int i = 0; i < experts_per_thread; i++) {
+ const uint idx = lane + i * WARP_SIZE;
+ const bool is_active = !use_limit || (idx < limit);
+ if (is_active) {
+ const float val = exp(vals[i] - max_val);
+ vals[i] = val;
+ sum += val;
+ } else {
+ vals[i] = 0.f;
+ }
}
- max_val = subgroupMax(max_val);
+ sum = subgroupAdd(sum);
- float wt[experts_per_thread];
- float tmp = 0.f;
+ const float inv_sum = 1.0f / sum;
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
- const float val = logits_r[i];
- wt[i] = exp(val - max_val);
- tmp += wt[i];
+ const uint idx = lane + i * WARP_SIZE;
+ const bool is_active = !use_limit || (idx < limit);
+ if (is_active) {
+ vals[i] *= inv_sum;
+ }
}
+}
- tmp = subgroupAdd(tmp);
+void main() {
+ const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
+ if (row >= n_rows) {
+ return;
+ }
- const float inv_sum = 1.0f / tmp;
+ const uint logits_offset = n_experts * row;
+ const uint weights_offset = n_expert_used * row;
+ const uint ids_offset = n_experts * row;
+
+ float wt[experts_per_thread];
[[unroll]]
- for (int i = 0; i < experts_per_thread; i++) {
- wt[i] = wt[i] * inv_sum;
+ for (uint i = 0; i < n_experts; i += WARP_SIZE) {
+ const uint expert = i + gl_LocalInvocationID.x;
+ wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
+ }
+
+ if (!late_softmax) {
+ softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false);
}
// at this point, each thread holds a portion of softmax,
@@ -82,6 +104,11 @@ void main() {
float output_weights[experts_per_thread];
+ [[unroll]]
+ for (int i = 0; i < experts_per_thread; i++) {
+ output_weights[i] = 0.f;
+ }
+
for (int k = 0; k < n_expert_used; k++) {
float max_val = wt[0];
uint max_expert = gl_LocalInvocationID.x;
@@ -121,6 +148,7 @@ void main() {
if (with_norm) {
wt_sum = subgroupAdd(wt_sum);
+ wt_sum = clamp(wt_sum, clamp_min, clamp_max);
const float inv_sum = 1.0f / wt_sum;
[[unroll]]
@@ -129,6 +157,10 @@ void main() {
}
}
+ if (late_softmax) {
+ softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true);
+ }
+
[[unroll]]
for (uint i = 0; i < experts_per_thread; ++i) {
uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,85 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Thu, 30 Oct 2025 01:27:41 -0500
Subject: [PATCH] vulkan: Handle argsort with a large number of rows (#16851)
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 ++++
ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp | 16 ++++++++++++----
2 files changed, 16 insertions(+), 4 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index db92a7901..e959674d1 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -1084,6 +1084,7 @@ struct vk_op_soft_max_push_constants {
struct vk_op_argsort_push_constants {
uint32_t ncols;
+ uint32_t nrows;
int32_t order;
};
@@ -8710,6 +8711,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
break;
case GGML_OP_ARGSORT:
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
+ elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
break;
case GGML_OP_IM2COL:
{
@@ -9952,9 +9954,11 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
int32_t * op_params = (int32_t *)dst->op_params;
uint32_t ncols = src0->ne[0];
+ uint32_t nrows = ggml_nrows(src0);
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
ncols,
+ nrows,
op_params[0],
}, dryrun);
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
index c81b84452..c4e68bc02 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
@@ -14,6 +14,7 @@ layout (binding = 1) buffer D {int data_d[];};
layout (push_constant) uniform parameter {
uint ncols;
+ uint nrows;
uint order;
} p;
@@ -26,10 +27,9 @@ void swap(uint idx0, uint idx1) {
dst_row[idx1] = tmp;
}
-void argsort(bool needs_bounds_check) {
+void argsort(bool needs_bounds_check, const uint row) {
// bitonic sort
const int col = int(gl_LocalInvocationID.x);
- const uint row = gl_WorkGroupID.y;
const uint row_offset = row * p.ncols;
@@ -72,8 +72,16 @@ void argsort(bool needs_bounds_check) {
void main() {
if (p.ncols == BLOCK_SIZE) {
- argsort(false);
+ uint row = gl_WorkGroupID.y;
+ while (row < p.nrows) {
+ argsort(false, row);
+ row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+ }
} else {
- argsort(true);
+ uint row = gl_WorkGroupID.y;
+ while (row < p.nrows) {
+ argsort(true, row);
+ row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+ }
}
}

View File

@@ -0,0 +1,77 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Ruben Ortlam <picard12@live.de>
Date: Fri, 31 Oct 2025 08:14:49 +0100
Subject: [PATCH] vulkan: fix shmem overrun in mmq id shader (#16873)
* vulkan: fix shmem overrun in mmq id shader
* metal : fix mul_mm_id
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
---
ggml/src/ggml-metal/ggml-metal-device.cpp | 2 +-
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp | 4 ++++
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl | 2 +-
tests/test-backend-ops.cpp | 3 +++
4 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 758116342..c78082ac3 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -677,7 +677,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
char name[256];
snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
- snprintf(name, 256, "%s", base);
+ snprintf(name, 256, "%s_ne02=%d", base, ne02);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
index 8b238ac4b..d955b4fc7 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
@@ -82,9 +82,13 @@ layout (constant_id = 10) const uint WARP = 32;
#include "mul_mmq_shmem_types.glsl"
+#ifdef MUL_MAT_ID
+#define BK_STEP 1
+#else
#ifndef BK_STEP
#define BK_STEP 4
#endif
+#endif
// Shared memory cache
shared block_a_cache buf_a[BM * BK_STEP];
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
index 72fec4404..1c0f5306f 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
@@ -27,7 +27,7 @@ struct block_a_cache {
#elif defined(DATA_A_Q8_0)
#define QUANT_R_MMQ 1
// AMD likes 4, Intel likes 1 and Nvidia likes 2
-#define BK_STEP 1
+// #define BK_STEP 1
struct block_a_cache {
int32_t qs[32/4];
FLOAT_TYPE dm;
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 657b6cc2f..1f8dda383 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -6722,6 +6722,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));
+ // gpt-oss issue with Vulkan mmq_id
+ test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
+
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
for (int n_mats : {4, 8}) {

View File

@@ -0,0 +1,80 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Masato Nakasaka <masato.nakasaka@intel.com>
Date: Fri, 31 Oct 2025 16:18:59 +0900
Subject: [PATCH] vulkan: Fix crash when FP16 mul_mat accumulation is not
supported (#16796)
* Experimenting crash fix
* added assert for aborting and fixed comment
* changed to check if a pipeline is empty or not
* Moved function in class definition
* replaced with is_empty
* Modified is_empty to check only unaligned pipelines
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 20 +++++++++++++-------
1 file changed, 13 insertions(+), 7 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index e959674d1..903050b0b 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -146,8 +146,13 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
struct vk_matmul_pipeline_struct {
vk_pipeline l, m, s;
vk_pipeline a_l, a_m, a_s;
+ // Returns true when all unaligned pipelines are null.
+ // We only check for unaligned variants since one of the unaligned pipelines must exist
+ // while aligned pipelines are optional
+ bool is_empty() const {
+ return l == nullptr && m == nullptr && s == nullptr;
+ }
};
-
typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;
struct vk_matmul_pipeline2 {
@@ -5080,7 +5085,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
if (src1_type == GGML_TYPE_Q8_1) {
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
- if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+ if (pipelines->is_empty()) {
return nullptr;
}
@@ -5229,7 +5234,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
if (src1_type == GGML_TYPE_Q8_1) {
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc;
- if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+ if (pipelines->is_empty()) {
return nullptr;
}
@@ -5264,16 +5269,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
return nullptr;
}
+ vk_matmul_pipeline2& mmp = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
// XXX TODO 'prec' is not actually allowed in mul_mat_id.
bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/;
- bool support_fp16acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc != nullptr;
- bool support_fp32acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc != nullptr;
+ bool support_fp16acc = !mmp.f16acc->is_empty();
+ bool support_fp32acc = !mmp.f32acc->is_empty();
if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) {
- return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc;
+ return mmp.f16acc;
} else {
GGML_ASSERT(support_fp32acc);
- return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc;
+ return mmp.f32acc;
}
}

View File

@@ -0,0 +1,25 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <git@mxy.ng>
Date: Tue, 18 Nov 2025 11:13:04 -0800
Subject: [PATCH] ggml-cuda: skip large batches
cuda panics on batches larger than 1024 so mark it as unsupported to
fallback to cpu
---
ggml/src/ggml-cuda/ggml-cuda.cu | 3 +++
1 file changed, 3 insertions(+)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index f1a20e7fe..1a71e07c9 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3677,6 +3677,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
return false;
}
+ if (op->op == GGML_OP_MUL_MAT && b->ne[2] * b->ne[3] > 1024) {
+ return false;
+ }
#ifdef GGML_USE_MUSA
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
if (b->ne[2]*b->ne[3] > 1 && !ggml_is_transposed(a) && !ggml_is_transposed(b)) {

View File

@@ -0,0 +1,28 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen <daniel@ollama.com>
Date: Tue, 18 Nov 2025 09:58:23 -0800
Subject: [PATCH] win: exit instead of abort
---
ggml/src/ggml.c | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 9be35c1be..923c33d05 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -229,8 +229,13 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) {
fprintf(stderr, "%s\n", message);
ggml_print_backtrace();
}
-
+#if defined(_WIN32)
+ fflush(stderr);
+ fflush(stdout);
+ exit(1);
+#else
abort();
+#endif
}
// ggml_print_backtrace is registered with std::set_terminate by ggml.cpp

View File

@@ -1,516 +0,0 @@
package llm
import (
"fmt"
"log/slog"
"os"
"slices"
"sort"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml"
)
// pickBestFullFitByLibrary will try to find the optimal placement of the model in the available GPUs where the model fully fits
// The list of GPUs returned will always be the same brand (library)
// If the model can not be fit fully within the available GPU(s) nil is returned
func pickBestFullFitByLibrary(f *ggml.GGML, modelPath string, projectors []string, adapters []string, opts api.Options, gpus []ml.DeviceInfo, numParallel int) []ml.DeviceInfo {
for _, gl := range ml.ByLibrary(gpus) {
sgl := append(make([]ml.DeviceInfo, 0, len(gl)), gl...)
// TODO - potentially sort by performance capability, existing models loaded, etc.
// TODO - Eliminate any GPUs that already have envconfig.MaxRunners loaded on them
// Note: at present, this will favor most current available VRAM descending and ignoring faster GPU speed in mixed setups
sort.Sort(sort.Reverse(ml.ByFreeMemory(sgl)))
if !envconfig.SchedSpread() {
// Try to pack into as few GPUs as possible, starting from 1 GPU
for numGPUs := 1; numGPUs <= len(sgl); numGPUs++ {
gpuSubset := sgl[:numGPUs]
ok, estimatedVRAM := predictServerFit(gpuSubset, f, adapters, projectors, opts, numParallel)
if ok {
slog.Info("new model will fit in available VRAM across minimum required GPUs, loading",
"model", modelPath,
"library", sgl[0].Library,
"parallel", numParallel,
"required", format.HumanBytes2(estimatedVRAM),
"gpus", numGPUs)
return gpuSubset
}
}
} else {
// TODO future refinements
// - if multiple Libraries, see if any single GPU in any Library will fit
// - try subsets of GPUs instead of just falling back to 1 or all in a family
// Now try all the GPUS (OLLAMA_SCHED_SPREAD is set)
if ok, estimatedVRAM := predictServerFit(sgl, f, adapters, projectors, opts, numParallel); ok {
slog.Info("new model will fit in available VRAM, loading",
"model", modelPath,
"library", sgl[0].Library,
"parallel", numParallel,
"required", format.HumanBytes2(estimatedVRAM),
"gpus", len(sgl))
return sgl
}
}
}
return nil
}
// If multiple Libraries are detected, pick the Library which loads the most layers for the model
func pickBestPartialFitByLibrary(f *ggml.GGML, projectors []string, adapters []string, opts api.Options, gpus []ml.DeviceInfo, numParallel int) []ml.DeviceInfo {
byLibrary := ml.ByLibrary(gpus)
if len(byLibrary) <= 1 {
return gpus
}
var bestEstimate uint64
var bestFit int
for i, gl := range byLibrary {
_, estimatedVRAM := predictServerFit(gl, f, adapters, projectors, opts, numParallel)
if estimatedVRAM > bestEstimate {
bestEstimate = estimatedVRAM
bestFit = i
}
}
return byLibrary[bestFit]
}
// This algorithm looks for a complete fit to determine if we need to unload other models
func predictServerFit(allGpus []ml.DeviceInfo, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
// Split up the GPUs by type and try them
var estimatedVRAM uint64
for _, gpus := range ml.ByLibrary(allGpus) {
var layerCount int
estimate := estimateGPULayers(gpus, f, projectors, opts, numParallel)
layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize
if opts.NumGPU < 0 {
if layerCount > 0 && layerCount >= int(f.KV().BlockCount()+1) {
return true, estimatedVRAM
}
} else {
if layerCount > 0 && layerCount >= opts.NumGPU {
return true, estimatedVRAM
}
}
}
return false, estimatedVRAM
}
func verifyCPUFit(f *ggml.GGML, modelPath string, projectors []string, adapters []string, opts api.Options, systemInfo ml.SystemInfo, numParallel int) bool {
estimate := estimateGPULayers(nil, f, projectors, opts, numParallel)
if estimate.TotalSize > systemInfo.FreeMemory {
return false
}
slog.Info("new model will fit in available system memory for CPU inference, loading",
"model", modelPath,
"parallel", numParallel,
"required", format.HumanBytes2(estimate.TotalSize),
)
return true
}
type MemoryEstimate struct {
// How many layers we predict we can load
Layers int
// The size of the graph which occupies the main GPU
Graph uint64
// How much VRAM will be allocated given the number of layers we predict
VRAMSize uint64
// The total size of the model if loaded into VRAM. If all layers are loaded, VRAMSize == TotalSize
TotalSize uint64
// For multi-GPU scenarios, this provides the tensor split parameter
TensorSplit []int
// For multi-GPU scenarios, this is the size in bytes per GPU
GPUSizes []uint64
// internal fields for logging purposes
inferenceLibrary string
layersRequested int
layersModel int
availableList []string
kv uint64
allocationsList []string
memoryWeights uint64
memoryLayerOutput uint64
graphFullOffload uint64
graphPartialOffload uint64
projectorWeights, projectorGraph uint64
}
// Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
// The GPUs provided must all be the same Library
func estimateGPULayers(gpus []ml.DeviceInfo, f *ggml.GGML, projectors []string, opts api.Options, numParallel int) MemoryEstimate {
// Graph size for a partial offload, applies to all GPUs
var graphPartialOffload uint64
// Graph size when all layers are offloaded, applies to all GPUs
var graphFullOffload uint64
// Final graph offload once we know full or partial
var graphOffload uint64
// Projectors loaded into GPU0 only
var llamaEngineProjectorWeights uint64
// Projectors loaded with output layer
var ollamaEngineProjectorWeights uint64
var ollamaEngineProjectorGraph uint64
// Conditional output size on GPU 0
var memoryLayerOutput uint64
// The sizes of a layer
var layerSize uint64
// The sum of all the layer sizes (just for logging)
var memoryWeights uint64
// True if all the layers are loaded
var fullyLoaded bool
// Overflow that didn't fit into the GPU
var overflow uint64
overhead := envconfig.GpuOverhead()
availableList := make([]string, len(gpus))
libraries := []string{}
for i, gpu := range gpus {
availableList[i] = format.HumanBytes2(gpu.FreeMemory)
if !slices.Contains(libraries, gpu.Library) {
libraries = append(libraries, gpu.Library)
}
}
if len(libraries) == 0 {
libraries = []string{"cpu"}
}
slog.Debug("evaluating", "library", strings.Join(libraries, ","), "gpu_count", len(gpus), "available", availableList)
for _, projector := range projectors {
llamaEngineProjectorWeights += projectorMemoryRequirements(projector)
}
if llamaEngineProjectorWeights == 0 {
ollamaEngineProjectorWeights, ollamaEngineProjectorGraph = f.VisionGraphSize()
}
layers := f.Tensors().GroupLayers()
// add one layer worth of memory as a buffer
if blk0, ok := layers["blk.0"]; ok {
layerSize = blk0.Size()
} else {
slog.Warn("model missing blk.0 layer size")
}
useFlashAttention := envconfig.FlashAttention(f.FlashAttention()) &&
ml.FlashAttentionSupported(gpus) &&
f.SupportsFlashAttention()
var kvct string
if useFlashAttention {
requested := strings.ToLower(envconfig.KvCacheType())
if f.SupportsKVCacheType(requested) {
kvct = requested
}
}
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct, useFlashAttention)
if len(kv) > 0 {
layerSize += kv[0]
}
var kvTotal uint64
for _, kvLayer := range kv {
kvTotal += kvLayer
}
if graphPartialOffload == 0 {
headsKV := f.KV().HeadCountKVMin()
if headsKV == 0 {
headsKV = 1
}
gqa := f.KV().HeadCountMax() / headsKV
graphPartialOffload = gqa * kvTotal / 6
}
if graphFullOffload == 0 {
graphFullOffload = graphPartialOffload
}
// on metal there's no partial offload overhead
if len(gpus) > 0 && gpus[0].Library == "Metal" {
graphPartialOffload = graphFullOffload
} else if len(gpus) > 1 {
// multigpu should always use the partial graph size
graphFullOffload = graphPartialOffload
}
// Output layer handled at the end if we have space
if layer, ok := layers["output_norm"]; ok {
memoryLayerOutput += layer.Size()
}
if layer, ok := layers["output"]; ok {
memoryLayerOutput += layer.Size()
} else if layer, ok := layers["token_embd"]; ok {
memoryLayerOutput += layer.Size()
}
gpuZeroOverhead := llamaEngineProjectorWeights
// Reduce set of GPUs to only those that have sufficient space to fit overhead and at least one layer
var layerCount int
tensorSplit := make([]int, len(gpus))
gpuAllocations := make([]uint64, len(gpus))
type gs struct {
i int
g *ml.DeviceInfo
}
gpusWithSpace := []gs{}
for i := range gpus {
var gzo uint64
if len(gpusWithSpace) == 0 {
gzo = gpuZeroOverhead
}
// Only include GPUs that can fit the graph, gpu minimum, the layer buffer and at least more layer
if gpus[i].FreeMemory < overhead+gzo+max(graphPartialOffload, graphFullOffload)+gpus[i].MinimumMemory()+2*layerSize {
slog.Debug("gpu has too little memory to allocate any layers",
"id", gpus[i].ID,
"library", gpus[i].Library,
"compute", gpus[i].Compute(),
"driver", fmt.Sprintf("%d.%d", gpus[i].DriverMajor, gpus[i].DriverMinor),
"name", gpus[i].Name,
"total", format.HumanBytes2(gpus[i].TotalMemory),
"available", format.HumanBytes2(gpus[i].FreeMemory),
"minimum_memory", gpus[i].MinimumMemory,
"layer_size", format.HumanBytes2(layerSize),
"gpu_zer_overhead", format.HumanBytes2(gzo),
"partial_offload", format.HumanBytes2(graphPartialOffload),
"full_offload", format.HumanBytes2(graphFullOffload),
)
continue
}
gpusWithSpace = append(gpusWithSpace, gs{i, &gpus[i]})
gpuAllocations[i] += gpus[i].MinimumMemory() + layerSize // We hold off on graph until we know partial vs. full
}
var gpuZeroID int
if len(gpusWithSpace) > 0 {
gpuZeroID = gpusWithSpace[0].i
gpuAllocations[gpuZeroID] += gpuZeroOverhead
} else {
overflow += gpuZeroOverhead
}
// For all the layers, find where they can fit on the GPU(s)
for i := int(f.KV().BlockCount()) - 1; i >= 0; i-- {
// Some models have inconsistent layer sizes
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
layerSize = blk.Size()
layerSize += kv[i]
memoryWeights += blk.Size()
}
if opts.NumGPU >= 0 && layerCount >= opts.NumGPU {
// Stop allocating on GPU(s) once we hit the users target NumGPU
overflow += layerSize
continue
}
// distribute the layers across the GPU(s) that have space
for j := len(gpusWithSpace); j > 0; j-- {
g := gpusWithSpace[i%j]
used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload)
if g.g.FreeMemory > overhead+used+layerSize {
gpuAllocations[g.i] += layerSize
tensorSplit[g.i]++
layerCount++
break
} else {
gpusWithSpace = append(gpusWithSpace[:i%j], gpusWithSpace[i%j+1:]...)
}
}
if len(gpusWithSpace) == 0 {
overflow += layerSize
}
}
if layerCount >= int(f.KV().BlockCount()) {
fullyLoaded = true
}
// Determine if we need to consider output then find where it fits
memoryLastLayer := memoryLayerOutput + ollamaEngineProjectorWeights + ollamaEngineProjectorGraph
if memoryLastLayer > 0 {
if opts.NumGPU < 0 || layerCount < opts.NumGPU {
for j := len(gpusWithSpace); j > 0; j-- {
g := gpusWithSpace[layerCount%j]
used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload)
if g.g.FreeMemory > overhead+used+memoryLastLayer {
gpuAllocations[g.i] += memoryLastLayer
tensorSplit[g.i]++
layerCount++
break
}
}
}
if layerCount < int(f.KV().BlockCount())+1 {
fullyLoaded = false
overflow += memoryLastLayer
}
}
// Add the applicable (full or partial) graph allocations
for i := range gpus {
if tensorSplit[i] <= 0 {
continue
}
if fullyLoaded {
gpuAllocations[i] += graphFullOffload
} else {
gpuAllocations[i] += graphPartialOffload
}
}
if fullyLoaded {
graphOffload = graphFullOffload
} else {
graphOffload = graphPartialOffload
}
// Summaries for the log
var memoryRequiredPartial, memoryRequiredTotal uint64
for i := range gpuAllocations {
memoryRequiredPartial += gpuAllocations[i]
}
memoryRequiredTotal = memoryRequiredPartial + overflow
allocationsList := []string{}
for _, a := range gpuAllocations {
allocationsList = append(allocationsList, format.HumanBytes2(a))
}
estimate := MemoryEstimate{
TotalSize: memoryRequiredTotal,
Layers: 0,
Graph: 0,
VRAMSize: 0,
GPUSizes: []uint64{},
inferenceLibrary: strings.Join(libraries, ","),
layersRequested: opts.NumGPU,
layersModel: int(f.KV().BlockCount()) + 1,
availableList: availableList,
kv: kvTotal,
allocationsList: allocationsList,
memoryWeights: memoryWeights,
memoryLayerOutput: memoryLayerOutput,
graphFullOffload: graphFullOffload,
graphPartialOffload: graphPartialOffload,
projectorWeights: llamaEngineProjectorWeights + ollamaEngineProjectorWeights,
projectorGraph: ollamaEngineProjectorGraph,
}
if len(gpus) == 0 {
return estimate
}
if layerCount == 0 {
slog.Debug("insufficient VRAM to load any model layers")
return estimate
}
estimate.Layers = layerCount
estimate.Graph = graphOffload
estimate.VRAMSize = memoryRequiredPartial
estimate.TotalSize = memoryRequiredTotal
estimate.TensorSplit = tensorSplit
estimate.GPUSizes = gpuAllocations
return estimate
}
func (m MemoryEstimate) LogValue() slog.Value {
attrs := []slog.Attr{
slog.String("library", m.inferenceLibrary),
slog.Group(
"layers",
// requested number of layers to offload
"requested", m.layersRequested,
// The number of layers the model has (including output)
"model", m.layersModel,
// estimated number of layers that can be offloaded
"offload", m.Layers,
// multi-gpu split for tensors
"split", m.TensorSplit,
),
slog.Group(
"memory",
// memory available by GPU for offloading
"available", m.availableList,
"gpu_overhead", format.HumanBytes2(envconfig.GpuOverhead()),
slog.Group(
"required",
// memory required for full offloading
"full", format.HumanBytes2(m.TotalSize),
// memory required to offload layers.estimate layers
"partial", format.HumanBytes2(m.VRAMSize),
// memory of KV cache
"kv", format.HumanBytes2(m.kv),
// Allocations across the GPUs
"allocations", m.allocationsList,
),
slog.Group(
"weights",
// memory of the weights
"total", format.HumanBytes2(m.memoryWeights+m.memoryLayerOutput),
// memory of repeating layers
"repeating", format.HumanBytes2(m.memoryWeights),
// memory of non-repeating layers
"nonrepeating", format.HumanBytes2(m.memoryLayerOutput),
),
slog.Group(
"graph",
// memory of graph when fully offloaded
"full", format.HumanBytes2(m.graphFullOffload),
// memory of graph when not fully offloaded
"partial", format.HumanBytes2(m.graphPartialOffload),
),
),
}
if m.projectorWeights > 0 {
attrs = append(attrs, slog.Group(
"projector",
"weights", format.HumanBytes2(m.projectorWeights),
"graph", format.HumanBytes2(m.projectorGraph),
))
}
return slog.GroupValue(attrs...)
}
func projectorMemoryRequirements(filename string) (weights uint64) {
file, err := os.Open(filename)
if err != nil {
return 0
}
defer file.Close()
ggml, err := ggml.Decode(file, 1024)
if err != nil {
return 0
}
for _, layer := range ggml.Tensors().GroupLayers() {
weights += layer.Size()
}
return weights
}

View File

@@ -1,130 +0,0 @@
package llm
import (
"bytes"
"fmt"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml"
)
func TestEstimateGPULayers(t *testing.T) {
t.Setenv("OLLAMA_DEBUG", "1")
t.Setenv("OLLAMA_KV_CACHE_TYPE", "") // Ensure default f16
t.Setenv("OLLAMA_CONTEXT_LENGTH", "2048")
modelName := "dummy"
f, err := os.CreateTemp(t.TempDir(), modelName)
require.NoError(t, err)
defer f.Close()
inputLayerCount := 5
tensors := []*ggml.Tensor{
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
{Name: "blk.1.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
{Name: "blk.2.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
{Name: "blk.3.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
{Name: "blk.4.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
}
assert.Len(t, tensors, inputLayerCount+1)
err = ggml.WriteGGUF(f, ggml.KV{
"general.architecture": "llama",
"llama.context_length": uint32(32),
"llama.embedding_length": uint32(4096),
"llama.block_count": uint32(inputLayerCount),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(32),
"tokenizer.ggml.tokens": []string{" "},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, tensors)
require.NoError(t, err)
ggml, err := LoadModel(f.Name(), 0)
if err != nil {
t.Fatal(err)
}
// Simple CPU scenario
gpus := []ml.DeviceInfo{}
projectors := []string{}
opts := api.DefaultOptions()
t.Run("cpu", func(t *testing.T) {
estimate := estimateGPULayers(gpus, ggml, projectors, opts, 1)
assert.Equal(t, 0, estimate.Layers)
assert.Equal(t, uint64(0), estimate.Graph)
})
// derived from the dummy ggml file above
graphPartialOffload := uint64(202377216)
graphFullOffload := uint64(171968512)
layerSize := uint64(33554436)
projectorSize := uint64(0)
memoryLayerOutput := uint64(4)
// Dual CUDA scenario with asymmetry
gpuMinimumMemory := uint64(457 * format.MebiByte)
gpus = []ml.DeviceInfo{
{
DeviceID: ml.DeviceID{
Library: "CUDA",
},
},
{
DeviceID: ml.DeviceID{
Library: "CUDA",
},
},
}
// Nested array: GPU0 layer space, GPU1 layer space, expected gpu0, expected gpu1
for i, s := range []struct {
layer0, layer1 uint64
expect0, expect1 int
}{
{1, 1, 1, 1},
{2, 1, 2, 1},
{2, 2, 2, 2},
{1, 2, 1, 2},
{3, 3, 3, 3},
{4, 4, 3, 3},
{6, 6, 3, 3},
{0, 3, 0, 3},
} {
t.Run(fmt.Sprintf("%v", s), func(t *testing.T) {
gpus[0].FreeMemory = 0
gpus[1].FreeMemory = 0
gpus[0].FreeMemory += projectorSize
if s.layer0 > 0 {
gpus[0].FreeMemory += memoryLayerOutput
} else {
gpus[1].FreeMemory += memoryLayerOutput
}
gpus[0].FreeMemory += gpuMinimumMemory + layerSize + s.layer0*layerSize + 1
gpus[1].FreeMemory += gpuMinimumMemory + layerSize + s.layer1*layerSize + 1
gpus[0].FreeMemory += max(graphFullOffload, graphPartialOffload)
gpus[1].FreeMemory += max(graphFullOffload, graphPartialOffload)
estimate := estimateGPULayers(gpus, ggml, projectors, opts, 1)
assert.Equal(t, s.expect0+s.expect1, estimate.Layers, "scenario %d: %v", i, s)
assert.Equal(t, []int{s.expect0, s.expect1}, estimate.TensorSplit, "scenario %d: %v", i, s)
var layerSums uint64
for _, b := range estimate.GPUSizes {
layerSums += b
}
if estimate.Layers < inputLayerCount+1 {
assert.Less(t, estimate.VRAMSize, estimate.TotalSize, "scenario %d: %v %+v", i, s, estimate)
assert.Equal(t, estimate.VRAMSize, layerSums, "scenario %d: %v %+v", i, s, estimate)
} else {
assert.Equal(t, estimate.VRAMSize, estimate.TotalSize, "scenario %d: %v %+v", i, s, estimate)
assert.Equal(t, estimate.TotalSize, layerSums, "scenario %d: %v %+v", i, s, estimate)
}
})
}
}

View File

@@ -84,25 +84,21 @@ type LlamaServer interface {
// llmServer is an instance of a runner hosting a single model
type llmServer struct {
port int
cmd *exec.Cmd
done chan error // Channel to signal when the process exits
status *StatusWriter
options api.Options
numParallel int
modelPath string
port int
cmd *exec.Cmd
done chan error // Channel to signal when the process exits
status *StatusWriter
options api.Options
modelPath string
loadRequest LoadRequest // Parameters used to initialize the runner
loadRequest LoadRequest // Parameters used to initialize the runner
mem *ml.BackendMemory // Memory allocations for this model
// llamaModel is an instance of the cgo llama.cpp model definition
// nil if this server is running the new engine
llamaModel *llama.Model
llamaModelLock *sync.Mutex
// textProcessor handles text encoding/decoding for the model in the Ollama engine
// nil if this server is running the llama.cpp based engine
textProcessor model.TextProcessor
totalLayers uint64
loadStart time.Time // Record how long it took the model to load
loadProgress float32
@@ -113,15 +109,13 @@ type llmServer struct {
type llamaServer struct {
llmServer
ggml *ggml.GGML
gpus []ml.DeviceInfo // The set of GPUs covered by the memory estimate
estimate MemoryEstimate
ggml *ggml.GGML
}
type ollamaServer struct {
llmServer
mem *ml.BackendMemory
textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
}
// LoadModel will load a model from disk. The model must be in the GGML format.
@@ -245,8 +239,6 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
loadRequest: loadRequest,
llamaModel: llamaModel,
llamaModelLock: &sync.Mutex{},
textProcessor: textProcessor,
numParallel: numParallel,
sem: semaphore.NewWeighted(int64(numParallel)),
totalLayers: f.KV().BlockCount() + 1,
loadStart: time.Now(),
@@ -281,7 +273,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
}()
if textProcessor != nil {
return &ollamaServer{llmServer: s}, nil
return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
} else {
return &llamaServer{llmServer: s, ggml: f}, nil
}
@@ -463,169 +455,226 @@ type LoadResponse struct {
var ErrLoadRequiredFull = errors.New("unable to load full model on GPU")
func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
systemTotalMemory := systemInfo.TotalMemory
systemFreeMemory := systemInfo.FreeMemory
systemSwapFreeMemory := systemInfo.FreeSwap
slog.Info("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))
func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, systemGPUs []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
slog.Info("loading model", "model layers", s.totalLayers, "requested", s.options.NumGPU)
if len(gpus) == 0 || s.options.NumGPU == 0 {
if !verifyCPUFit(s.ggml, s.modelPath, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, systemInfo, s.numParallel) {
slog.Info("model requires more memory than is currently available, evicting a model to make space", "estimate", s.estimate)
return nil, fmt.Errorf("model requires more system memory than is currently available %w", ErrLoadRequiredFull)
gpus := append(make([]ml.DeviceInfo, 0, len(systemGPUs)), systemGPUs...)
// Synthesize memory allocation information based on our estimates
s.mem = &ml.BackendMemory{CPU: ml.DeviceMemory{
Name: "CPU",
Weights: make([]uint64, s.totalLayers),
Cache: make([]uint64, s.totalLayers),
}, GPUs: make([]ml.DeviceMemory, len(gpus))}
for i := range s.mem.GPUs {
s.mem.GPUs[i].Name = gpus[i].Name
s.mem.GPUs[i].DeviceID = gpus[i].DeviceID
s.mem.GPUs[i].Weights = make([]uint64, s.totalLayers)
s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers)
}
kv, graphPartialOffload, graphFullOffload := s.ggml.GraphSize(uint64(s.options.NumCtx), uint64(s.loadRequest.BatchSize),
s.loadRequest.Parallel, s.loadRequest.KvCacheType, s.loadRequest.FlashAttention)
// Use the size of one layer as a buffer
layers := s.ggml.Tensors().GroupLayers()
if blk0, ok := layers["blk.0"]; ok {
for i := range gpus {
gpus[i].FreeMemory -= blk0.Size() + kv[0]
}
} else {
g := pickBestFullFitByLibrary(s.ggml, s.modelPath, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, gpus, s.numParallel)
if g == nil {
if !requireFull {
g = pickBestPartialFitByLibrary(s.ggml, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, gpus, s.numParallel)
} else {
slog.Info("model requires more memory than is currently available, evicting a model to make space", "estimate", s.estimate)
return nil, ErrLoadRequiredFull
slog.Warn("model missing blk.0 layer size")
}
// Assign all the layers to the CPU for now, they will get reassigned later
for i := range s.ggml.KV().BlockCount() {
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
s.mem.CPU.Weights[i] = blk.Size()
s.mem.CPU.Cache[i] += kv[i]
}
}
// We historically haven't included InputWeights in the model size
var outputWeights uint64
if layer, ok := layers["output_norm"]; ok {
outputWeights += layer.Size()
}
if layer, ok := layers["output"]; ok {
outputWeights += layer.Size()
} else if layer, ok := layers["token_embd"]; ok {
outputWeights += layer.Size()
}
s.mem.CPU.Weights[s.totalLayers-1] = outputWeights
// The vision projector is always loaded on the first GPU if available.
// This can't be assigned by us, so just subtract it from free space
projectorGPU := -1
var projectorWeights uint64
if len(gpus) > 0 {
for _, projector := range s.loadRequest.LoraPath {
projectorWeights += projectorMemoryRequirements(projector)
}
// llama.cpp uses the first discrete GPU if available, otherwise the first iGPU
firstIntegrated := -1
for i := range gpus {
if !gpus[i].Integrated {
projectorGPU = i
break
}
if firstIntegrated == -1 {
firstIntegrated = i
}
}
gpus = g
}
s.estimate = estimateGPULayers(gpus, s.ggml, []string{s.loadRequest.ProjectorPath}, s.options, s.numParallel)
if len(gpus) >= 1 {
switch {
case s.options.NumGPU == 0:
gpus = []ml.DeviceInfo{}
case gpus[0].Library == "Metal" && s.estimate.VRAMSize > systemInfo.TotalMemory:
// disable partial offloading when model is greater than total system memory as this
// can lead to locking up the system
s.options.NumGPU = 0
gpus = []ml.DeviceInfo{}
case gpus[0].Library != "Metal" && s.estimate.Layers == 0:
// Don't bother loading into the GPU if no layers can fit
gpus = []ml.DeviceInfo{}
case s.options.NumGPU < 0 && s.estimate.Layers > 0:
s.options.NumGPU = s.estimate.Layers
if projectorGPU == -1 {
projectorGPU = firstIntegrated
}
} else {
s.options.NumGPU = 0
gpus[projectorGPU].FreeMemory -= projectorWeights
}
// On linux and windows, over-allocating CPU memory will almost always result in an error
// Darwin has fully dynamic swap so has no direct concept of free swap space
if runtime.GOOS != "darwin" {
systemMemoryRequired := s.estimate.TotalSize - s.estimate.VRAMSize
available := systemInfo.FreeMemory + systemInfo.FreeSwap
if systemMemoryRequired > available {
slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", format.HumanBytes2(available), "total", format.HumanBytes2(systemInfo.TotalMemory), "free", format.HumanBytes2(systemInfo.FreeMemory), "swap", format.HumanBytes2(systemInfo.FreeSwap))
return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available))
var kvTotal uint64
for _, kvLayer := range kv {
kvTotal += kvLayer
}
if graphPartialOffload == 0 {
headsKV := s.ggml.KV().HeadCountKVMin()
if headsKV == 0 {
headsKV = 1
}
gqa := s.ggml.KV().HeadCountMax() / headsKV
graphPartialOffload = gqa * kvTotal / 6
}
if graphFullOffload == 0 {
graphFullOffload = graphPartialOffload
}
// On Metal there's no partial offload overhead
if len(gpus) > 0 && gpus[0].Library == "Metal" {
graphPartialOffload = graphFullOffload
}
// Create a layout based on the memory data that we've built. The compute graph
// for GPUs is iteratively assigned based on the number of GPUs that are required.
var gpuLayers ml.GPULayersList
for {
prevGPULayers := gpuLayers
var err error
gpuLayers, err = s.createLayout(systemInfo, gpus, s.mem, requireFull, 0)
if err != nil {
return nil, err
}
if len(gpuLayers) > len(prevGPULayers) {
for _, gl := range gpuLayers {
for i := range s.mem.GPUs {
if gl.DeviceID == s.mem.GPUs[i].DeviceID {
s.mem.GPUs[i].Graph = max(graphPartialOffload, graphFullOffload)
break
}
}
}
} else {
break
}
}
slog.Info("offload", "", s.estimate)
// This maintains the historical assignment of graph sizes, though it isn't fully accurate
graphSize := graphFullOffload
if gpuLayers.Sum() < int(s.totalLayers) {
graphSize = graphPartialOffload
}
s.gpus = gpus
s.loadRequest.GPULayers = createGPULayers(s.estimate, s.ggml, gpus, s.options.NumGPU)
// For all layers that we have assigned to GPUs, move them in the memory data so
// that it is reported accurately
for _, gl := range gpuLayers {
for i := range s.mem.GPUs {
if gl.DeviceID == s.mem.GPUs[i].DeviceID {
for _, l := range gl.Layers {
s.mem.GPUs[i].Weights[l] = s.mem.CPU.Weights[l]
s.mem.GPUs[i].Cache[l] = s.mem.CPU.Cache[l]
// Mmap is only supported on the llama engine
if s.textProcessor == nil {
s.loadRequest.UseMmap = true
s.mem.CPU.Weights[l] = 0
s.mem.CPU.Cache[l] = 0
}
// mmap has issues with partial offloading on metal
for _, g := range gpus {
if g.Library == "Metal" &&
uint64(s.options.NumGPU) > 0 &&
uint64(s.options.NumGPU) < s.ggml.KV().BlockCount()+1 {
s.options.UseMMap = new(bool)
*s.options.UseMMap = false
s.mem.GPUs[i].Graph = graphSize
break
}
}
}
// Windows CUDA should not use mmap for best performance
// Linux with a model larger than free space, mmap leads to thrashing
// For CPU loads we want the memory to be allocated, not FS cache
if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < s.estimate.TotalSize && s.options.UseMMap == nil) ||
(len(gpus) == 0 && s.options.UseMMap == nil) ||
(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
(s.options.UseMMap != nil && !*s.options.UseMMap) {
s.loadRequest.UseMmap = false
if projectorGPU > 0 && len(s.mem.GPUs[projectorGPU].Weights) > 0 {
s.mem.GPUs[projectorGPU].Weights[s.totalLayers-1] += projectorWeights
}
slog.Debug("memory", "estimate", s.mem)
s.mem.Log(slog.LevelInfo)
// The llama engine uses mmap by default
s.loadRequest.UseMmap = true
// mmap has issues with partial offloading on metal
for _, g := range gpus {
if g.Library == "Metal" &&
uint64(s.options.NumGPU) > 0 &&
uint64(s.options.NumGPU) < s.totalLayers {
s.options.UseMMap = new(bool)
*s.options.UseMMap = false
}
}
// Windows CUDA should not use mmap for best performance
// Linux with a model larger than free space, mmap leads to thrashing
// For CPU loads we want the memory to be allocated, not FS cache
if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < s.TotalSize() && s.options.UseMMap == nil) ||
(len(gpus) == 0 && s.options.UseMMap == nil) ||
(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
(s.options.UseMMap != nil && !*s.options.UseMMap) {
s.loadRequest.UseMmap = false
}
if err := s.waitUntilRunnerLaunched(ctx); err != nil {
return nil, err
}
s.loadRequest.GPULayers = gpuLayers
resp, err := s.initModel(ctx, s.loadRequest, LoadOperationCommit)
if err != nil {
return nil, err
}
// On the Ollama engine, we can print out a summary of the memory allocations.
// We don't have this for the llama engine but it does something similar itself.
if s.textProcessor != nil {
resp.Memory.Log(slog.LevelInfo)
}
if !resp.Success {
slog.Warn("failed to allocate memory for model", "memory", resp.Memory)
return nil, errors.New("failed to allocate memory for model")
}
// The llama engine does its memory allocations together with model loading, so we
// need to wait until it is done to ensure that we have accurate memory data before
// loading the next model
if s.textProcessor == nil {
return uniqueDeviceIDs(s.loadRequest.GPULayers), s.WaitUntilRunning(ctx)
} else {
return uniqueDeviceIDs(s.loadRequest.GPULayers), nil
}
return uniqueDeviceIDs(s.loadRequest.GPULayers), s.WaitUntilRunning(ctx)
}
// createGPULayers maps from the tensor splits assigned by the memory estimates to explicit assignment
// of particular layers onto GPUs
func createGPULayers(estimate MemoryEstimate, ggml *ggml.GGML, gpus []ml.DeviceInfo, numGPU int) ml.GPULayersList {
if numGPU <= 0 || len(gpus) == 0 {
return nil
func projectorMemoryRequirements(filename string) (weights uint64) {
file, err := os.Open(filename)
if err != nil {
return 0
}
defer file.Close()
ggml, err := ggml.Decode(file, 1024)
if err != nil {
return 0
}
gpuLayers := make(ml.GPULayersList, len(gpus))
for i := range gpuLayers {
gpuLayers[i].DeviceID = gpus[i].DeviceID
for _, layer := range ggml.Tensors().GroupLayers() {
weights += layer.Size()
}
var sum float32
splits := make([]float32, len(estimate.TensorSplit))
// cumulative sum of all splits
for i := range splits {
sum += float32(estimate.TensorSplit[i])
splits[i] = sum
}
if sum <= 0 {
return nil
}
// normalize splits
for i := range splits {
splits[i] /= sum
}
blocks := int(ggml.KV().BlockCount())
gpuRangeStart := max(0, blocks-numGPU)
gpuRangeStop := min(gpuRangeStart+numGPU, blocks+1)
for i := range blocks + 1 {
if i < gpuRangeStart || i >= gpuRangeStop {
continue
}
index := slices.IndexFunc(splits, func(f float32) bool { return float32(i-gpuRangeStart)/float32(gpuRangeStop-gpuRangeStart) < f })
if index < 0 || index >= len(gpus) {
continue
}
gpuLayers[index].Layers = append(gpuLayers[index].Layers, i)
}
return gpuLayers
return weights
}
// Load finds the optimal layout of layers to offload on GPUs based on no initial information about the size of the model
@@ -652,23 +701,6 @@ func (s *ollamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus
slog.Info("loading model", "model layers", s.totalLayers, "requested", s.options.NumGPU)
systemTotalMemory := systemInfo.TotalMemory
systemFreeMemory := systemInfo.FreeMemory
systemSwapFreeMemory := systemInfo.FreeSwap
slog.Info("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))
for _, gpu := range gpus {
available := gpu.FreeMemory - envconfig.GpuOverhead() - gpu.MinimumMemory()
if gpu.FreeMemory < envconfig.GpuOverhead()+gpu.MinimumMemory() {
available = 0
}
slog.Info("gpu memory", "id", gpu.ID, "library", gpu.Library,
"available", format.HumanBytes2(available),
"free", format.HumanBytes2(gpu.FreeMemory),
"minimum", format.HumanBytes2(gpu.MinimumMemory()),
"overhead", format.HumanBytes2(envconfig.GpuOverhead()))
}
pastAllocations := make(map[uint64]struct{})
var backoff float32
@@ -834,25 +866,22 @@ func uniqueDeviceIDs(gpuLayers ml.GPULayersList) []ml.DeviceID {
// - Calculating how much space each GPU has available for layers, based on free memory and space occupied by the graph
// - Assigning layers
// - Ensuring that we don't exceed limits, such as requirements about partial offloading or system memory
func (s *ollamaServer) createLayout(systemInfo ml.SystemInfo, systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, backoff float32) (ml.GPULayersList, error) {
func (s *llmServer) createLayout(systemInfo ml.SystemInfo, systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, backoff float32) (ml.GPULayersList, error) {
if memory == nil {
memory = &ml.BackendMemory{CPU: ml.DeviceMemory{
Weights: make([]uint64, s.totalLayers),
Cache: make([]uint64, s.totalLayers),
}}
}
gpuLayers, layers, err := s.buildLayout(systemGPUs, memory, requireFull, backoff)
if err != nil {
return nil, err
}
err = s.verifyLayout(systemInfo, memory, requireFull, gpuLayers, layers)
gpuLayers, layers := s.buildLayout(systemGPUs, memory, requireFull, backoff)
err := s.verifyLayout(systemInfo, memory, requireFull, gpuLayers, layers)
if err != nil {
return nil, err
}
return gpuLayers, nil
}
func (s *ollamaServer) buildLayout(systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, backoff float32) (ml.GPULayersList, []uint64, error) {
func (s *llmServer) buildLayout(systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, backoff float32) (ml.GPULayersList, []uint64) {
gpus := append(make([]ml.DeviceInfo, 0, len(systemGPUs)), systemGPUs...)
sort.Sort(sort.Reverse(ml.ByFreeMemory(gpus)))
@@ -910,11 +939,11 @@ func (s *ollamaServer) buildLayout(systemGPUs []ml.DeviceInfo, memory *ml.Backen
gpuLayers = libraryGpuLayers
}
}
return gpuLayers, layers, nil
return gpuLayers, layers
}
// verifyLayout ensures that we don't exceed limits, such as requirements about partial offloading or system memory
func (s *ollamaServer) verifyLayout(systemInfo ml.SystemInfo, memory *ml.BackendMemory, requireFull bool, gpuLayers ml.GPULayersList, layers []uint64) error {
func (s *llmServer) verifyLayout(systemInfo ml.SystemInfo, memory *ml.BackendMemory, requireFull bool, gpuLayers ml.GPULayersList, layers []uint64) error {
// These sizes will only increase as we go through additional iterations and get additional information.
cpuSize := memory.InputWeights + memory.CPU.Graph
var vramSize uint64
@@ -942,11 +971,13 @@ nextLayer:
if requireFull {
if gpuLayers.Sum() < len(layers) && (s.options.NumGPU < 0 || gpuLayers.Sum() < s.options.NumGPU) {
slog.Info("model requires more memory than is currently available, evicting a model to make space", "loaded layers", gpuLayers.Sum())
return ErrLoadRequiredFull
}
if cpuSize > systemInfo.FreeMemory {
return ErrLoadRequiredFull
slog.Info("model requires more system memory than is currently available, evicting a model to make space", "required", cpuSize, "free", systemInfo.FreeMemory)
return fmt.Errorf("model requires more system memory than is currently available %w", ErrLoadRequiredFull)
}
}
@@ -976,6 +1007,13 @@ nextLayer:
// assignLayers packs the maximum number of layers onto the smallest set of GPUs and comes up with a layer assignment
func assignLayers(layers []uint64, gpus []ml.DeviceInfo, requireFull bool, requestedLayers int, lastUsedGPU int) (gpuLayers ml.GPULayersList) {
// If the user is manually overriding parameters, treat all GPUs equally so they split according to VRAM
if requestedLayers >= 0 || envconfig.SchedSpread() {
for i := range gpus {
gpus[i].Integrated = false
}
}
// If we can't fit everything then prefer offloading layers other than the output layer
for range 2 {
// requestedLayers may be -1 if nothing was requested
@@ -1008,33 +1046,38 @@ func assignLayers(layers []uint64, gpus []ml.DeviceInfo, requireFull bool, reque
// findBestFit binary searches to find the smallest capacity factor that can fit
// the max number of layers. The capacity factor is multiplied by the free space on
// each GPU and a small one will force even balancing.
// each GPU and a small one will force even balancing. Higher performance GPUs are
// used first.
func findBestFit(layers []uint64, gpus []ml.DeviceInfo, requestedLayers int, forceRequest bool) (gpuLayers ml.GPULayersList) {
var high float32 = 1
var low float32 = 0
for _, gl := range ml.ByPerformance(gpus) {
var high float32 = 1
var low float32 = 0
// If we need to fulfill the requested number of layers, pretend we have almost infinite VRAM
if requestedLayers >= 0 && forceRequest {
high = 1000
}
bestAssignments := greedyFit(layers, gpus, high, requestedLayers)
maxNumGPU := bestAssignments.Sum()
if maxNumGPU == 0 {
return bestAssignments
}
for high-low > 1e-6 {
mid := (low + high) / 2
assignments := greedyFit(layers, gpus, mid, requestedLayers)
if assignments.Sum() == maxNumGPU {
high = mid
bestAssignments = assignments
} else {
low = mid
// If we need to fulfill the requested number of layers, pretend we have almost infinite VRAM
if requestedLayers >= 0 && forceRequest {
high = 1000
}
bestAssignments := greedyFit(layers, gl, high, requestedLayers)
maxNumGPU := bestAssignments.Sum()
for high-low > 1e-6 {
mid := (low + high) / 2
assignments := greedyFit(layers, gl, mid, requestedLayers)
if assignments.Sum() == maxNumGPU {
high = mid
bestAssignments = assignments
} else {
low = mid
}
}
layers = layers[:len(layers)-bestAssignments.Sum()]
requestedLayers -= bestAssignments.Sum()
gpuLayers = append(bestAssignments, gpuLayers...)
}
return bestAssignments
return gpuLayers
}
// greedyFit assigns layers incrementally to GPUs, spilling over as each runs out of free space
@@ -1362,6 +1405,12 @@ type CompletionRequest struct {
Grammar string // set before sending the request to the subprocess
Shift bool
Truncate bool
// Logprobs specifies whether to include log probabilities in the response
Logprobs bool
// TopLogprobs specifies the number of most likely alternative tokens to return (0-20)
TopLogprobs int
}
// DoneReason represents the reason why a completion response is done
@@ -1387,6 +1436,18 @@ func (d DoneReason) String() string {
}
}
// TokenLogprob represents log probability information for a single token alternative.
type TokenLogprob struct {
Token string `json:"token"`
Logprob float64 `json:"logprob"`
}
// Logprob contains log probability information for a generated token.
type Logprob struct {
TokenLogprob
TopLogprobs []TokenLogprob `json:"top_logprobs,omitempty"`
}
type CompletionResponse struct {
Content string `json:"content"`
DoneReason DoneReason `json:"done_reason"`
@@ -1395,6 +1456,9 @@ type CompletionResponse struct {
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
EvalCount int `json:"eval_count"`
EvalDuration time.Duration `json:"eval_duration"`
// Logprobs contains log probability information if requested
Logprobs []Logprob `json:"logprobs,omitempty"`
}
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
@@ -1530,7 +1594,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
if c.Content != "" {
fn(CompletionResponse{
Content: c.Content,
Content: c.Content,
Logprobs: c.Logprobs,
})
}
@@ -1623,68 +1688,59 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
return e.Embedding, nil
}
type TokenizeRequest struct {
Content string `json:"content"`
}
type TokenizeResponse struct {
Tokens []int `json:"tokens"`
}
func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
s.llamaModelLock.Lock()
defer s.llamaModelLock.Unlock()
if s.llamaModel != nil {
return s.llamaModel.Tokenize(content, false, true)
if s.llamaModel == nil {
return nil, fmt.Errorf("no tokenizer configured")
}
if s.textProcessor != nil {
tokens, err := s.textProcessor.Encode(content, false)
if err != nil {
return nil, err
}
toks := make([]int, len(tokens))
for i, t := range tokens {
toks[i] = int(t)
}
return toks, nil
return s.llamaModel.Tokenize(content, false, true)
}
func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
tokens, err := s.textProcessor.Encode(content, false)
if err != nil {
return nil, err
}
// not reached
return nil, fmt.Errorf("no tokenizer configured")
toks := make([]int, len(tokens))
for i, t := range tokens {
toks[i] = int(t)
}
return toks, nil
}
type DetokenizeRequest struct {
Tokens []int `json:"tokens"`
}
type DetokenizeResponse struct {
Content string `json:"content"`
}
func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
func (s *llamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
s.llamaModelLock.Lock()
defer s.llamaModelLock.Unlock()
if s.llamaModel != nil {
var resp string
for _, token := range tokens {
resp += s.llamaModel.TokenToPiece(token)
}
return resp, nil
if s.llamaModel == nil {
return "", fmt.Errorf("no tokenizer configured")
}
if s.textProcessor != nil {
toks := make([]int32, len(tokens))
for i, t := range tokens {
toks[i] = int32(t)
}
content, err := s.textProcessor.Decode(toks)
if err != nil {
return "", err
}
return content, nil
var resp string
for _, token := range tokens {
resp += s.llamaModel.TokenToPiece(token)
}
// not reached
return "", fmt.Errorf("no tokenizer configured")
return resp, nil
}
func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
toks := make([]int32, len(tokens))
for i, t := range tokens {
toks[i] = int32(t)
}
content, err := s.textProcessor.Decode(toks)
if err != nil {
return "", err
}
return content, nil
}
func (s *llmServer) Close() error {
@@ -1712,31 +1768,12 @@ func (s *llmServer) Close() error {
return nil
}
func (s *llamaServer) VRAMSize() uint64 {
return s.estimate.VRAMSize
}
func (s *llamaServer) TotalSize() uint64 {
return s.estimate.TotalSize
}
func (s *llamaServer) VRAMByGPU(id ml.DeviceID) uint64 {
for i, gpu := range s.gpus {
if gpu.DeviceID == id {
if i < len(s.estimate.GPUSizes) {
return s.estimate.GPUSizes[i]
}
}
}
return 0
}
func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
slog.Debug("llamarunner free vram reporting not supported")
return nil
}
func (s *ollamaServer) VRAMSize() uint64 {
func (s *llmServer) VRAMSize() uint64 {
if s.mem == nil {
return 0
}
@@ -1764,7 +1801,7 @@ func (s *ollamaServer) VRAMSize() uint64 {
return mem
}
func (s *ollamaServer) TotalSize() uint64 {
func (s *llmServer) TotalSize() uint64 {
if s.mem == nil {
return 0
}
@@ -1778,7 +1815,7 @@ func (s *ollamaServer) TotalSize() uint64 {
return mem
}
func (s *ollamaServer) VRAMByGPU(id ml.DeviceID) uint64 {
func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
if s.mem == nil {
return 0
}

View File

@@ -14,16 +14,11 @@ import (
)
func TestLLMServerFitGPU(t *testing.T) {
type gpu struct {
id ml.DeviceID
free int
}
minMemory := 457 * format.MebiByte
tests := []struct {
name string
gpus []gpu
gpus []ml.DeviceInfo
layers []int
numGPU int
requireFull bool
@@ -38,91 +33,91 @@ func TestLLMServerFitGPU(t *testing.T) {
},
{
name: "Full single GPU",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0, 1, 2}}},
},
{
name: "Partial single GPU",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1, 2}}},
},
{
name: "Single GPU with numGPU 1",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
numGPU: 1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1}}},
},
{
name: "Single GPU with numGPU 0",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
numGPU: 0,
expected: ml.GPULayersList{},
},
{
name: "Single GPU with numGPU 999",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
numGPU: 999,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0, 1, 2, 3}}},
},
{
name: "Multi GPU fits on one",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0, 1, 2}}},
},
{
name: "Multi GPU split",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{256 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1, 2}}},
},
{
name: "Multi GPU partial",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{256 * format.MebiByte, 256 * format.MebiByte, 50 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{1}}},
},
{
name: "Multi GPU numGPU 1",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
numGPU: 1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{1}}},
},
{
name: "Multi GPU numGPU 2",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{256 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
numGPU: 2,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1}}},
},
{
name: "Multi GPU numGPU 999",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{256 * format.MebiByte, 256 * format.MebiByte, 50 * format.MebiByte},
numGPU: 999,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0, 1}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{2}}},
},
{
name: "Multi GPU different libraries",
gpus: []gpu{{id: ml.DeviceID{Library: "CUDA", ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{Library: "ROCm", ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{Library: "CUDA", ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{Library: "ROCm", ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{128 * format.MebiByte, 128 * format.MebiByte, 50 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1", Library: "ROCm"}, Layers: []int{0, 1}}},
},
{
name: "requireFull",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
numGPU: -1,
requireFull: true,
@@ -130,12 +125,54 @@ func TestLLMServerFitGPU(t *testing.T) {
},
{
name: "requireFull numGPU",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256 * format.MebiByte)}},
layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
numGPU: 4,
requireFull: true,
expectedErr: ErrLoadRequiredFull,
},
{
name: "iGPU",
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0, 1, 2}}},
},
{
name: "iGPU + dGPU",
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1, 2}}},
},
{
name: "iGPU + dGPU fits on one",
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{50 * format.MebiByte, 50 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0, 1}}},
},
{
name: "iGPU + dGPU partial",
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0, 1}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{2}}},
},
{
name: "iGPU + dGPU numGPU 1",
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
numGPU: 1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{2}}},
},
{
name: "iGPU + dGPU numGPU 999",
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
numGPU: 999,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0}}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{1, 2, 3}}},
},
}
for _, tt := range tests {
@@ -145,12 +182,6 @@ func TestLLMServerFitGPU(t *testing.T) {
systemInfo.FreeMemory = 512 * format.MebiByte
systemInfo.FreeSwap = 256 * format.MebiByte
gpus := make([]ml.DeviceInfo, len(tt.gpus))
for i := range tt.gpus {
gpus[i].DeviceID = tt.gpus[i].id
gpus[i].FreeMemory = uint64(tt.gpus[i].free)
}
s := &ollamaServer{
llmServer: llmServer{
totalLayers: uint64(len(tt.layers)),
@@ -165,19 +196,19 @@ func TestLLMServerFitGPU(t *testing.T) {
s.mem = &ml.BackendMemory{CPU: ml.DeviceMemory{
Weights: make([]uint64, s.totalLayers),
Cache: make([]uint64, s.totalLayers),
}, GPUs: make([]ml.DeviceMemory, len(gpus))}
}, GPUs: make([]ml.DeviceMemory, len(tt.gpus))}
for i := range tt.layers {
s.mem.CPU.Weights[i] = uint64(tt.layers[i])
}
for i := range s.mem.GPUs {
s.mem.GPUs[i].DeviceID = gpus[i].DeviceID
s.mem.GPUs[i].DeviceID = tt.gpus[i].DeviceID
s.mem.GPUs[i].Weights = make([]uint64, s.totalLayers)
s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers)
}
gpuLayers, err := s.createLayout(systemInfo, gpus, s.mem, tt.requireFull, 0)
gpuLayers, err := s.createLayout(systemInfo, tt.gpus, s.mem, tt.requireFull, 0)
if err != tt.expectedErr {
t.Fatalf("fitGPU returned error: %v", err)
}

View File

@@ -1,16 +0,0 @@
{
"env": {
"browser": true,
"es6": true,
"node": true
},
"extends": [
"eslint:recommended",
"plugin:@typescript-eslint/eslint-recommended",
"plugin:@typescript-eslint/recommended",
"plugin:import/recommended",
"plugin:import/electron",
"plugin:import/typescript"
],
"parser": "@typescript-eslint/parser"
}

92
macapp/.gitignore vendored
View File

@@ -1,92 +0,0 @@
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
lerna-debug.log*
# Diagnostic reports (https://nodejs.org/api/report.html)
report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
# Runtime data
pids
*.pid
*.seed
*.pid.lock
.DS_Store
# Directory for instrumented libs generated by jscoverage/JSCover
lib-cov
# Coverage directory used by tools like istanbul
coverage
*.lcov
# nyc test coverage
.nyc_output
# node-waf configuration
.lock-wscript
# Compiled binary addons (https://nodejs.org/api/addons.html)
build/Release
# Dependency directories
node_modules/
jspm_packages/
# TypeScript v1 declaration files
typings/
# TypeScript cache
*.tsbuildinfo
# Optional npm cache directory
.npm
# Optional eslint cache
.eslintcache
# Optional REPL history
.node_repl_history
# Output of 'npm pack'
*.tgz
# Yarn Integrity file
.yarn-integrity
# dotenv environment variables file
.env
.env.test
# parcel-bundler cache (https://parceljs.org/)
.cache
# next.js build output
.next
# nuxt.js build output
.nuxt
# vuepress build output
.vuepress/dist
# Serverless directories
.serverless/
# FuseBox cache
.fusebox/
# DynamoDB Local files
.dynamodb/
# Webpack
.webpack/
# Vite
.vite/
# Electron-Forge
out/

View File

@@ -1,21 +0,0 @@
# Desktop
This app builds upon Ollama to provide a desktop experience for running models.
## Developing
First, build the `ollama` binary:
```shell
cd ..
go build .
```
Then run the desktop app with `npm start`:
```shell
cd macapp
npm install
npm start
```

View File

Binary file not shown.

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 402 B

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 741 B

Some files were not shown because too many files have changed in this diff Show More