Compare commits

...

64 Commits

Author SHA1 Message Date
ParthSareen
233d5c5eda refactor(agent): implement three-tier approval system with warn patterns
- Remove git commands from auto-allowlist
- Add new warn patterns tier for commands requiring explicit approval
- Move network commands and env files from deny to warn
- Add IsWarn() and containsWord() helper functions
- Enhanced git prefix extraction for granular allowlisting
- Move credential path patterns to denyPathPatterns
- UI improvements: dynamic warning messages and allowlist info
- Update tests: add TestIsWarn(), adjust expectations
2026-01-09 00:10:10 -08:00
Daniel Hiltgen
33ee7168ba Add experimental MLX backend and engine with imagegen support (#13648)
* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

To build the new MLX backend code:
  cmake --preset MLX
  cmake --build --preset MLX --parallel
  cmake --install build --component MLX
  go build -tags mlx .

Note: the main.go entrypoint for the MLX engine will change in a follow up commit.

* add experimental image generation runtime

* add experimental image generation runtime

* MLX: wire up cuda build for linux

* MLX: get dependencies correct and dedup

This is still too large for a unified github artifact, but is now "correct" for the mlx_cuda_v13
directory.

* fix relative link bug in dedup

* Add darwin build and readme

* add go build tag for mlx dependent code and wire up build_darwin.sh

* lint cleanup

* macos: build mlx for x86

This will be CPU only.

* cuda build instructions and fix drift from mlx bump

* stale comment

* Delete agent helper doc

* Clean up readme.md

* Revise README for tokenizer clarity and details

Updated README to clarify tokenizer functionality and removed correctness section.

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2026-01-08 16:18:59 -08:00
Daniel Hiltgen
34d0c55ea5 Linux: switch to zstd compression (#13651)
With the upcoming addition of MLX, the linux bundle will exceed the
maximum github artifact size of 2G.  This change will bring the size
back down.

The install.sh changes support backwards compatibility for prior versions
thus should be safe to merge concurrently with this change.
2026-01-08 15:47:32 -08:00
Parth Sareen
53a5a9e9ae x: redesign agent UI with minimal styling (#13650) 2026-01-08 15:40:07 -08:00
Parth Sareen
e30e08a7d6 x: remove Ctrl+O tool output expansion feature (#13640) 2026-01-07 15:34:08 -08:00
Parth Sareen
12e2b3514a x: agent loop ux improvements (#13635) 2026-01-07 01:27:15 -08:00
Devon Rifkin
626af2d809 template: fix args-as-json rendering (#13636)
In #13525, I accidentally broke templates' ability to automatically
render tool call function arguments as JSON.

We do need these to be proper maps because we need templates to be able
to call range, which can't be done on custom types.
2026-01-06 18:33:57 -08:00
Parth Sareen
76912c062a x: add experimental agent loop (#13628) 2026-01-05 23:38:40 -08:00
Devon Rifkin
6c3faafed2 olmo3: fix flaky test (#13629)
I introduced this in <https://github.com/ollama/ollama/pull/13525>
2026-01-05 22:37:20 -08:00
Devon Rifkin
e51dead636 preserve tool definition and call JSON ordering (#13525)
* preserve tool definition and call JSON ordering

This is another iteration of
<https://github.com/ollama/ollama/pull/12518>, but this time we've
simplified things by relaxing the competing requirements of being
compatible AND order-preserving with templates (vs. renderers). We
maintain backwards compatibility at the cost of not guaranteeing order
for templates. We plan on moving more and more models to renderers,
which have been updated to use these new data types, and additionally
we could add an opt-in way of templates getting an order-preserved list
(e.g., via sibling template vars)

* orderedmap_test: remove testify
2026-01-05 18:03:36 -08:00
Harry V. Kiselev
d087e46bd1 docs/capabilities/vision: fix curl related code snippet (#13615) 2026-01-03 17:27:46 -05:00
lif
37f6f3af24 server: return error when embedding contains NaN or Inf values (#13599)
The normalize function now checks for NaN and Inf values in the
embedding vector before processing. This prevents JSON encoding
failures when models produce invalid floating-point values.

Fixes #13572

Signed-off-by: majiayu000 <1835304752@qq.com>
2026-01-03 02:20:12 -05:00
Nhan Nguyen
e1bdc23dd2 docs: fix tool name mismatch and trailing commas in api.md example (#13559)
The tool calling example used "get_temperature" for tool_calls but
defined the tool as "get_weather". Also removed trailing commas that
made the JSON invalid.

Fixes #13031
2026-01-03 02:14:53 -05:00
lif
2e78653ff9 app/ui: add swift syntax highlighting support (#13574)
Fixes #13476

Signed-off-by: majiayu000 <1835304752@qq.com>
2026-01-03 02:12:08 -05:00
lif
f5f74e12c1 docs: add version note for /v1/responses API (#13596)
Signed-off-by: majiayu000 <1835304752@qq.com>
2026-01-03 01:58:20 -05:00
Vallabh Mahajan
18fdcc94e5 docs: fix broken .md links and render issues (#13550) 2025-12-23 12:44:55 -05:00
Daniel Hiltgen
7ad036992f amd: use GTT on iGPUs on linux (#13196)
On Linux, look at the GTT memory information for iGPUs.
2025-12-23 09:30:05 -08:00
Jesse Gross
172b5924af llm: Avoid integer underflow on llama engine memory layout
On the llama engine, when we compute the memory layout, we reserve
a buffer to allow for some flexibility for incorrect estimates.
This is subtracted from GPU free memory and on GPUs with limited
memory, it may underflow.

Fixes #13494
2025-12-19 15:48:15 -08:00
Jeffrey Morgan
8852220f59 add REQUIRES command to Modelfile (#13361) 2025-12-18 13:21:29 -08:00
Parth Sareen
7325791599 parsers/renderers: functiongemma (#13521) 2025-12-18 07:55:37 -08:00
Grace
522c11a763 Revert "Omit args and params in tool function def and calls (#13516)" (#13518)
This reverts commit 0fadeffaee.
2025-12-17 19:06:56 -08:00
Grace
0fadeffaee Omit args and params in tool function def and calls (#13516) 2025-12-17 18:42:21 -08:00
Daniel Hiltgen
49a9c9ba6a GGML update to ec98e2002 (#13451)
* Revert "add support for NVIDIA Nemotron 3 Nano"

This reverts commit e7d2ae9d69.

* GGML update to 380b4c984

Remove MaskBatchPadding as GGML_KQ_MASK_PAD is no longer present (no
padding required)

* update to c45f89d55

* ec98e2002

solar pro needed more adjusting - needs verification

* review comments
2025-12-17 13:13:55 -08:00
Parth Sareen
1c094038bc types: add nested property support for tool definitions (#13508) 2025-12-17 11:54:09 -08:00
Grace
a013693f80 DeepseekV3 Family Parser (#13484) 2025-12-16 18:56:30 -08:00
Michael Yang
f6a016f49d revert granite-embedding (#13505) 2025-12-16 15:44:52 -08:00
Bruce MacDonald
45c4739374 types: ConfigV2 and RootFS (#13504)
Refactored the ConfigV2 and RootFS types from server/images.go to a new types/model/config.go file under the model package. Updated all references to use model.ConfigV2 and model.RootFS. This allows for use in other projects without worrying about compiling the c code in the llama package.
2025-12-16 15:18:17 -08:00
Michael Yang
2dd029de12 remove unnecessary code (#13502)
slog is already lazily evaluated so this code is completely redundant
2025-12-16 15:11:26 -08:00
Michael Yang
903b1fc97f use ollama engine for bert models (#13501)
register bpe tokenizer which enables granite-embedding
2025-12-16 11:29:19 -08:00
Parth Sareen
89eb795293 parsers/renderers: use think from user for nemotron (#13492) 2025-12-15 18:55:17 -08:00
Parth Sareen
7e3ea813c1 llama/parsers/renderers: nemotron 3 nano (#13489)
---------

Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
2025-12-15 18:00:08 -08:00
Grace
7b95087b9d Adding tool definitions to DeepseekV3 renderer (#13491) 2025-12-15 17:57:06 -08:00
Michael Yang
971d62595a fix: qwen2.5 vl rope (#13486)
* qwen25vl: bump max pixels

* qwen25vl: mrope

fix qwen2.5vl window

* qwen25vl: vision rope
2025-12-15 17:30:33 -08:00
Parth Sareen
ffbe8e076d model: add olmo3 and olmo3.1 (#13415) 2025-12-15 15:20:04 -08:00
Grace
2c639431b1 DeepseekV3 family renderer (#13180) 2025-12-15 14:50:52 -08:00
Nhan Nguyen
aacd1cb394 fix: define GGML_VERSION variables for proper SOVERSION expansion (#13469)
The ggml/src/CMakeLists.txt uses GGML_VERSION_MAJOR for the shared
library SOVERSION property, but these variables were not defined when
building from ollama's CMakeLists.txt.

This caused libggml-base.so to be named with a literal "SOVERSION"
suffix (libggml-base.so.SOVERSION) instead of the actual version
number (libggml-base.so.0).

The fix adds the required GGML_VERSION_* variables before including
the ggml subdirectory.

Fixes #13436
2025-12-15 14:42:15 -08:00
Parth Sareen
e3731fb160 renderers: add olmo3.1 and olmo3 fixes (#13447) 2025-12-15 11:26:43 -08:00
Eva H
8dbc9e7b68 app/ui: handle unspecified bind addresses and wait for server in ollama proxy (#13159) 2025-12-15 13:33:09 -05:00
Daniel Hiltgen
abe67acf8a Revert "Enable Ollama engine by default" (#13481)
This reverts commit 56f754f46b.
2025-12-15 09:55:45 -08:00
Jeffrey Morgan
4ff8a691bc model: default gemma 3 rope scale to 1.0, apply corrections based on layer counts (#13453) 2025-12-12 17:51:56 -08:00
Jeffrey Morgan
1b308e1d2a model: fix global layer rope scale values for gemma 3 (#13452) 2025-12-12 16:29:01 -08:00
Daniel Hiltgen
bd6c1d6b49 flash attn: add auto mode for llama engine (#13052)
* flash attn: add auto mode for llama engine

If the user does not specify fa in the environment, use auto-mode.

* review comments

* ensure kv cache quantized types have FA explicitly enabled

additional review comments
2025-12-12 13:27:19 -08:00
Jeffrey Morgan
3af5d3b738 model: force rope factor 1.0 for Gemma 3 (#13445) 2025-12-12 13:27:08 -08:00
Daniel Hiltgen
7730895158 Enable Ollama engine by default (#13443)
This changes the default behavior to use the Ollama engine for supported
models, while retaining the ability to disable the Ollama engine and
fall back to the Llama engine.  Models in the OllamaEngineRequired list
will always run on the Ollama engine.
2025-12-12 11:48:43 -08:00
Eva H
de9ecfd01c tidy up lint warnings on windows (#13430) 2025-12-12 11:43:35 -05:00
Eva H
95fdd8d619 fix: select and update models folder in settings (#13412) 2025-12-12 11:09:37 -05:00
Devon Rifkin
9f7822851c docs: add docs for v1/responses and rework openai compat section (#13416)
* docs: add docs for v1/responses and rework openai compat section

I reworked the examples to be separated by topic and to be fully
runnable (i.e., they now log output instead of just suggesting how a
call might be made).

We now use `<CodeGroup>`s so that each example has a dropdown on the
docs site for users to choose, which makes the examples a lot more
digestible (since you only see approx 1/3 of the code you used to).

I also added a new tool to extract code examples into files so that it's
easier to actually run them and check that they work.

## Example

```shell
go run docs/tools/extract-examples/main.go docs/api/openai-compatibility.mdx
```

Output:

```
Extracting code examples to: /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368

  - 01_basic.py
  - 01_basic.js
  - 01_basic.sh
  - 02_responses.py
  - 02_responses.js
  - 02_responses.sh
  - 03_vision.py
  - 03_vision.js
  - 03_vision.sh

Extracted 9 file(s) to /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368

To run examples:

  cd /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
  npm install   # for JS examples

then run individual files with `node file.js`, `python file.py`, `bash file.sh`
```

In the future we should consider actually running the examples in CI and
having some sort of acceptance test so we can automatically detect when
our examples break. So this is just a start in that direction.

* Update docs/api/openai-compatibility.mdx

Co-authored-by: Parth Sareen <parth.sareen@ollama.com>

* Update docs/api/openai-compatibility.mdx

Co-authored-by: Parth Sareen <parth.sareen@ollama.com>

---------

Co-authored-by: Parth Sareen <parth.sareen@ollama.com>
2025-12-11 17:39:40 -08:00
Parth Sareen
9b2035d194 openai: add tool call appending to previous assistant message (#13434)
* openai: add tool call appending to previous asst message

* add tests for thinking appending
2025-12-11 17:30:12 -08:00
Alexander Gusak
93d45d7a04 docs: fix link to modelfile.mdx (#13220) 2025-12-11 16:14:45 -08:00
JJ
709f842457 Update README.md (#13373)
Correct Markdown syntax for Swollama GitHub and DocC documentation links
2025-12-11 16:08:57 -08:00
Jeffrey Morgan
2dfb74410d model: fix rotary embeddings for ministral 3 (#13432) 2025-12-11 16:02:05 -08:00
Devon Rifkin
1eb5e75972 openai: add v1/responses support (#13351)
Only supporting the stateless part of the API.

Doc updates to come once this is shipped.

Closes: #9659
2025-12-11 15:37:10 -08:00
nicole pardal
3475d915cb embeddings: modified batch size (#13429)
This PR detects embedding models and sets batch_size = context_size so the full input fits in a single batch.
Previously, if batch size was smaller than the input, tokens could be split across batches and cause a SIGTRAP crash.
This change ensures all tokens stay in one batch and prevents crashes.
Fixes: #12938 #13054

Co-authored-by: Jesse Gross <jesse@ollama.com>
2025-12-11 15:36:31 -08:00
Jeffrey Morgan
48e78e9be1 template: add yesterdayDate helper function (#13431) 2025-12-11 14:47:55 -08:00
Jeffrey Morgan
a838421ea3 model: conversion and hyperparameter fixes for ministral and devstral (#13424) 2025-12-11 13:04:00 -08:00
EasonLin
1c4e85b4df routes: add logprobs in tool calls (#13238) 2025-12-10 17:28:41 -08:00
Eloi Torrents
dac4f17fea cmd/bench: fix binary name in README (#13276) 2025-12-10 14:16:58 -08:00
Julia Scheaffer
56b8fb024c cmd/bench: fix options table in cmd/bench/README.md (#13216) 2025-12-10 14:07:48 -08:00
Gabe Goodhart
b95693056c feat: llama.cpp bump (17f7f4) for SSM performance improvements (#13408)
* feat: Bump llama.cpp to the latest master (17f7f4b)

This brings in significant improvements to prefill performance for all
models using the SSM_CONV and SSM_SCAN ops (granite4, jamba, falcon-h,
nemotron-h, Qwen3 Next) on Apple Metal.

See https://github.com/ggml-org/llama.cpp/pull/17876

Branch: LlamaCPPMetalSSMImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Update patches 1-4

Branch: LlamaCPPMetalSSMImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Update patches 5-12

Branch: LlamaCPPMetalSSMImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Update patches 13-18

Branch: LlamaCPPMetalSSMImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Update patch 20

Branch: LlamaCPPMetalSSMImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Update patches 21-31

Branch: LlamaCPPMetalSSMImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Sync vendored code

The two files I'm not sure about here are the swap from gemma3-iswa.cpp to
gemma3.cpp (I chose to include this because I think it's required), and the
inclusion of `ggml-zendnn.h` which I chose to omit.

Branch: LlamaCPPMetalSSMImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

---------

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-12-10 12:59:27 -08:00
Eva H
c34fc64688 app/ui: use requestAnimationFrame to prevent bottom line cutoff in streaming thinking display (#13137) 2025-12-10 15:29:48 -05:00
Eva H
7cf6f18c1f app/ui: refactor to use Ollama endpoints for user auth and health checks (#13081) 2025-12-10 15:24:31 -05:00
Eva H
bbbb6b2a01 app/ui: fix model capabilities not updating after download completion (#13179) 2025-12-10 14:40:02 -05:00
nicole pardal
76f88caf43 nomic-embed-text:v2: model implementation (#13162) 2025-12-09 14:24:51 -08:00
Parth Sareen
2bccf8c624 renderers/parsers: olmo3 instruct (#13383) 2025-12-09 11:12:27 -08:00
435 changed files with 58278 additions and 10206 deletions

View File

@@ -68,6 +68,7 @@ jobs:
name: bundles-darwin
path: |
dist/*.tgz
dist/*.tar.zst
dist/*.zip
dist/*.dmg
@@ -392,13 +393,13 @@ jobs:
done
- run: |
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | zstd --ultra -22 -T0 >$(basename ${ARCHIVE//.*/}.tar.zst);
done
- uses: actions/upload-artifact@v4
with:
name: bundles-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.target }}
path: |
*.tgz
*.tar.zst
# Build each Docker variant (OS, arch, and flavor) separately. Using QEMU is unreliable and slower.
docker-build-push:
@@ -531,7 +532,7 @@ jobs:
- name: Upload release artifacts
run: |
pids=()
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.exe dist/*.dmg ; do
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg ; do
echo "Uploading $payload"
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
pids[$!]=$!

View File

@@ -2,6 +2,22 @@ cmake_minimum_required(VERSION 3.21)
project(Ollama C CXX)
# Handle cross-compilation on macOS: when CMAKE_OSX_ARCHITECTURES is set to a
# single architecture different from the host, override CMAKE_SYSTEM_PROCESSOR
# to match. This is necessary because CMAKE_SYSTEM_PROCESSOR defaults to the
# host architecture, but downstream projects (like MLX) use it to detect the
# target architecture.
if(CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES ";")
# Single architecture specified
if(CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
message(STATUS "Cross-compiling for x86_64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to x86_64")
set(CMAKE_SYSTEM_PROCESSOR "x86_64")
elseif(CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
message(STATUS "Cross-compiling for arm64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to arm64")
set(CMAKE_SYSTEM_PROCESSOR "arm64")
endif()
endif()
include(CheckLanguage)
include(GNUInstallDirs)
@@ -12,7 +28,7 @@ set(BUILD_SHARED_LIBS ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_CXX_EXTENSIONS ON) # Recent versions of MLX Requires gnu++17 extensions to compile properly
set(GGML_BUILD ON)
set(GGML_SHARED ON)
@@ -54,6 +70,13 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cp
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
# Define GGML version variables for shared library SOVERSION
# These are required by ggml/src/CMakeLists.txt for proper library versioning
set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 0)
set(GGML_VERSION_PATCH 0)
set(GGML_VERSION "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
set(GGML_CPU ON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE)
@@ -140,14 +163,48 @@ if(CMAKE_HIP_COMPILER)
endif()
endif()
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
if(NOT APPLE)
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
endif()
endif()
option(MLX_ENGINE "Enable MLX backend" OFF)
if(MLX_ENGINE)
message(STATUS "Setting up MLX (this takes a while...)")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
# Find CUDA toolkit if MLX is built with CUDA support
find_package(CUDAToolkit)
install(TARGETS mlx mlxc
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
)
# Manually install cudart and cublas since they might not be picked up as direct dependencies
if(CUDAToolkit_FOUND)
file(GLOB CUDART_LIBS
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
if(CUDART_LIBS)
install(FILES ${CUDART_LIBS}
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
endif()
endif()

View File

@@ -41,7 +41,7 @@
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
"CMAKE_CUDA_FLAGS": "-t 2",
"CMAKE_CUDA_FLAGS": "-t 4",
"OLLAMA_RUNNER_DIR": "cuda_v13"
}
},
@@ -83,6 +83,28 @@
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "vulkan"
}
},
{
"name": "MLX",
"inherits": [ "Default" ],
"cacheVariables": {
"MLX_ENGINE": "ON",
"OLLAMA_RUNNER_DIR": "mlx"
}
},
{
"name": "MLX CUDA 12",
"inherits": [ "MLX", "CUDA 12" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v12"
}
},
{
"name": "MLX CUDA 13",
"inherits": [ "MLX", "CUDA 13" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
}
}
],
"buildPresets": [
@@ -140,6 +162,21 @@
"name": "Vulkan",
"targets": [ "ggml-vulkan" ],
"configurePreset": "Vulkan"
},
{
"name": "MLX",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX"
},
{
"name": "MLX CUDA 12",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 12"
},
{
"name": "MLX CUDA 13",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 13"
}
]
}

View File

@@ -131,7 +131,39 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'Vulkan' \
&& cmake --build --parallel --preset 'Vulkan' \
&& cmake --install build --component Vulkan --strip --parallel 8
&& cmake --install build --component Vulkan --strip --parallel 8
FROM base AS mlx
ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} \
&& dnf install -y openblas-devel lapack-devel \
&& dnf install -y libcudnn9-cuda-13 libcudnn9-devel-cuda-13 \
&& dnf install -y libnccl libnccl-devel
ENV PATH=/usr/local/cuda-13/bin:$PATH
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
ARG PARALLEL
WORKDIR /go/src/github.com/ollama/ollama
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/ml/backend/mlx x/ml/backend/mlx
COPY go.mod go.sum .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
COPY . .
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
# TODO wire up the actual MLX engine here instead of building the main binary...
RUN mkdir -p dist/bin
RUN go build -tags mlx -trimpath -buildmode=pie -o dist/bin/imagegen ./x/imagegen/cmd/engine
FROM base AS build
@@ -153,6 +185,8 @@ FROM --platform=linux/amd64 scratch AS amd64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
COPY --from=vulkan dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/
FROM --platform=linux/arm64 scratch AS arm64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/

View File

@@ -1,6 +1,6 @@
UPSTREAM=https://github.com/ggml-org/llama.cpp.git
WORKDIR=llama/vendor
FETCH_HEAD=7f8ef50cce40e3e7e4526a3696cb45658190e69a
FETCH_HEAD=ec98e2002
.PHONY: help
help:

View File

@@ -555,7 +555,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Parakeet](https://github.com/parakeet-nest/parakeet) is a GoLang library, made to simplify the development of small generative AI applications with Ollama.
- [Haverscript](https://github.com/andygill/haverscript) with [examples](https://github.com/andygill/haverscript/tree/main/examples)
- [Ollama for Swift](https://github.com/mattt/ollama-swift)
- [Swollama for Swift]([https://github.com/marcusziade/Swollama](https://github.com/guitaripod/Swollama) with [DocC]( https://guitaripod.github.io/Swollama/documentation/swollama)
- [Swollama for Swift](https://github.com/guitaripod/Swollama) with [DocC](https://guitaripod.github.io/Swollama/documentation/swollama)
- [GoLamify](https://github.com/prasad89/golamify)
- [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API)

View File

@@ -347,7 +347,7 @@ type CreateProgressFunc func(ProgressResponse) error
// Create creates a model from a [Modelfile]. fn is a progress function that
// behaves similarly to other methods (see [Client.Pull]).
//
// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md
// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.mdx
func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
var resp ProgressResponse

View File

@@ -3,6 +3,7 @@ package api
import (
"encoding/json"
"fmt"
"iter"
"log/slog"
"math"
"os"
@@ -14,6 +15,7 @@ import (
"github.com/google/uuid"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/internal/orderedmap"
"github.com/ollama/ollama/types/model"
)
@@ -227,13 +229,79 @@ type ToolCallFunction struct {
Arguments ToolCallFunctionArguments `json:"arguments"`
}
type ToolCallFunctionArguments map[string]any
// ToolCallFunctionArguments holds tool call arguments in insertion order.
type ToolCallFunctionArguments struct {
om *orderedmap.Map[string, any]
}
// NewToolCallFunctionArguments creates a new empty ToolCallFunctionArguments.
func NewToolCallFunctionArguments() ToolCallFunctionArguments {
return ToolCallFunctionArguments{om: orderedmap.New[string, any]()}
}
// Get retrieves a value by key.
func (t *ToolCallFunctionArguments) Get(key string) (any, bool) {
if t == nil || t.om == nil {
return nil, false
}
return t.om.Get(key)
}
// Set sets a key-value pair, preserving insertion order.
func (t *ToolCallFunctionArguments) Set(key string, value any) {
if t == nil {
return
}
if t.om == nil {
t.om = orderedmap.New[string, any]()
}
t.om.Set(key, value)
}
// Len returns the number of arguments.
func (t *ToolCallFunctionArguments) Len() int {
if t == nil || t.om == nil {
return 0
}
return t.om.Len()
}
// All returns an iterator over all key-value pairs in insertion order.
func (t *ToolCallFunctionArguments) All() iter.Seq2[string, any] {
if t == nil || t.om == nil {
return func(yield func(string, any) bool) {}
}
return t.om.All()
}
// ToMap returns a regular map (order not preserved).
func (t *ToolCallFunctionArguments) ToMap() map[string]any {
if t == nil || t.om == nil {
return nil
}
return t.om.ToMap()
}
func (t *ToolCallFunctionArguments) String() string {
bts, _ := json.Marshal(t)
if t == nil || t.om == nil {
return "{}"
}
bts, _ := json.Marshal(t.om)
return string(bts)
}
func (t *ToolCallFunctionArguments) UnmarshalJSON(data []byte) error {
t.om = orderedmap.New[string, any]()
return json.Unmarshal(data, t.om)
}
func (t ToolCallFunctionArguments) MarshalJSON() ([]byte, error) {
if t.om == nil {
return []byte("{}"), nil
}
return json.Marshal(t.om)
}
type Tool struct {
Type string `json:"type"`
Items any `json:"items,omitempty"`
@@ -282,12 +350,78 @@ func (pt PropertyType) String() string {
return fmt.Sprintf("%v", []string(pt))
}
// ToolPropertiesMap holds tool properties in insertion order.
type ToolPropertiesMap struct {
om *orderedmap.Map[string, ToolProperty]
}
// NewToolPropertiesMap creates a new empty ToolPropertiesMap.
func NewToolPropertiesMap() *ToolPropertiesMap {
return &ToolPropertiesMap{om: orderedmap.New[string, ToolProperty]()}
}
// Get retrieves a property by name.
func (t *ToolPropertiesMap) Get(key string) (ToolProperty, bool) {
if t == nil || t.om == nil {
return ToolProperty{}, false
}
return t.om.Get(key)
}
// Set sets a property, preserving insertion order.
func (t *ToolPropertiesMap) Set(key string, value ToolProperty) {
if t == nil {
return
}
if t.om == nil {
t.om = orderedmap.New[string, ToolProperty]()
}
t.om.Set(key, value)
}
// Len returns the number of properties.
func (t *ToolPropertiesMap) Len() int {
if t == nil || t.om == nil {
return 0
}
return t.om.Len()
}
// All returns an iterator over all properties in insertion order.
func (t *ToolPropertiesMap) All() iter.Seq2[string, ToolProperty] {
if t == nil || t.om == nil {
return func(yield func(string, ToolProperty) bool) {}
}
return t.om.All()
}
// ToMap returns a regular map (order not preserved).
func (t *ToolPropertiesMap) ToMap() map[string]ToolProperty {
if t == nil || t.om == nil {
return nil
}
return t.om.ToMap()
}
func (t ToolPropertiesMap) MarshalJSON() ([]byte, error) {
if t.om == nil {
return []byte("null"), nil
}
return json.Marshal(t.om)
}
func (t *ToolPropertiesMap) UnmarshalJSON(data []byte) error {
t.om = orderedmap.New[string, ToolProperty]()
return json.Unmarshal(data, t.om)
}
type ToolProperty struct {
AnyOf []ToolProperty `json:"anyOf,omitempty"`
Type PropertyType `json:"type,omitempty"`
Items any `json:"items,omitempty"`
Description string `json:"description,omitempty"`
Enum []any `json:"enum,omitempty"`
AnyOf []ToolProperty `json:"anyOf,omitempty"`
Type PropertyType `json:"type,omitempty"`
Items any `json:"items,omitempty"`
Description string `json:"description,omitempty"`
Enum []any `json:"enum,omitempty"`
Properties *ToolPropertiesMap `json:"properties,omitempty"`
}
// ToTypeScriptType converts a ToolProperty to a TypeScript type string
@@ -336,11 +470,11 @@ func mapToTypeScriptType(jsonType string) string {
}
type ToolFunctionParameters struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required,omitempty"`
Properties map[string]ToolProperty `json:"properties"`
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required,omitempty"`
Properties *ToolPropertiesMap `json:"properties"`
}
func (t *ToolFunctionParameters) String() string {
@@ -553,6 +687,9 @@ type CreateRequest struct {
Renderer string `json:"renderer,omitempty"`
Parser string `json:"parser,omitempty"`
// Requires is the minimum version of Ollama required by the model.
Requires string `json:"requires,omitempty"`
// Info is a map of additional information for the model
Info map[string]any `json:"info,omitempty"`
@@ -603,6 +740,7 @@ type ShowResponse struct {
Tensors []Tensor `json:"tensors,omitempty"`
Capabilities []model.Capability `json:"capabilities,omitempty"`
ModifiedAt time.Time `json:"modified_at,omitempty"`
Requires string `json:"requires,omitempty"`
}
// CopyRequest is the request passed to [Client.Copy].

View File

@@ -11,6 +11,24 @@ import (
"github.com/stretchr/testify/require"
)
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
func testPropsMap(m map[string]ToolProperty) *ToolPropertiesMap {
props := NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
func testArgs(m map[string]any) ToolCallFunctionArguments {
args := NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
func TestKeepAliveParsingFromJSON(t *testing.T) {
tests := []struct {
name string
@@ -309,9 +327,9 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
input: ToolFunctionParameters{
Type: "object",
Required: []string{"name"},
Properties: map[string]ToolProperty{
Properties: testPropsMap(map[string]ToolProperty{
"name": {Type: PropertyType{"string"}},
},
}),
},
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string"}}}`,
},
@@ -319,9 +337,9 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
name: "no required",
input: ToolFunctionParameters{
Type: "object",
Properties: map[string]ToolProperty{
Properties: testPropsMap(map[string]ToolProperty{
"name": {Type: PropertyType{"string"}},
},
}),
},
expected: `{"type":"object","properties":{"name":{"type":"string"}}}`,
},
@@ -339,7 +357,7 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
func TestToolCallFunction_IndexAlwaysMarshals(t *testing.T) {
fn := ToolCallFunction{
Name: "echo",
Arguments: ToolCallFunctionArguments{"message": "hi"},
Arguments: testArgs(map[string]any{"message": "hi"}),
}
data, err := json.Marshal(fn)
@@ -504,6 +522,116 @@ func TestThinking_UnmarshalJSON(t *testing.T) {
}
}
func TestToolPropertyNestedProperties(t *testing.T) {
tests := []struct {
name string
input string
expected ToolProperty
}{
{
name: "nested object properties",
input: `{
"type": "object",
"description": "Location details",
"properties": {
"address": {
"type": "string",
"description": "Street address"
},
"city": {
"type": "string",
"description": "City name"
}
}
}`,
expected: ToolProperty{
Type: PropertyType{"object"},
Description: "Location details",
Properties: testPropsMap(map[string]ToolProperty{
"address": {
Type: PropertyType{"string"},
Description: "Street address",
},
"city": {
Type: PropertyType{"string"},
Description: "City name",
},
}),
},
},
{
name: "deeply nested properties",
input: `{
"type": "object",
"description": "Event",
"properties": {
"location": {
"type": "object",
"description": "Location",
"properties": {
"coordinates": {
"type": "object",
"description": "GPS coordinates",
"properties": {
"lat": {"type": "number", "description": "Latitude"},
"lng": {"type": "number", "description": "Longitude"}
}
}
}
}
}
}`,
expected: ToolProperty{
Type: PropertyType{"object"},
Description: "Event",
Properties: testPropsMap(map[string]ToolProperty{
"location": {
Type: PropertyType{"object"},
Description: "Location",
Properties: testPropsMap(map[string]ToolProperty{
"coordinates": {
Type: PropertyType{"object"},
Description: "GPS coordinates",
Properties: testPropsMap(map[string]ToolProperty{
"lat": {Type: PropertyType{"number"}, Description: "Latitude"},
"lng": {Type: PropertyType{"number"}, Description: "Longitude"},
}),
},
}),
},
}),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var prop ToolProperty
err := json.Unmarshal([]byte(tt.input), &prop)
require.NoError(t, err)
// Compare JSON representations since pointer comparison doesn't work
expectedJSON, err := json.Marshal(tt.expected)
require.NoError(t, err)
actualJSON, err := json.Marshal(prop)
require.NoError(t, err)
assert.JSONEq(t, string(expectedJSON), string(actualJSON))
// Round-trip test: marshal and unmarshal again
data, err := json.Marshal(prop)
require.NoError(t, err)
var prop2 ToolProperty
err = json.Unmarshal(data, &prop2)
require.NoError(t, err)
prop2JSON, err := json.Marshal(prop2)
require.NoError(t, err)
assert.JSONEq(t, string(expectedJSON), string(prop2JSON))
})
}
}
func TestToolFunctionParameters_String(t *testing.T) {
tests := []struct {
name string
@@ -515,12 +643,12 @@ func TestToolFunctionParameters_String(t *testing.T) {
params: ToolFunctionParameters{
Type: "object",
Required: []string{"name"},
Properties: map[string]ToolProperty{
Properties: testPropsMap(map[string]ToolProperty{
"name": {
Type: PropertyType{"string"},
Description: "The name of the person",
},
},
}),
},
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
},
@@ -537,7 +665,7 @@ func TestToolFunctionParameters_String(t *testing.T) {
s.Self = s
return s
}(),
Properties: map[string]ToolProperty{},
Properties: testPropsMap(map[string]ToolProperty{}),
},
expected: "",
},
@@ -550,3 +678,235 @@ func TestToolFunctionParameters_String(t *testing.T) {
})
}
}
func TestToolCallFunctionArguments_OrderPreservation(t *testing.T) {
t.Run("marshal preserves insertion order", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("zebra", "z")
args.Set("apple", "a")
args.Set("mango", "m")
data, err := json.Marshal(args)
require.NoError(t, err)
// Should preserve insertion order, not alphabetical
assert.Equal(t, `{"zebra":"z","apple":"a","mango":"m"}`, string(data))
})
t.Run("unmarshal preserves JSON order", func(t *testing.T) {
jsonData := `{"zebra":"z","apple":"a","mango":"m"}`
var args ToolCallFunctionArguments
err := json.Unmarshal([]byte(jsonData), &args)
require.NoError(t, err)
// Verify iteration order matches JSON order
var keys []string
for k := range args.All() {
keys = append(keys, k)
}
assert.Equal(t, []string{"zebra", "apple", "mango"}, keys)
})
t.Run("round trip preserves order", func(t *testing.T) {
original := `{"z":1,"a":2,"m":3,"b":4}`
var args ToolCallFunctionArguments
err := json.Unmarshal([]byte(original), &args)
require.NoError(t, err)
data, err := json.Marshal(args)
require.NoError(t, err)
assert.Equal(t, original, string(data))
})
t.Run("String method returns ordered JSON", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("c", 3)
args.Set("a", 1)
args.Set("b", 2)
assert.Equal(t, `{"c":3,"a":1,"b":2}`, args.String())
})
t.Run("Get retrieves correct values", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("key1", "value1")
args.Set("key2", 42)
v, ok := args.Get("key1")
assert.True(t, ok)
assert.Equal(t, "value1", v)
v, ok = args.Get("key2")
assert.True(t, ok)
assert.Equal(t, 42, v)
_, ok = args.Get("nonexistent")
assert.False(t, ok)
})
t.Run("Len returns correct count", func(t *testing.T) {
args := NewToolCallFunctionArguments()
assert.Equal(t, 0, args.Len())
args.Set("a", 1)
assert.Equal(t, 1, args.Len())
args.Set("b", 2)
assert.Equal(t, 2, args.Len())
})
t.Run("empty args marshal to empty object", func(t *testing.T) {
args := NewToolCallFunctionArguments()
data, err := json.Marshal(args)
require.NoError(t, err)
assert.Equal(t, `{}`, string(data))
})
t.Run("zero value args marshal to empty object", func(t *testing.T) {
var args ToolCallFunctionArguments
assert.Equal(t, "{}", args.String())
})
}
func TestToolPropertiesMap_OrderPreservation(t *testing.T) {
t.Run("marshal preserves insertion order", func(t *testing.T) {
props := NewToolPropertiesMap()
props.Set("zebra", ToolProperty{Type: PropertyType{"string"}})
props.Set("apple", ToolProperty{Type: PropertyType{"number"}})
props.Set("mango", ToolProperty{Type: PropertyType{"boolean"}})
data, err := json.Marshal(props)
require.NoError(t, err)
// Should preserve insertion order, not alphabetical
expected := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}`
assert.Equal(t, expected, string(data))
})
t.Run("unmarshal preserves JSON order", func(t *testing.T) {
jsonData := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}`
var props ToolPropertiesMap
err := json.Unmarshal([]byte(jsonData), &props)
require.NoError(t, err)
// Verify iteration order matches JSON order
var keys []string
for k := range props.All() {
keys = append(keys, k)
}
assert.Equal(t, []string{"zebra", "apple", "mango"}, keys)
})
t.Run("round trip preserves order", func(t *testing.T) {
original := `{"z":{"type":"string"},"a":{"type":"number"},"m":{"type":"boolean"}}`
var props ToolPropertiesMap
err := json.Unmarshal([]byte(original), &props)
require.NoError(t, err)
data, err := json.Marshal(props)
require.NoError(t, err)
assert.Equal(t, original, string(data))
})
t.Run("Get retrieves correct values", func(t *testing.T) {
props := NewToolPropertiesMap()
props.Set("name", ToolProperty{Type: PropertyType{"string"}, Description: "The name"})
props.Set("age", ToolProperty{Type: PropertyType{"integer"}, Description: "The age"})
v, ok := props.Get("name")
assert.True(t, ok)
assert.Equal(t, "The name", v.Description)
v, ok = props.Get("age")
assert.True(t, ok)
assert.Equal(t, "The age", v.Description)
_, ok = props.Get("nonexistent")
assert.False(t, ok)
})
t.Run("Len returns correct count", func(t *testing.T) {
props := NewToolPropertiesMap()
assert.Equal(t, 0, props.Len())
props.Set("a", ToolProperty{})
assert.Equal(t, 1, props.Len())
props.Set("b", ToolProperty{})
assert.Equal(t, 2, props.Len())
})
t.Run("nil props marshal to null", func(t *testing.T) {
var props *ToolPropertiesMap
data, err := json.Marshal(props)
require.NoError(t, err)
assert.Equal(t, `null`, string(data))
})
t.Run("ToMap returns regular map", func(t *testing.T) {
props := NewToolPropertiesMap()
props.Set("a", ToolProperty{Type: PropertyType{"string"}})
props.Set("b", ToolProperty{Type: PropertyType{"number"}})
m := props.ToMap()
assert.Equal(t, 2, len(m))
assert.Equal(t, PropertyType{"string"}, m["a"].Type)
assert.Equal(t, PropertyType{"number"}, m["b"].Type)
})
}
func TestToolCallFunctionArguments_ComplexValues(t *testing.T) {
t.Run("nested objects preserve order", func(t *testing.T) {
jsonData := `{"outer":{"z":1,"a":2},"simple":"value"}`
var args ToolCallFunctionArguments
err := json.Unmarshal([]byte(jsonData), &args)
require.NoError(t, err)
// Outer keys should be in order
var keys []string
for k := range args.All() {
keys = append(keys, k)
}
assert.Equal(t, []string{"outer", "simple"}, keys)
})
t.Run("arrays as values", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("items", []string{"a", "b", "c"})
args.Set("numbers", []int{1, 2, 3})
data, err := json.Marshal(args)
require.NoError(t, err)
assert.Equal(t, `{"items":["a","b","c"],"numbers":[1,2,3]}`, string(data))
})
}
func TestToolPropertiesMap_NestedProperties(t *testing.T) {
t.Run("nested properties preserve order", func(t *testing.T) {
props := NewToolPropertiesMap()
nestedProps := NewToolPropertiesMap()
nestedProps.Set("z_field", ToolProperty{Type: PropertyType{"string"}})
nestedProps.Set("a_field", ToolProperty{Type: PropertyType{"number"}})
props.Set("outer", ToolProperty{
Type: PropertyType{"object"},
Properties: nestedProps,
})
data, err := json.Marshal(props)
require.NoError(t, err)
// Both outer and inner should preserve order
expected := `{"outer":{"type":"object","properties":{"z_field":{"type":"string"},"a_field":{"type":"number"}}}}`
assert.Equal(t, expected, string(data))
})
}

View File

@@ -273,10 +273,6 @@ func main() {
Handler: uiServer.Handler(),
}
if _, err := uiServer.UserData(ctx); err != nil {
slog.Warn("failed to load user data", "error", err)
}
// Start the UI server
slog.Info("starting ui server", "port", port)
go func() {
@@ -320,6 +316,17 @@ func main() {
slog.Debug("no URL scheme request to handle")
}
go func() {
slog.Debug("waiting for ollama server to be ready")
if err := ui.WaitForServer(ctx, 10*time.Second); err != nil {
slog.Warn("ollama server not ready, continuing anyway", "error", err)
}
if _, err := uiServer.UserData(ctx); err != nil {
slog.Warn("failed to load user data", "error", err)
}
}()
osRun(cancel, hasCompletedFirstRun, startHidden)
slog.Info("shutting down desktop server")
@@ -361,7 +368,7 @@ func checkUserLoggedIn(uiServerPort int) bool {
return false
}
resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/api/v1/me", uiServerPort))
resp, err := http.Post(fmt.Sprintf("http://127.0.0.1:%d/api/me", uiServerPort), "application/json", nil)
if err != nil {
slog.Debug("failed to call local auth endpoint", "error", err)
return false

View File

@@ -191,13 +191,6 @@ func LaunchNewApp() {
C.launchApp(appName)
}
// Send a request to the main app thread to load a UI page
func sendUIRequestMessage(path string) {
p := C.CString(path)
defer C.free(unsafe.Pointer(p))
C.uiRequest(p)
}
func registerLaunchAgent(hasCompletedFirstRun bool) {
// Remove any stale Login Item registrations
C.unregisterSelfFromLoginItem()

View File

@@ -263,11 +263,6 @@ func createLoginShortcut() error {
return nil
}
// Send a request to the main app thread to load a UI page
func sendUIRequestMessage(path string) {
wintray.SendUIRequestMessage(path)
}
func LaunchNewApp() {
}

View File

@@ -169,37 +169,47 @@ DlgResult fileDlg(FileDlgParams* params) {
}
NSArray* urls = [panel URLs];
if(self->params->allowMultiple && [urls count] >= 1) {
if([urls count] == 0) {
return DLG_CANCEL;
}
if(self->params->allowMultiple) {
// For multiple files, we need to return all paths separated by null bytes
char* bufPtr = self->params->buf;
int remainingBuf = self->params->nbuf;
// Calculate total required buffer size first
int totalSize = 0;
for(NSURL* url in urls) {
char tempBuf[PATH_MAX];
if(![url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]) {
return DLG_URLFAIL;
}
totalSize += strlen(tempBuf) + 1; // +1 for null terminator
}
totalSize += 1; // Final null terminator
// Calculate total required buffer size first
int totalSize = 0;
for(NSURL* url in urls) {
char tempBuf[PATH_MAX];
if(![url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]) {
return DLG_URLFAIL;
}
totalSize += strlen(tempBuf) + 1; // +1 for null terminator
}
totalSize += 1; // Final null terminator
if(totalSize > self->params->nbuf) {
// Not enough buffer space
return DLG_URLFAIL;
}
if(totalSize > self->params->nbuf) {
// Not enough buffer space
return DLG_URLFAIL;
}
// Now actually copy the paths (we know we have space)
bufPtr = self->params->buf;
for(NSURL* url in urls) {
char tempBuf[PATH_MAX];
[url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX];
int pathLen = strlen(tempBuf);
strcpy(bufPtr, tempBuf);
bufPtr += pathLen + 1;
}
*bufPtr = '\0'; // Final null terminator
// Now actually copy the paths (we know we have space)
bufPtr = self->params->buf;
for(NSURL* url in urls) {
char tempBuf[PATH_MAX];
[url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX];
int pathLen = strlen(tempBuf);
strcpy(bufPtr, tempBuf);
bufPtr += pathLen + 1;
}
*bufPtr = '\0'; // Final null terminator
} else {
// Single file/directory selection - write path to buffer
NSURL* url = [urls firstObject];
if(![url getFileSystemRepresentation:self->params->buf maxLength:self->params->nbuf]) {
return DLG_URLFAIL;
}
}
return DLG_OK;

View File

@@ -15,7 +15,7 @@ const multiFileBufferSize = w32.MAX_PATH * 10
type WinDlgError int
func (e WinDlgError) Error() string {
return fmt.Sprintf("CommDlgExtendedError: %#x", e)
return fmt.Sprintf("CommDlgExtendedError: %#x", int(e))
}
func err() error {

View File

@@ -224,9 +224,7 @@ func (s *Server) cmd(ctx context.Context) (*exec.Cmd, error) {
if _, err := os.Stat(settings.Models); err == nil {
env["OLLAMA_MODELS"] = settings.Models
} else {
slog.Warn("models path not accessible, clearing models setting", "path", settings.Models, "err", err)
settings.Models = ""
s.store.SetSettings(settings)
slog.Warn("models path not accessible, using default", "path", settings.Models, "err", err)
}
}
if settings.ContextLength > 0 {

View File

@@ -469,26 +469,24 @@ export class HealthResponse {
}
export class User {
id: string;
name: string;
email: string;
avatarURL: string;
plan: string;
bio: string;
firstName: string;
lastName: string;
overThreshold: boolean;
name: string;
bio?: string;
avatarurl?: string;
firstname?: string;
lastname?: string;
plan?: string;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.id = source["id"];
this.name = source["name"];
this.email = source["email"];
this.avatarURL = source["avatarURL"];
this.plan = source["plan"];
this.name = source["name"];
this.bio = source["bio"];
this.firstName = source["firstName"];
this.lastName = source["lastName"];
this.overThreshold = source["overThreshold"];
this.avatarurl = source["avatarurl"];
this.firstname = source["firstname"];
this.lastname = source["lastname"];
this.plan = source["plan"];
}
}
export class Attachment {

View File

@@ -15,7 +15,7 @@ import {
import { parseJsonlFromResponse } from "./util/jsonl-parsing";
import { ollamaClient as ollama } from "./lib/ollama-client";
import type { ModelResponse } from "ollama/browser";
import { API_BASE } from "./lib/config";
import { API_BASE, OLLAMA_DOT_COM } from "./lib/config";
// Extend Model class with utility methods
declare module "@/gotypes" {
@@ -27,7 +27,6 @@ declare module "@/gotypes" {
Model.prototype.isCloud = function (): boolean {
return this.model.endsWith("cloud");
};
// Helper function to convert Uint8Array to base64
function uint8ArrayToBase64(uint8Array: Uint8Array): string {
const chunkSize = 0x8000; // 32KB chunks to avoid stack overflow
@@ -42,44 +41,50 @@ function uint8ArrayToBase64(uint8Array: Uint8Array): string {
}
export async function fetchUser(): Promise<User | null> {
try {
const response = await fetch(`${API_BASE}/api/v1/me`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});
if (response.ok) {
const userData: User = await response.json();
return userData;
}
return null;
} catch (error) {
console.error("Error fetching user:", error);
return null;
}
}
export async function fetchConnectUrl(): Promise<string> {
const response = await fetch(`${API_BASE}/api/v1/connect`, {
method: "GET",
const response = await fetch(`${API_BASE}/api/me`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
});
if (!response.ok) {
throw new Error("Failed to fetch connect URL");
if (response.ok) {
const userData: User = await response.json();
if (userData.avatarurl && !userData.avatarurl.startsWith("http")) {
userData.avatarurl = `${OLLAMA_DOT_COM}${userData.avatarurl}`;
}
return userData;
}
const data = await response.json();
return data.connect_url;
if (response.status === 401 || response.status === 403) {
return null;
}
throw new Error(`Failed to fetch user: ${response.status}`);
}
export async function fetchConnectUrl(): Promise<string> {
const response = await fetch(`${API_BASE}/api/me`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
});
if (response.status === 401) {
const data = await response.json();
if (data.signin_url) {
return data.signin_url;
}
}
throw new Error("Failed to fetch connect URL");
}
export async function disconnectUser(): Promise<void> {
const response = await fetch(`${API_BASE}/api/v1/disconnect`, {
const response = await fetch(`${API_BASE}/api/signout`, {
method: "POST",
headers: {
"Content-Type": "application/json",
@@ -389,7 +394,8 @@ export async function getInferenceCompute(): Promise<InferenceCompute[]> {
export async function fetchHealth(): Promise<boolean> {
try {
const response = await fetch(`${API_BASE}/api/v1/health`, {
// Use the /api/version endpoint as a health check
const response = await fetch(`${API_BASE}/api/version`, {
method: "GET",
headers: {
"Content-Type": "application/json",
@@ -398,7 +404,8 @@ export async function fetchHealth(): Promise<boolean> {
if (response.ok) {
const data = await response.json();
return data.healthy || false;
// If we get a version back, the server is healthy
return !!data.version;
}
return false;

View File

@@ -299,9 +299,9 @@ export default function Settings() {
</Button>
</div>
</div>
{user?.avatarURL && (
{user?.avatarurl && (
<img
src={user.avatarURL}
src={user.avatarurl}
alt={user?.name}
className="h-10 w-10 rounded-full bg-neutral-200 dark:bg-neutral-700 flex-shrink-0"
onError={(e) => {

View File

@@ -50,21 +50,33 @@ export default function Thinking({
// Position content to show bottom when collapsed
useEffect(() => {
if (isCollapsed && contentRef.current && wrapperRef.current) {
const contentHeight = contentRef.current.scrollHeight;
const wrapperHeight = wrapperRef.current.clientHeight;
if (contentHeight > wrapperHeight) {
const translateY = -(contentHeight - wrapperHeight);
contentRef.current.style.transform = `translateY(${translateY}px)`;
setHasOverflow(true);
} else {
setHasOverflow(false);
}
requestAnimationFrame(() => {
if (!contentRef.current || !wrapperRef.current) return;
const contentHeight = contentRef.current.scrollHeight;
const wrapperHeight = wrapperRef.current.clientHeight;
if (contentHeight > wrapperHeight) {
const translateY = -(contentHeight - wrapperHeight);
contentRef.current.style.transform = `translateY(${translateY}px)`;
setHasOverflow(true);
} else {
contentRef.current.style.transform = "translateY(0)";
setHasOverflow(false);
}
});
} else if (contentRef.current) {
contentRef.current.style.transform = "translateY(0)";
setHasOverflow(false);
}
}, [thinking, isCollapsed]);
useEffect(() => {
if (activelyThinking && wrapperRef.current && !isCollapsed) {
// When expanded and actively thinking, scroll to bottom
wrapperRef.current.scrollTop = wrapperRef.current.scrollHeight;
}
}, [thinking, activelyThinking, isCollapsed]);
const handleToggle = () => {
setIsCollapsed(!isCollapsed);
setHasUserInteracted(true);

View File

@@ -7,6 +7,7 @@ import { createQueryBatcher } from "./useQueryBatcher";
import { useRefetchModels } from "./useModels";
import { useStreamingContext } from "@/contexts/StreamingContext";
import { useSettings } from "./useSettings";
import { getModelCapabilities } from "@/api";
export const useChats = () => {
return useQuery({
@@ -606,6 +607,24 @@ export const useSendMessage = (chatId: string) => {
queryClient.setQueryData(["staleModels"], newStaleMap);
queryClient.invalidateQueries({ queryKey: ["models"] });
// Fetch fresh capabilities for the downloaded model
getModelCapabilities(selectedModel.model)
.then((capabilities) => {
queryClient.setQueryData(
["modelCapabilities", selectedModel.model],
capabilities,
);
})
.catch((error) => {
console.error(
"Failed to fetch capabilities after download:",
error,
);
queryClient.invalidateQueries({
queryKey: ["modelCapabilities", selectedModel.model],
});
});
}
break;
}

View File

@@ -1,114 +0,0 @@
import { useMutation, useQueryClient } from "@tanstack/react-query";
import { useState } from "react";
import { pullModel } from "@/api";
import { useSelectedModel } from "./useSelectedModel";
import { useSettings } from "./useSettings";
interface DownloadProgress {
status: string;
digest?: string;
total?: number;
completed?: number;
done?: boolean;
}
export function useDownloadModel(chatId?: string) {
const queryClient = useQueryClient();
const { selectedModel } = useSelectedModel(chatId);
const { setSettings } = useSettings();
const [downloadProgress, setDownloadProgress] =
useState<DownloadProgress | null>(null);
const [abortController, setAbortController] =
useState<AbortController | null>(null);
const [downloadingChatIds, setDownloadingChatIds] = useState<Set<string>>(
new Set(),
);
const mutation = useMutation({
mutationFn: async (modelName: string) => {
const controller = new AbortController();
setAbortController(controller);
setDownloadProgress({ status: "Starting download..." });
if (chatId) {
setDownloadingChatIds((prev) => new Set(prev).add(chatId));
}
try {
for await (const progress of pullModel(modelName, controller.signal)) {
setDownloadProgress(progress);
if (progress.status === "success") {
// Update selected model to indicate it's now available locally
if (selectedModel && selectedModel.model === modelName) {
setSettings({ SelectedModel: modelName });
}
// Invalidate models query to refresh the list
await queryClient.invalidateQueries({ queryKey: ["models"] });
break;
}
}
} finally {
setAbortController(null);
if (chatId) {
setDownloadingChatIds((prev) => {
const newSet = new Set(prev);
newSet.delete(chatId);
return newSet;
});
}
}
},
onSuccess: () => {
setDownloadProgress(null);
if (chatId) {
setDownloadingChatIds((prev) => {
const newSet = new Set(prev);
newSet.delete(chatId);
return newSet;
});
}
},
onError: (error: Error) => {
const status =
error.name === "AbortError" ? "Download cancelled" : "Download failed";
setDownloadProgress({ status, done: true });
// Clear error message after delay
const delay = error.name === "AbortError" ? 1500 : 3000;
setTimeout(() => {
setDownloadProgress(null);
if (chatId) {
setDownloadingChatIds((prev) => {
const newSet = new Set(prev);
newSet.delete(chatId);
return newSet;
});
}
}, delay);
},
});
const cancelDownload = () => {
if (abortController) {
abortController.abort();
setAbortController(null);
if (chatId) {
setDownloadingChatIds((prev) => {
const newSet = new Set(prev);
newSet.delete(chatId);
return newSet;
});
}
}
};
return {
downloadModel: mutation.mutate,
isDownloading:
mutation.isPending && chatId ? downloadingChatIds.has(chatId) : false,
downloadProgress:
chatId && downloadingChatIds.has(chatId) ? downloadProgress : null,
error: mutation.error,
cancelDownload,
};
}

View File

@@ -1,29 +1,20 @@
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
import { useEffect, useState } from "react";
import { fetchUser, fetchConnectUrl, disconnectUser } from "@/api";
export function useUser() {
const queryClient = useQueryClient();
const [initialDataLoaded, setInitialDataLoaded] = useState(false);
// Wait for initial data to be loaded
useEffect(() => {
const initialPromise = window.__initialUserDataPromise;
if (initialPromise) {
initialPromise.finally(() => {
setInitialDataLoaded(true);
});
} else {
setInitialDataLoaded(true);
}
}, []);
const userQuery = useQuery({
queryKey: ["user"],
queryFn: () => fetchUser(),
queryFn: async () => {
const result = await fetchUser();
return result;
},
staleTime: 5 * 60 * 1000, // Consider data stale after 5 minutes
gcTime: 10 * 60 * 1000, // Keep in cache for 10 minutes
initialData: null, // Start with null to prevent flashing
retry: 10,
retryDelay: (attemptIndex) => Math.min(500 * attemptIndex, 2000),
refetchOnMount: true, // Always fetch when component mounts
});
// Mutation to refresh user data
@@ -49,14 +40,15 @@ export function useUser() {
},
});
const isLoading = userQuery.isLoading || userQuery.isFetching;
const isAuthenticated = Boolean(userQuery.data?.name);
return {
user: userQuery.data,
isLoading:
!initialDataLoaded ||
(userQuery.isLoading && userQuery.data === undefined), // Show loading until initial data is loaded
isLoading,
isError: userQuery.isError,
error: userQuery.error,
isAuthenticated: Boolean(userQuery.data?.name),
isAuthenticated,
refreshUser: refreshUser.mutate,
isRefreshing: refreshUser.isPending,
refetchUser: userQuery.refetch,

View File

@@ -8,3 +8,6 @@ export const API_BASE = import.meta.env.DEV ? DEV_API_URL : "";
export const OLLAMA_HOST = import.meta.env.DEV
? DEV_API_URL
: window.location.origin;
export const OLLAMA_DOT_COM =
import.meta.env.VITE_OLLAMA_DOT_COM_URL || "https://ollama.com";

View File

@@ -147,6 +147,7 @@ export const highlighterPromise = createHighlighter({
"c",
"cpp",
"sql",
"swift",
"yaml",
"markdown",
],

View File

@@ -5,13 +5,6 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { routeTree } from "./routeTree.gen";
import { fetchUser } from "./api";
import { StreamingProvider } from "./contexts/StreamingContext";
import { User } from "@/gotypes";
declare global {
interface Window {
__initialUserDataPromise?: Promise<User | null>;
}
}
const queryClient = new QueryClient({
defaultOptions: {
@@ -24,27 +17,11 @@ const queryClient = new QueryClient({
},
});
// Track initial user data fetch
let initialUserDataPromise: Promise<User | null> | null = null;
// Initialize user data on app startup
const initializeUserData = async () => {
try {
const userData = await fetchUser();
fetchUser().then((userData) => {
if (userData) {
queryClient.setQueryData(["user"], userData);
return userData;
} catch (error) {
console.error("Error initializing user data:", error);
queryClient.setQueryData(["user"], null);
return null;
}
};
// Start initialization immediately and track the promise
initialUserDataPromise = initializeUserData();
// Export the promise so hooks can await it
window.__initialUserDataPromise = initialUserDataPromise;
});
const router = createRouter({
routeTree,

View File

@@ -101,15 +101,14 @@ type HealthResponse struct {
}
type User struct {
ID string `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
AvatarURL string `json:"avatarURL"`
Plan string `json:"plan"`
Bio string `json:"bio"`
FirstName string `json:"firstName"`
LastName string `json:"lastName"`
OverThreshold bool `json:"overThreshold"`
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Bio string `json:"bio,omitempty"`
AvatarURL string `json:"avatarurl,omitempty"`
FirstName string `json:"firstname,omitempty"`
LastName string `json:"lastname,omitempty"`
Plan string `json:"plan,omitempty"`
}
type Attachment struct {

View File

@@ -12,18 +12,17 @@ import (
"log/slog"
"net/http"
"net/http/httputil"
"net/url"
"os"
"runtime"
"runtime/debug"
"slices"
"strconv"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/auth"
"github.com/ollama/ollama/app/server"
"github.com/ollama/ollama/app/store"
"github.com/ollama/ollama/app/tools"
@@ -118,40 +117,66 @@ func (s *Server) log() *slog.Logger {
// ollamaProxy creates a reverse proxy handler to the Ollama server
func (s *Server) ollamaProxy() http.Handler {
ollamaHost := os.Getenv("OLLAMA_HOST")
if ollamaHost == "" {
ollamaHost = "http://127.0.0.1:11434"
}
var (
proxy http.Handler
proxyMu sync.Mutex
)
if !strings.HasPrefix(ollamaHost, "http://") && !strings.HasPrefix(ollamaHost, "https://") {
ollamaHost = "http://" + ollamaHost
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxyMu.Lock()
p := proxy
proxyMu.Unlock()
target, err := url.Parse(ollamaHost)
if err != nil {
s.log().Error("failed to parse OLLAMA_HOST", "error", err, "host", ollamaHost)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "failed to configure proxy", http.StatusInternalServerError)
})
}
if p == nil {
proxyMu.Lock()
if proxy == nil {
var err error
for i := range 2 {
if i > 0 {
s.log().Warn("ollama server not ready, retrying", "attempt", i+1)
time.Sleep(1 * time.Second)
}
s.log().Info("configuring ollama proxy", "target", target.String())
err = WaitForServer(context.Background(), 10*time.Second)
if err == nil {
break
}
}
proxy := httputil.NewSingleHostReverseProxy(target)
if err != nil {
proxyMu.Unlock()
s.log().Error("ollama server not ready after retries", "error", err)
http.Error(w, "Ollama server is not ready", http.StatusServiceUnavailable)
return
}
originalDirector := proxy.Director
proxy.Director = func(req *http.Request) {
originalDirector(req)
req.Host = target.Host
s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host)
}
target := envconfig.Host()
s.log().Info("configuring ollama proxy", "target", target.String())
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String())
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
}
newProxy := httputil.NewSingleHostReverseProxy(target)
return proxy
originalDirector := newProxy.Director
newProxy.Director = func(req *http.Request) {
originalDirector(req)
req.Host = target.Host
s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host)
}
newProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String())
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
}
proxy = newProxy
p = newProxy
} else {
p = proxy
}
proxyMu.Unlock()
}
p.ServeHTTP(w, r)
})
}
type errHandlerFunc func(http.ResponseWriter, *http.Request) error
@@ -264,11 +289,10 @@ func (s *Server) Handler() http.Handler {
ollamaProxy := s.ollamaProxy()
mux.Handle("GET /api/tags", ollamaProxy)
mux.Handle("POST /api/show", ollamaProxy)
mux.Handle("GET /api/v1/me", handle(s.me))
mux.Handle("POST /api/v1/disconnect", handle(s.disconnect))
mux.Handle("GET /api/v1/connect", handle(s.connectURL))
mux.Handle("GET /api/v1/health", handle(s.health))
mux.Handle("GET /api/version", ollamaProxy)
mux.Handle("HEAD /api/version", ollamaProxy)
mux.Handle("POST /api/me", ollamaProxy)
mux.Handle("POST /api/signout", ollamaProxy)
// React app - catch all non-API routes and serve the React app
mux.Handle("GET /", s.appHandler())
@@ -338,7 +362,7 @@ func (s *Server) doSelfSigned(ctx context.Context, method, path string) (*http.R
}
// UserData fetches user data from ollama.com API for the current ollama key
func (s *Server) UserData(ctx context.Context) (*responses.User, error) {
func (s *Server) UserData(ctx context.Context) (*api.UserResponse, error) {
resp, err := s.doSelfSigned(ctx, http.MethodPost, "/api/me")
if err != nil {
return nil, fmt.Errorf("failed to call ollama.com/api/me: %w", err)
@@ -349,7 +373,7 @@ func (s *Server) UserData(ctx context.Context) (*responses.User, error) {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
var user responses.User
var user api.UserResponse
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
return nil, fmt.Errorf("failed to parse user response: %w", err)
}
@@ -368,29 +392,27 @@ func (s *Server) UserData(ctx context.Context) (*responses.User, error) {
return &user, nil
}
func waitForServer(ctx context.Context) error {
timeout := time.Now().Add(10 * time.Second)
// TODO: this avoids an error on first load of the app
// however we should either show a loading state or
// wait for the Ollama server to be ready before redirecting
for {
// WaitForServer waits for the Ollama server to be ready
func WaitForServer(ctx context.Context, timeout time.Duration) error {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
c, err := api.ClientFromEnvironment()
if err != nil {
return err
}
if _, err := c.Version(ctx); err == nil {
break
}
if time.Now().After(timeout) {
return fmt.Errorf("timeout waiting for Ollama server to be ready")
slog.Debug("ollama server is ready")
return nil
}
time.Sleep(10 * time.Millisecond)
}
return nil
return errors.New("timeout waiting for Ollama server to be ready")
}
func (s *Server) createChat(w http.ResponseWriter, r *http.Request) error {
waitForServer(r.Context())
if err := WaitForServer(r.Context(), 10*time.Second); err != nil {
return err
}
id, err := uuid.NewV7()
if err != nil {
@@ -975,7 +997,7 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
for _, toolCall := range res.Message.ToolCalls {
// continues loop as tools were executed
toolsExecuted = true
result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments.ToMap())
if err != nil {
errContent := fmt.Sprintf("Error: %v", err)
toolErrMsg := store.NewMessage("tool", errContent, nil)
@@ -1438,129 +1460,6 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
})
}
func (s *Server) me(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodGet {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return nil
}
user, err := s.UserData(r.Context())
if err != nil {
// If fetching from API fails, try to return cached user data if available
if cachedUser, cacheErr := s.Store.User(); cacheErr == nil && cachedUser != nil {
s.log().Info("API request failed, returning cached user data", "error", err)
responseUser := &responses.User{
Name: cachedUser.Name,
Email: cachedUser.Email,
Plan: cachedUser.Plan,
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(responseUser)
}
s.log().Error("failed to get user data", "error", err)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to get user data",
})
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(user)
}
func (s *Server) disconnect(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodPost {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return nil
}
if err := s.Store.ClearUser(); err != nil {
s.log().Warn("failed to clear cached user data", "error", err)
}
// Get the SSH public key to encode for the delete request
pubKey, err := ollamaAuth.GetPublicKey()
if err != nil {
s.log().Error("failed to get public key", "error", err)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to get public key",
})
}
// Encode the key using base64 URL encoding
encodedKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
// Call the /api/user/keys/{encodedKey} endpoint with DELETE
resp, err := s.doSelfSigned(r.Context(), http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey))
if err != nil {
s.log().Error("failed to call ollama.com/api/user/keys", "error", err)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to disconnect from ollama.com",
})
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
s.log().Error("disconnect request failed", "status", resp.StatusCode)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to disconnect from ollama.com",
})
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(map[string]string{"status": "disconnected"})
}
func (s *Server) connectURL(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodGet {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return nil
}
connectURL, err := auth.BuildConnectURL(OllamaDotCom)
if err != nil {
s.log().Error("failed to build connect URL", "error", err)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to build connect URL",
})
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(map[string]string{
"connect_url": connectURL,
})
}
func (s *Server) health(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodGet {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return nil
}
healthy := false
c, err := api.ClientFromEnvironment()
if err == nil {
if _, err := c.Version(r.Context()); err == nil {
healthy = true
}
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(responses.HealthResponse{
Healthy: healthy,
})
}
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
defer cancel()
@@ -1659,13 +1558,13 @@ func convertToOllamaTool(toolSchema map[string]any) api.Tool {
tool.Function.Parameters.Type = "object"
tool.Function.Parameters.Required = []string{}
tool.Function.Parameters.Properties = make(map[string]api.ToolProperty)
tool.Function.Parameters.Properties = api.NewToolPropertiesMap()
if schemaProps, ok := toolSchema["schema"].(map[string]any); ok {
tool.Function.Parameters.Type = getStringFromMap(schemaProps, "type", "object")
if props, ok := schemaProps["properties"].(map[string]any); ok {
tool.Function.Parameters.Properties = make(map[string]api.ToolProperty)
tool.Function.Parameters.Properties = api.NewToolPropertiesMap()
for propName, propDef := range props {
if propMap, ok := propDef.(map[string]any); ok {
@@ -1673,7 +1572,7 @@ func convertToOllamaTool(toolSchema map[string]any) api.Tool {
Type: api.PropertyType{getStringFromMap(propMap, "type", "string")},
Description: getStringFromMap(propMap, "description", ""),
}
tool.Function.Parameters.Properties[propName] = prop
tool.Function.Parameters.Properties.Set(propName, prop)
}
}
}

View File

@@ -158,16 +158,16 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui
case uint32(UI_REQUEST_MSG_ID):
// Requests for the UI must always come from the main event thread
l := int(wParam)
path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l)
path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l) //nolint:govet,gosec
t.app.UIRun(path)
case WM_COPYDATA:
// Handle URL scheme requests from other instances
if lParam != 0 {
cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam))
if cds.DwData == 1 { // Our identifier for URL scheme messages
cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam)) //nolint:govet,gosec
if cds.DwData == 1 { // Our identifier for URL scheme messages
// Convert the data back to string
data := make([]byte, cds.CbData)
copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData])
copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData]) //nolint:govet,gosec
urlScheme := string(data)
handleURLSchemeRequest(urlScheme)
lResult = 1 // Return non-zero to indicate success

View File

@@ -15,7 +15,7 @@ A Go-based command-line tool for benchmarking Ollama models with configurable pa
```
go build -o ollama-bench bench.go
./bench -model gpt-oss:20b -epochs 6 -format csv
./ollama-bench -model gpt-oss:20b -epochs 6 -format csv
```
Using Go Run (without building)
@@ -29,31 +29,32 @@ go run bench.go -model gpt-oss:20b -epochs 3
### Basic Example
```
./bench -model gemma3 -epochs 6
./ollama-bench -model gemma3 -epochs 6
```
### Benchmark Multiple Models
```
./bench -model gemma3,gemma3n -epochs 6 -max-tokens 100 -p "Write me a short story" | tee gemma.bench
./ollama-bench -model gemma3,gemma3n -epochs 6 -max-tokens 100 -p "Write me a short story" | tee gemma.bench
benchstat -col /name gemma.bench
```
### With Image Prompt
```
./bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
```
### Advanced Example
```
./bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
```
## Command Line Options
| Option | Description | Default |
|----------|-------------|---------|
| -model | Comma-separated list of models to benchmark | (required) |
| -epochs | Number of iterations per model | 1 |
| -max-tokens | Maximum tokens for model response | 0 (unlimited) |

View File

@@ -45,6 +45,7 @@ import (
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd"
)
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
@@ -517,6 +518,10 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
}
// Check for experimental flag
isExperimental, _ := cmd.Flags().GetBool("experimental")
yoloMode, _ := cmd.Flags().GetBool("yolo")
if interactive {
if err := loadOrUnloadModel(cmd, &opts); err != nil {
var sErr api.AuthorizationError
@@ -543,6 +548,11 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
}
// Use experimental agent loop with tools
if isExperimental {
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode)
}
return generateInteractive(cmd, opts)
}
return generate(cmd, opts)
@@ -943,6 +953,9 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
}
rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
if resp.Requires != "" {
rows = append(rows, []string{"", "requires", resp.Requires})
}
return
})
@@ -1751,6 +1764,8 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
runCmd.Flags().BoolP("yolo", "y", false, "Skip all tool approval prompts (use with caution)")
stopCmd := &cobra.Command{
Use: "stop MODEL",

View File

@@ -291,6 +291,31 @@ Weigh anchor!
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
t.Run("min version", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
Requires: "0.14.0",
}, false, &b); err != nil {
t.Fatal(err)
}
expect := ` Model
architecture test
parameters 7B
quantization FP16
requires 0.14.0
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
}
func TestDeleteHandler(t *testing.T) {

View File

@@ -40,6 +40,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")

View File

@@ -6,11 +6,14 @@ import (
"errors"
"fmt"
"io/fs"
"iter"
"log/slog"
"maps"
"os"
"slices"
"strings"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -18,8 +21,13 @@ type ModelParameters struct {
Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"`
// TODO is this needed?
ModelType string `json:"model_type"`
TextModel struct {
VocabSize uint32 `json:"vocab_size"`
VocabSize uint32 `json:"vocab_size"`
HiddenSize uint32 `json:"hidden_size"`
ModelType string `json:"model_type"`
} `json:"text_config"`
}
@@ -33,8 +41,94 @@ type AdapterParameters struct {
} `json:"lora_parameters"`
}
func (ModelParameters) KV(t *Tokenizer) ggml.KV {
kv := ggml.KV{
type KV map[string]any
func (kv KV) Architecture() string {
return kv.String("general.architecture", "unknown")
}
type valueTypes interface {
uint8 | int8 | uint16 | int16 |
uint32 | int32 | uint64 | int64 |
string | float32 | float64 | bool
}
type arrayValueTypes interface {
[]uint8 | []int8 | []uint16 | []int16 |
[]uint32 | []int32 | []uint64 | []int64 |
[]string | []float32 | []float64 | []bool
}
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key
}
if val, ok := kv[key].(T); ok {
return val, true
}
return defaultValue[0], false
}
func (kv KV) String(key string, defaultValue ...string) string {
val, _ := keyValue(kv, key, append(defaultValue, "")...)
return val
}
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
}
func (kv KV) Float(key string, defaultValue ...float32) float32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
}
func (kv KV) Bool(key string, defaultValue ...bool) bool {
val, _ := keyValue(kv, key, append(defaultValue, false)...)
return val
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
val, _ := keyValue(kv, key, append(defaultValue, []string{""})...)
return val
}
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
val, _ := keyValue(kv, key, append(defaultValue, []int32{0})...)
return val
}
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
val, _ := keyValue(kv, key, append(defaultValue, []uint32{0})...)
return val
}
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
val, _ := keyValue(kv, key, append(defaultValue, []float32{0})...)
return val
}
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
val, _ := keyValue(kv, key, append(defaultValue, []bool{false})...)
return val
}
func (kv KV) Len() int {
return len(kv)
}
func (kv KV) Keys() iter.Seq[string] {
return maps.Keys(kv)
}
func (kv KV) Value(key string) any {
return kv[key]
}
func (ModelParameters) KV(t *Tokenizer) KV {
kv := KV{
"general.file_type": uint32(1),
"general.quantization_version": uint32(2),
"tokenizer.ggml.pre": t.Pre,
@@ -63,7 +157,7 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV {
return kv
}
func (p AdapterParameters) KV() ggml.KV {
func (p AdapterParameters) KV() KV {
var alpha float32
if p.LoraParameters.Alpha == 0 {
alpha = float32(p.Alpha)
@@ -71,7 +165,7 @@ func (p AdapterParameters) KV() ggml.KV {
alpha = p.LoraParameters.Alpha
}
kv := ggml.KV{
kv := KV{
"adapter.lora.alpha": alpha,
"adapter.type": "lora",
"general.file_type": uint32(1),
@@ -88,9 +182,14 @@ func (ModelParameters) specialTokenTypes() []string {
}
}
type ModelConverter interface {
type ModelKV interface {
// KV maps parameters to LLM key-values
KV(*Tokenizer) ggml.KV
KV(*Tokenizer) KV
}
type ModelConverter interface {
ModelKV
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
@@ -107,7 +206,7 @@ type moreParser interface {
type AdapterConverter interface {
// KV maps parameters to LLM key-values
KV(ggml.KV) ggml.KV
KV(ofs.Config) KV
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
@@ -115,7 +214,7 @@ type AdapterConverter interface {
Replacements() []string
}
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
bts, err := fs.ReadFile(fsys, "adapter_config.json")
if err != nil {
return err
@@ -126,8 +225,8 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
return err
}
arch, ok := baseKV["general.architecture"]
if !ok {
arch := baseKV.Architecture()
if arch == "" {
return errors.New("architecture not set for the base model")
}
@@ -153,23 +252,19 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
}
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
bts, err := fs.ReadFile(fsys, "config.json")
if err != nil {
return err
return nil, nil, err
}
var p ModelParameters
if err := json.Unmarshal(bts, &p); err != nil {
return err
return nil, nil, err
}
if len(p.Architectures) < 1 {
return errors.New("unknown architecture")
return nil, nil, errors.New("unknown architecture")
}
var conv ModelConverter
@@ -182,6 +277,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &llama4Model{}
case "Mistral3ForConditionalGeneration":
conv = &mistral3Model{}
case "Ministral3ForCausalLM":
conv = &mistral3CausalModel{}
case "MixtralForCausalLM":
conv = &mixtralModel{}
case "GemmaForCausalLM":
@@ -200,8 +297,12 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &qwen25VLModel{}
case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration":
conv = &qwen3VLModel{}
case "Olmo3ForCausalLM":
conv = &olmoModel{}
case "BertModel":
conv = &bertModel{}
case "NomicBertModel", "NomicBertMoEModel":
conv = &nomicbertModel{}
case "CohereForCausalLM":
conv = &commandrModel{}
case "GptOssForCausalLM":
@@ -211,22 +312,22 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
case "DeepseekV3ForCausalLM":
conv = &deepseek2Model{}
default:
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}
if err := json.Unmarshal(bts, conv); err != nil {
return err
return nil, nil, err
}
if t, ok := conv.(moreParser); ok {
if err := t.parseMore(fsys); err != nil {
return err
return nil, nil, err
}
}
t, err := parseTokenizer(fsys, conv.specialTokenTypes())
if err != nil {
return err
return nil, nil, err
}
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
@@ -248,6 +349,19 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
default:
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
}
return conv, t, nil
}
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
kv, t, err := LoadModelMetadata(fsys)
if err != nil {
return err
}
conv := kv.(ModelConverter)
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
if err != nil {
@@ -257,7 +371,7 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
return writeFile(f, conv.KV(t), conv.Tensors(ts))
}
func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
func writeFile(f *os.File, kv KV, ts []*ggml.Tensor) error {
for i := range ts {
ts[i].Shape = slices.Clone(ts[i].Shape)
slices.Reverse(ts[i].Shape)

View File

@@ -88,7 +88,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
return nil
}
func (p *bertModel) KV(t *Tokenizer) ggml.KV {
func (p *bertModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "bert"
kv["bert.attention.causal"] = false

View File

@@ -24,7 +24,7 @@ type commandrModel struct {
var _ ModelConverter = (*commandrModel)(nil)
func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
func (p *commandrModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "command-r"
kv["general.name"] = "command-r"

View File

@@ -47,7 +47,7 @@ type deepseek2Model struct {
Architecture string
}
func (p *deepseek2Model) KV(t *Tokenizer) ggml.KV {
func (p *deepseek2Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "deepseek2"
kv["general.type"] = "model"

View File

@@ -41,7 +41,7 @@ type deepseekocr struct {
} `json:"vision_config"`
}
func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
func (m *deepseekocr) KV(t *Tokenizer) KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "deepseekocr"
kv["block_count"] = m.LanguageConfig.HiddenLayers

View File

@@ -23,7 +23,7 @@ type gemmaModel struct {
var _ ModelConverter = (*gemmaModel)(nil)
func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
func (p *gemmaModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma"
kv["gemma.context_length"] = p.MaxPositionEmbeddings

View File

@@ -1,7 +1,5 @@
package convert
import "github.com/ollama/ollama/fs/ggml"
type gemma2Model struct {
gemmaModel
SlidingWindow uint32 `json:"sliding_window"`
@@ -9,7 +7,7 @@ type gemma2Model struct {
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
}
func (p *gemma2Model) KV(t *Tokenizer) ggml.KV {
func (p *gemma2Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma2"
kv["gemma2.context_length"] = p.MaxPositionEmbeddings

View File

@@ -6,6 +6,7 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -15,7 +16,7 @@ type gemma2Adapter struct {
var _ AdapterConverter = (*gemma2Adapter)(nil)
func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
func (p *gemma2Adapter) KV(baseKV fs.Config) KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "gemma2"
return kv

View File

@@ -3,8 +3,6 @@ package convert
import (
"cmp"
"slices"
"github.com/ollama/ollama/fs/ggml"
)
type gemma3Model struct {
@@ -55,7 +53,7 @@ const (
gemma27BLayerCount = 62
)
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
func (p *gemma3Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3"

View File

@@ -38,7 +38,7 @@ type gemma3nModel struct {
VisionModel struct{} `json:"vision_config"`
}
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
func (m *gemma3nModel) KV(t *Tokenizer) KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3n"
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {

View File

@@ -37,7 +37,7 @@ type gptossModel struct {
var _ ModelConverter = (*gptossModel)(nil)
func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
func (m *gptossModel) KV(t *Tokenizer) KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gptoss"
kv["general.file_type"] = uint32(4)

View File

@@ -48,7 +48,7 @@ type llamaModel struct {
var _ ModelConverter = (*llamaModel)(nil)
func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
func (p *llamaModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama"
kv["llama.vocab_size"] = p.VocabSize

View File

@@ -35,7 +35,7 @@ type llama4Model struct {
}
// KV implements ModelConverter.
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
func (p *llama4Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama4"

View File

@@ -7,6 +7,7 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -18,13 +19,13 @@ type llamaAdapter struct {
var _ AdapterConverter = (*llamaAdapter)(nil)
func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
func (p *llamaAdapter) KV(baseKV fs.Config) KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "llama"
kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
kv["llama.attention.head_count_kv"] = baseKV["llama.attention.head_count_kv"]
kv["llama.attention.head_count"] = baseKV.Value("llama.attention.head_count")
kv["llama.attention.head_count_kv"] = baseKV.Value("llama.attention.head_count_kv")
p.NumAttentionHeads = baseKV["llama.attention.head_count"].(uint32)
p.NumAttentionHeads = baseKV.Value("llama.attention.head_count").(uint32)
return kv
}

View File

@@ -30,13 +30,15 @@ type mistral3Model struct {
HiddenAct string `json:"hidden_act"`
VocabSize uint32 `json:"vocab_size"`
RopeParameters struct {
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
Factor float32 `json:"factor"`
ScalingBeta float32 `json:"llama_4_scaling_beta"`
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
RopeType string `json:"rope_type"`
RopeTheta float32 `json:"rope_theta"`
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
Factor float32 `json:"factor"`
Llama4ScalingBeta *float32 `json:"llama_4_scaling_beta"`
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
RopeType string `json:"rope_type"`
RopeTheta float32 `json:"rope_theta"`
Mscale *float32 `json:"mscale"`
MscaleAllDim *float32 `json:"mscale_all_dim"`
} `json:"rope_parameters"`
} `json:"text_config"`
VisionModel struct {
@@ -50,12 +52,15 @@ type mistral3Model struct {
HeadDim uint32 `json:"head_dim"`
HiddenAct string `json:"hidden_act"`
RopeTheta float32 `json:"rope_theta"`
RopeParameters struct {
RopeTheta float32 `json:"rope_theta"`
} `json:"rope_parameters"`
} `json:"vision_config"`
MultiModalProjectorBias bool `json:"multimodal_projector_bias"`
ProjectorHiddenAct string `json:"projector_hidden_act"`
}
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
func (p *mistral3Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.TextModel.VocabSize
@@ -72,10 +77,22 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads)
kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta)
kv["mistral3.rope.scaling.factor"] = p.TextModel.RopeParameters.Factor
kv["mistral3.rope.scaling.type"] = p.TextModel.RopeParameters.RopeType
kv["mistral3.rope.scaling.beta_fast"] = p.TextModel.RopeParameters.BetaFast
kv["mistral3.rope.scaling.beta_slow"] = p.TextModel.RopeParameters.BetaSlow
if p.TextModel.RopeParameters.Mscale != nil {
kv["mistral3.rope.scaling.mscale"] = *p.TextModel.RopeParameters.Mscale
}
if p.TextModel.RopeParameters.MscaleAllDim != nil {
kv["mistral3.rope.scaling.mscale_all_dim"] = *p.TextModel.RopeParameters.MscaleAllDim
}
if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 {
kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings
kv["mistral3.rope.scaling_beta"] = p.TextModel.RopeParameters.ScalingBeta
}
if p.TextModel.RopeParameters.Llama4ScalingBeta != nil {
kv["mistral3.rope.scaling_beta"] = *p.TextModel.RopeParameters.Llama4ScalingBeta
}
// Vision configuration
@@ -88,7 +105,7 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta
kv["mistral3.vision.rope.freq_base"] = cmp.Or(p.VisionModel.RopeTheta, p.VisionModel.RopeParameters.RopeTheta)
// Multimodal configuration
kv["mistral3.image_token_index"] = p.ImageTokenIndex

View File

@@ -0,0 +1,181 @@
package convert
import (
"cmp"
"fmt"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type mistral3CausalModel struct {
ModelParameters
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
SlidingWindow *uint32 `json:"sliding_window"`
HiddenAct string `json:"hidden_act"`
VocabSize uint32 `json:"vocab_size"`
RopeParameters struct {
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
Factor float32 `json:"factor"`
Llama4ScalingBeta *float32 `json:"llama_4_scaling_beta"`
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
RopeType string `json:"rope_type"`
RopeTheta float32 `json:"rope_theta"`
Mscale *float32 `json:"mscale"`
MscaleAllDim *float32 `json:"mscale_all_dim"`
} `json:"rope_parameters"`
}
func (p *mistral3CausalModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.VocabSize
// Text configuration
kv["mistral3.block_count"] = p.NumHiddenLayers
kv["mistral3.context_length"] = p.MaxPositionEmbeddings
kv["mistral3.embedding_length"] = p.HiddenSize
kv["mistral3.feed_forward_length"] = p.IntermediateSize
kv["mistral3.attention.head_count"] = p.NumAttentionHeads
kv["mistral3.attention.head_count_kv"] = p.NumKeyValueHeads
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["mistral3.attention.key_length"] = p.HeadDim
kv["mistral3.attention.value_length"] = p.HeadDim
kv["mistral3.rope.dimension_count"] = cmp.Or(p.HeadDim, p.HiddenSize/p.NumAttentionHeads)
kv["mistral3.rope.freq_base"] = cmp.Or(p.RopeTheta, p.RopeParameters.RopeTheta)
kv["mistral3.rope.scaling.factor"] = p.RopeParameters.Factor
kv["mistral3.rope.scaling.type"] = p.RopeParameters.RopeType
kv["mistral3.rope.scaling.beta_fast"] = p.RopeParameters.BetaFast
kv["mistral3.rope.scaling.beta_slow"] = p.RopeParameters.BetaSlow
if p.RopeParameters.Mscale != nil {
kv["mistral3.rope.scaling.mscale"] = *p.RopeParameters.Mscale
}
if p.RopeParameters.MscaleAllDim != nil {
kv["mistral3.rope.scaling.mscale_all_dim"] = *p.RopeParameters.MscaleAllDim
}
if p.RopeParameters.OrigMaxPositionEmbeddings > 0 {
kv["mistral3.rope.scaling.original_context_length"] = p.RopeParameters.OrigMaxPositionEmbeddings
kv["mistral3.rope.scaling_beta"] = *p.RopeParameters.Llama4ScalingBeta
}
if p.RopeParameters.Llama4ScalingBeta != nil {
kv["mistral3.rope.scaling_beta"] = *p.RopeParameters.Llama4ScalingBeta
}
return kv
}
func (p *mistral3CausalModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") {
if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
strings.HasSuffix(t.Name(), ".attn_k.weight") {
t.SetRepacker(p.repack)
}
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *mistral3CausalModel) Replacements() []string {
return []string{
"model.norm", "output_norm",
"model.", "",
"layers", "blk",
"transformer.layers", "blk",
"vision_tower", "v",
"ln_pre", "encoder_norm",
"input_layernorm", "attn_norm",
"post_attention_layernorm", "ffn_norm",
"embed_tokens", "token_embd",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"mlp.down_proj", "ffn_down",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"attention.q_proj", "attn_q",
"attention.k_proj", "attn_k",
"attention.v_proj", "attn_v",
"attention.o_proj", "attn_output",
"attention_norm", "attn_norm",
"feed_forward.gate_proj", "ffn_gate",
"feed_forward.down_proj", "ffn_down",
"feed_forward.up_proj", "ffn_up",
"multi_modal_projector", "mm",
"ffn_norm", "ffn_norm",
"lm_head", "output",
}
}
func (p *mistral3CausalModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
var dims []int
for _, dim := range shape {
dims = append(dims, int(dim))
}
var heads uint32
if strings.HasSuffix(name, ".attn_q.weight") {
heads = p.NumAttentionHeads
} else if strings.HasSuffix(name, ".attn_k.weight") {
heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
} else {
return nil, fmt.Errorf("unknown tensor for repack: %s", name)
}
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 1)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}

View File

@@ -12,7 +12,7 @@ type mixtralModel struct {
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
}
func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
func (p *mixtralModel) KV(t *Tokenizer) KV {
kv := p.llamaModel.KV(t)
if p.NumLocalExperts > 0 {

View File

@@ -34,7 +34,7 @@ type mllamaModel struct {
} `json:"vision_config"`
}
func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
func (m *mllamaModel) KV(t *Tokenizer) KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "mllama"

View File

@@ -0,0 +1,213 @@
package convert
import (
"cmp"
"encoding/json"
"io/fs"
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type nomicbertModel struct {
ModelParameters
NLayers uint32 `json:"n_layers"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
LayerNormEPS float32 `json:"layer_norm_eps"`
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
RopeFreqBase float32 `json:"rope_theta"`
normalizeEmbeddings bool
PoolingType uint32
// MoE parameters (only present in v2 models)
NumExperts uint32 `json:"num_local_experts"`
NumExpertsUsed uint32 `json:"num_experts_per_tok"`
MoEEveryNLayers uint32 `json:"moe_every_n_layers"`
}
var (
_ ModelConverter = (*nomicbertModel)(nil)
_ moreParser = (*nomicbertModel)(nil)
)
func (p *nomicbertModel) parseMore(fsys fs.FS) error {
bts, err := fs.ReadFile(fsys, "modules.json")
if err != nil {
return err
}
var modules []struct {
Type string `json:"type"`
Path string `json:"path"`
}
if err := json.Unmarshal(bts, &modules); err != nil {
return err
}
var pooling string
for _, m := range modules {
switch m.Type {
case "sentence_transformers.models.Pooling":
pooling = m.Path
case "sentence_transformers.models.Normalize":
p.normalizeEmbeddings = true
}
}
if pooling != "" {
bts, err := fs.ReadFile(fsys, filepath.Join(pooling, "config.json"))
if err != nil {
return err
}
var pc struct {
PoolingModeCLSToken bool `json:"pooling_mode_cls_token"`
PoolingModeMeanTokens bool `json:"pooling_mode_mean_tokens"`
}
if err := json.Unmarshal(bts, &pc); err != nil {
return err
}
if pc.PoolingModeMeanTokens {
p.PoolingType = 1
} else if pc.PoolingModeCLSToken {
p.PoolingType = 2
}
}
return nil
}
func (p *nomicbertModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
// Determine architecture based on MoE parameters (following qwen3 pattern)
arch := "nomic-bert"
if p.MoEEveryNLayers > 0 {
arch += "-moe"
}
kv["general.architecture"] = arch
kv["attention.causal"] = false
kv["pooling_type"] = p.PoolingType
kv["normalize_embeddings"] = p.normalizeEmbeddings
kv["block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers)
if contextLength := p.MaxPositionEmbeddings; contextLength > 0 {
kv["context_length"] = contextLength
}
if embeddingLength := p.HiddenSize; embeddingLength > 0 {
kv["embedding_length"] = p.HiddenSize
}
if feedForwardLength := p.IntermediateSize; feedForwardLength > 0 {
kv["feed_forward_length"] = p.IntermediateSize
}
if headCount := p.NumAttentionHeads; headCount > 0 {
kv["attention.head_count"] = p.NumAttentionHeads
}
if kvHeadCount := p.NumKeyValueHeads; kvHeadCount > 0 {
kv["attention.head_count_kv"] = p.NumKeyValueHeads
}
if layerNormEpsilon := cmp.Or(p.LayerNormEPS, p.LayerNormEpsilon); layerNormEpsilon > 0 {
kv["attention.layer_norm_epsilon"] = layerNormEpsilon
}
if p.RopeFreqBase > 0 {
kv["rope.freq_base"] = p.RopeFreqBase
}
// MoE specific parameters (only if MoE is enabled)
if p.NumExperts > 0 {
kv["expert_count"] = p.NumExperts
}
if p.NumExpertsUsed > 0 {
kv["expert_used_count"] = p.NumExpertsUsed
}
if p.MoEEveryNLayers > 0 {
kv["moe_every_n_layers"] = p.MoEEveryNLayers
}
kv["tokenizer.ggml.model"] = "bert"
kv["tokenizer.ggml.token_type_count"] = uint32(2)
// convert to phantom space tokens
for i, e := range t.Tokens {
switch {
case strings.HasPrefix(e, "[") && strings.HasSuffix(e, "]"):
// noop - keep special tokens as-is
case strings.HasPrefix(e, "##"):
t.Tokens[i] = e[2:]
default:
t.Tokens[i] = "\u2581" + e
}
}
kv["tokenizer.ggml.tokens"] = t.Tokens
return kv
}
func (p *nomicbertModel) Tensors(ts []Tensor) []*ggml.Tensor {
out := make([]*ggml.Tensor, 0, len(ts))
for _, t := range ts {
if slices.Contains([]string{
"embeddings.position_ids",
"pooler.dense.weight",
"pooler.dense.bias",
}, t.Name()) {
continue
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (nomicbertModel) Replacements() []string {
return []string{
"encoder.layer", "blk",
"encoder.layers", "blk",
"embeddings.word_embeddings", "token_embd",
"embeddings.token_type_embeddings", "token_types",
"embeddings.LayerNorm", "token_embd_norm",
"attention.self.qkv", "attn_qkv",
"attention.output.dense", "attn_output",
"attention.output.LayerNorm", "attn_output_norm",
"mlp.up", "ffn_up",
"mlp.down", "ffn_down",
"mlp.router", "ffn_gate_inp",
"mlp.experts.up", "ffn_up_exps",
"mlp.experts.down", "ffn_down_exps",
"intermediate.dense", "ffn_up",
"output.dense", "ffn_down",
"output.LayerNorm", "layer_output_norm",
}
}

117
convert/convert_olmo.go Normal file
View File

@@ -0,0 +1,117 @@
package convert
import (
"cmp"
"github.com/ollama/ollama/fs/ggml"
)
type ropeScaling struct {
Factor float32 `json:"factor"`
OriginalMaxPositionEmbeds uint32 `json:"original_max_position_embeddings"`
AttentionFactor float32 `json:"attention_factor"`
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
RopeType string `json:"rope_type"`
ExtrapolationFactor float32 `json:"extrapolation_factor"`
}
type olmoModel struct {
ModelParameters
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
RMSNormEPS float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
RopeScaling *ropeScaling `json:"rope_scaling"`
SlidingWindow uint32 `json:"sliding_window"`
LayerTypes []string `json:"layer_types"`
}
var _ ModelConverter = (*olmoModel)(nil)
func (p *olmoModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "olmo3"
kv["olmo3.block_count"] = p.NumHiddenLayers
kv["olmo3.context_length"] = p.MaxPositionEmbeddings
kv["olmo3.embedding_length"] = p.HiddenSize
kv["olmo3.feed_forward_length"] = p.IntermediateSize
kv["olmo3.attention.head_count"] = p.NumAttentionHeads
kv["olmo3.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
if p.RopeTheta > 0 {
kv["olmo3.rope.freq_base"] = p.RopeTheta
}
if p.RopeScaling != nil {
if p.RopeScaling.Factor > 0 {
kv["olmo3.rope.scaling.factor"] = p.RopeScaling.Factor
}
if p.RopeScaling.OriginalMaxPositionEmbeds > 0 {
kv["olmo3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeds
}
if p.RopeScaling.AttentionFactor > 0 {
kv["olmo3.rope.scaling.attn_factor"] = p.RopeScaling.AttentionFactor
}
if p.RopeScaling.RopeType != "" {
kv["olmo3.rope.scaling.type"] = p.RopeScaling.RopeType
}
}
if p.RMSNormEPS > 0 {
kv["olmo3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
}
if p.SlidingWindow > 0 {
kv["olmo3.attention.sliding_window"] = p.SlidingWindow
}
if len(p.LayerTypes) > 0 {
slidingPattern := make([]bool, len(p.LayerTypes))
for i, layerType := range p.LayerTypes {
slidingPattern[i] = (layerType == "sliding_attention")
}
kv["olmo3.attention.sliding_window_pattern"] = slidingPattern
}
return kv
}
func (p *olmoModel) Tensors(ts []Tensor) []*ggml.Tensor {
out := make([]*ggml.Tensor, 0, len(ts))
for _, t := range ts {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *olmoModel) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.layers", "blk",
"model.norm", "output_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"self_attn.q_norm", "attn_q_norm",
"self_attn.k_norm", "attn_k_norm",
"post_attention_layernorm", "post_attention_norm",
"post_feedforward_layernorm", "post_ffw_norm",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
}
}

View File

@@ -37,7 +37,7 @@ type phi3Model struct {
var _ ModelConverter = (*phi3Model)(nil)
func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
func (p *phi3Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "phi3"
kv["phi3.context_length"] = p.MaxPositionEmbeddings

View File

@@ -22,7 +22,7 @@ type qwen2Model struct {
var _ ModelConverter = (*qwen2Model)(nil)
func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
func (q *qwen2Model) KV(t *Tokenizer) KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen2"
kv["qwen2.block_count"] = q.HiddenLayers

View File

@@ -29,7 +29,7 @@ type qwen25VLModel struct {
var _ ModelConverter = (*qwen25VLModel)(nil)
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
func (q *qwen25VLModel) KV(t *Tokenizer) KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen25vl"

View File

@@ -32,7 +32,7 @@ type qwen3Model struct {
}
// KV implements ModelConverter.
func (q *qwen3Model) KV(t *Tokenizer) ggml.KV {
func (q *qwen3Model) KV(t *Tokenizer) KV {
arch := "qwen3"
if q.NumExperts > 0 {
arch += "moe"

View File

@@ -45,7 +45,7 @@ func (m *qwen3VLModel) parseMore(fsys fs.FS) error {
return json.Unmarshal(bts, &m.VisionModel)
}
func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV {
func (m *qwen3VLModel) KV(t *Tokenizer) KV {
kv := m.qwen3Model.KV(t)
arch := "qwen3vl"

View File

@@ -19,6 +19,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
fsc "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -28,7 +29,7 @@ type tensorData struct {
Shape []int `json:"shape"`
}
func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors) {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "f16")
@@ -59,9 +60,10 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
return r, m.KV(), m.Tensors()
}
func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string {
func generateResultsJSON(t *testing.T, f *os.File, kv fsc.Config, tensors ggml.Tensors) map[string]string {
actual := make(map[string]string)
for k, v := range kv {
for k := range kv.Keys() {
v := kv.Value(k)
if s, ok := v.(json.Marshaler); !ok {
actual[k] = fmt.Sprintf("%v", v)
} else {
@@ -277,7 +279,7 @@ func generateSafetensorTestData(t *testing.T, tempDir string, tensorData map[str
func TestConvertAdapter(t *testing.T) {
type AdapterCase struct {
Name string
BaseKV map[string]any
BaseKV KV
Expected map[string]string
}

View File

@@ -49,7 +49,8 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
// temporary fix to handle gemma3 broken configs
if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
// TODO(parthsareen): allow reading of tokenizer.json to allow managing special tokens when using spm
if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>", "<start_function_declaration>", "<end_function_declaration>", "<start_function_call>", "<end_function_call>", "<start_function_response>", "<end_function_response>", "<escape>"}, piece.GetPiece()) {
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
}

View File

@@ -50,7 +50,7 @@ Generate a response for a given prompt with a provided model. This is a streamin
Advanced parameters (optional):
- `format`: the format to return a response in. Format can be `json` or a JSON schema
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
- `system`: system message to (overrides what is defined in the `Modelfile`)
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
@@ -507,7 +507,7 @@ The `message` object has the following fields:
Advanced parameters (optional):
- `format`: the format to return a response in. Format can be `json` or a JSON schema.
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
@@ -895,11 +895,11 @@ curl http://localhost:11434/api/chat -d '{
"tool_calls": [
{
"function": {
"name": "get_temperature",
"name": "get_weather",
"arguments": {
"city": "Toronto"
}
},
}
}
]
},
@@ -907,7 +907,7 @@ curl http://localhost:11434/api/chat -d '{
{
"role": "tool",
"content": "11 degrees celsius",
"tool_name": "get_temperature",
"tool_name": "get_weather"
}
],
"stream": false,
@@ -1189,7 +1189,7 @@ If you are creating a model from a safetensors directory or from a GGUF file, yo
- `template`: (optional) the prompt template for the model
- `license`: (optional) a string or list of strings containing the license or licenses for the model
- `system`: (optional) a string containing the system prompt for the model
- `parameters`: (optional) a dictionary of parameters for the model (see [Modelfile](./modelfile.md#valid-parameters-and-values) for a list of parameters)
- `parameters`: (optional) a dictionary of parameters for the model (see [Modelfile](./modelfile.mdx#valid-parameters-and-values) for a list of parameters)
- `messages`: (optional) a list of message objects used to create a conversation
- `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
- `quantize` (optional): quantize a non-quantized (e.g. float16) model
@@ -1698,7 +1698,7 @@ Generate embeddings from a model
Advanced parameters:
- `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
- `dimensions`: number of dimensions for the embedding
@@ -1817,7 +1817,7 @@ Generate embeddings from a model
Advanced parameters:
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
### Examples

View File

File diff suppressed because one or more lines are too long

View File

@@ -36,7 +36,6 @@ Provide an `images` array. SDKs accept file paths, URLs or raw bytes while the R
}],
"stream": false
}'
"
```
</Tab>
<Tab title="Python">

View File

@@ -14,11 +14,11 @@ curl -fsSL https://ollama.com/install.sh | sh
## How can I view the logs?
Review the [Troubleshooting](./troubleshooting.md) docs for more about using logs.
Review the [Troubleshooting](./troubleshooting) docs for more about using logs.
## Is my GPU compatible with Ollama?
Please refer to the [GPU docs](./gpu.md).
Please refer to the [GPU docs](./gpu).
## How can I specify the context window size?

View File

@@ -33,7 +33,7 @@ Check your compute compatibility to see if your card is supported:
| 5.0 | GeForce GTX | `GTX 750 Ti` `GTX 750` `NVS 810` |
| | Quadro | `K2200` `K1200` `K620` `M1200` `M520` `M5000M` `M4000M` `M3000M` `M2000M` `M1000M` `K620M` `M600M` `M500M` |
For building locally to support older GPUs, see [developer.md](./development.md#linux-cuda-nvidia)
For building locally to support older GPUs, see [developer](./development#linux-cuda-nvidia)
### GPU Selection
@@ -54,7 +54,7 @@ sudo modprobe nvidia_uvm`
Ollama supports the following AMD GPUs via the ROCm library:
> [!NOTE]
> **NOTE:**
> Additional AMD GPU support is provided by the Vulkan Library - see below.
@@ -132,9 +132,9 @@ Ollama supports GPU acceleration on Apple devices via the Metal API.
## Vulkan GPU Support
> [!NOTE]
> **NOTE:**
> Vulkan is currently an Experimental feature. To enable, you must set OLLAMA_VULKAN=1 for the Ollama server as
described in the [FAQ](faq.md#how-do-i-configure-ollama-server)
described in the [FAQ](faq#how-do-i-configure-ollama-server)
Additional GPU support on Windows and Linux is provided via
[Vulkan](https://www.vulkan.org/). On Windows most GPU vendors drivers come
@@ -161,6 +161,6 @@ sudo setcap cap_perfmon+ep /usr/local/bin/ollama
To select specific Vulkan GPU(s), you can set the environment variable
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
described in the [FAQ](faq.md#how-do-i-configure-ollama-server). If you
described in the [FAQ](faq#how-do-i-configure-ollama-server). If you
encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
by setting `GGML_VK_VISIBLE_DEVICES=-1`

View File

@@ -20,8 +20,8 @@ curl -fsSL https://ollama.com/install.sh | sh
Download and extract the package:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
| sudo tar zx -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
| sudo tar x -C /usr
```
Start Ollama:
@@ -41,8 +41,8 @@ ollama -v
If you have an AMD GPU, also download and extract the additional ROCm package:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
| sudo tar zx -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
| sudo tar x -C /usr
```
### ARM64 install
@@ -50,8 +50,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
Download and extract the ARM64-specific package:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \
| sudo tar zx -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
| sudo tar x -C /usr
```
### Adding Ollama as a startup service (recommended)
@@ -146,8 +146,8 @@ curl -fsSL https://ollama.com/install.sh | sh
Or by re-downloading Ollama:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
| sudo tar zx -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
| sudo tar x -C /usr
```
## Installing specific versions

View File

@@ -41,6 +41,7 @@ INSTRUCTION arguments
| [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. |
| [`LICENSE`](#license) | Specifies the legal license. |
| [`MESSAGE`](#message) | Specify message history. |
| [`REQUIRES`](#requires) | Specify the minimum version of Ollama required by the model. |
## Examples
@@ -248,6 +249,16 @@ MESSAGE user Is Ontario in Canada?
MESSAGE assistant yes
```
### REQUIRES
The `REQUIRES` instruction allows you to specify the minimum version of Ollama required by the model.
```
REQUIRES <version>
```
The version should be a valid Ollama version (e.g. 0.14.0).
## Notes
- the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments.

View File

@@ -0,0 +1,46 @@
# extract-examples
Extracts code examples from MDX files to a temp directory so you can run them.
## Usage
```shell
go run docs/tools/extract-examples/main.go <mdx-file>
```
## Example
```shell
go run docs/tools/extract-examples/main.go docs/api/openai-compatibility.mdx
```
Output:
```
Extracting code examples to: /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
- 01_basic.py
- 01_basic.js
- 01_basic.sh
- 02_responses.py
- 02_responses.js
- 02_responses.sh
- 03_vision.py
- 03_vision.js
- 03_vision.sh
Extracted 9 file(s) to /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
To run examples:
cd /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
npm install # for JS examples
then run individual files with `node file.js`, `python file.py`, `bash file.sh`
```
## How it works
- Parses MDX files looking for fenced code blocks with filenames (e.g., ` ```python basic.py `)
- Groups examples by their `<CodeGroup>` and prefixes filenames with `01_`, `02_`, etc.
- Writes all extracted files to a temp directory

View File

@@ -0,0 +1,137 @@
package main
import (
"bufio"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
)
func main() {
if len(os.Args) < 2 {
fmt.Fprintln(os.Stderr, "Usage: go run extract-examples.go <mdx-file>")
os.Exit(1)
}
mdxFile := os.Args[1]
f, err := os.Open(mdxFile)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
defer f.Close()
// Create temp directory
tempDir, err := os.MkdirTemp("", "mdx-examples-*")
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating temp dir: %v\n", err)
os.Exit(1)
}
fmt.Printf("Extracting code examples to: %s\n\n", tempDir)
// Patterns
codeBlockStart := regexp.MustCompile("^```([a-zA-Z0-9_-]+)\\s+([^\\s]+)$")
codeGroupStart := regexp.MustCompile("^<CodeGroup")
codeGroupEnd := regexp.MustCompile("^</CodeGroup>")
scanner := bufio.NewScanner(f)
inCodeBlock := false
inCodeGroup := false
var currentFile string
var content strings.Builder
count := 0
codeGroupNum := 0
for scanner.Scan() {
line := scanner.Text()
// Track CodeGroup boundaries
if codeGroupStart.MatchString(line) {
inCodeGroup = true
codeGroupNum++
continue
}
if codeGroupEnd.MatchString(line) {
inCodeGroup = false
continue
}
if inCodeBlock {
if line == "```" {
// End of code block - write file
if currentFile != "" {
outPath := filepath.Join(tempDir, currentFile)
if err := os.WriteFile(outPath, []byte(content.String()), 0o644); err != nil {
fmt.Fprintf(os.Stderr, "Error writing %s: %v\n", currentFile, err)
} else {
fmt.Printf(" - %s\n", currentFile)
count++
}
}
inCodeBlock = false
currentFile = ""
content.Reset()
} else {
content.WriteString(line)
content.WriteString("\n")
}
} else {
if matches := codeBlockStart.FindStringSubmatch(line); matches != nil {
inCodeBlock = true
filename := matches[2]
// Prefix with CodeGroup number if inside a CodeGroup
if inCodeGroup {
currentFile = fmt.Sprintf("%02d_%s", codeGroupNum, filename)
} else {
currentFile = filename
}
content.Reset()
}
}
}
if err := scanner.Err(); err != nil {
fmt.Fprintf(os.Stderr, "Error reading file: %v\n", err)
os.Exit(1)
}
// Write package.json for JavaScript dependencies
packageJSON := `{
"name": "mdx-examples",
"type": "module",
"dependencies": {
"openai": "^4",
"ollama": "^0.5"
}
}
`
if err := os.WriteFile(filepath.Join(tempDir, "package.json"), []byte(packageJSON), 0o644); err != nil {
fmt.Fprintf(os.Stderr, "Error writing package.json: %v\n", err)
}
// Write pyproject.toml for Python dependencies
pyprojectTOML := `[project]
name = "mdx-examples"
version = "0.0.0"
dependencies = [
"openai",
"ollama",
]
`
if err := os.WriteFile(filepath.Join(tempDir, "pyproject.toml"), []byte(pyprojectTOML), 0o644); err != nil {
fmt.Fprintf(os.Stderr, "Error writing pyproject.toml: %v\n", err)
}
fmt.Printf("\n")
fmt.Printf("Extracted %d file(s) to %s\n", count, tempDir)
fmt.Printf("\n")
fmt.Printf("To run examples:\n")
fmt.Printf("\n")
fmt.Printf(" cd %s\n npm install # for JS examples\n", tempDir)
fmt.Printf("\n")
fmt.Printf("then run individual files with `node file.js`, `python file.py`, `bash file.sh`\n")
}

View File

@@ -87,7 +87,7 @@ When Ollama starts up, it takes inventory of the GPUs present in the system to d
### Linux NVIDIA Troubleshooting
If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker.md](./docker.md)
If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker](./docker)
Sometimes the Ollama can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem

View File

@@ -1,5 +1,7 @@
package fs
import "iter"
type Config interface {
Architecture() string
String(string, ...string) string
@@ -11,4 +13,8 @@ type Config interface {
Ints(string, ...[]int32) []int32
Floats(string, ...[]float32) []float32
Bools(string, ...[]bool) []bool
Len() int
Keys() iter.Seq[string]
Value(key string) any
}

View File

@@ -6,13 +6,16 @@ import (
"errors"
"fmt"
"io"
"iter"
"log/slog"
"maps"
"math"
"slices"
"strings"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/util/bufioutil"
"github.com/ollama/ollama/ml"
)
type GGML struct {
@@ -238,20 +241,34 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
return val.values
}
func (kv KV) Len() int {
return len(kv)
}
func (kv KV) Keys() iter.Seq[string] {
return maps.Keys(kv)
}
func (kv KV) Value(key string) any {
return kv[key]
}
func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{
"bert",
"deepseek2",
"deepseekocr",
"gemma3",
"gemma3n",
"gptoss", "gpt-oss",
"llama4",
"mistral3",
"mllama",
"nomic-bert",
"olmo3",
"qwen25vl",
"qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe",
"deepseekocr",
"deepseek2",
"nomic-bert",
}, kv.Architecture())
}
@@ -550,7 +567,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
}, nil
}
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention ml.FlashAttentionType) (kv []uint64, partialOffload, fullOffload uint64) {
context *= uint64(numParallel)
embedding := f.KV().EmbeddingLength()
@@ -791,7 +808,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
}
partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
if useFlashAttention {
if useFlashAttention == ml.FlashAttentionEnabled {
// rough estimate of graph size with flash attention on
partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
}
@@ -809,6 +826,14 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool {
return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
}
// KVCacheTypeIsQuantized checks if the requested cache type is a quantized type
func (f GGML) KVCacheTypeIsQuantized(cacheType string) bool {
if cacheType == "" || cacheType == "f16" || cacheType == "f32" || cacheType == "bf16" {
return false
}
return true
}
// SupportsFlashAttention checks if the model supports flash attention
func (f GGML) SupportsFlashAttention() bool {
_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]
@@ -829,9 +854,11 @@ func (f GGML) SupportsFlashAttention() bool {
// FlashAttention checks if the model should enable flash attention
func (f GGML) FlashAttention() bool {
return slices.Contains([]string{
"bert",
"gemma3",
"gptoss", "gpt-oss",
"mistral3",
"olmo3",
"qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe",
}, f.KV().String("general.architecture"))

View File

@@ -8,12 +8,12 @@ import (
"fmt"
"io"
"log/slog"
"maps"
"os"
"runtime"
"slices"
"strings"
"github.com/ollama/ollama/fs"
"golang.org/x/sync/errgroup"
)
@@ -508,7 +508,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
return binary.Write(w, binary.LittleEndian, s)
}
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
arch := kv.String("general.architecture")
if arch == "" {
return fmt.Errorf("architecture not set")
@@ -526,12 +526,12 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
return err
}
if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
if err := binary.Write(f, binary.LittleEndian, uint64(kv.Len())); err != nil {
return err
}
for _, key := range slices.Sorted(maps.Keys(kv)) {
if err := ggufWriteKV(f, arch, key, kv[key]); err != nil {
for _, key := range slices.Sorted(kv.Keys()) {
if err := ggufWriteKV(f, arch, key, kv.Value(key)); err != nil {
return err
}
}

19
go.mod
View File

@@ -15,8 +15,8 @@ require (
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.12.0
golang.org/x/sys v0.36.0
golang.org/x/sync v0.17.0
golang.org/x/sys v0.37.0
)
require (
@@ -28,13 +28,17 @@ require (
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
github.com/tkrajina/typescriptify-golang-structs v0.2.0
github.com/wk8/go-ordered-map/v2 v2.1.8
golang.org/x/image v0.22.0
golang.org/x/tools v0.30.0
golang.org/x/mod v0.30.0
golang.org/x/tools v0.38.0
gonum.org/v1/gonum v0.15.0
)
require (
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/chewxy/hm v1.0.0 // indirect
github.com/chewxy/math32 v1.11.0 // indirect
@@ -44,6 +48,7 @@ require (
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
@@ -76,11 +81,11 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.36.0
golang.org/x/crypto v0.43.0
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/term v0.30.0
golang.org/x/text v0.23.0
golang.org/x/net v0.46.0 // indirect
golang.org/x/term v0.36.0
golang.org/x/text v0.30.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)

39
go.sum
View File

@@ -14,7 +14,11 @@ github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6IC
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
@@ -123,6 +127,7 @@ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
@@ -143,6 +148,8 @@ github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 h1:QwWKgMY28TAXaDl+
github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
@@ -207,6 +214,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
@@ -224,8 +233,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@@ -255,6 +264,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -267,8 +278,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -278,8 +289,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -295,17 +306,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -319,8 +330,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@@ -487,6 +487,63 @@ func TestEmbedTruncation(t *testing.T) {
}
}
// TestEmbedLargeInput tests that embedding models can handle large inputs that would exceed typical batch sizes.
func TestEmbedLargeInput(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
for _, model := range libraryEmbedModels {
model := model
t.Run(model, func(t *testing.T) {
mctx, mcancel := context.WithTimeout(ctx, 2*time.Minute)
defer mcancel()
// Test with progressively larger inputs
testCases := []struct {
name string
inputWords int
}{
{"medium_input_256_words", 256},
{"large_input_512_words", 512},
{"very_large_input_800_words", 800},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
words := make([]string, tc.inputWords)
for i := range words {
words[i] = "word"
}
input := strings.Join(words, " ")
req := api.EmbedRequest{
Model: model,
Input: input,
KeepAlive: &api.Duration{Duration: 30 * time.Second},
}
res, err := embedTestHelper(mctx, client, t, req)
if err != nil {
t.Fatalf("embedding failed for %d words: %v", tc.inputWords, err)
}
if len(res.Embeddings) != 1 {
t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
}
if len(res.Embeddings[0]) == 0 {
t.Fatal("expected non-empty embedding")
}
t.Logf("Successfully embedded %d words (%d tokens)", tc.inputWords, res.PromptEvalCount)
})
}
})
}
}
// TestEmbedStatusCode tests that errors from the embedding endpoint
// properly preserve their HTTP status codes when returned to the client.
// This test specifically checks the error handling path in EmbedHandler

View File

@@ -11,6 +11,15 @@ import (
"github.com/ollama/ollama/api"
)
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
func TestAPIToolCalling(t *testing.T) {
initialTimeout := 60 * time.Second
streamTimeout := 60 * time.Second
@@ -57,12 +66,12 @@ func TestAPIToolCalling(t *testing.T) {
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
Properties: testPropsMap(map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "The city and state, e.g. San Francisco, CA",
},
},
}),
},
},
},

View File

@@ -0,0 +1,94 @@
// Package orderedmap provides a generic ordered map that maintains insertion order.
// It wraps github.com/wk8/go-ordered-map/v2 to encapsulate the dependency.
package orderedmap
import (
"encoding/json"
"iter"
orderedmap "github.com/wk8/go-ordered-map/v2"
)
// Map is a generic ordered map that maintains insertion order.
type Map[K comparable, V any] struct {
om *orderedmap.OrderedMap[K, V]
}
// New creates a new empty ordered map.
func New[K comparable, V any]() *Map[K, V] {
return &Map[K, V]{
om: orderedmap.New[K, V](),
}
}
// Get retrieves a value by key.
func (m *Map[K, V]) Get(key K) (V, bool) {
if m == nil || m.om == nil {
var zero V
return zero, false
}
return m.om.Get(key)
}
// Set sets a key-value pair. If the key already exists, its value is updated
// but its position in the iteration order is preserved. If the key is new,
// it is appended to the end.
func (m *Map[K, V]) Set(key K, value V) {
if m == nil {
return
}
if m.om == nil {
m.om = orderedmap.New[K, V]()
}
m.om.Set(key, value)
}
// Len returns the number of entries.
func (m *Map[K, V]) Len() int {
if m == nil || m.om == nil {
return 0
}
return m.om.Len()
}
// All returns an iterator over all key-value pairs in insertion order.
func (m *Map[K, V]) All() iter.Seq2[K, V] {
return func(yield func(K, V) bool) {
if m == nil || m.om == nil {
return
}
for pair := m.om.Oldest(); pair != nil; pair = pair.Next() {
if !yield(pair.Key, pair.Value) {
return
}
}
}
}
// ToMap converts to a regular Go map.
// Note: The resulting map does not preserve order.
func (m *Map[K, V]) ToMap() map[K]V {
if m == nil || m.om == nil {
return nil
}
result := make(map[K]V, m.om.Len())
for pair := m.om.Oldest(); pair != nil; pair = pair.Next() {
result[pair.Key] = pair.Value
}
return result
}
// MarshalJSON implements json.Marshaler. The JSON output preserves key order.
func (m *Map[K, V]) MarshalJSON() ([]byte, error) {
if m == nil || m.om == nil {
return []byte("null"), nil
}
return json.Marshal(m.om)
}
// UnmarshalJSON implements json.Unmarshaler. The insertion order matches the
// order of keys in the JSON input.
func (m *Map[K, V]) UnmarshalJSON(data []byte) error {
m.om = orderedmap.New[K, V]()
return json.Unmarshal(data, &m.om)
}

View File

@@ -0,0 +1,348 @@
package orderedmap
import (
"encoding/json"
"slices"
"testing"
)
func TestMap_BasicOperations(t *testing.T) {
m := New[string, int]()
// Test empty map
if m.Len() != 0 {
t.Errorf("expected Len() = 0, got %d", m.Len())
}
v, ok := m.Get("a")
if ok {
t.Error("expected Get on empty map to return false")
}
if v != 0 {
t.Errorf("expected zero value, got %d", v)
}
// Test Set and Get
m.Set("a", 1)
m.Set("b", 2)
m.Set("c", 3)
if m.Len() != 3 {
t.Errorf("expected Len() = 3, got %d", m.Len())
}
v, ok = m.Get("a")
if !ok || v != 1 {
t.Errorf("expected Get(a) = (1, true), got (%d, %v)", v, ok)
}
v, ok = m.Get("b")
if !ok || v != 2 {
t.Errorf("expected Get(b) = (2, true), got (%d, %v)", v, ok)
}
v, ok = m.Get("c")
if !ok || v != 3 {
t.Errorf("expected Get(c) = (3, true), got (%d, %v)", v, ok)
}
// Test updating existing key preserves position
m.Set("a", 10)
v, ok = m.Get("a")
if !ok || v != 10 {
t.Errorf("expected Get(a) = (10, true), got (%d, %v)", v, ok)
}
if m.Len() != 3 {
t.Errorf("expected Len() = 3 after update, got %d", m.Len())
}
}
func TestMap_InsertionOrderPreserved(t *testing.T) {
m := New[string, int]()
// Insert in non-alphabetical order
m.Set("z", 1)
m.Set("a", 2)
m.Set("m", 3)
m.Set("b", 4)
// Verify iteration order matches insertion order
var keys []string
var values []int
for k, v := range m.All() {
keys = append(keys, k)
values = append(values, v)
}
expectedKeys := []string{"z", "a", "m", "b"}
expectedValues := []int{1, 2, 3, 4}
if !slices.Equal(keys, expectedKeys) {
t.Errorf("expected keys %v, got %v", expectedKeys, keys)
}
if !slices.Equal(values, expectedValues) {
t.Errorf("expected values %v, got %v", expectedValues, values)
}
}
func TestMap_UpdatePreservesPosition(t *testing.T) {
m := New[string, int]()
m.Set("first", 1)
m.Set("second", 2)
m.Set("third", 3)
// Update middle element
m.Set("second", 20)
var keys []string
for k := range m.All() {
keys = append(keys, k)
}
// Order should still be first, second, third
expected := []string{"first", "second", "third"}
if !slices.Equal(keys, expected) {
t.Errorf("expected keys %v, got %v", expected, keys)
}
}
func TestMap_MarshalJSON_PreservesOrder(t *testing.T) {
m := New[string, int]()
// Insert in non-alphabetical order
m.Set("z", 1)
m.Set("a", 2)
m.Set("m", 3)
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
// JSON should preserve insertion order, not alphabetical
expected := `{"z":1,"a":2,"m":3}`
if string(data) != expected {
t.Errorf("expected %s, got %s", expected, string(data))
}
}
func TestMap_UnmarshalJSON_PreservesOrder(t *testing.T) {
// JSON with non-alphabetical key order
jsonData := `{"z":1,"a":2,"m":3}`
m := New[string, int]()
if err := json.Unmarshal([]byte(jsonData), m); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
// Verify iteration order matches JSON order
var keys []string
for k := range m.All() {
keys = append(keys, k)
}
expected := []string{"z", "a", "m"}
if !slices.Equal(keys, expected) {
t.Errorf("expected keys %v, got %v", expected, keys)
}
}
func TestMap_JSONRoundTrip(t *testing.T) {
// Test that unmarshal -> marshal produces identical JSON
original := `{"zebra":"z","apple":"a","mango":"m","banana":"b"}`
m := New[string, string]()
if err := json.Unmarshal([]byte(original), m); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
if string(data) != original {
t.Errorf("round trip failed: expected %s, got %s", original, string(data))
}
}
func TestMap_ToMap(t *testing.T) {
m := New[string, int]()
m.Set("a", 1)
m.Set("b", 2)
regular := m.ToMap()
if len(regular) != 2 {
t.Errorf("expected len 2, got %d", len(regular))
}
if regular["a"] != 1 {
t.Errorf("expected regular[a] = 1, got %d", regular["a"])
}
if regular["b"] != 2 {
t.Errorf("expected regular[b] = 2, got %d", regular["b"])
}
}
func TestMap_NilSafety(t *testing.T) {
var m *Map[string, int]
// All operations should be safe on nil
if m.Len() != 0 {
t.Errorf("expected Len() = 0 on nil map, got %d", m.Len())
}
v, ok := m.Get("a")
if ok {
t.Error("expected Get on nil map to return false")
}
if v != 0 {
t.Errorf("expected zero value from nil map, got %d", v)
}
// Set on nil is a no-op
m.Set("a", 1)
if m.Len() != 0 {
t.Errorf("expected Len() = 0 after Set on nil, got %d", m.Len())
}
// All returns empty iterator
var keys []string
for k := range m.All() {
keys = append(keys, k)
}
if len(keys) != 0 {
t.Errorf("expected empty iteration on nil map, got %v", keys)
}
// ToMap returns nil
if m.ToMap() != nil {
t.Error("expected ToMap to return nil on nil map")
}
// MarshalJSON returns null
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
if string(data) != "null" {
t.Errorf("expected null, got %s", string(data))
}
}
func TestMap_EmptyMapMarshal(t *testing.T) {
m := New[string, int]()
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
if string(data) != "{}" {
t.Errorf("expected {}, got %s", string(data))
}
}
func TestMap_NestedValues(t *testing.T) {
m := New[string, any]()
m.Set("string", "hello")
m.Set("number", 42)
m.Set("bool", true)
m.Set("nested", map[string]int{"x": 1})
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
expected := `{"string":"hello","number":42,"bool":true,"nested":{"x":1}}`
if string(data) != expected {
t.Errorf("expected %s, got %s", expected, string(data))
}
}
func TestMap_AllIteratorEarlyExit(t *testing.T) {
m := New[string, int]()
m.Set("a", 1)
m.Set("b", 2)
m.Set("c", 3)
m.Set("d", 4)
// Collect only first 2
var keys []string
for k := range m.All() {
keys = append(keys, k)
if len(keys) == 2 {
break
}
}
expected := []string{"a", "b"}
if !slices.Equal(keys, expected) {
t.Errorf("expected %v, got %v", expected, keys)
}
}
func TestMap_IntegerKeys(t *testing.T) {
m := New[int, string]()
m.Set(3, "three")
m.Set(1, "one")
m.Set(2, "two")
var keys []int
for k := range m.All() {
keys = append(keys, k)
}
// Should preserve insertion order, not numerical order
expected := []int{3, 1, 2}
if !slices.Equal(keys, expected) {
t.Errorf("expected %v, got %v", expected, keys)
}
}
func TestMap_UnmarshalIntoExisting(t *testing.T) {
m := New[string, int]()
m.Set("existing", 999)
// Unmarshal should replace contents
if err := json.Unmarshal([]byte(`{"new":1}`), m); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
_, ok := m.Get("existing")
if ok {
t.Error("existing key should be gone after unmarshal")
}
v, ok := m.Get("new")
if !ok || v != 1 {
t.Errorf("expected Get(new) = (1, true), got (%d, %v)", v, ok)
}
}
func TestMap_LargeOrderPreservation(t *testing.T) {
m := New[string, int]()
// Create many keys in specific order
keys := make([]string, 100)
for i := range 100 {
keys[i] = string(rune('a' + (99 - i))) // reverse order: 'd', 'c', 'b', 'a' (extended)
if i >= 26 {
keys[i] = string(rune('A'+i-26)) + string(rune('a'+i%26))
}
}
for i, k := range keys {
m.Set(k, i)
}
// Verify order preserved
var resultKeys []string
for k := range m.All() {
resultKeys = append(resultKeys, k)
}
if !slices.Equal(keys, resultKeys) {
t.Error("large map should preserve insertion order")
}
}

View File

@@ -140,10 +140,6 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
c.config.CachePadding = 1
}
if c.config.MaskBatchPadding == 0 {
c.config.MaskBatchPadding = 1
}
if c.config.MaskDType == ml.DTypeOther {
c.config.MaskDType = ml.DTypeF32
}
@@ -364,15 +360,12 @@ func roundUp(length, pad int) int {
// token in the history should apply. This is based on both the sequence and causality (the
// position of the history is not ahead of the token in the batch).
func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
// Align and pad the two dimensions as required by the backend
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
length := c.curCellRange.max - c.curCellRange.min + 1
mask := make([]float32, batchSize*length)
mask := make([]float32, c.curBatchSize*length)
for i := range c.curBatchSize {
enabled := !slices.Contains(c.opts.Except, i)
@@ -386,13 +379,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
}
}
// Mask out any padding tokens we added. For padding that we added to the cache history, this
// has already been masked out because the sequence doesn't match.
for i := c.curBatchSize * length; i < len(mask); i++ {
mask[i] = float32(math.Inf(-1))
}
maskTensor := ctx.Input().FromFloats(mask, length, batchSize)
maskTensor := ctx.Input().FromFloats(mask, length, c.curBatchSize)
if c.config.MaskDType != ml.DTypeF32 {
maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)

2
llama/build-info.cpp generated vendored
View File

@@ -1,4 +1,4 @@
int LLAMA_BUILD_NUMBER = 0;
char const *LLAMA_COMMIT = "7f8ef50cce40e3e7e4526a3696cb45658190e69a";
char const *LLAMA_COMMIT = "ec98e2002";
char const *LLAMA_COMPILER = "";
char const *LLAMA_BUILD_TARGET = "";

View File

@@ -17,6 +17,9 @@ include /tools/mtmd/clip.cpp
include /tools/mtmd/mtmd.cpp
include /tools/mtmd/mtmd-audio.cpp
include /tools/mtmd/mtmd-helper.cpp
include /tools/mtmd/models/
include /tools/mtmd/models/*.h
include /tools/mtmd/models/*.cpp
include /src/
include /src/llama.*
include /src/llama-*.*

View File

@@ -694,7 +694,7 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
// Validate if a filename is safe to use
// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function
bool fs_validate_filename(const std::string & filename) {
bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
if (!filename.length()) {
// Empty filename invalid
return false;
@@ -754,10 +754,14 @@ bool fs_validate_filename(const std::string & filename) {
|| (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
|| c == 0xFFFD // Replacement Character (UTF-8)
|| c == 0xFEFF // Byte Order Mark (BOM)
|| c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters
|| c == ':' || c == '*' // Illegal characters
|| c == '?' || c == '"' || c == '<' || c == '>' || c == '|') {
return false;
}
if (!allow_subdirs && (c == '/' || c == '\\')) {
// Subdirectories not allowed, reject path separators
return false;
}
}
// Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
@@ -782,11 +786,29 @@ bool fs_validate_filename(const std::string & filename) {
#include <iostream>
#ifdef _WIN32
static std::wstring utf8_to_wstring(const std::string & str) {
if (str.empty()) {
return std::wstring();
}
int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0);
if (size <= 0) {
return std::wstring();
}
std::wstring wstr(size, 0);
MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstr[0], size);
return wstr;
}
#endif
// returns true if successful, false otherwise
bool fs_create_directory_with_parents(const std::string & path) {
#ifdef _WIN32
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
std::wstring wpath = converter.from_bytes(path);
std::wstring wpath = utf8_to_wstring(path);
// if the path already exists, check whether it's a directory
const DWORD attributes = GetFileAttributesW(wpath.c_str());
@@ -859,6 +881,11 @@ bool fs_create_directory_with_parents(const std::string & path) {
#endif // _WIN32
}
bool fs_is_directory(const std::string & path) {
std::filesystem::path dir(path);
return std::filesystem::exists(dir) && std::filesystem::is_directory(dir);
}
std::string fs_get_cache_directory() {
std::string cache_directory = "";
auto ensure_trailing_slash = [](std::string p) {
@@ -893,6 +920,8 @@ std::string fs_get_cache_directory() {
cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
#elif defined(_WIN32)
cache_directory = std::getenv("LOCALAPPDATA");
#elif defined(__EMSCRIPTEN__)
GGML_ABORT("not implemented on this platform");
#else
# error Unknown architecture
#endif
@@ -912,7 +941,7 @@ std::string fs_get_cache_file(const std::string & filename) {
return cache_directory + filename;
}
std::vector<common_file_info> fs_list_files(const std::string & path) {
std::vector<common_file_info> fs_list(const std::string & path, bool include_directories) {
std::vector<common_file_info> files;
if (path.empty()) return files;
@@ -927,14 +956,22 @@ std::vector<common_file_info> fs_list_files(const std::string & path) {
const auto & p = entry.path();
if (std::filesystem::is_regular_file(p)) {
common_file_info info;
info.path = p.string();
info.name = p.filename().string();
info.path = p.string();
info.name = p.filename().string();
info.is_dir = false;
try {
info.size = static_cast<size_t>(std::filesystem::file_size(p));
} catch (const std::filesystem::filesystem_error &) {
info.size = 0;
}
files.push_back(std::move(info));
} else if (include_directories && std::filesystem::is_directory(p)) {
common_file_info info;
info.path = p.string();
info.name = p.filename().string();
info.size = 0; // Directories have no size
info.is_dir = true;
files.push_back(std::move(info));
}
} catch (const std::filesystem::filesystem_error &) {
// skip entries we cannot inspect
@@ -945,36 +982,71 @@ std::vector<common_file_info> fs_list_files(const std::string & path) {
return files;
}
//
// TTY utils
//
bool tty_can_use_colors() {
// Check NO_COLOR environment variable (https://no-color.org/)
if (const char * no_color = std::getenv("NO_COLOR")) {
if (no_color[0] != '\0') {
return false;
}
}
// Check TERM environment variable
if (const char * term = std::getenv("TERM")) {
if (std::strcmp(term, "dumb") == 0) {
return false;
}
}
// Check if stdout and stderr are connected to a terminal
// We check both because log messages can go to either
bool stdout_is_tty = isatty(fileno(stdout));
bool stderr_is_tty = isatty(fileno(stderr));
return stdout_is_tty || stderr_is_tty;
}
//
// Model utils
//
static inline void common_init_sampler_from_model(
// TODO: move to common/sampling
static void common_init_sampler_from_model(
const llama_model * model,
common_params_sampling & sparams) {
const uint64_t config = sparams.user_sampling_config;
auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
if (config & user_config) return;
if (config & user_config) {
return;
}
char buf[64] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
int32_t v = strtol(buf, &end, 10);
if (end && end != buf) dst = v;
if (end && end != buf) {
dst = v;
}
}
};
auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
if (config & user_config) return;
if (config & user_config) {
return;
}
char buf[128] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
float v = strtof(buf, &end);
if (end && end != buf) dst = v;
if (end && end != buf) {
dst = v;
}
}
};
@@ -1002,31 +1074,125 @@ static inline void common_init_sampler_from_model(
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
}
struct common_init_result common_init_from_params(common_params & params) {
common_init_result iparams;
struct common_init_result::impl {
impl() = default;
~impl() = default;
llama_model_ptr model;
llama_context_ptr context;
std::vector<llama_adapter_lora_ptr> lora;
std::vector<common_sampler_ptr> samplers;
};
common_init_result::common_init_result(common_params & params) :
pimpl(new impl{}) {
auto mparams = common_model_params_to_llama(params);
auto cparams = common_context_params_to_llama(params);
if (params.fit_params) {
LOG_INF("%s: fitting params to device memory, to report bugs during this step use -fit off (or --verbose if you can't)\n", __func__);
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
}
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
if (model == NULL) {
LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
__func__, params.model.path.c_str());
return iparams;
return;
}
common_init_sampler_from_model(model, params.sampling);
pimpl->model.reset(model);
const llama_vocab * vocab = llama_model_get_vocab(model);
auto cparams = common_context_params_to_llama(params);
// updates params.sampling
// TODO: fix naming
common_init_sampler_from_model(model, params.sampling);
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
params.sampling.ignore_eos = false;
}
// initialize once
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
if (llama_vocab_is_eog(vocab, i)) {
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
}
}
if (params.sampling.ignore_eos) {
// add EOG biases to the active set of logit biases
params.sampling.logit_bias.insert(
params.sampling.logit_bias.end(),
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
}
//if (params.sampling.penalty_last_n == -1) {
// LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
// params.sampling.penalty_last_n = llama_n_ctx(lctx);
//}
//if (params.sampling.dry_penalty_last_n == -1) {
// LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
// params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
//}
pimpl->samplers.resize(cparams.n_seq_max);
for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
}
llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
__func__, params.model.path.c_str());
llama_model_free(model);
return iparams;
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
return;
}
pimpl->context.reset(lctx);
}
llama_model * common_init_result::model() {
return pimpl->model.get();
}
llama_context * common_init_result::context() {
return pimpl->context.get();
}
common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
return pimpl->samplers[seq_id].get();
}
std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
return pimpl->lora;
}
void common_init_result::free_context() {
pimpl->context.reset();
}
common_init_result_ptr common_init_from_params(common_params & params) {
common_init_result_ptr res(new common_init_result(params));
llama_model * model = res->model();
if (model == NULL) {
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
return res;
}
llama_context * lctx = res->context();
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
return res;
}
const llama_vocab * vocab = llama_model_get_vocab(model);
if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
params.ctx_shift = false;
@@ -1038,10 +1204,7 @@ struct common_init_result common_init_from_params(common_params & params) {
const auto cvec = common_control_vector_load(params.control_vectors);
if (cvec.n_embd == -1) {
llama_free(lctx);
llama_model_free(model);
return iparams;
return res;
}
int err = llama_apply_adapter_cvec(
@@ -1052,10 +1215,7 @@ struct common_init_result common_init_from_params(common_params & params) {
params.control_vector_layer_start,
params.control_vector_layer_end);
if (err) {
llama_free(lctx);
llama_model_free(model);
return iparams;
return res;
}
}
@@ -1079,10 +1239,7 @@ struct common_init_result common_init_from_params(common_params & params) {
}
if (!ok) {
llama_free(lctx);
llama_model_free(model);
return iparams;
return res;
}
}
@@ -1092,9 +1249,7 @@ struct common_init_result common_init_from_params(common_params & params) {
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
if (lora == nullptr) {
LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
llama_free(lctx);
llama_model_free(model);
return iparams;
return res;
}
char buf[1024];
@@ -1103,43 +1258,13 @@ struct common_init_result common_init_from_params(common_params & params) {
la.task_name = buf;
llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
la.prompt_prefix = buf;
iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters
}
if (!params.lora_init_without_apply) {
common_set_adapter_lora(lctx, params.lora_adapters);
}
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
params.sampling.ignore_eos = false;
}
// initialize once
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
if (llama_vocab_is_eog(vocab, i)) {
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
}
}
if (params.sampling.ignore_eos) {
// add EOG biases to the active set of logit biases
params.sampling.logit_bias.insert(
params.sampling.logit_bias.end(),
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
}
if (params.sampling.penalty_last_n == -1) {
LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
params.sampling.penalty_last_n = llama_n_ctx(lctx);
}
if (params.sampling.dry_penalty_last_n == -1) {
LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
}
if (params.warmup) {
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
@@ -1178,12 +1303,11 @@ struct common_init_result common_init_from_params(common_params & params) {
llama_set_warmup(lctx, false);
}
iparams.model.reset(model);
iparams.context.reset(lctx);
return iparams;
return res;
}
common_init_result::~common_init_result() = default;
std::string get_model_endpoint() {
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
// We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
@@ -1192,7 +1316,9 @@ std::string get_model_endpoint() {
std::string model_endpoint = "https://huggingface.co/";
if (endpoint_env) {
model_endpoint = endpoint_env;
if (model_endpoint.back() != '/') model_endpoint += '/';
if (model_endpoint.back() != '/') {
model_endpoint += '/';
}
}
return model_endpoint;
}

View File

@@ -12,6 +12,10 @@
#include <vector>
#include <map>
#if defined(_WIN32) && !defined(_WIN32_WINNT)
#define _WIN32_WINNT 0x0A00
#endif
#ifdef _WIN32
#define DIRECTORY_SEPARATOR '\\'
#else
@@ -26,8 +30,6 @@
fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \
} while(0)
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
struct common_time_meas {
common_time_meas(int64_t & t_acc, bool disable = false);
~common_time_meas();
@@ -80,7 +82,8 @@ int32_t cpu_get_num_math();
enum llama_example {
LLAMA_EXAMPLE_COMMON,
LLAMA_EXAMPLE_SPECULATIVE,
LLAMA_EXAMPLE_MAIN,
LLAMA_EXAMPLE_COMPLETION,
LLAMA_EXAMPLE_CLI,
LLAMA_EXAMPLE_EMBEDDING,
LLAMA_EXAMPLE_PERPLEXITY,
LLAMA_EXAMPLE_RETRIEVAL,
@@ -96,6 +99,7 @@ enum llama_example {
LLAMA_EXAMPLE_TTS,
LLAMA_EXAMPLE_DIFFUSION,
LLAMA_EXAMPLE_FINETUNE,
LLAMA_EXAMPLE_FIT_PARAMS,
LLAMA_EXAMPLE_COUNT,
};
@@ -192,7 +196,6 @@ struct common_params_sampling {
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
std::vector<enum common_sampler_type> samplers = {
COMMON_SAMPLER_TYPE_PENALTIES,
COMMON_SAMPLER_TYPE_DRY,
@@ -213,6 +216,10 @@ struct common_params_sampling {
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
bool has_logit_bias() const {
return !logit_bias.empty();
}
// print the parameters into a string
std::string print() const;
};
@@ -223,6 +230,7 @@ struct common_params_model {
std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT
std::string docker_repo = ""; // Docker repo // NOLINT
std::string name = ""; // in format <user>/<model>[:<tag>] (tag is optional) // NOLINT
};
struct common_params_speculative {
@@ -299,8 +307,8 @@ struct lr_opt {
struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata);
struct common_params {
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 4096; // context size
int32_t n_predict = -1; // max. number of new tokens to predict, -1 == no limit
int32_t n_ctx = 0; // context size, 0 == context the model was trained with
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
@@ -321,9 +329,12 @@ struct common_params {
// offload params
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
bool fit_params = true; // whether to fit unset model/context parameters to free device memory
size_t fit_params_target = 1024 * 1024*1024; // margin per device in bytes for fitting parameters to free memory
int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
@@ -369,7 +380,7 @@ struct common_params {
std::vector<common_control_vector_load_info> control_vectors; // control vector with user defined scale
int32_t verbosity = 0;
int32_t verbosity = 3; // LOG_LEVEL_INFO
int32_t control_vector_layer_start = -1; // layer range for control vector
int32_t control_vector_layer_end = -1; // layer range for control vector
bool offline = false;
@@ -403,6 +414,7 @@ struct common_params {
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool no_perf = false; // disable performance metrics
bool show_timings = true; // show timing information on CLI
bool ctx_shift = false; // context shift on infinite text generation
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
bool kv_unified = false; // enable unified KV cache
@@ -459,7 +471,7 @@ struct common_params {
std::string public_path = ""; // NOLINT
std::string api_prefix = ""; // NOLINT
std::string chat_template = ""; // NOLINT
bool use_jinja = false; // NOLINT
bool use_jinja = true; // NOLINT
bool enable_chat_template = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
int reasoning_budget = -1;
@@ -478,9 +490,16 @@ struct common_params {
bool endpoint_props = false; // only control POST requests, not GET
bool endpoint_metrics = false;
// router server configs
std::string models_dir = ""; // directory containing models for the router server
std::string models_preset = ""; // directory containing model presets for the router server
int models_max = 4; // maximum number of models to load simultaneously
bool models_autoload = true; // automatically load models when requested via the router server
bool log_json = false;
std::string slot_save_path;
std::string media_path; // path to directory for loading media files
float slot_prompt_similarity = 0.1f;
@@ -631,8 +650,9 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
// Filesystem utils
//
bool fs_validate_filename(const std::string & filename);
bool fs_validate_filename(const std::string & filename, bool allow_subdirs = false);
bool fs_create_directory_with_parents(const std::string & path);
bool fs_is_directory(const std::string & path);
std::string fs_get_cache_directory();
std::string fs_get_cache_file(const std::string & filename);
@@ -641,22 +661,44 @@ struct common_file_info {
std::string path;
std::string name;
size_t size = 0; // in bytes
bool is_dir = false;
};
std::vector<common_file_info> fs_list_files(const std::string & path);
std::vector<common_file_info> fs_list(const std::string & path, bool include_directories);
//
// TTY utils
//
// Auto-detect if colors can be enabled based on terminal and environment
bool tty_can_use_colors();
//
// Model utils
//
// note: defines object's lifetime
struct common_init_result {
llama_model_ptr model;
llama_context_ptr context;
struct common_sampler;
std::vector<llama_adapter_lora_ptr> lora;
// note: defines the model, context, samplers, ets. lifetimes
struct common_init_result {
common_init_result(common_params & params);
~common_init_result();
llama_model * model();
llama_context * context();
common_sampler * sampler(llama_seq_id seq_id);
std::vector<llama_adapter_lora_ptr> & lora();
void free_context();
private:
struct impl;
std::unique_ptr<impl> pimpl;
};
struct common_init_result common_init_from_params(common_params & params);
using common_init_result_ptr = std::unique_ptr<common_init_result>;
common_init_result_ptr common_init_from_params(common_params & params);
struct llama_model_params common_model_params_to_llama ( common_params & params);
struct llama_context_params common_context_params_to_llama(const common_params & params);

View File

@@ -305,8 +305,9 @@ static std::string format_literal(const std::string & literal) {
std::string gbnf_format_literal(const std::string & literal) { return format_literal(literal); }
class SchemaConverter {
class common_schema_converter {
private:
friend class common_schema_info;
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
std::function<json(const std::string &)> _fetch_json;
bool _dotall;
@@ -729,7 +730,7 @@ private:
}
public:
SchemaConverter(
common_schema_converter(
const std::function<json(const std::string &)> & fetch_json,
bool dotall)
: _fetch_json(fetch_json), _dotall(dotall)
@@ -974,7 +975,7 @@ public:
void check_errors() {
if (!_errors.empty()) {
throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
throw std::invalid_argument("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
}
if (!_warnings.empty()) {
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());
@@ -990,6 +991,134 @@ public:
}
};
// common_schema_info implementation (pimpl)
common_schema_info::common_schema_info()
: impl_(std::make_unique<common_schema_converter>(
[](const std::string &) { return json(); },
false)) {}
common_schema_info::~common_schema_info() = default;
common_schema_info::common_schema_info(common_schema_info &&) noexcept = default;
common_schema_info & common_schema_info::operator=(common_schema_info &&) noexcept = default;
void common_schema_info::resolve_refs(nlohmann::ordered_json & schema) {
impl_->resolve_refs(schema, "");
}
// Determines if a JSON schema can resolve to a string type through any path.
// Some models emit raw string values rather than JSON-encoded strings for string parameters.
// If any branch of the schema (via oneOf, anyOf, $ref, etc.) permits a string, this returns
// true, allowing callers to handle the value as a raw string for simplicity.
bool common_schema_info::resolves_to_string(const nlohmann::ordered_json & schema) {
std::unordered_set<std::string> visited_refs;
std::function<bool(const json &)> check = [&](const json & s) -> bool {
if (!s.is_object()) {
return false;
}
// Handle $ref
if (s.contains("$ref")) {
const std::string & ref = s["$ref"];
if (visited_refs.find(ref) != visited_refs.end()) {
// Circular reference, assume not a string to be safe
return false;
}
visited_refs.insert(ref);
auto it = impl_->_refs.find(ref);
if (it != impl_->_refs.end()) {
return check(it->second);
}
return false;
}
// Check type field
if (s.contains("type")) {
const json & schema_type = s["type"];
if (schema_type.is_string()) {
if (schema_type == "string") {
return true;
}
} else if (schema_type.is_array()) {
// Type can be an array like ["string", "null"]
for (const auto & t : schema_type) {
if (t == "string") {
return true;
}
}
}
}
// Check oneOf/anyOf - if any alternative can be a string
if (s.contains("oneOf")) {
for (const auto & alt : s["oneOf"]) {
if (check(alt)) {
return true;
}
}
}
if (s.contains("anyOf")) {
for (const auto & alt : s["anyOf"]) {
if (check(alt)) {
return true;
}
}
}
// Check allOf - all components must be compatible with string type
if (s.contains("allOf")) {
bool all_string = true;
for (const auto & component : s["allOf"]) {
if (!check(component)) {
all_string = false;
break;
}
}
if (all_string) {
return true;
}
}
// Check const - if the constant value is a string
if (s.contains("const")) {
if (s["const"].is_string()) {
return true;
}
}
// Check enum - if any enum value is a string
if (s.contains("enum")) {
for (const auto & val : s["enum"]) {
if (val.is_string()) {
return true;
}
}
}
// String-specific keywords imply string type
if (s.contains("pattern") || s.contains("minLength") || s.contains("maxLength")) {
return true;
}
// Check format - many formats imply string
if (s.contains("format")) {
const std::string & fmt = s["format"];
if (fmt == "date" || fmt == "time" || fmt == "date-time" ||
fmt == "uri" || fmt == "email" || fmt == "hostname" ||
fmt == "ipv4" || fmt == "ipv6" || fmt == "uuid" ||
fmt.find("uuid") == 0) {
return true;
}
}
return false;
};
return check(schema);
}
std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
#ifdef LLAMA_USE_LLGUIDANCE
if (!force_gbnf) {
@@ -1006,7 +1135,7 @@ std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
}
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
common_schema_converter converter([&](const std::string &) { return json(); }, options.dotall);
common_grammar_builder builder {
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
return converter._add_rule(name, rule);

View File

@@ -3,11 +3,31 @@
#include <nlohmann/json_fwd.hpp>
#include <functional>
#include <memory>
#include <string>
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
bool force_gbnf = false);
class common_schema_converter;
// Probes a JSON schema to extract information about its structure and type constraints.
class common_schema_info {
std::unique_ptr<common_schema_converter> impl_;
public:
common_schema_info();
~common_schema_info();
common_schema_info(const common_schema_info &) = delete;
common_schema_info & operator=(const common_schema_info &) = delete;
common_schema_info(common_schema_info &&) noexcept;
common_schema_info & operator=(common_schema_info &&) noexcept;
void resolve_refs(nlohmann::ordered_json & schema);
bool resolves_to_string(const nlohmann::ordered_json & schema);
};
struct common_grammar_builder {
std::function<std::string(const std::string &, const std::string &)> add_rule;
std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;

View File

@@ -1,3 +1,4 @@
#include "common.h"
#include "log.h"
#include <chrono>
@@ -26,30 +27,6 @@ void common_log_set_verbosity_thold(int verbosity) {
common_log_verbosity_thold = verbosity;
}
// Auto-detect if colors should be enabled based on terminal and environment
static bool common_log_should_use_colors_auto() {
// Check NO_COLOR environment variable (https://no-color.org/)
if (const char * no_color = std::getenv("NO_COLOR")) {
if (no_color[0] != '\0') {
return false;
}
}
// Check TERM environment variable
if (const char * term = std::getenv("TERM")) {
if (std::strcmp(term, "dumb") == 0) {
return false;
}
}
// Check if stdout and stderr are connected to a terminal
// We check both because log messages can go to either
bool stdout_is_tty = isatty(fileno(stdout));
bool stderr_is_tty = isatty(fileno(stderr));
return stdout_is_tty || stderr_is_tty;
}
static int64_t t_us() {
return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
}
@@ -391,7 +368,7 @@ struct common_log * common_log_main() {
static std::once_flag init_flag;
std::call_once(init_flag, [&]() {
// Set default to auto-detect colors
log.set_colors(common_log_should_use_colors_auto());
log.set_colors(tty_can_use_colors());
});
return &log;
@@ -422,7 +399,7 @@ void common_log_set_file(struct common_log * log, const char * file) {
void common_log_set_colors(struct common_log * log, log_colors colors) {
if (colors == LOG_COLORS_AUTO) {
log->set_colors(common_log_should_use_colors_auto());
log->set_colors(tty_can_use_colors());
return;
}
@@ -443,8 +420,27 @@ void common_log_set_timestamps(struct common_log * log, bool timestamps) {
log->set_timestamps(timestamps);
}
void common_log_flush(struct common_log * log) {
log->pause();
log->resume();
}
static int common_get_verbosity(enum ggml_log_level level) {
switch (level) {
case GGML_LOG_LEVEL_DEBUG: return LOG_LEVEL_DEBUG;
case GGML_LOG_LEVEL_INFO: return LOG_LEVEL_INFO;
case GGML_LOG_LEVEL_WARN: return LOG_LEVEL_WARN;
case GGML_LOG_LEVEL_ERROR: return LOG_LEVEL_ERROR;
case GGML_LOG_LEVEL_CONT: return LOG_LEVEL_INFO; // same as INFO
case GGML_LOG_LEVEL_NONE:
default:
return LOG_LEVEL_OUTPUT;
}
}
void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) {
if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
auto verbosity = common_get_verbosity(level);
if (verbosity <= common_log_verbosity_thold) {
common_log_add(common_log_main(), level, "%s", text);
}
}

View File

@@ -21,8 +21,14 @@
# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
#endif
#define LOG_DEFAULT_DEBUG 1
#define LOG_DEFAULT_LLAMA 0
#define LOG_LEVEL_DEBUG 4
#define LOG_LEVEL_INFO 3
#define LOG_LEVEL_WARN 2
#define LOG_LEVEL_ERROR 1
#define LOG_LEVEL_OUTPUT 0 // output data from tools
#define LOG_DEFAULT_DEBUG LOG_LEVEL_DEBUG
#define LOG_DEFAULT_LLAMA LOG_LEVEL_INFO
enum log_colors {
LOG_COLORS_AUTO = -1,
@@ -67,16 +73,18 @@ void common_log_add(struct common_log * log, enum ggml_log_level level, const ch
// 0.00.090.578 I llm_load_tensors: offloading 32 repeating layers to GPU
// 0.00.090.579 I llm_load_tensors: offloading non-repeating layers to GPU
//
// I - info (stdout, V = 0)
// W - warning (stderr, V = 0)
// E - error (stderr, V = 0)
// D - debug (stderr, V = LOG_DEFAULT_DEBUG)
// I - info (stdout, V = LOG_DEFAULT_INFO)
// W - warning (stderr, V = LOG_DEFAULT_WARN)
// E - error (stderr, V = LOG_DEFAULT_ERROR)
// O - output (stdout, V = LOG_DEFAULT_OUTPUT)
//
void common_log_set_file (struct common_log * log, const char * file); // not thread-safe
void common_log_set_colors (struct common_log * log, log_colors colors); // not thread-safe
void common_log_set_prefix (struct common_log * log, bool prefix); // whether to output prefix to each log
void common_log_set_timestamps(struct common_log * log, bool timestamps); // whether to output timestamps in the prefix
void common_log_flush (struct common_log * log); // flush all pending log messages
// helper macros for logging
// use these to avoid computing log arguments if the verbosity of the log is higher than the threshold
@@ -95,14 +103,14 @@ void common_log_set_timestamps(struct common_log * log, bool timestamps); // w
} \
} while (0)
#define LOG(...) LOG_TMPL(GGML_LOG_LEVEL_NONE, 0, __VA_ARGS__)
#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__)
#define LOG(...) LOG_TMPL(GGML_LOG_LEVEL_NONE, LOG_LEVEL_OUTPUT, __VA_ARGS__)
#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__)
#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, 0, __VA_ARGS__)
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, 0, __VA_ARGS__)
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, 0, __VA_ARGS__)
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_DEFAULT_DEBUG, __VA_ARGS__)
#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, 0, __VA_ARGS__)
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_LEVEL_DEBUG, __VA_ARGS__)
#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, LOG_LEVEL_INFO, __VA_ARGS__)
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, LOG_LEVEL_WARN, __VA_ARGS__)
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, LOG_LEVEL_ERROR, __VA_ARGS__)
#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, LOG_LEVEL_INFO, __VA_ARGS__) // same as INFO
#define LOG_INFV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_INFO, verbosity, __VA_ARGS__)
#define LOG_WRNV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_WARN, verbosity, __VA_ARGS__)

View File

@@ -104,9 +104,10 @@ struct ring_buffer {
struct common_sampler {
common_params_sampling params;
struct llama_sampler * grmr;
struct llama_sampler * chain;
bool grammar;
ring_buffer<llama_token> prev;
std::vector<llama_token_data> cur;
@@ -116,7 +117,6 @@ struct common_sampler {
void reset() {
prev.clear();
llama_sampler_reset(grmr);
llama_sampler_reset(chain);
}
@@ -167,10 +167,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
lparams.no_perf = params.no_perf;
struct llama_sampler * grmr;
llama_sampler * chain = llama_sampler_chain_init(lparams);
bool grammar = false;
std::vector<llama_sampler *> samplers;
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()));
grammar = true;
#else
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
@@ -217,30 +222,23 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
trigger_patterns_c.push_back(regex.c_str());
}
grmr = params.grammar_lazy
? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
trigger_patterns_c.data(), trigger_patterns_c.size(),
trigger_tokens.data(), trigger_tokens.size())
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
if (!grmr) {
return nullptr;
if (!params.grammar.empty()) {
if (params.grammar_lazy) {
samplers.push_back(
llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
trigger_patterns_c.data(), trigger_patterns_c.size(),
trigger_tokens.data(), trigger_tokens.size()));
} else {
samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"));
}
grammar = true;
}
}
auto * result = new common_sampler {
/* .params = */ params,
/* .grmr = */ grmr,
/* .chain = */ llama_sampler_chain_init(lparams),
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},
/* .cur_p = */ {},
};
llama_sampler_chain_add(result->chain,
llama_sampler_init_logit_bias(
llama_vocab_n_tokens(vocab),
params.logit_bias.size(),
params.logit_bias.data()));
if (params.has_logit_bias()) {
samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
}
if (params.mirostat == 0) {
for (const auto & cnstr : params.samplers) {
@@ -253,58 +251,70 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
c_breakers.push_back(str.c_str());
}
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
}
break;
case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
samplers.push_back(llama_sampler_init_top_k (params.top_k));
break;
case COMMON_SAMPLER_TYPE_TOP_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
break;
case COMMON_SAMPLER_TYPE_MIN_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_XTC:
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
break;
case COMMON_SAMPLER_TYPE_TYPICAL_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break;
case COMMON_SAMPLER_TYPE_INFILL:
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
samplers.push_back(llama_sampler_init_infill (vocab));
break;
case COMMON_SAMPLER_TYPE_PENALTIES:
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
break;
default:
GGML_ASSERT(false && "unknown sampler type");
}
}
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
samplers.push_back(llama_sampler_init_dist(params.seed));
} else if (params.mirostat == 1) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
samplers.push_back(llama_sampler_init_temp(params.temp));
samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
} else if (params.mirostat == 2) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
samplers.push_back(llama_sampler_init_temp(params.temp));
samplers.push_back(llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
} else {
GGML_ASSERT(false && "unknown mirostat version");
}
for (auto * smpl : samplers) {
llama_sampler_chain_add(chain, smpl);
}
auto * result = new common_sampler {
/* .params = */ params,
/* .chain = */ chain,
/* .grammar = */ grammar,
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},
/* .cur_p = */ {},
};
return result;
}
void common_sampler_free(struct common_sampler * gsmpl) {
if (gsmpl) {
llama_sampler_free(gsmpl->grmr);
llama_sampler_free(gsmpl->chain);
delete gsmpl;
@@ -314,11 +324,24 @@ void common_sampler_free(struct common_sampler * gsmpl) {
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
const auto tm = gsmpl->tm();
if (accept_grammar) {
llama_sampler_accept(gsmpl->grmr, token);
}
if (gsmpl->grammar) {
const int n_smpl = llama_sampler_chain_n(gsmpl->chain);
llama_sampler_accept(gsmpl->chain, token);
for (int i = 0; i < n_smpl; i++) {
auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
// the grammar sampler is always the first one
if (i == 0) {
if (accept_grammar) {
llama_sampler_accept(smpl, token);
}
} else {
llama_sampler_accept(smpl, token);
}
}
} else {
llama_sampler_accept(gsmpl->chain, token);
}
gsmpl->prev.push_back(token);
}
@@ -329,12 +352,12 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
return new common_sampler {
/* .params = */ gsmpl->params,
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
/* .chain = */ llama_sampler_clone(gsmpl->chain),
/* .prev = */ gsmpl->prev,
/* .cur = */ gsmpl->cur,
/* .cur_p = */ gsmpl->cur_p,
/* .params = */ gsmpl->params,
/* .chain = */ llama_sampler_clone(gsmpl->chain),
/* .grammar = */ gsmpl->grammar,
/* .prev = */ gsmpl->prev,
/* .cur = */ gsmpl->cur,
/* .cur_p = */ gsmpl->cur_p,
};
}
@@ -383,58 +406,33 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
}
}
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
return gsmpl->chain;
}
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
llama_synchronize(ctx);
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
const auto tm = gsmpl->tm();
gsmpl->set_logits(ctx, idx);
llama_token id = LLAMA_TOKEN_NULL;
auto & grmr = gsmpl->grmr;
auto & chain = gsmpl->chain;
auto & cur_p = gsmpl->cur_p; // initialized by set_logits
if (grammar_first) {
llama_sampler_apply(grmr, &cur_p);
}
gsmpl->set_logits(ctx, idx);
llama_sampler_apply(chain, &cur_p);
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
const llama_token id = cur_p.data[cur_p.selected].id;
id = cur_p.data[cur_p.selected].id;
if (grammar_first) {
return id;
}
// check if it the sampled token fits the grammar
{
llama_token_data single_token_data = { id, 1.0f, 0.0f };
llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
llama_sampler_apply(grmr, &single_token_data_array);
const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
if (is_valid) {
return id;
}
}
// resampling:
// if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
gsmpl->set_logits(ctx, idx);
llama_sampler_apply(grmr, &cur_p);
llama_sampler_apply(chain, &cur_p);
GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
return cur_p.data[cur_p.selected].id;
return id;
}
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft) {
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
std::vector<llama_token> result;
@@ -442,7 +440,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
size_t i = 0;
for (; i < draft.size(); i++) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
common_sampler_accept(gsmpl, id, true);
@@ -454,7 +452,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
}
if (i == draft.size()) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
common_sampler_accept(gsmpl, id, true);
@@ -464,13 +462,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
return result;
}
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) {
std::vector<int> idxs(draft.size() + 1);
for (size_t i = 0; i < idxs.size(); ++i) {
idxs[i] = i;
}
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
}
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
@@ -515,7 +513,8 @@ std::string common_sampler_print(const struct common_sampler * gsmpl) {
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
result += std::string("-> ") + llama_sampler_name(smpl) + " ";
result += std::string("-> ");
result += std::string(llama_sampler_name(smpl)) + " ";
}
return result;

View File

@@ -48,6 +48,8 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
// arguments can be nullptr to skip printing
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
// extended sampling implementation:
//
// - set logits
@@ -55,10 +57,7 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
// - check if the token fits the grammar (if any)
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
//
// if grammar_first is true, the grammar is applied before the samplers (slower)
// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
//
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx);
// generalized version of common_sampler_sample
//
@@ -76,10 +75,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
//
// returns at least 1 token, up to idxs.size()
//
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft);
// assume idxs == [ 0, 1, 2, ..., draft.size() ]
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft);
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
@@ -107,3 +106,9 @@ std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std:
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
const char * grammar_kind, const char * grammar_data);
struct common_sampler_deleter {
void operator()(common_sampler * s) { common_sampler_free(s); }
};
typedef std::unique_ptr<common_sampler, common_sampler_deleter> common_sampler_ptr;

View File

@@ -313,6 +313,7 @@ extern "C" {
bool check_tensors; // validate model tensor data
bool use_extra_bufts; // use extra buffer types (used for weight repacking)
bool no_host; // bypass host buffer allowing extra buffers to be used
bool no_alloc; // only load metadata and simulate memory allocations
};
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
@@ -466,10 +467,24 @@ extern "C" {
// Frees all allocated memory
LLAMA_API void llama_free(struct llama_context * ctx);
// fits mparams and cparams to free device memory (assumes system memory is unlimited)
// returns true if the parameters could be successfully modified to fit device memory
// this function is NOT thread safe because it modifies the global llama logger state
LLAMA_API bool llama_params_fit(
const char * path_model,
struct llama_model_params * mparams,
struct llama_context_params * cparams,
float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements
struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements
size_t margin, // margin of memory to leave per device in bytes
uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use
enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log
LLAMA_API int64_t llama_time_us(void);
LLAMA_API size_t llama_max_devices(void);
LLAMA_API size_t llama_max_parallel_sequences(void);
LLAMA_API size_t llama_max_tensor_buft_overrides(void);
LLAMA_API bool llama_supports_mmap (void);
LLAMA_API bool llama_supports_mlock (void);
@@ -1354,7 +1369,9 @@ extern "C" {
// Set callback for all future logging events.
// If this is not called, or NULL is supplied, everything is output on stderr.
LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
// The logger state is global so these functions are NOT thread safe.
LLAMA_API void llama_log_get(ggml_log_callback * log_callback, void ** user_data);
LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
//
// Performance utils

View File

File diff suppressed because it is too large Load Diff

View File

@@ -3,6 +3,7 @@
#include "ggml.h" // ggml_op
#include <string>
#include <set>
//
// gguf constants (sync with gguf.py)
@@ -79,6 +80,7 @@ enum llm_arch {
LLM_ARCH_JAIS,
LLM_ARCH_NEMOTRON,
LLM_ARCH_NEMOTRON_H,
LLM_ARCH_NEMOTRON_H_MOE,
LLM_ARCH_EXAONE,
LLM_ARCH_EXAONE4,
LLM_ARCH_RWKV6,
@@ -116,6 +118,7 @@ enum llm_arch {
LLM_ARCH_COGVLM,
LLM_ARCH_RND1,
LLM_ARCH_PANGU_EMBED,
LLM_ARCH_MISTRAL3,
LLM_ARCH_UNKNOWN,
};
@@ -209,6 +212,7 @@ enum llm_kv {
LLM_KV_ATTENTION_SCALE,
LLM_KV_ATTENTION_OUTPUT_SCALE,
LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
LLM_KV_ATTENTION_TEMPERATURE_SCALE,
LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
@@ -315,6 +319,7 @@ enum llm_tensor {
LLM_TENSOR_DENSE_3_OUT,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT_NORM_LFM2, // fix for wrong tensor name
LLM_TENSOR_ROPE_FREQS,
LLM_TENSOR_ROPE_FACTORS_LONG,
LLM_TENSOR_ROPE_FACTORS_SHORT,
@@ -379,6 +384,7 @@ enum llm_tensor {
LLM_TENSOR_SSM_DT,
LLM_TENSOR_SSM_DT_NORM,
LLM_TENSOR_SSM_A,
LLM_TENSOR_SSM_A_NOSCAN, // qwen3next special case with MUL instead of SSM_SCAN
LLM_TENSOR_SSM_B_NORM,
LLM_TENSOR_SSM_C_NORM,
LLM_TENSOR_SSM_D,
@@ -525,6 +531,10 @@ struct LLM_TN_IMPL {
const int bid;
const int xid;
const std::set<llm_tensor> model_tensors;
LLM_TN_IMPL(llm_arch arch, llm_tensor tensor, const char * suffix, int bid, int xid);
std::string str() const;
operator std::string() const {
@@ -546,11 +556,11 @@ struct LLM_TN {
llm_arch arch;
LLM_TN_IMPL operator()(llm_tensor tensor, const char * suffix, int bid = -1, int xid = -1) const {
return { arch, tensor, suffix, bid, xid };
return LLM_TN_IMPL(arch, tensor, suffix, bid, xid);
}
LLM_TN_IMPL operator()(llm_tensor tensor, int bid = -1, int xid = -1) const {
return { arch, tensor, nullptr, bid, xid };
return LLM_TN_IMPL(arch, tensor, nullptr, bid, xid);
}
};

View File

@@ -695,6 +695,8 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
udata->seq_idx .resize(LLAMA_MAX_SEQ, -1);
udata->output .resize(n_tokens);
udata->seq_id_data.reserve(n_tokens);
seq_set_t seq_set_unq;
for (size_t i = 0; i < idxs.size(); ++i) {
@@ -716,11 +718,13 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
}
udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
udata->seq_id[i] = batch.seq_id[idxs[i]];
udata->output[i] = batch.logits[idxs[i]];
for (int s = 0; s < udata->n_seq_id[i]; ++s) {
seq_set_unq.set(udata->seq_id[i][s]);
const llama_seq_id seq_id = batch.seq_id[idxs[i]][s];
udata->seq_id_data.push_back(seq_id);
seq_set_unq.set(seq_id);
}
if (udata->output[i]) {
@@ -728,6 +732,12 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
}
}
llama_seq_id * seq_id_ptr = udata->seq_id_data.data();
for (size_t i = 0; i < idxs.size(); ++i) {
udata->seq_id[i] = seq_id_ptr;
seq_id_ptr += udata->n_seq_id[i];
}
for (uint32_t s = 0; s < n_seq_max; ++s) {
if (seq_set_unq.test(s)) {
udata->seq_idx[s] = udata->seq_id_unq.size();

View File

@@ -56,13 +56,15 @@ struct llama_ubatch {
std::vector<float> embd;
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
std::vector<llama_seq_id *> seq_id; // these point into the seq_id_data below
std::vector<llama_seq_id> seq_id_unq;
std::vector<int32_t> seq_idx;
std::vector<int8_t> output;
std::vector<llama_seq_id> seq_id_data;
};
// the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data
// the llama_ubatch pointers above point to this data if set. otherwise - point to external non-owning data
std::shared_ptr<data_t> data;
};

View File

@@ -9,6 +9,7 @@
#include "llama-model.h"
#include <cinttypes>
#include <cmath>
#include <cstring>
#include <limits>
#include <stdexcept>
@@ -72,6 +73,43 @@ llama_context::llama_context(
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
}
if (cparams.yarn_ext_factor != 0) {
static auto get_mscale = [](float scale, float mscale) {
return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
};
const float factor = 1.0f / cparams.rope_freq_scale;
// ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
if (hparams.rope_yarn_log_mul != 0.0f) {
// note: here we assume `mscale == 1.0f`
// TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
float mscale = 1.0f;
const float mscale_all_dims = hparams.rope_yarn_log_mul;
// [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
// special-case DEEPSEEK v2:
// https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
if (model.arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
mscale = mscale_all_dims;
}
cparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
__func__, cparams.yarn_attn_factor, mscale, mscale_all_dims);
} else {
cparams.yarn_attn_factor = get_mscale(factor, 1.0f);
}
// when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
// https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
//
// ref: https://github.com/ggml-org/llama.cpp/discussions/7416
// https://github.com/ggml-org/llama.cpp/pull/17945
cparams.yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor));
}
cparams.yarn_attn_factor *= hparams.rope_attn_factor;
if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
@@ -93,14 +131,6 @@ llama_context::llama_context(
// with causal attention, the batch size is limited by the context size
cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
// the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
// this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
// ref: https://github.com/ggerganov/llama.cpp/pull/5021
// TODO: this padding is not needed for the cache-less context so we should probably move it to llama_memory
if (cparams.n_batch < GGML_KQ_MASK_PAD) {
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
cparams.n_batch = GGML_KQ_MASK_PAD;
}
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
cparams.op_offload = params.op_offload;
@@ -228,6 +258,7 @@ llama_context::llama_context(
backend_buft.clear();
backend_ptrs.clear();
backend_buf_exp_size.clear();
for (auto & backend : backends) {
auto * buft = ggml_backend_get_default_buffer_type(backend.get());
@@ -244,11 +275,15 @@ llama_context::llama_context(
backend_buft.push_back(buft);
backend_ptrs.push_back(backend.get());
backend_buf_exp_size.push_back(0);
}
LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
const size_t max_nodes = this->graph_max_nodes();
const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
const size_t max_nodes = this->graph_max_nodes(n_tokens);
LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
@@ -300,9 +335,6 @@ llama_context::llama_context(
cross.v_embd.clear();
const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
// avoid reserving graphs with zero outputs - assume one output per sequence
n_outputs = n_seqs;
@@ -359,7 +391,8 @@ llama_context::llama_context(
// reserve pp (prompt processing) graph first so that buffers are only allocated once
{
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(),
model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr);
if (!gf) {
if (pipeline_parallel) {
LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
@@ -377,7 +410,7 @@ llama_context::llama_context(
// reserve with tg (token generation) graph to get the number of splits and nodes
{
auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc);
if (!gf) {
throw std::runtime_error("failed to allocate compute tg buffers");
}
@@ -392,7 +425,7 @@ llama_context::llama_context(
//
// auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
//
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc);
if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers");
}
@@ -401,11 +434,13 @@ llama_context::llama_context(
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
ggml_backend_t backend = backend_ptrs[i];
ggml_backend_buffer_type_t buft = backend_buft[i];
size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend);
if (size > 1) {
if (!model.hparams.no_alloc) {
backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend);
}
if (backend_buf_exp_size[i] > 1) {
LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
ggml_backend_buft_name(buft),
size / 1024.0 / 1024.0);
backend_buf_exp_size[i] / 1024.0 / 1024.0);
}
}
@@ -424,6 +459,23 @@ llama_context::llama_context(
}
llama_context::~llama_context() {
// FIXME this currently results in a use-after-free bug if the model is freed before the context
// if (!model.hparams.no_alloc) {
// for (size_t i = 0; i < backend_ptrs.size(); ++i) {
// ggml_backend_t backend = backend_ptrs[i];
// ggml_backend_buffer_type_t buft = backend_buft[i];
// const size_t size_exp = backend_buf_exp_size[i];
// const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend);
// if (size_exp == size_act) {
// LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n",
// __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
// } else {
// LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n",
// __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
// }
// }
// }
ggml_opt_free(opt_ctx);
}
@@ -1325,6 +1377,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
// This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
#endif
synchronize();
buf_output = nullptr;
logits = nullptr;
embd = nullptr;
@@ -1385,9 +1438,9 @@ void llama_context::output_reorder() {
// graph
//
uint32_t llama_context::graph_max_nodes() const {
uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
if (model.arch == LLM_ARCH_QWEN3NEXT) {
return std::max<uint32_t>(8192u, 32u*model.n_tensors());
return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
}
return std::max<uint32_t>(1024u, 8u*model.n_tensors());
}
@@ -1396,7 +1449,8 @@ llm_graph_result * llama_context::get_gf_res_reserve() const {
return static_cast<llm_graph_result *>(gf_res_reserve.get());
}
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
ggml_cgraph * llama_context::graph_reserve(
uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only, size_t * sizes) {
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
GGML_ASSERT(n_outputs >= 1);
@@ -1433,8 +1487,13 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
// initialize scheduler with the specified graph
if (split_only) {
ggml_backend_sched_split_graph(sched.get(), gf);
if (sizes) {
ggml_backend_sched_reserve_size(sched.get(), gf, sizes);
} else {
ggml_backend_sched_split_graph(sched.get(), gf);
}
} else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
GGML_ASSERT(!sizes);
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
return nullptr;
}
@@ -2056,15 +2115,26 @@ void llama_context::perf_reset() {
std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> llama_context::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> ret;
for (const auto & buft_size : model.memory_breakdown()) {
ret[buft_size.first].model += buft_size.second;
for (const auto & [buft, size] : model.memory_breakdown()) {
ret[buft].model += size;
}
for (const auto & buft_size : memory->memory_breakdown()) {
ret[buft_size.first].context += buft_size.second;
if (memory) {
for (const auto & [buft, size] : memory->memory_breakdown()) {
ret[buft].context += size;
}
}
for (const auto & backend_ptr : backends) {
ggml_backend_t backend = backend_ptr.get();
ret[ggml_backend_sched_get_buffer_type(sched.get(), backend)].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend);
if (model.hparams.no_alloc) {
for (size_t i = 0; i < backends.size(); ++i) {
ggml_backend_t backend = backends[i].get();
ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend);
ret[buft].compute += backend_buf_exp_size[i];
}
} else {
for (const auto & backend_ptr : backends) {
ggml_backend_t backend = backend_ptr.get();
ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend);
ret[buft].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend);
}
}
return ret;
}

View File

@@ -26,6 +26,10 @@ struct llama_memory_breakdown_data {
size_t model = 0; // memory allocated for the model
size_t context = 0; // memory allocated for the context
size_t compute = 0; // memory allocated for temporary compute buffers
size_t total() const {
return model + context + compute;
}
};
struct llama_context {
@@ -197,7 +201,7 @@ private:
//
public:
uint32_t graph_max_nodes() const;
uint32_t graph_max_nodes(uint32_t n_tokens) const;
// can reuse the llm_graph_result instance of the context (for example to update a memory module)
llm_graph_result * get_gf_res_reserve() const;
@@ -206,7 +210,8 @@ public:
ggml_status graph_compute(ggml_cgraph * gf, bool batched);
// reserve a graph with a dummy ubatch of the specified size
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);
ggml_cgraph * graph_reserve(
uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false, size_t * sizes = nullptr);
private:
llm_graph_params graph_params(
@@ -281,9 +286,10 @@ private:
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
// buffer types used for the compute buffer of each backend
// pointers and buffer types used for the compute buffer of each backend
std::vector<ggml_backend_t> backend_ptrs;
std::vector<ggml_backend_buffer_type_t> backend_buft;
std::vector<size_t> backend_buf_exp_size; // expected buffer sizes
llm_graph_result_ptr gf_res_prev;
llm_graph_result_ptr gf_res_reserve;

View File

@@ -181,6 +181,52 @@ static std::pair<uint32_t, const char *> parse_char(const char * src) {
throw std::runtime_error("unexpected end of input");
}
static std::pair<uint32_t, const char *> parse_token(const llama_vocab * vocab, const char * src) {
const char * pos = src;
if (*pos != '<') {
throw std::runtime_error(std::string("expecting '<' at ") + pos);
}
pos++;
// Parse <[id]>
if (*pos == '[') {
pos++;
const char * int_end = parse_int(pos);
uint32_t token_id = std::stoul(std::string(pos, int_end - pos));
pos = int_end;
if (*pos != ']') {
throw std::runtime_error(std::string("expecting ']' at ") + pos);
}
pos++;
if (*pos != '>') {
throw std::runtime_error(std::string("expecting '>' at ") + pos);
}
pos++;
return std::make_pair(token_id, pos);
}
if (vocab == nullptr) {
throw std::runtime_error(std::string("no vocab to parse token at ") + src);
}
// Parse <token> and tokenize to obtain the token id
while (*pos != 0 && *pos != '>') {
pos++;
}
if (*pos != '>') {
throw std::runtime_error(std::string("expecting '>' at ") + pos);
}
pos++;
llama_token tokens[2];
int32_t n_tokens = vocab->tokenize(src, static_cast<int32_t>(pos - src), tokens, 2, false, true);
if (n_tokens != 1) {
// must tokenize to exactly 1 token
throw std::runtime_error("invalid token '" + std::string(src, pos - src) + "'");
}
return std::make_pair(tokens[0], pos);
}
static void print_grammar_char(FILE * file, uint32_t c) {
if (0x20 <= c && c <= 0x7f) {
fprintf(file, "%c", static_cast<char>(c));
@@ -212,6 +258,8 @@ static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
case LLAMA_GRETYPE_TOKEN: fprintf(file, "TOKEN"); break;
case LLAMA_GRETYPE_TOKEN_NOT: fprintf(file, "TOKEN_NOT"); break;
}
switch (elem.type) {
case LLAMA_GRETYPE_END:
@@ -228,6 +276,17 @@ static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
print_grammar_char(file, elem.value);
fprintf(file, "\") ");
break;
case LLAMA_GRETYPE_TOKEN:
fprintf(file, "<[");
fprintf(file, "%u", elem.value);
fprintf(file, "]> ");
break;
case LLAMA_GRETYPE_TOKEN_NOT:
fprintf(file, "!");
fprintf(file, "<[");
fprintf(file, "%u", elem.value);
fprintf(file, "]> ");
break;
}
}
fprintf(file, "\n");
@@ -284,6 +343,17 @@ static void print_rule(
case LLAMA_GRETYPE_CHAR_ANY:
fprintf(file, ".");
break;
case LLAMA_GRETYPE_TOKEN:
fprintf(file, "<[");
fprintf(file, "%u", elem.value);
fprintf(file, "]> ");
break;
case LLAMA_GRETYPE_TOKEN_NOT:
fprintf(file, "!");
fprintf(file, "<[");
fprintf(file, "%u", elem.value);
fprintf(file, "]> ");
break;
}
if (is_char_element(elem)) {
switch (rule[i + 1].type) {
@@ -444,6 +514,17 @@ const char * llama_grammar_parser::parse_sequence(
}
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '<' || *pos == '!') { // token
auto type = LLAMA_GRETYPE_TOKEN;
if (*pos == '!') { // token inverse
type = LLAMA_GRETYPE_TOKEN_NOT;
pos++;
}
auto token_pair = parse_token(vocab, pos);
const char * token_end = token_pair.second;
last_sym_start = rule.size();
rule.push_back({type, token_pair.first});
pos = parse_space(token_end, is_nested);
} else if (is_word_char(*pos)) { // rule reference
const char * name_end = parse_name(pos);
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
@@ -691,6 +772,21 @@ static bool llama_grammar_match_partial_char(
return !is_positive_char;
}
// returns true iff token matches the rule at pos (regular or inverse)
// asserts that pos is pointing to a token element
static bool llama_grammar_match_token(
const llama_grammar_element * pos,
const llama_token token) {
GGML_ASSERT(pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT);
if (pos->type == LLAMA_GRETYPE_TOKEN) {
return pos->value == static_cast<uint32_t>(token);
}
if (pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
return pos->value != static_cast<uint32_t>(token);
}
return false;
}
// transforms a grammar pushdown stack into N possible stacks, all ending
// at a character range (terminal element)
static void llama_grammar_advance_stack(
@@ -738,6 +834,8 @@ static void llama_grammar_advance_stack(
case LLAMA_GRETYPE_CHAR:
case LLAMA_GRETYPE_CHAR_NOT:
case LLAMA_GRETYPE_CHAR_ANY:
case LLAMA_GRETYPE_TOKEN:
case LLAMA_GRETYPE_TOKEN_NOT:
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
// only add the stack if it's not a duplicate of one we already have
new_stacks.emplace_back(stack);
@@ -831,26 +929,38 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
return grammar->stacks;
}
static void llama_grammar_accept_chr(
struct llama_grammar & grammar,
const llama_grammar_stack & stack,
uint32_t chr,
llama_grammar_stacks & new_stacks) {
if (stack.empty()) {
return;
}
const llama_grammar_element * pos = stack.back();
// ignore if this turns into a token
if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
return;
}
auto match = llama_grammar_match_char(pos, chr);
if (match.first) {
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
if (!llama_grammar_is_end_of_sequence(match.second)) {
new_stack.push_back(match.second);
}
llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks);
}
}
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
llama_grammar_stacks stacks_new;
stacks_new.reserve(grammar->stacks.size());
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
continue;
}
auto match = llama_grammar_match_char(stack.back(), chr);
if (match.first) {
const llama_grammar_element * pos = match.second;
// update top of stack to next element, if any
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
if (!llama_grammar_is_end_of_sequence(pos)) {
new_stack.push_back(pos);
}
llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new);
}
llama_grammar_accept_chr(*grammar, stack, chr, stacks_new);
}
grammar->stacks = std::move(stacks_new);
@@ -875,6 +985,22 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
const llama_grammar_element * stack_pos = stack.back();
// if the top of the stack is a token rule, then we only need to check the token id
if (stack_pos->type == LLAMA_GRETYPE_TOKEN || stack_pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
for (const auto & tok : candidates) {
if (*tok.code_points == 0) {
// reached the end of a token consumed by char rules, reject iff it ended
// in a partial response
if (tok.partial_utf8.n_remain != 0) {
rejects.push_back(tok);
}
} else if (!llama_grammar_match_token(stack_pos, tok.id)) {
rejects.push_back(tok);
}
}
return rejects;
}
llama_grammar_candidates next_candidates;
next_candidates.reserve(candidates.size());
@@ -887,7 +1013,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
rejects.push_back(tok);
}
} else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8, tok.id });
} else {
rejects.push_back(tok);
}
@@ -905,7 +1031,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
for (const auto & tok : next_rejects) {
rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8, tok.id });
}
return rejects;
@@ -974,12 +1100,13 @@ struct llama_grammar * llama_grammar_init_impl(
ollama_vocab,
std::move(vec_rules),
std::move(stacks),
/* .partial_utf8 = */ {},
/* .lazy =*/ false,
/* .awaiting_trigger = */ false,
/* .trigger_buffer = */ "",
/* .trigger_tokens = */ {},
/* .trigger_patterns = */ {},
/* .partial_utf8 = */ {},
/* .lazy = */ false,
/* .awaiting_trigger = */ false,
/* .trigger_buffer = */ "",
/* .trigger_buffer_positions = */ {},
/* .trigger_tokens = */ {},
/* .trigger_patterns = */ {},
};
}
@@ -993,7 +1120,7 @@ struct llama_grammar * llama_grammar_init_impl(
size_t num_trigger_patterns,
const llama_token * trigger_tokens,
size_t num_trigger_tokens) {
llama_grammar_parser parser;
llama_grammar_parser parser(vocab);
// if there is a grammar, parse it
// rules will be empty (default) if there are parse errors
@@ -1081,10 +1208,11 @@ struct llama_grammar * llama_grammar_init_impl(
ollama_vocab,
std::move(vec_rules),
std::move(stacks),
/* .partial_utf8 = */ {},
/* .lazy = */ lazy,
/* .awaiting_trigger = */ lazy,
/* .trigger_buffer = */ "",
/* .partial_utf8 = */ {},
/* .lazy = */ lazy,
/* .awaiting_trigger = */ lazy,
/* .trigger_buffer = */ "",
/* .trigger_buffer_positions = */ {},
std::move(vec_trigger_tokens),
std::move(vec_trigger_patterns),
};
@@ -1108,6 +1236,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
grammar.lazy,
grammar.awaiting_trigger,
grammar.trigger_buffer,
grammar.trigger_buffer_positions,
grammar.trigger_tokens,
grammar.trigger_patterns,
};
@@ -1164,7 +1293,7 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
cur_p->data[i].logit = -INFINITY;
} else {
candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second, id });
}
}
@@ -1184,10 +1313,12 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
grammar.awaiting_trigger = false;
grammar.trigger_buffer.clear();
llama_grammar_accept_str(grammar, piece);
llama_grammar_accept_token(grammar, token, piece);
LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
return;
} else {
auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size());
grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
grammar.trigger_buffer += piece;
std::smatch match;
@@ -1205,10 +1336,23 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
if (start == std::string::npos) {
start = match.position(0);
}
// replay tokens that overlap with [start, end)
for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
auto [tok_start, tok_end] = tok_pos;
if (tok_end <= start) {
continue;
}
size_t piece_start = (tok_start < start) ? start : tok_start; // allow for partial token pieces
size_t piece_len = tok_end - piece_start;
auto tok_piece = grammar.trigger_buffer.substr(piece_start, piece_len);
llama_grammar_accept_token(grammar, tok, tok_piece);
}
auto constrained_str = grammar.trigger_buffer.substr(start);
// std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
grammar.trigger_buffer.clear();
llama_grammar_accept_str(grammar, constrained_str);
grammar.trigger_buffer_positions.clear();
LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
return;
}
@@ -1228,7 +1372,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
GGML_ABORT("grammar error: end of grammar token received but grammar stack is not empty");
}
llama_grammar_accept_str(grammar, piece);
llama_grammar_accept_token(grammar, token, piece);
}
void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) {
@@ -1246,6 +1390,61 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
}
}
void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) {
// Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
const auto & code_points = decoded.first;
llama_grammar_stacks stacks_new;
stacks_new.reserve(grammar.stacks.size());
for (const auto & stack : grammar.stacks) {
if (stack.empty()) {
continue;
}
const llama_grammar_element * pos = stack.back();
if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
if (llama_grammar_match_token(pos, token)) {
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
if (!llama_grammar_is_end_of_sequence(pos + 1)) {
new_stack.push_back(pos + 1);
}
llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new);
}
} else {
llama_grammar_stacks current_stacks = {stack};
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
llama_grammar_stacks next_stacks;
for (const auto & cur_stack : current_stacks) {
llama_grammar_accept_chr(grammar, cur_stack, *it, next_stacks);
}
current_stacks = std::move(next_stacks);
if (current_stacks.empty()) {
break;
}
}
for (auto & surviving_stack : current_stacks) {
if (std::find(stacks_new.begin(), stacks_new.end(), surviving_stack) == stacks_new.end()) {
stacks_new.emplace_back(surviving_stack);
}
}
}
}
grammar.stacks = std::move(stacks_new);
grammar.partial_utf8 = decoded.second;
if (grammar.stacks.empty()) {
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")");
}
}
const std::string & ollama_vocab::token_to_piece(const uint32_t token) const {
try {

View File

@@ -47,11 +47,17 @@ enum llama_gretype {
// any character (.)
LLAMA_GRETYPE_CHAR_ANY = 7,
// terminal element: token (<[token-id]>)
LLAMA_GRETYPE_TOKEN = 8,
// inverse token (!<[token-id]>)
LLAMA_GRETYPE_TOKEN_NOT = 9,
};
typedef struct llama_grammar_element {
enum llama_gretype type;
uint32_t value; // Unicode code point or rule ID
uint32_t value; // Unicode code point, rule ID, or token ID
} llama_grammar_element;
struct llama_partial_utf8 {
@@ -63,6 +69,7 @@ struct llama_grammar_candidate {
size_t index;
const uint32_t * code_points;
llama_partial_utf8 partial_utf8;
llama_token id;
};
using llama_grammar_rule = std::vector< llama_grammar_element>;
@@ -88,10 +95,13 @@ std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
const llama_grammar_candidates & candidates);
struct llama_grammar_parser {
const llama_vocab * vocab;
std::map<std::string, uint32_t> symbol_ids;
llama_grammar_rules rules;
llama_grammar_parser(const struct llama_vocab * vocab = nullptr) : vocab(vocab) {}
llama_grammar_stack c_rules() const;
uint32_t get_symbol_id(const char * src, size_t len);
@@ -123,6 +133,9 @@ struct llama_grammar_trigger_pattern {
};
struct llama_grammar {
// maintain a list of llama_tokens and their positions in the trigger_buffer
using token_pos = std::pair<llama_token, std::pair<size_t, size_t>>;
// note: allow null vocab for testing (not great)
const llama_vocab * vocab;
const ollama_vocab * o_vocab;
@@ -139,6 +152,7 @@ struct llama_grammar {
bool lazy = false;
bool awaiting_trigger = false; // Initialized to true for lazy grammars only
std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
std::vector<token_pos> trigger_buffer_positions; // Tokens buffered by lazy grammar. Used to replay when a trigger is found.
std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
std::vector<llama_grammar_trigger_pattern>
trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated
@@ -185,3 +199,8 @@ void llama_grammar_accept_impl(
void llama_grammar_accept_str(
struct llama_grammar & grammar,
const std::string & piece);
void llama_grammar_accept_token(
struct llama_grammar & grammar,
llama_token token,
const std::string & piece);

View File

@@ -71,11 +71,14 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
if (ubatch->pos && attn_scale) {
const int64_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(f_attn_temp_scale != 0.0f);
GGML_ASSERT(n_attn_temp_floor_scale != 0);
std::vector<float> attn_scale_data(n_tokens, 0.0f);
for (int i = 0; i < n_tokens; ++i) {
const float pos = ubatch->pos[i];
attn_scale_data[i] = std::log(
std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1.0
) * f_attn_temp_scale + 1.0;
}
@@ -251,6 +254,24 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
}
}
bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);
this->mctx = mctx;
bool res = true;
res &= s_copy->ne[0] == mctx->get_n_rs();
res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
res &= head == mctx->get_head();
res &= rs_z == mctx->get_rs_z();
return res;
}
void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch);
@@ -382,7 +403,7 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= self_kq_mask->ne[0] == mctx->get_n_kv();
res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
return res;
}
@@ -413,10 +434,10 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
return res;
}
@@ -449,7 +470,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
}
}
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int i = n_tokens; i < n_tokens; ++i) {
for (int j = 0; j < n_enc; ++j) {
data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
}
@@ -458,8 +479,46 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
}
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
inp_attn->set_input(ubatch);
inp_rs->set_input(ubatch);
mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
const int64_t n_rs = mctx->get_recr()->get_n_rs();
if (inp_rs->s_copy) {
GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
int32_t * data = (int32_t *) inp_rs->s_copy->data;
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_rs; ++i) {
data[i] = mctx->get_recr()->s_copy(i);
}
}
}
bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
this->mctx = mctx;
bool res = true;
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
res &= inp_rs->head == mctx->get_recr()->get_head();
res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
return res;
}
//
@@ -810,9 +869,6 @@ ggml_tensor * llm_graph_context::build_ffn(
GGML_ABORT("fatal error");
}
//expand here so that we can fuse ffn gate
ggml_build_forward_expand(gf, cur);
if (gate && type_gate == LLM_FFN_PAR) {
cur = ggml_mul(ctx0, cur, tmp);
cb(cur, "ffn_gate_par", il);
@@ -973,7 +1029,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
// mask out the other groups
selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
selection_probs = ggml_set_rows(ctx0, ggml_scale_bias(ctx0, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
selection_probs = ggml_set_rows(ctx0, ggml_fill(ctx0, selection_groups, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
cb(selection_probs, "ffn_moe_probs_masked", il);
}
@@ -1089,13 +1145,19 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
cur = ggml_relu(ctx0, cur);
cb(cur, "ffn_moe_relu", il);
} break;
case LLM_FFN_RELU_SQR:
if (gate_exps) {
// TODO: add support for gated squared relu
GGML_ABORT("fatal error: gated squared relu not implemented");
} else {
cur = ggml_relu(ctx0, cur);
cur = ggml_sqr(ctx0, cur);
cb(cur, "ffn_moe_relu_sqr", il);
} break;
default:
GGML_ABORT("fatal error");
}
//expand here so that we can fuse ffn gate
ggml_build_forward_expand(gf, cur);
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
cb(experts, "ffn_moe_down", il);
@@ -1206,7 +1268,7 @@ ggml_tensor * llm_graph_context::build_inp_pos() const {
}
ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset);
auto & cur = inp->attn_scale;
@@ -1473,13 +1535,13 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
@@ -1561,7 +1623,7 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1704,7 +1766,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1);
ggml_set_input(inp->cross_kq_mask);
inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
@@ -1770,7 +1832,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1784,7 +1846,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
@@ -1844,6 +1906,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
inp->head = mctx_cur->get_head();
inp->rs_z = mctx_cur->get_rs_z();
return inp;
}
@@ -1912,10 +1977,10 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
}

View File

@@ -132,8 +132,8 @@ public:
// temperature tuning, used by llama4
class llm_graph_input_attn_temp : public llm_graph_input_i {
public:
llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
: n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale, float f_attn_temp_offset)
: n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale), f_attn_temp_offset(f_attn_temp_offset) {}
virtual ~llm_graph_input_attn_temp() = default;
void set_input(const llama_ubatch * ubatch) override;
@@ -142,6 +142,7 @@ public:
const uint32_t n_attn_temp_floor_scale;
const float f_attn_temp_scale;
const float f_attn_temp_offset;
};
class llm_graph_input_pos_bucket : public llm_graph_input_i {
@@ -224,6 +225,8 @@ public:
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
ggml_tensor * s_copy; // I32 [n_rs]
// views of s_copy, computed once per graph
@@ -232,6 +235,10 @@ public:
ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]
const llama_memory_recurrent_context * mctx;
// used in view offsets, need to match for valid graph reuse
uint32_t head;
int32_t rs_z;
};
class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -364,22 +371,28 @@ public:
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
public:
llm_graph_input_mem_hybrid(
const llama_cparams & cparams,
std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
std::unique_ptr<llm_graph_input_rs> inp_rs,
const llama_memory_hybrid_context * mctx) :
std::unique_ptr<llm_graph_input_rs> inp_rs,
const llama_memory_hybrid_context * mctx) :
inp_attn(std::move(inp_attn)),
inp_rs(std::move(inp_rs)),
cparams(cparams),
mctx(mctx) { }
virtual ~llm_graph_input_mem_hybrid() = default;
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
std::unique_ptr<llm_graph_input_rs> inp_rs;
llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
const llama_cparams cparams;
const llama_memory_hybrid_context * mctx;
};

Some files were not shown because too many files have changed in this diff Show More