Compare commits


1 Commit

Author: Bruce MacDonald
SHA1: 31d04eb795
Message: model: add a test for model forward pass during implementation

Adds a new test file to verify model forward pass behavior through
JSON-specified test cases. The framework loads model files (.gguf) and their
corresponding test specifications to validate expected outputs using greedy
sampling.

Date: 2025-02-18 14:21:10 -08:00
326 changed files with 13861 additions and 35042 deletions
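
The commit description gives the shape of the test harness without showing it. The sketch below illustrates what a JSON-driven forward-pass test of that kind could look like; it is an assumption-laden illustration, not the code in this compare. The testdata/models path is suggested by the .gitignore entries further down, while the spec fields and the runModel placeholder are hypothetical.

```go
// Hypothetical sketch of a JSON-specified forward-pass test; all names here
// (testSpec fields, runModel) are illustrative assumptions.
package model_test

import (
	"encoding/json"
	"os"
	"path/filepath"
	"testing"
)

// testSpec mirrors the kind of test case the commit message describes:
// a prompt plus the output expected under greedy sampling.
type testSpec struct {
	Prompt   string `json:"prompt"`
	Expected string `json:"expected"`
}

// runModel stands in for the real harness: load the .gguf that corresponds
// to the spec, decode the prompt greedily, and return the generated text.
func runModel(t *testing.T, specPath, prompt string) string {
	t.Helper()
	t.Skip("placeholder harness; no model loaded for", specPath)
	return ""
}

func TestForwardPass(t *testing.T) {
	specs, err := filepath.Glob("testdata/models/*.json")
	if err != nil {
		t.Fatal(err)
	}
	for _, path := range specs {
		t.Run(filepath.Base(path), func(t *testing.T) {
			data, err := os.ReadFile(path)
			if err != nil {
				t.Fatal(err)
			}
			var spec testSpec
			if err := json.Unmarshal(data, &spec); err != nil {
				t.Fatal(err)
			}
			if got := runModel(t, path, spec.Prompt); got != spec.Expected {
				t.Errorf("greedy output mismatch: got %q, want %q", got, spec.Expected)
			}
		})
	}
}
```

Greedy sampling keeps the expected output deterministic for a given model and prompt, which is what makes exact comparison against a stored expectation workable in a spec like this.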

View File

@@ -111,13 +111,13 @@ jobs:
- os: windows - os: windows
arch: amd64 arch: amd64
preset: 'CUDA 12' preset: 'CUDA 12'
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe install: https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_551.61_windows.exe
cuda-version: '12.8' cuda-version: '12.4'
- os: windows - os: windows
arch: amd64 arch: amd64
preset: 'ROCm 6' preset: 'ROCm 6'
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe
rocm-version: '6.2' rocm-version: '6.1'
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }} runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release environment: release
env: env:
@@ -160,10 +160,6 @@ jobs:
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: matrix.preset == 'CPU'
run: |
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }} - if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4 uses: actions/cache/save@v4
with: with:

View File

@@ -78,10 +78,10 @@ jobs:
include: include:
- preset: CPU - preset: CPU
- preset: CUDA - preset: CUDA
install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe install: https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_522.06_windows.exe
flags: '-DCMAKE_CUDA_ARCHITECTURES=80' flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
- preset: ROCm - preset: ROCm
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe
flags: '-DAMDGPU_TARGETS=gfx1010' flags: '-DAMDGPU_TARGETS=gfx1010'
runs-on: windows runs-on: windows
steps: steps:
@@ -102,7 +102,7 @@ jobs:
$ErrorActionPreference = "Stop" $ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') { if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe" Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_11.3", "nvcc_11.3", "cublas_11.3", "cublas_dev_11.3")) -NoNewWindow -Wait Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_11.8", "nvcc_11.8", "cublas_11.8", "cublas_dev_11.8")) -NoNewWindow -Wait
} }
$cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path $cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path
@@ -140,13 +140,6 @@ jobs:
env: env:
CMAKE_GENERATOR: Ninja CMAKE_GENERATOR: Ninja
go_mod_tidy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: check that 'go mod tidy' is clean
run: go mod tidy --diff || (echo "Please run 'go mod tidy'." && exit 1)
test: test:
strategy: strategy:
matrix: matrix:
@@ -154,82 +147,15 @@ jobs:
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
env: env:
CGO_ENABLED: '1' CGO_ENABLED: '1'
GOEXPERIMENT: 'synctest'
steps: steps:
- name: checkout - uses: actions/checkout@v4
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2 - uses: actions/setup-go@v5
- name: cache restore
uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with: with:
# Note: unlike the other setups, this is only grabbing the mod download
# cache, rather than the whole mod directory, as the download cache
# contains zips that can be unpacked in parallel faster than they can be
# fetched and extracted by tar
path: |
~/.cache/go-build
~/go/pkg/mod/cache
~\AppData\Local\go-build
# NOTE: The -3- here should be incremented when the scheme of data to be
# cached changes (e.g. path above changes).
key: ${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-${{ hashFiles('**/go.sum') }}-${{ github.run_id }}
restore-keys: |
${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-${{ hashFiles('**/go.sum') }}
${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-
- name: Setup Go
uses: actions/setup-go@v5
with:
# The caching strategy of setup-go is less than ideal, and wastes
# time by not saving artifacts due to small failures like the linter
# complaining, etc. This means subsequent runs have to rebuild their world
# again until all checks pass. For instance, if you misspell a word,
# you're punished until you fix it. This is more hostile than
# helpful.
cache: false
go-version-file: go.mod go-version-file: go.mod
# It is tempting to run this in a platform independent way, but the past
# shows this codebase will see introductions of platform specific code
# generation, and so we need to check this per platform to ensure we
# don't abuse go generate on specific platforms.
- name: check that 'go generate' is clean
if: always()
run: |
go generate ./...
git diff --name-only --exit-code || (echo "Please run 'go generate ./...'." && exit 1)
- name: go test
if: always()
run: go test -count=1 -benchtime=1x ./...
# TODO(bmizerany): replace this heavy tool with just the
# tools/checks/binaries we want and then make them all run in parallel
# across jobs, not on a single tiny vm on Github Actions.
- uses: golangci/golangci-lint-action@v6 - uses: golangci/golangci-lint-action@v6
with: with:
args: --timeout 10m0s -v args: --timeout 10m0s -v
- run: go test ./...
- name: cache save
# Always save the cache, even if the job fails. The artifacts produced
# during the building of test binaries are not all for naught. They can
# be used to speed up subsequent runs.
if: always()
uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
# Note: unlike the other setups, this is only grabbing the mod download
# cache, rather than the whole mod directory, as the download cache
# contains zips that can be unpacked in parallel faster than they can be
# fetched and extracted by tar
path: |
~/.cache/go-build
~/go/pkg/mod/cache
~\AppData\Local\go-build
# NOTE: The -3- here should be incremented when the scheme of data to be
# cached changes (e.g. path above changes).
key: ${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-${{ hashFiles('**/go.sum') }}-${{ github.run_id }}
patches: patches:
runs-on: ubuntu-latest runs-on: ubuntu-latest

.gitignore (vendored): 5 changed lines
View File

@@ -5,6 +5,7 @@
.swp .swp
dist dist
build build
ollama
.cache .cache
*.exe *.exe
.idea .idea
@@ -13,4 +14,6 @@ test_data
__debug_bin* __debug_bin*
llama/build llama/build
llama/vendor llama/vendor
/ollama model/testdata/models/*
!model/testdata/models/*.md
!model/testdata/models/*.json

View File

@@ -6,6 +6,8 @@ linters:
- bidichk - bidichk
- bodyclose - bodyclose
- containedctx - containedctx
- contextcheck
- errcheck
- gocheckcompilerdirectives - gocheckcompilerdirectives
- gofmt - gofmt
- gofumpt - gofumpt
@@ -21,11 +23,10 @@ linters:
- staticcheck - staticcheck
- tenv - tenv
- unconvert - unconvert
- unused
- usestdlibvars
- wastedassign - wastedassign
- whitespace - whitespace
disable:
- usestdlibvars
- errcheck
linters-settings: linters-settings:
staticcheck: staticcheck:
checks: checks:
@@ -38,4 +39,5 @@ severity:
- gofmt - gofmt
- goimports - goimports
- intrange - intrange
- usestdlibvars
severity: info severity: info

View File

@@ -23,9 +23,8 @@ set(GGML_SCHED_MAX_COPIES 4)
set(GGML_LLAMAFILE ON) set(GGML_LLAMAFILE ON)
set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128) set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
set(GGML_CUDA_GRAPHS ON) set(GGML_CUDA_GRAPHS ON)
set(GGML_CUDA_FA ON)
if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64") if((NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|ARMv[0-9]+")) OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|ARMv[0-9]+"))
set(GGML_CPU_ALL_VARIANTS ON) set(GGML_CPU_ALL_VARIANTS ON)
endif() endif()
@@ -86,9 +85,9 @@ if(CMAKE_CUDA_COMPILER)
) )
endif() endif()
set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a|1200|1201):xnack[+-]$" set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a):xnack[+-]$"
CACHE STRING CACHE STRING
"Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a|1200|1201):xnack[+-]$\"." "Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a):xnack[+-]$\"."
) )
check_language(HIP) check_language(HIP)
@@ -97,7 +96,7 @@ if(CMAKE_HIP_COMPILER)
find_package(hip REQUIRED) find_package(hip REQUIRED)
if(NOT AMDGPU_TARGETS) if(NOT AMDGPU_TARGETS)
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012]|120[01])$") list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012])$")
elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX) elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX}) list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
endif() endif()
@@ -106,11 +105,9 @@ if(CMAKE_HIP_COMPILER)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
if (WIN32) if (WIN32)
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY) target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY=1)
endif() endif()
target_compile_definitions(ggml-hip PRIVATE GGML_HIP_NO_VMM)
set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm) set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
install(TARGETS ggml-hip install(TARGETS ggml-hip
RUNTIME_DEPENDENCIES RUNTIME_DEPENDENCIES

View File

@@ -21,14 +21,14 @@
"name": "CUDA 11", "name": "CUDA 11",
"inherits": [ "CUDA" ], "inherits": [ "CUDA" ],
"cacheVariables": { "cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86" "CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;62;70;72;75;80;86"
} }
}, },
{ {
"name": "CUDA 12", "name": "CUDA 12",
"inherits": [ "CUDA" ], "inherits": [ "CUDA" ],
"cacheVariables": { "cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120" "CMAKE_CUDA_ARCHITECTURES": "60;61;62;70;72;75;80;86;87;89;90;90a"
} }
}, },
{ {
@@ -56,7 +56,7 @@
"name": "ROCm 6", "name": "ROCm 6",
"inherits": [ "ROCm" ], "inherits": [ "ROCm" ],
"cacheVariables": { "cacheVariables": {
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-" "AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
} }
} }
], ],

View File

@@ -6,6 +6,8 @@ Thank you for your interest in contributing to Ollama! Here are a few guidelines
See the [development documentation](./docs/development.md) for instructions on how to build and run Ollama locally. See the [development documentation](./docs/development.md) for instructions on how to build and run Ollama locally.
## Pull requests
### Ideal issues ### Ideal issues
* [Bugs](https://github.com/ollama/ollama/issues?q=is%3Aissue+is%3Aopen+label%3Abug): issues where Ollama stops working or where it results in an unexpected error. * [Bugs](https://github.com/ollama/ollama/issues?q=is%3Aissue+is%3Aopen+label%3Abug): issues where Ollama stops working or where it results in an unexpected error.
@@ -24,64 +26,11 @@ See the [development documentation](./docs/development.md) for instructions on h
* Changes that add significant friction to the user experience * Changes that add significant friction to the user experience
* Changes that create a large future maintenance burden for maintainers and contributors * Changes that create a large future maintenance burden for maintainers and contributors
## Proposing a (non-trivial) change ### Best practices
> By "non-trivial", we mean a change that is not a bug fix or small * Commit messages: please leave both a title and a description in your commit messages. The title should be a short summary of the changes, with a leading word that explains the section of the code being changed (e.g. `api: fix parsing of prompt field`) . In the description, leave a short 2-3 sentences that explain more about the change and its impact.
> documentation update. If you are unsure, please ask us on our [Discord * Tests: please add test coverage to changes where possible.
> server](https://discord.gg/ollama). * Minimize dependencies: avoid adding new dependencies unless absolutely necessary.
Before opening a non-trivial Pull Request, please open an issue to discuss the change and
get feedback from the maintainers. This helps us understand the context of the
change and how it fits into Ollama's roadmap and prevents us from duplicating
work or you from spending time on a change that we may not be able to accept.
Tips for proposals:
* Explain the problem you are trying to solve, not what you are trying to do.
* Explain why the change is important.
* Explain how the change will be used.
* Explain how the change will be tested.
Additionally, for bonus points: Provide draft documentation you would expect to
see if the change were accepted.
## Pull requests
**Commit messages**
The title should look like:
<package>: <short description>
The package is the most affected Go package. If the change does not affect Go
code, then use the directory name instead. Changes to a single well-known
file in the root directory may use the file name.
The short description should start with a lowercase letter and be a
continuation of the sentence:
"This changes Ollama to..."
Examples:
llm/backend/mlx: support the llama architecture
CONTRIBUTING: provide clarity on good commit messages, and bad
Bad Examples:
feat: add more emoji
fix: was not using famous web framework
chore: generify code
**Tests**
Please include tests. Strive to test behavior, not implementation.
**New dependencies**
Dependencies should be added sparingly. If you are adding a new dependency,
please explain why it is necessary and what other ways you attempted that
did not work without it.
## Need help? ## Need help?

View File

@@ -2,24 +2,22 @@
ARG FLAVOR=${TARGETARCH} ARG FLAVOR=${TARGETARCH}
ARG ROCMVERSION=6.3.3 ARG ROCMVERSION=6.1.2
ARG JETPACK5VERSION=r35.4.1 ARG JETPACK5VERSION=r35.4.1
ARG JETPACK6VERSION=r36.4.0 ARG JETPACK6VERSION=r36.2.0
ARG CMAKEVERSION=3.31.2 ARG CMAKEVERSION=3.31.2
# CUDA v11 requires gcc v10. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version FROM --platform=linux/amd64 rocm/dev-centos-7:${ROCMVERSION}-complete AS base-amd64
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64 RUN sed -i -e 's/mirror.centos.org/vault.centos.org/g' -e 's/^#.*baseurl=http/baseurl=http/g' -e 's/^mirrorlist=http/#mirrorlist=http/g' /etc/yum.repos.d/*.repo \
RUN yum install -y yum-utils \ && yum install -y yum-utils devtoolset-10-gcc devtoolset-10-gcc-c++ \
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \ && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo \
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \ && curl -s -L https://github.com/ccache/ccache/releases/download/v4.10.2/ccache-4.10.2-linux-x86_64.tar.xz | tar -Jx -C /usr/local/bin --strip-components 1
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \ ENV PATH=/opt/rh/devtoolset-10/root/usr/bin:/opt/rh/devtoolset-11/root/usr/bin:$PATH
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
FROM --platform=linux/arm64 almalinux:8 AS base-arm64 FROM --platform=linux/arm64 rockylinux:8 AS base-arm64
# install epel-release for ccache # install epel-release for ccache
RUN yum install -y yum-utils epel-release \ RUN yum install -y yum-utils epel-release \
&& dnf install -y clang ccache \ && yum install -y clang ccache \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
ENV CC=clang CXX=clang++ ENV CC=clang CXX=clang++
@@ -31,8 +29,9 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ENV LDFLAGS=-s ENV LDFLAGS=-s
FROM base AS cpu FROM base AS cpu
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ # amd64 uses gcc which requires devtoolset-11 for AVX extensions while arm64 uses clang
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH RUN if [ "$(uname -m)" = "x86_64" ]; then yum install -y devtoolset-11-gcc devtoolset-11-gcc-c++; fi
ENV PATH=/opt/rh/devtoolset-11/root/usr/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CPU' \ cmake --preset 'CPU' \
&& cmake --build --parallel --preset 'CPU' \ && cmake --build --parallel --preset 'CPU' \
@@ -40,7 +39,7 @@ RUN --mount=type=cache,target=/root/.ccache \
FROM base AS cuda-11 FROM base AS cuda-11
ARG CUDA11VERSION=11.3 ARG CUDA11VERSION=11.3
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-} RUN yum install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH ENV PATH=/usr/local/cuda-11/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 11' \ cmake --preset 'CUDA 11' \
@@ -48,8 +47,8 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --install build --component CUDA --strip --parallel 8 && cmake --install build --component CUDA --strip --parallel 8
FROM base AS cuda-12 FROM base AS cuda-12
ARG CUDA12VERSION=12.8 ARG CUDA12VERSION=12.4
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-} RUN yum install -y cuda-toolkit-${CUDA12VERSION//./-}
ENV PATH=/usr/local/cuda-12/bin:$PATH ENV PATH=/usr/local/cuda-12/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 12' \ cmake --preset 'CUDA 12' \
@@ -57,7 +56,6 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --install build --component CUDA --strip --parallel 8 && cmake --install build --component CUDA --strip --parallel 8
FROM base AS rocm-6 FROM base AS rocm-6
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'ROCm 6' \ cmake --preset 'ROCm 6' \
&& cmake --build --parallel --preset 'ROCm 6' \ && cmake --build --parallel --preset 'ROCm 6' \
@@ -86,11 +84,10 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --install build --component CUDA --strip --parallel 8 && cmake --install build --component CUDA --strip --parallel 8
FROM base AS build FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama ARG GOVERSION=1.23.4
COPY go.mod go.sum . RUN curl -fsSL https://golang.org/dl/go${GOVERSION}.linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download WORKDIR /go/src/github.com/ollama/ollama
COPY . . COPY . .
ARG GOFLAGS="'-ldflags=-w -s'" ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1 ENV CGO_ENABLED=1
@@ -104,10 +101,10 @@ COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
FROM --platform=linux/arm64 scratch AS arm64 FROM --platform=linux/arm64 scratch AS arm64
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11 COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12 COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_jetpack5 COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_jetpack6 COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 lib/ollama/cuda_jetpack6
FROM scratch AS rocm FROM --platform=linux/arm64 scratch AS rocm
COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm
FROM ${FLAVOR} AS archive FROM ${FLAVOR} AS archive

View File

@@ -1,6 +1,6 @@
UPSTREAM=https://github.com/ggerganov/llama.cpp.git UPSTREAM=https://github.com/ggerganov/llama.cpp.git
WORKDIR=llama/vendor WORKDIR=llama/vendor
FETCH_HEAD=d7cfe1ffe0f435d0048a6058d529daf76e072d9c FETCH_HEAD=46e3556e01b824e52395fb050b29804b6cff2a7c
.PHONY: help .PHONY: help
help: help:

View File

@@ -1,5 +1,5 @@
<div align="center"> <div align="center">
  <a href="https://ollama.com">   <a href="https://ollama.com" />
<img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7"> <img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
</a> </a>
</div> </div>
@@ -54,11 +54,6 @@ Here are some example models that can be downloaded:
| Model | Parameters | Size | Download | | Model | Parameters | Size | Download |
| ------------------ | ---------- | ----- | -------------------------------- | | ------------------ | ---------- | ----- | -------------------------------- |
| Gemma 3 | 1B | 815MB | `ollama run gemma3:1b` |
| Gemma 3 | 4B | 3.3GB | `ollama run gemma3` |
| Gemma 3 | 12B | 8.1GB | `ollama run gemma3:12b` |
| Gemma 3 | 27B | 17GB | `ollama run gemma3:27b` |
| QwQ | 32B | 20GB | `ollama run qwq` |
| DeepSeek-R1 | 7B | 4.7GB | `ollama run deepseek-r1` | | DeepSeek-R1 | 7B | 4.7GB | `ollama run deepseek-r1` |
| DeepSeek-R1 | 671B | 404GB | `ollama run deepseek-r1:671b` | | DeepSeek-R1 | 671B | 404GB | `ollama run deepseek-r1:671b` |
| Llama 3.3 | 70B | 43GB | `ollama run llama3.3` | | Llama 3.3 | 70B | 43GB | `ollama run llama3.3` |
@@ -69,7 +64,10 @@ Here are some example models that can be downloaded:
| Llama 3.1 | 8B | 4.7GB | `ollama run llama3.1` | | Llama 3.1 | 8B | 4.7GB | `ollama run llama3.1` |
| Llama 3.1 | 405B | 231GB | `ollama run llama3.1:405b` | | Llama 3.1 | 405B | 231GB | `ollama run llama3.1:405b` |
| Phi 4 | 14B | 9.1GB | `ollama run phi4` | | Phi 4 | 14B | 9.1GB | `ollama run phi4` |
| Phi 4 Mini | 3.8B | 2.5GB | `ollama run phi4-mini` | | Phi 3 Mini | 3.8B | 2.3GB | `ollama run phi3` |
| Gemma 2 | 2B | 1.6GB | `ollama run gemma2:2b` |
| Gemma 2 | 9B | 5.5GB | `ollama run gemma2` |
| Gemma 2 | 27B | 16GB | `ollama run gemma2:27b` |
| Mistral | 7B | 4.1GB | `ollama run mistral` | | Mistral | 7B | 4.1GB | `ollama run mistral` |
| Moondream 2 | 1.4B | 829MB | `ollama run moondream` | | Moondream 2 | 1.4B | 829MB | `ollama run moondream` |
| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` | | Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
@@ -77,7 +75,7 @@ Here are some example models that can be downloaded:
| Code Llama | 7B | 3.8GB | `ollama run codellama` | | Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` | | Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| LLaVA | 7B | 4.5GB | `ollama run llava` | | LLaVA | 7B | 4.5GB | `ollama run llava` |
| Granite-3.2 | 8B | 4.9GB | `ollama run granite3.2` | | Solar | 10.7B | 6.1GB | `ollama run solar` |
> [!NOTE] > [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models. > You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -277,7 +275,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Web & Desktop ### Web & Desktop
- [Open WebUI](https://github.com/open-webui/open-webui) - [Open WebUI](https://github.com/open-webui/open-webui)
- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted) - [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
- [Hollama](https://github.com/fmaclen/hollama) - [Hollama](https://github.com/fmaclen/hollama)
- [Lollms-Webui](https://github.com/ParisNeo/lollms-webui) - [Lollms-Webui](https://github.com/ParisNeo/lollms-webui)
@@ -285,13 +282,12 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt) - [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui) - [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [Saddle](https://github.com/jikkuatwork/saddle) - [Saddle](https://github.com/jikkuatwork/saddle)
- [TagSpaces](https://www.tagspaces.org) (A platform for file based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama) - [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui) - [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file) - [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
- [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui) - [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
- [Ollamac](https://github.com/kevinhermawan/Ollamac) - [Ollamac](https://github.com/kevinhermawan/Ollamac)
- [big-AGI](https://github.com/enricoros/big-AGI) - [big-AGI](https://github.com/enricoros/big-AGI/blob/main/docs/config-local-ollama.md)
- [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core) - [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
- [Amica](https://github.com/semperai/amica) - [Amica](https://github.com/semperai/amica)
- [chatd](https://github.com/BruceMacD/chatd) - [chatd](https://github.com/BruceMacD/chatd)
@@ -325,7 +321,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama) - [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama)
- [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (app to evaluate and compare models) - [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (app to evaluate and compare models)
- [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama) - [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support and multiple large language models.)
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS) - [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama) - [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
- [Shinkai Desktop](https://github.com/dcSpark/shinkai-apps) (Two click install Local AI using Ollama + Files + RAG) - [Shinkai Desktop](https://github.com/dcSpark/shinkai-apps) (Two click install Local AI using Ollama + Files + RAG)
@@ -348,7 +343,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery) - [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot and Ollama4j - [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot and Ollama4j
- [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models. - [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding - [Claude Dev](https://github.com/saoudrizwan/claude-dev) - VSCode extension for multi-file/whole-repo coding
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support) - [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption) - [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
- [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library) - [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
@@ -387,17 +382,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LocalLLM](https://github.com/qusaismael/localllm) (Minimal Web-App to run ollama models on it with a GUI) - [LocalLLM](https://github.com/qusaismael/localllm) (Minimal Web-App to run ollama models on it with a GUI)
- [Ollamazing](https://github.com/buiducnhat/ollamazing) (Web extension to run Ollama models) - [Ollamazing](https://github.com/buiducnhat/ollamazing) (Web extension to run Ollama models)
- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivalent endpoint with Ollama support for running locally) - [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivalent endpoint with Ollama support for running locally)
- [AntSK](https://github.com/AIDotNet/AntSK) (Out-of-the-box & Adaptable RAG Chatbot)
- [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot)
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
- [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms)
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
- [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).)
- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama)
### Cloud ### Cloud
@@ -437,14 +421,10 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [SwollamaCLI](https://github.com/marcusziade/Swollama) bundled with the Swollama Swift package. [Demo](https://github.com/marcusziade/Swollama?tab=readme-ov-file#cli-usage) - [SwollamaCLI](https://github.com/marcusziade/Swollama) bundled with the Swollama Swift package. [Demo](https://github.com/marcusziade/Swollama?tab=readme-ov-file#cli-usage)
- [aichat](https://github.com/sigoden/aichat) All-in-one LLM CLI tool featuring Shell Assistant, Chat-REPL, RAG, AI tools & agents, with access to OpenAI, Claude, Gemini, Ollama, Groq, and more. - [aichat](https://github.com/sigoden/aichat) All-in-one LLM CLI tool featuring Shell Assistant, Chat-REPL, RAG, AI tools & agents, with access to OpenAI, Claude, Gemini, Ollama, Groq, and more.
- [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama - [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
- [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama. - [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull and download models from Ollama Registry in your terminal.
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
### Apple Vision Pro ### Apple Vision Pro
- [SwiftChat](https://github.com/aws-samples/swift-chat) (Cross-platform AI chat app supporting Apple Vision Pro via "Designed for iPad")
- [Enchanted](https://github.com/AugustDev/enchanted) - [Enchanted](https://github.com/AugustDev/enchanted)
### Database ### Database
@@ -518,18 +498,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LlmTornado](https://github.com/lofcz/llmtornado) (C# library providing a unified interface for major FOSS & Commercial inference APIs) - [LlmTornado](https://github.com/lofcz/llmtornado) (C# library providing a unified interface for major FOSS & Commercial inference APIs)
- [Ollama for Zig](https://github.com/dravenk/ollama-zig) - [Ollama for Zig](https://github.com/dravenk/ollama-zig)
- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider) - [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
- [Ollama for D](https://github.com/kassane/ollama-d)
### Mobile ### Mobile
- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS and iPad)
- [Enchanted](https://github.com/AugustDev/enchanted) - [Enchanted](https://github.com/AugustDev/enchanted)
- [Maid](https://github.com/Mobile-Artificial-Intelligence/maid) - [Maid](https://github.com/Mobile-Artificial-Intelligence/maid)
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama) - [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption) - [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
- [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux, start the Ollama service with one click on an Android device)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
### Extensions & Plugins ### Extensions & Plugins
@@ -575,14 +550,12 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language - [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language
- [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai) - [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai)
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c) - [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
### Supported backends ### Supported backends
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov. - [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
### Observability ### Observability
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration with Ollama.
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing. - [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics. - [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production. - [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.

View File

@@ -10,7 +10,7 @@
// repository]. // repository].
// //
// [the API documentation]: https://github.com/ollama/ollama/blob/main/docs/api.md // [the API documentation]: https://github.com/ollama/ollama/blob/main/docs/api.md
// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/api/examples // [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/examples
package api package api
import ( import (
@@ -132,7 +132,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
const maxBufferSize = 512 * format.KiloByte const maxBufferSize = 512 * format.KiloByte
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error { func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf io.Reader var buf *bytes.Buffer
if data != nil { if data != nil {
bts, err := json.Marshal(data) bts, err := json.Marshal(data)
if err != nil { if err != nil {

View File

@@ -1,13 +1,6 @@
package api package api
import ( import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing" "testing"
) )
@@ -50,206 +43,3 @@ func TestClientFromEnvironment(t *testing.T) {
}) })
} }
} }
// testError represents an internal error type with status code and message
// this is used since the error response from the server is not a standard error struct
type testError struct {
message string
statusCode int
}
func (e testError) Error() string {
return e.message
}
func TestClientStream(t *testing.T) {
testCases := []struct {
name string
responses []any
wantErr string
}{
{
name: "immediate error response",
responses: []any{
testError{
message: "test error message",
statusCode: http.StatusBadRequest,
},
},
wantErr: "test error message",
},
{
name: "error after successful chunks, ok response",
responses: []any{
ChatResponse{Message: Message{Content: "partial response 1"}},
ChatResponse{Message: Message{Content: "partial response 2"}},
testError{
message: "mid-stream error",
statusCode: http.StatusOK,
},
},
wantErr: "mid-stream error",
},
{
name: "successful stream completion",
responses: []any{
ChatResponse{Message: Message{Content: "chunk 1"}},
ChatResponse{Message: Message{Content: "chunk 2"}},
ChatResponse{
Message: Message{Content: "final chunk"},
Done: true,
DoneReason: "stop",
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
flusher, ok := w.(http.Flusher)
if !ok {
t.Fatal("expected http.Flusher")
}
w.Header().Set("Content-Type", "application/x-ndjson")
for _, resp := range tc.responses {
if errResp, ok := resp.(testError); ok {
w.WriteHeader(errResp.statusCode)
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
return
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("failed to encode response: %v", err)
}
flusher.Flush()
}
}))
defer ts.Close()
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
var receivedChunks []ChatResponse
err := client.stream(context.Background(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
var resp ChatResponse
if err := json.Unmarshal(chunk, &resp); err != nil {
return fmt.Errorf("failed to unmarshal chunk: %w", err)
}
receivedChunks = append(receivedChunks, resp)
return nil
})
if tc.wantErr != "" {
if err == nil {
t.Fatal("expected error but got nil")
}
if !strings.Contains(err.Error(), tc.wantErr) {
t.Errorf("expected error containing %q, got %v", tc.wantErr, err)
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
}
}
func TestClientDo(t *testing.T) {
testCases := []struct {
name string
response any
wantErr string
}{
{
name: "immediate error response",
response: testError{
message: "test error message",
statusCode: http.StatusBadRequest,
},
wantErr: "test error message",
},
{
name: "server error response",
response: testError{
message: "internal error",
statusCode: http.StatusInternalServerError,
},
wantErr: "internal error",
},
{
name: "successful response",
response: struct {
ID string `json:"id"`
Success bool `json:"success"`
}{
ID: "msg_123",
Success: true,
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if errResp, ok := tc.response.(testError); ok {
w.WriteHeader(errResp.statusCode)
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
return
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(tc.response); err != nil {
t.Fatalf("failed to encode response: %v", err)
}
}))
defer ts.Close()
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
var resp struct {
ID string `json:"id"`
Success bool `json:"success"`
}
err := client.do(context.Background(), http.MethodPost, "/v1/messages", nil, &resp)
if tc.wantErr != "" {
if err == nil {
t.Fatalf("got nil, want error %q", tc.wantErr)
}
if err.Error() != tc.wantErr {
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
}
return
}
if err != nil {
t.Fatalf("got error %q, want nil", err)
}
if expectedResp, ok := tc.response.(struct {
ID string `json:"id"`
Success bool `json:"success"`
}); ok {
if resp.ID != expectedResp.ID {
t.Errorf("response ID mismatch: got %q, want %q", resp.ID, expectedResp.ID)
}
if resp.Success != expectedResp.Success {
t.Errorf("response Success mismatch: got %v, want %v", resp.Success, expectedResp.Success)
}
}
})
}
}

View File

@@ -10,9 +10,6 @@ import (
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
) )
// StatusError is an error with an HTTP status code and message. // StatusError is an error with an HTTP status code and message.
@@ -82,7 +79,7 @@ type GenerateRequest struct {
// Options lists model-specific options. For example, temperature can be // Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it. // set through this field, if the model supports it.
Options map[string]any `json:"options"` Options map[string]interface{} `json:"options"`
} }
// ChatRequest describes a request sent by [Client.Chat]. // ChatRequest describes a request sent by [Client.Chat].
@@ -107,7 +104,7 @@ type ChatRequest struct {
Tools `json:"tools,omitempty"` Tools `json:"tools,omitempty"`
// Options lists model-specific options. // Options lists model-specific options.
Options map[string]any `json:"options"` Options map[string]interface{} `json:"options"`
} }
type Tools []Tool type Tools []Tool
@@ -163,65 +160,19 @@ func (t *ToolCallFunctionArguments) String() string {
type Tool struct { type Tool struct {
Type string `json:"type"` Type string `json:"type"`
Items any `json:"items,omitempty"`
Function ToolFunction `json:"function"` Function ToolFunction `json:"function"`
} }
// PropertyType can be either a string or an array of strings
type PropertyType []string
// UnmarshalJSON implements the json.Unmarshaler interface
func (pt *PropertyType) UnmarshalJSON(data []byte) error {
// Try to unmarshal as a string first
var s string
if err := json.Unmarshal(data, &s); err == nil {
*pt = []string{s}
return nil
}
// If that fails, try to unmarshal as an array of strings
var a []string
if err := json.Unmarshal(data, &a); err != nil {
return err
}
*pt = a
return nil
}
// MarshalJSON implements the json.Marshaler interface
func (pt PropertyType) MarshalJSON() ([]byte, error) {
if len(pt) == 1 {
// If there's only one type, marshal as a string
return json.Marshal(pt[0])
}
// Otherwise marshal as an array
return json.Marshal([]string(pt))
}
// String returns a string representation of the PropertyType
func (pt PropertyType) String() string {
if len(pt) == 0 {
return ""
}
if len(pt) == 1 {
return pt[0]
}
return fmt.Sprintf("%v", []string(pt))
}
type ToolFunction struct { type ToolFunction struct {
Name string `json:"name"` Name string `json:"name"`
Description string `json:"description"` Description string `json:"description"`
Parameters struct { Parameters struct {
Type string `json:"type"` Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"` Required []string `json:"required"`
Properties map[string]struct { Properties map[string]struct {
Type PropertyType `json:"type"` Type string `json:"type"`
Items any `json:"items,omitempty"` Description string `json:"description"`
Description string `json:"description"` Enum []string `json:"enum,omitempty"`
Enum []any `json:"enum,omitempty"`
} `json:"properties"` } `json:"properties"`
} `json:"parameters"` } `json:"parameters"`
} }
@@ -307,7 +258,7 @@ type EmbedRequest struct {
Truncate *bool `json:"truncate,omitempty"` Truncate *bool `json:"truncate,omitempty"`
// Options lists model-specific options. // Options lists model-specific options.
Options map[string]any `json:"options"` Options map[string]interface{} `json:"options"`
} }
// EmbedResponse is the response from [Client.Embed]. // EmbedResponse is the response from [Client.Embed].
@@ -333,7 +284,7 @@ type EmbeddingRequest struct {
KeepAlive *Duration `json:"keep_alive,omitempty"` KeepAlive *Duration `json:"keep_alive,omitempty"`
// Options lists model-specific options. // Options lists model-specific options.
Options map[string]any `json:"options"` Options map[string]interface{} `json:"options"`
} }
// EmbeddingResponse is the response from [Client.Embeddings]. // EmbeddingResponse is the response from [Client.Embeddings].
@@ -379,7 +330,7 @@ type ShowRequest struct {
Template string `json:"template"` Template string `json:"template"`
Verbose bool `json:"verbose"` Verbose bool `json:"verbose"`
Options map[string]any `json:"options"` Options map[string]interface{} `json:"options"`
// Deprecated: set the model name with Model instead // Deprecated: set the model name with Model instead
Name string `json:"name"` Name string `json:"name"`
@@ -387,18 +338,16 @@ type ShowRequest struct {
// ShowResponse is the response returned from [Client.Show]. // ShowResponse is the response returned from [Client.Show].
type ShowResponse struct { type ShowResponse struct {
License string `json:"license,omitempty"` License string `json:"license,omitempty"`
Modelfile string `json:"modelfile,omitempty"` Modelfile string `json:"modelfile,omitempty"`
Parameters string `json:"parameters,omitempty"` Parameters string `json:"parameters,omitempty"`
Template string `json:"template,omitempty"` Template string `json:"template,omitempty"`
System string `json:"system,omitempty"` System string `json:"system,omitempty"`
Details ModelDetails `json:"details,omitempty"` Details ModelDetails `json:"details,omitempty"`
Messages []Message `json:"messages,omitempty"` Messages []Message `json:"messages,omitempty"`
ModelInfo map[string]any `json:"model_info,omitempty"` ModelInfo map[string]any `json:"model_info,omitempty"`
ProjectorInfo map[string]any `json:"projector_info,omitempty"` ProjectorInfo map[string]any `json:"projector_info,omitempty"`
Tensors []Tensor `json:"tensors,omitempty"` ModifiedAt time.Time `json:"modified_at,omitempty"`
Capabilities []model.Capability `json:"capabilities,omitempty"`
ModifiedAt time.Time `json:"modified_at,omitempty"`
} }
// CopyRequest is the request passed to [Client.Copy]. // CopyRequest is the request passed to [Client.Copy].
@@ -410,9 +359,9 @@ type CopyRequest struct {
// PullRequest is the request passed to [Client.Pull]. // PullRequest is the request passed to [Client.Pull].
type PullRequest struct { type PullRequest struct {
Model string `json:"model"` Model string `json:"model"`
Insecure bool `json:"insecure,omitempty"` // Deprecated: ignored Insecure bool `json:"insecure,omitempty"`
Username string `json:"username"` // Deprecated: ignored Username string `json:"username"`
Password string `json:"password"` // Deprecated: ignored Password string `json:"password"`
Stream *bool `json:"stream,omitempty"` Stream *bool `json:"stream,omitempty"`
// Deprecated: set the model name with Model instead // Deprecated: set the model name with Model instead
@@ -516,13 +465,6 @@ type ModelDetails struct {
QuantizationLevel string `json:"quantization_level"` QuantizationLevel string `json:"quantization_level"`
} }
// Tensor describes the metadata for a given tensor.
type Tensor struct {
Name string `json:"name"`
Type string `json:"type"`
Shape []uint64 `json:"shape"`
}
func (m *Metrics) Summary() { func (m *Metrics) Summary() {
if m.TotalDuration > 0 { if m.TotalDuration > 0 {
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration) fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
@@ -551,7 +493,7 @@ func (m *Metrics) Summary() {
} }
} }
func (opts *Options) FromMap(m map[string]any) error { func (opts *Options) FromMap(m map[string]interface{}) error {
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
@@ -608,12 +550,12 @@ func (opts *Options) FromMap(m map[string]any) error {
} }
field.SetString(val) field.SetString(val)
case reflect.Slice: case reflect.Slice:
// JSON unmarshals to []any, not []string // JSON unmarshals to []interface{}, not []string
val, ok := val.([]any) val, ok := val.([]interface{})
if !ok { if !ok {
return fmt.Errorf("option %q must be of type array", key) return fmt.Errorf("option %q must be of type array", key)
} }
// convert []any to []string // convert []interface{} to []string
slice := make([]string, len(val)) slice := make([]string, len(val))
for i, item := range val { for i, item := range val {
str, ok := item.(string) str, ok := item.(string)
@@ -667,7 +609,7 @@ func DefaultOptions() Options {
Runner: Runner{ Runner: Runner{
// options set when the model is loaded // options set when the model is loaded
NumCtx: int(envconfig.ContextLength()), NumCtx: 2048,
NumBatch: 512, NumBatch: 512,
NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically
NumThread: 0, // let the runtime decide NumThread: 0, // let the runtime decide
@@ -720,7 +662,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
} }
// FormatParams converts specified parameter options to their correct types // FormatParams converts specified parameter options to their correct types
func FormatParams(params map[string][]string) (map[string]any, error) { func FormatParams(params map[string][]string) (map[string]interface{}, error) {
opts := Options{} opts := Options{}
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
@@ -734,7 +676,7 @@ func FormatParams(params map[string][]string) (map[string]any, error) {
} }
} }
out := make(map[string]any) out := make(map[string]interface{})
// iterate params and set values based on json struct tags // iterate params and set values based on json struct tags
for key, vals := range params { for key, vals := range params {
if opt, ok := jsonOpts[key]; !ok { if opt, ok := jsonOpts[key]; !ok {
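
One side of the api/types.go diff above introduces a PropertyType that accepts either a JSON string or an array of strings for a tool property's "type". A self-contained restatement of that pattern follows; the type is re-declared here rather than imported so the snippet stands alone, and it is a sketch of the idea rather than the package's actual code.

```go
// Minimal restatement of the string-or-array pattern shown in the diff above.
package main

import (
	"encoding/json"
	"fmt"
)

type PropertyType []string

func (pt *PropertyType) UnmarshalJSON(data []byte) error {
	// Accept a bare string such as "string".
	var s string
	if err := json.Unmarshal(data, &s); err == nil {
		*pt = []string{s}
		return nil
	}
	// Otherwise expect an array such as ["string", "null"].
	var a []string
	if err := json.Unmarshal(data, &a); err != nil {
		return err
	}
	*pt = a
	return nil
}

func main() {
	for _, in := range []string{`"string"`, `["string","null"]`} {
		var pt PropertyType
		if err := json.Unmarshal([]byte(in), &pt); err != nil {
			panic(err)
		}
		fmt.Printf("%s -> %v\n", in, []string(pt))
	}
}
```

This mirrors how JSON Schema allows "type" to be a single name or a list of names, and the corresponding MarshalJSON in the diff rounds a one-element slice back to a bare string so single-type properties keep their original encoding.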

View File

@@ -134,7 +134,7 @@ func TestUseMmapParsingFromJSON(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
var oMap map[string]any var oMap map[string]interface{}
err := json.Unmarshal([]byte(test.req), &oMap) err := json.Unmarshal([]byte(test.req), &oMap)
require.NoError(t, err) require.NoError(t, err)
opts := DefaultOptions() opts := DefaultOptions()
@@ -231,144 +231,3 @@ func TestMessage_UnmarshalJSON(t *testing.T) {
} }
} }
} }
func TestToolFunction_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
input string
wantErr string
}{
{
name: "valid enum with same types",
input: `{
"name": "test",
"description": "test function",
"parameters": {
"type": "object",
"required": ["test"],
"properties": {
"test": {
"type": "string",
"description": "test prop",
"enum": ["a", "b", "c"]
}
}
}
}`,
wantErr: "",
},
{
name: "empty enum array",
input: `{
"name": "test",
"description": "test function",
"parameters": {
"type": "object",
"required": ["test"],
"properties": {
"test": {
"type": "string",
"description": "test prop",
"enum": []
}
}
}
}`,
wantErr: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var tf ToolFunction
err := json.Unmarshal([]byte(tt.input), &tf)
if tt.wantErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.wantErr)
} else {
require.NoError(t, err)
}
})
}
}
func TestPropertyType_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
input string
expected PropertyType
}{
{
name: "string type",
input: `"string"`,
expected: PropertyType{"string"},
},
{
name: "array of types",
input: `["string", "number"]`,
expected: PropertyType{"string", "number"},
},
{
name: "array with single type",
input: `["string"]`,
expected: PropertyType{"string"},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var pt PropertyType
if err := json.Unmarshal([]byte(test.input), &pt); err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(pt) != len(test.expected) {
t.Errorf("Length mismatch: got %v, expected %v", len(pt), len(test.expected))
}
for i, v := range pt {
if v != test.expected[i] {
t.Errorf("Value mismatch at index %d: got %v, expected %v", i, v, test.expected[i])
}
}
})
}
}
func TestPropertyType_MarshalJSON(t *testing.T) {
tests := []struct {
name string
input PropertyType
expected string
}{
{
name: "single type",
input: PropertyType{"string"},
expected: `"string"`,
},
{
name: "multiple types",
input: PropertyType{"string", "number"},
expected: `["string","number"]`,
},
{
name: "empty type",
input: PropertyType{},
expected: `[]`,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
data, err := json.Marshal(test.input)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if string(data) != test.expected {
t.Errorf("Marshaled data mismatch: got %v, expected %v", string(data), test.expected)
}
})
}
}

View File

@@ -1,178 +0,0 @@
package benchmark
import (
"context"
"flag"
"fmt"
"testing"
"time"
"github.com/ollama/ollama/api"
)
// Command line flags
var modelFlag string
func init() {
flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
flag.Lookup("m").DefValue = "model"
}
// modelName returns the model name from flags, failing the test if not set
func modelName(b *testing.B) string {
if modelFlag == "" {
b.Fatal("Error: -m flag is required for benchmark tests")
}
return modelFlag
}
type TestCase struct {
name string
prompt string
maxTokens int
}
// runGenerateBenchmark contains the common generate and metrics logic
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
start := time.Now()
var ttft time.Duration
var metrics api.Metrics
err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
if ttft == 0 && resp.Response != "" {
ttft = time.Since(start)
}
if resp.Done {
metrics = resp.Metrics
}
return nil
})
// Report custom metrics as part of the benchmark results
b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
// Token throughput metrics
promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
b.ReportMetric(promptThroughput, "prompt_tok/s")
b.ReportMetric(genThroughput, "gen_tok/s")
// Token counts
b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
if err != nil {
b.Fatal(err)
}
}
// BenchmarkColdStart runs benchmarks with model loading from cold state
func BenchmarkColdStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
ctx := context.Background()
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
b.StopTimer()
// Ensure model is unloaded before each iteration
unload(client, m, b)
b.StartTimer()
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// BenchmarkWarmStart runs benchmarks with pre-loaded model
func BenchmarkWarmStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
ctx := context.Background()
// Pre-warm the model
warmup(client, m, tt.prompt, b)
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// setup verifies server and model availability
func setup(b *testing.B) *api.Client {
client, err := api.ClientFromEnvironment()
if err != nil {
b.Fatal(err)
}
if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
b.Fatalf("Model unavailable: %v", err)
}
return client
}
// warmup ensures the model is loaded and warmed up
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
for range 3 {
err := client.Generate(
context.Background(),
&api.GenerateRequest{
Model: model,
Prompt: prompt,
Options: map[string]any{"num_predict": 50, "temperature": 0.1},
},
func(api.GenerateResponse) error { return nil },
)
if err != nil {
b.Logf("Error during model warm-up: %v", err)
}
}
}
// unload forces model unloading using KeepAlive: 0 parameter
func unload(client *api.Client, model string, b *testing.B) {
req := &api.GenerateRequest{
Model: model,
KeepAlive: &api.Duration{Duration: 0},
}
if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
b.Logf("Unload error: %v", err)
}
time.Sleep(1 * time.Second)
}

View File

@@ -18,8 +18,6 @@ import (
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"runtime" "runtime"
"slices"
"sort"
"strconv" "strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
@@ -36,6 +34,7 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress" "github.com/ollama/ollama/progress"
"github.com/ollama/ollama/runner" "github.com/ollama/ollama/runner"
@@ -257,7 +256,6 @@ func StopHandler(cmd *cobra.Command, args []string) error {
if strings.Contains(err.Error(), "not found") { if strings.Contains(err.Error(), "not found") {
return fmt.Errorf("couldn't find model \"%s\" to stop", args[0]) return fmt.Errorf("couldn't find model \"%s\" to stop", args[0])
} }
return err
} }
return nil return nil
} }
@@ -268,7 +266,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
opts := runOptions{ opts := runOptions{
Model: args[0], Model: args[0],
WordWrap: os.Getenv("TERM") == "xterm-256color", WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]any{}, Options: map[string]interface{}{},
} }
format, err := cmd.Flags().GetString("format") format, err := cmd.Flags().GetString("format")
@@ -340,21 +338,10 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision) // TODO(jessegross): We should either find another way to know if this is
// a vision model or remove the logic. Also consider that other modalities will
// TODO: remove the projector info and vision info checks below, // need different behavior anyways.
// these are left in for backwards compatibility with older servers opts.MultiModal = len(info.ProjectorInfo) != 0 || envconfig.NewEngine()
// that don't have the capabilities field in the model info
if len(info.ProjectorInfo) != 0 {
opts.MultiModal = true
}
for k := range info.ModelInfo {
if strings.Contains(k, ".vision.") {
opts.MultiModal = true
break
}
}
opts.ParentModel = info.Details.ParentModel opts.ParentModel = info.Details.ParentModel
if interactive { if interactive {
@@ -575,9 +562,8 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
parameters, errParams := cmd.Flags().GetBool("parameters") parameters, errParams := cmd.Flags().GetBool("parameters")
system, errSystem := cmd.Flags().GetBool("system") system, errSystem := cmd.Flags().GetBool("system")
template, errTemplate := cmd.Flags().GetBool("template") template, errTemplate := cmd.Flags().GetBool("template")
verbose, errVerbose := cmd.Flags().GetBool("verbose")
for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate, errVerbose} { for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate} {
if boolErr != nil { if boolErr != nil {
return errors.New("error retrieving flags") return errors.New("error retrieving flags")
} }
@@ -615,7 +601,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified") return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
} }
req := api.ShowRequest{Name: args[0], Verbose: verbose} req := api.ShowRequest{Name: args[0]}
resp, err := client.Show(cmd.Context(), &req) resp, err := client.Show(cmd.Context(), &req)
if err != nil { if err != nil {
return err return err
@@ -638,10 +624,10 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
return nil return nil
} }
return showInfo(resp, verbose, os.Stdout) return showInfo(resp, os.Stdout)
} }
func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error { func showInfo(resp *api.ShowResponse, w io.Writer) error {
tableRender := func(header string, rows func() [][]string) { tableRender := func(header string, rows func() [][]string) {
fmt.Fprintln(w, " ", header) fmt.Fprintln(w, " ", header)
table := tablewriter.NewWriter(w) table := tablewriter.NewWriter(w)
@@ -675,15 +661,6 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
return return
}) })
if len(resp.Capabilities) > 0 {
tableRender("Capabilities", func() (rows [][]string) {
for _, capability := range resp.Capabilities {
rows = append(rows, []string{"", capability.String()})
}
return
})
}
if resp.ProjectorInfo != nil { if resp.ProjectorInfo != nil {
tableRender("Projector", func() (rows [][]string) { tableRender("Projector", func() (rows [][]string) {
arch := resp.ProjectorInfo["general.architecture"].(string) arch := resp.ProjectorInfo["general.architecture"].(string)
@@ -707,47 +684,6 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
}) })
} }
if resp.ModelInfo != nil && verbose {
tableRender("Metadata", func() (rows [][]string) {
keys := make([]string, 0, len(resp.ModelInfo))
for k := range resp.ModelInfo {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
var v string
switch vData := resp.ModelInfo[k].(type) {
case bool:
v = fmt.Sprintf("%t", vData)
case string:
v = vData
case float64:
v = fmt.Sprintf("%g", vData)
case []any:
n := 3
if len(vData) < n {
n = len(vData)
}
v = fmt.Sprintf("%v", vData[:n])
default:
v = fmt.Sprintf("%T", vData)
}
rows = append(rows, []string{"", k, v})
}
return
})
}
if len(resp.Tensors) > 0 && verbose {
tableRender("Tensors", func() (rows [][]string) {
for _, t := range resp.Tensors {
rows = append(rows, []string{"", t.Name, t.Type, fmt.Sprint(t.Shape)})
}
return
})
}
head := func(s string, n int) (rows [][]string) { head := func(s string, n int) (rows [][]string) {
scanner := bufio.NewScanner(strings.NewReader(s)) scanner := bufio.NewScanner(strings.NewReader(s))
for scanner.Scan() && (len(rows) < n || n < 0) { for scanner.Scan() && (len(rows) < n || n < 0) {
@@ -852,7 +788,7 @@ type runOptions struct {
Format string Format string
System string System string
Images []api.ImageData Images []api.ImageData
Options map[string]any Options map[string]interface{}
MultiModal bool MultiModal bool
KeepAlive *api.Duration KeepAlive *api.Duration
} }
@@ -1254,7 +1190,6 @@ func NewCLI() *cobra.Command {
showCmd.Flags().Bool("parameters", false, "Show parameters of a model") showCmd.Flags().Bool("parameters", false, "Show parameters of a model")
showCmd.Flags().Bool("template", false, "Show template of a model") showCmd.Flags().Bool("template", false, "Show template of a model")
showCmd.Flags().Bool("system", false, "Show system message of a model") showCmd.Flags().Bool("system", false, "Show system message of a model")
showCmd.Flags().BoolP("verbose", "v", false, "Show detailed model information")
runCmd := &cobra.Command{ runCmd := &cobra.Command{
Use: "run MODEL [PROMPT]", Use: "run MODEL [PROMPT]",
@@ -1339,6 +1274,7 @@ func NewCLI() *cobra.Command {
runnerCmd := &cobra.Command{ runnerCmd := &cobra.Command{
Use: "runner", Use: "runner",
Short: llama.PrintSystemInfo(),
Hidden: true, Hidden: true,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
return runner.Execute(os.Args[1:]) return runner.Execute(os.Args[1:])
@@ -1381,6 +1317,7 @@ func NewCLI() *cobra.Command {
envVars["OLLAMA_NOPRUNE"], envVars["OLLAMA_NOPRUNE"],
envVars["OLLAMA_ORIGINS"], envVars["OLLAMA_ORIGINS"],
envVars["OLLAMA_SCHED_SPREAD"], envVars["OLLAMA_SCHED_SPREAD"],
envVars["OLLAMA_TMPDIR"],
envVars["OLLAMA_FLASH_ATTENTION"], envVars["OLLAMA_FLASH_ATTENTION"],
envVars["OLLAMA_KV_CACHE_TYPE"], envVars["OLLAMA_KV_CACHE_TYPE"],
envVars["OLLAMA_LLM_LIBRARY"], envVars["OLLAMA_LLM_LIBRARY"],

View File

@@ -10,13 +10,11 @@ import (
"os" "os"
"strings" "strings"
"testing" "testing"
"time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
) )
func TestShowInfo(t *testing.T) { func TestShowInfo(t *testing.T) {
@@ -28,7 +26,7 @@ func TestShowInfo(t *testing.T) {
ParameterSize: "7B", ParameterSize: "7B",
QuantizationLevel: "FP16", QuantizationLevel: "FP16",
}, },
}, false, &b); err != nil { }, &b); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -58,7 +56,7 @@ func TestShowInfo(t *testing.T) {
ParameterSize: "7B", ParameterSize: "7B",
QuantizationLevel: "FP16", QuantizationLevel: "FP16",
}, },
}, false, &b); err != nil { }, &b); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -69,60 +67,6 @@ func TestShowInfo(t *testing.T) {
embedding length 0 embedding length 0
quantization FP16 quantization FP16
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
t.Run("verbose model", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "8B",
QuantizationLevel: "FP16",
},
Parameters: `
stop up`,
ModelInfo: map[string]any{
"general.architecture": "test",
"general.parameter_count": float64(8_000_000_000),
"some.true_bool": true,
"some.false_bool": false,
"test.context_length": float64(1000),
"test.embedding_length": float64(11434),
},
Tensors: []api.Tensor{
{Name: "blk.0.attn_k.weight", Type: "BF16", Shape: []uint64{42, 3117}},
{Name: "blk.0.attn_q.weight", Type: "FP16", Shape: []uint64{3117, 42}},
},
}, true, &b); err != nil {
t.Fatal(err)
}
expect := ` Model
architecture test
parameters 8B
context length 1000
embedding length 11434
quantization FP16
Parameters
stop up
Metadata
general.architecture test
general.parameter_count 8e+09
some.false_bool false
some.true_bool true
test.context_length 1000
test.embedding_length 11434
Tensors
blk.0.attn_k.weight BF16 [42 3117]
blk.0.attn_q.weight FP16 [3117 42]
` `
if diff := cmp.Diff(expect, b.String()); diff != "" { if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff) t.Errorf("unexpected output (-want +got):\n%s", diff)
@@ -144,7 +88,7 @@ func TestShowInfo(t *testing.T) {
stop you stop you
stop up stop up
temperature 99`, temperature 99`,
}, false, &b); err != nil { }, &b); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -181,7 +125,7 @@ func TestShowInfo(t *testing.T) {
"clip.vision.embedding_length": float64(0), "clip.vision.embedding_length": float64(0),
"clip.vision.projection_dim": float64(0), "clip.vision.projection_dim": float64(0),
}, },
}, false, &b); err != nil { }, &b); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -214,7 +158,7 @@ func TestShowInfo(t *testing.T) {
Ahoy, matey! Ahoy, matey!
Weigh anchor! Weigh anchor!
`, `,
}, false, &b); err != nil { }, &b); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -243,7 +187,7 @@ Weigh anchor!
QuantizationLevel: "FP16", QuantizationLevel: "FP16",
}, },
License: license, License: license,
}, false, &b); err != nil { }, &b); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -261,34 +205,6 @@ Weigh anchor!
t.Errorf("unexpected output (-want +got):\n%s", diff) t.Errorf("unexpected output (-want +got):\n%s", diff)
} }
}) })
t.Run("capabilities", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
Capabilities: []model.Capability{model.CapabilityVision, model.CapabilityTools},
}, false, &b); err != nil {
t.Fatal(err)
}
expect := " Model\n" +
" architecture test \n" +
" parameters 7B \n" +
" quantization FP16 \n" +
"\n" +
" Capabilities\n" +
" vision \n" +
" tools \n" +
"\n"
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
} }
func TestDeleteHandler(t *testing.T) { func TestDeleteHandler(t *testing.T) {
@@ -574,96 +490,6 @@ func TestPushHandler(t *testing.T) {
} }
} }
func TestListHandler(t *testing.T) {
tests := []struct {
name string
args []string
serverResponse []api.ListModelResponse
expectedError string
expectedOutput string
}{
{
name: "list all models",
args: []string{},
serverResponse: []api.ListModelResponse{
{Name: "model1", Digest: "sha256:abc123", Size: 1024, ModifiedAt: time.Now().Add(-24 * time.Hour)},
{Name: "model2", Digest: "sha256:def456", Size: 2048, ModifiedAt: time.Now().Add(-48 * time.Hour)},
},
expectedOutput: "NAME ID SIZE MODIFIED \n" +
"model1 sha256:abc12 1.0 KB 24 hours ago \n" +
"model2 sha256:def45 2.0 KB 2 days ago \n",
},
{
name: "filter models by prefix",
args: []string{"model1"},
serverResponse: []api.ListModelResponse{
{Name: "model1", Digest: "sha256:abc123", Size: 1024, ModifiedAt: time.Now().Add(-24 * time.Hour)},
{Name: "model2", Digest: "sha256:def456", Size: 2048, ModifiedAt: time.Now().Add(-24 * time.Hour)},
},
expectedOutput: "NAME ID SIZE MODIFIED \n" +
"model1 sha256:abc12 1.0 KB 24 hours ago \n",
},
{
name: "server error",
args: []string{},
expectedError: "server error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/tags" || r.Method != http.MethodGet {
t.Errorf("unexpected request to %s %s", r.Method, r.URL.Path)
http.Error(w, "not found", http.StatusNotFound)
return
}
if tt.expectedError != "" {
http.Error(w, tt.expectedError, http.StatusInternalServerError)
return
}
response := api.ListResponse{Models: tt.serverResponse}
if err := json.NewEncoder(w).Encode(response); err != nil {
t.Fatal(err)
}
}))
defer mockServer.Close()
t.Setenv("OLLAMA_HOST", mockServer.URL)
cmd := &cobra.Command{}
cmd.SetContext(context.TODO())
// Capture stdout
oldStdout := os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w
err := ListHandler(cmd, tt.args)
// Restore stdout and get output
w.Close()
os.Stdout = oldStdout
output, _ := io.ReadAll(r)
if tt.expectedError == "" {
if err != nil {
t.Errorf("expected no error, got %v", err)
}
if got := string(output); got != tt.expectedOutput {
t.Errorf("expected output:\n%s\ngot:\n%s", tt.expectedOutput, got)
}
} else {
if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
}
}
})
}
}
func TestCreateHandler(t *testing.T) { func TestCreateHandler(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -790,132 +616,3 @@ func TestCreateHandler(t *testing.T) {
}) })
} }
} }
func TestNewCreateRequest(t *testing.T) {
tests := []struct {
name string
from string
opts runOptions
expected *api.CreateRequest
}{
{
"basic test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "",
Prompt: "You are a fun AI agent",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "mymodel",
Model: "newmodel",
},
},
{
"parent model test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "parentmodel",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "parentmodel",
Model: "newmodel",
},
},
{
"parent model as filepath test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "/some/file/like/etc/passwd",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "mymodel",
Model: "newmodel",
},
},
{
"parent model as windows filepath test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "D:\\some\\file\\like\\etc\\passwd",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "mymodel",
Model: "newmodel",
},
},
{
"options test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "parentmodel",
Options: map[string]any{
"temperature": 1.0,
},
},
&api.CreateRequest{
From: "parentmodel",
Model: "newmodel",
Parameters: map[string]any{
"temperature": 1.0,
},
},
},
{
"messages test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "parentmodel",
System: "You are a fun AI agent",
Messages: []api.Message{
{
Role: "user",
Content: "hello there!",
},
{
Role: "assistant",
Content: "hello to you!",
},
},
WordWrap: true,
},
&api.CreateRequest{
From: "parentmodel",
Model: "newmodel",
System: "You are a fun AI agent",
Messages: []api.Message{
{
Role: "user",
Content: "hello there!",
},
{
Role: "assistant",
Content: "hello to you!",
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := NewCreateRequest(tt.from, tt.opts)
if !cmp.Equal(actual, tt.expected) {
t.Errorf("expected output %#v, got %#v", tt.expected, actual)
}
})
}
}

View File

@@ -18,7 +18,6 @@ import (
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/readline" "github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
) )
type MultilineState int type MultilineState int
@@ -196,10 +195,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
opts.Messages = []api.Message{} opts.Messages = []api.Message{}
fmt.Printf("Loading model '%s'\n", opts.Model) fmt.Printf("Loading model '%s'\n", opts.Model)
if err := loadOrUnloadModel(cmd, &opts); err != nil { if err := loadOrUnloadModel(cmd, &opts); err != nil {
if strings.Contains(err.Error(), "not found") {
fmt.Printf("error: %v\n", err)
continue
}
return err return err
} }
continue continue
@@ -348,7 +343,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
switch args[1] { switch args[1] {
case "info": case "info":
_ = showInfo(resp, false, os.Stderr) _ = showInfo(resp, os.Stderr)
case "license": case "license":
if resp.License == "" { if resp.License == "" {
fmt.Println("No license was specified for this model.") fmt.Println("No license was specified for this model.")
@@ -460,16 +455,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
} }
func NewCreateRequest(name string, opts runOptions) *api.CreateRequest { func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
parentModel := opts.ParentModel
modelName := model.ParseName(parentModel)
if !modelName.IsValid() {
parentModel = ""
}
req := &api.CreateRequest{ req := &api.CreateRequest{
Model: name, Name: name,
From: cmp.Or(parentModel, opts.Model), From: cmp.Or(opts.ParentModel, opts.Model),
} }
if opts.System != "" { if opts.System != "" {

View File

@@ -13,13 +13,8 @@ import (
) )
type ModelParameters struct { type ModelParameters struct {
Architectures []string `json:"architectures"` Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"` VocabSize uint32 `json:"vocab_size"`
TextModel TextParameters `json:"text_config"`
}
type TextParameters struct {
VocabSize uint32 `json:"vocab_size"`
} }
type AdapterParameters struct { type AdapterParameters struct {
@@ -182,18 +177,14 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
var conv ModelConverter var conv ModelConverter
switch p.Architectures[0] { switch p.Architectures[0] {
case "LlamaForCausalLM": case "LlamaForCausalLM", "MistralForCausalLM":
conv = &llamaModel{} conv = &llamaModel{}
case "Mistral3ForConditionalGeneration":
conv = &mistral3Model{}
case "MixtralForCausalLM": case "MixtralForCausalLM":
conv = &mixtralModel{} conv = &mixtralModel{}
case "GemmaForCausalLM": case "GemmaForCausalLM":
conv = &gemmaModel{} conv = &gemmaModel{}
case "Gemma2ForCausalLM": case "Gemma2ForCausalLM":
conv = &gemma2Model{} conv = &gemma2Model{}
case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration":
conv = &gemma3Model{Architecture: p.Architectures[0]}
case "Phi3ForCausalLM": case "Phi3ForCausalLM":
conv = &phi3Model{} conv = &phi3Model{}
case "Qwen2ForCausalLM": case "Qwen2ForCausalLM":
@@ -203,7 +194,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
case "CohereForCausalLM": case "CohereForCausalLM":
conv = &commandrModel{} conv = &commandrModel{}
default: default:
return fmt.Errorf("unsupported architecture %q", p.Architectures[0]) return errors.New("unsupported architecture")
} }
if err := json.Unmarshal(bts, conv); err != nil { if err := json.Unmarshal(bts, conv); err != nil {
@@ -222,14 +213,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
} }
vocabSize := int(p.VocabSize) vocabSize := int(p.VocabSize)
if vocabSize == 0 {
tVocabSize := int(p.TextModel.VocabSize)
vocabSize = tVocabSize
}
switch { switch {
case vocabSize == 0:
slog.Warn("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
case vocabSize > len(t.Vocabulary.Tokens): case vocabSize > len(t.Vocabulary.Tokens):
slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens)) slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
for i := range vocabSize - len(t.Vocabulary.Tokens) { for i := range vocabSize - len(t.Vocabulary.Tokens) {

View File

@@ -45,7 +45,7 @@ func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
func (p *gemmaModel) Tensors(ts []Tensor) []ggml.Tensor { func (p *gemmaModel) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor var out []ggml.Tensor
for _, t := range ts { for _, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") && strings.HasSuffix(t.Name(), "_norm.weight") { if strings.HasSuffix(t.Name(), "_norm.weight") {
t.SetRepacker(p.addOne) t.SetRepacker(p.addOne)
} }

View File

@@ -1,142 +0,0 @@
package convert
import (
"cmp"
"github.com/ollama/ollama/fs/ggml"
)
type gemma3Model struct {
gemmaModel
Architecture string
TextModel struct {
HeadDim uint32 `json:"head_dim"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
SlidingWindow uint32 `json:"sliding_window"`
} `json:"text_config"`
VisionModel struct {
NumAttentionHeads uint32 `json:"num_attention_heads"` // attention.head_count 16
LayerNormEpsilon float32 `json:"layer_norm_eps"` // attention.layer_norm_epsilon 1e-05
NumHiddenLayers uint32 `json:"num_hidden_layers"` // block_count 32
HiddenSize uint32 `json:"hidden_size"` // embedding_length 1280
IntermediateSize uint32 `json:"intermediate_size"` // feed_forward_length 5120
ImageSize uint32 `json:"image_size"` // image_size 560
NumChannels uint32 `json:"num_channels"` // num_channels 3
PatchSize uint32 `json:"patch_size"` // patch_size 14
} `json:"vision_config"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
RopeLocalTheta float32 `json:"rope_local_base_freq"`
RopeGlobalTheta float32 `json:"rope_global_base_freq"`
SlidingWindow uint32 `json:"sliding_window"`
MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
}
const (
gemma4BLayerCount = 34
gemma12BLayerCount = 48
gemma27BLayerCount = 62
)
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3"
numBlocks := cmp.Or(p.HiddenLayers, p.TextModel.HiddenLayers)
kv["gemma3.block_count"] = numBlocks
var (
numHeads uint32
numKVHeads uint32
)
switch numBlocks {
case gemma4BLayerCount:
numHeads = 8
numKVHeads = 4
case gemma12BLayerCount:
numHeads = 16
numKVHeads = 8
case gemma27BLayerCount:
numHeads = 32
numKVHeads = 16
default:
numHeads = p.NumAttentionHeads
numKVHeads = p.NumKeyValueHeads
}
kv["gemma3.attention.head_count"] = numHeads
kv["gemma3.attention.head_count_kv"] = numKVHeads
switch p.Architecture {
case "Gemma3ForCausalLM":
kv["gemma3.context_length"] = p.MaxPositionEmbeddings
kv["gemma3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["gemma3.attention.key_length"] = p.HeadDim
kv["gemma3.attention.value_length"] = p.HeadDim
kv["gemma3.attention.sliding_window"] = p.SlidingWindow
kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
kv["gemma3.embedding_length"] = p.HiddenSize
kv["gemma3.feed_forward_length"] = p.IntermediateSize
default:
kv["gemma3.context_length"] = cmp.Or(p.MaxPositionEmbeddings, 131072)
kv["gemma3.embedding_length"] = p.TextModel.HiddenSize
kv["gemma3.feed_forward_length"] = p.TextModel.IntermediateSize
kv["gemma3.attention.sliding_window"] = p.TextModel.SlidingWindow
kv["gemma3.vision.block_count"] = p.VisionModel.NumHiddenLayers
kv["gemma3.vision.embedding_length"] = p.VisionModel.HiddenSize
kv["gemma3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
kv["gemma3.vision.image_size"] = p.VisionModel.ImageSize
kv["gemma3.vision.patch_size"] = p.VisionModel.PatchSize
kv["gemma3.vision.num_channels"] = cmp.Or(p.VisionModel.NumChannels, 3)
kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["gemma3.vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionModel.LayerNormEpsilon, 1e-6)
kv["gemma3.attention.key_length"] = cmp.Or(p.TextModel.HeadDim, 256)
kv["gemma3.attention.value_length"] = cmp.Or(p.TextModel.HeadDim, 256)
}
if p.MultiModalTokensPerImage > 0 {
kv["gemma3.mm.tokens_per_image"] = p.MultiModalTokensPerImage
}
return kv
}
func (p *gemma3Model) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"vision_tower.vision_model.embeddings", "v",
"vision_tower.vision_model", "v",
"vision_model.vision_model.embeddings", "v",
"vision_model.vision_model", "v",
"language_model.", "",
"model.layers", "blk",
"encoder.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.q_norm", "attn_q_norm",
"self_attn.k_proj", "attn_k",
"self_attn.k_norm", "attn_k_norm",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"self_attn.out_proj", "attn_output",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"post_attention_layernorm", "post_attention_norm",
"pre_feedforward_layernorm", "ffn_norm",
"post_feedforward_layernorm", "post_ffw_norm",
"input_projection_weight", "input_projection.weight",
"multi_modal_projector", "mm",
}
}

View File

@@ -1,190 +0,0 @@
package convert
import (
"cmp"
"fmt"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type mistral3Model struct {
ModelParameters
ImageTokenIndex uint32 `json:"image_token_index"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
VisionFeatureLayer int32 `json:"vision_feature_layer"`
TextModel struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
SlidingWindow *uint32 `json:"sliding_window"`
HiddenAct string `json:"hidden_act"`
VocabSize uint32 `json:"vocab_size"`
} `json:"text_config"`
VisionModel struct {
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
ImageSize uint32 `json:"image_size"`
NumChannels uint32 `json:"num_channels"`
PatchSize uint32 `json:"patch_size"`
HeadDim uint32 `json:"head_dim"`
HiddenAct string `json:"hidden_act"`
RopeTheta float32 `json:"rope_theta"`
} `json:"vision_config"`
MultiModalProjectorBias bool `json:"multimodal_projector_bias"`
ProjectorHiddenAct string `json:"projector_hidden_act"`
}
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.TextModel.VocabSize
// Text configuration
kv["mistral3.block_count"] = p.TextModel.NumHiddenLayers
kv["mistral3.context_length"] = p.TextModel.MaxPositionEmbeddings
kv["mistral3.embedding_length"] = p.TextModel.HiddenSize
kv["mistral3.feed_forward_length"] = p.TextModel.IntermediateSize
kv["mistral3.attention.head_count"] = p.TextModel.NumAttentionHeads
kv["mistral3.attention.head_count_kv"] = p.TextModel.NumKeyValueHeads
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
// Vision configuration
kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
kv["mistral3.vision.embedding_length"] = p.VisionModel.HiddenSize
kv["mistral3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
kv["mistral3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["mistral3.vision.attention.key_length"] = p.VisionModel.HeadDim
kv["mistral3.vision.image_size"] = p.VisionModel.ImageSize
kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta
// Multimodal configuration
kv["mistral3.image_token_index"] = p.ImageTokenIndex
kv["mistral3.spatial_merge_size"] = p.SpatialMergeSize
kv["mistral3.mm.projector_bias"] = p.MultiModalProjectorBias
if p.ProjectorHiddenAct != "" {
kv["mistral3.mm.projector_hidden_act"] = p.ProjectorHiddenAct
}
return kv
}
func (p *mistral3Model) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") {
if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
strings.HasSuffix(t.Name(), ".attn_k.weight") {
t.SetRepacker(p.repack)
}
}
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *mistral3Model) Replacements() []string {
return []string{
"language_model.model.norm", "output_norm",
"language_model.model.", "",
"language_model.", "",
"layers", "blk",
"transformer.layers", "blk",
"vision_tower", "v",
"ln_pre", "encoder_norm",
"input_layernorm", "attn_norm",
"post_attention_layernorm", "ffn_norm",
"embed_tokens", "token_embd",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"mlp.down_proj", "ffn_down",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"attention.q_proj", "attn_q",
"attention.k_proj", "attn_k",
"attention.v_proj", "attn_v",
"attention.o_proj", "attn_output",
"attention_norm", "attn_norm",
"feed_forward.gate_proj", "ffn_gate",
"feed_forward.down_proj", "ffn_down",
"feed_forward.up_proj", "ffn_up",
"multi_modal_projector", "mm",
"ffn_norm", "ffn_norm",
"lm_head", "output",
}
}
func (p *mistral3Model) repack(name string, data []float32, shape []uint64) ([]float32, error) {
var dims []int
for _, dim := range shape {
dims = append(dims, int(dim))
}
var heads uint32
if strings.HasSuffix(name, ".attn_q.weight") {
heads = p.TextModel.NumAttentionHeads
} else if strings.HasSuffix(name, ".attn_k.weight") {
heads = cmp.Or(p.TextModel.NumKeyValueHeads, p.TextModel.NumAttentionHeads)
} else {
return nil, fmt.Errorf("unknown tensor for repack: %s", name)
}
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 1)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}

View File

@@ -62,7 +62,10 @@ func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
Pattern string Pattern string
Func func(fs.FS, *strings.Replacer, ...string) ([]Tensor, error) Func func(fs.FS, *strings.Replacer, ...string) ([]Tensor, error)
}{ }{
{"*.safetensors", parseSafetensors}, {"model-*-of-*.safetensors", parseSafetensors},
{"model.safetensors", parseSafetensors},
{"adapters.safetensors", parseSafetensors},
{"adapter_model.safetensors", parseSafetensors},
{"pytorch_model-*-of-*.bin", parseTorch}, {"pytorch_model-*-of-*.bin", parseTorch},
{"pytorch_model.bin", parseTorch}, {"pytorch_model.bin", parseTorch},
{"consolidated.*.pth", parseTorch}, {"consolidated.*.pth", parseTorch},

View File

@@ -1360,7 +1360,7 @@ func file_sentencepiece_model_proto_rawDescGZIP() []byte {
var file_sentencepiece_model_proto_enumTypes = make([]protoimpl.EnumInfo, 2) var file_sentencepiece_model_proto_enumTypes = make([]protoimpl.EnumInfo, 2)
var file_sentencepiece_model_proto_msgTypes = make([]protoimpl.MessageInfo, 6) var file_sentencepiece_model_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
var file_sentencepiece_model_proto_goTypes = []any{ var file_sentencepiece_model_proto_goTypes = []interface{}{
(TrainerSpec_ModelType)(0), // 0: sentencepiece.TrainerSpec.ModelType (TrainerSpec_ModelType)(0), // 0: sentencepiece.TrainerSpec.ModelType
(ModelProto_SentencePiece_Type)(0), // 1: sentencepiece.ModelProto.SentencePiece.Type (ModelProto_SentencePiece_Type)(0), // 1: sentencepiece.ModelProto.SentencePiece.Type
(*TrainerSpec)(nil), // 2: sentencepiece.TrainerSpec (*TrainerSpec)(nil), // 2: sentencepiece.TrainerSpec
@@ -1392,7 +1392,7 @@ func file_sentencepiece_model_proto_init() {
return return
} }
if !protoimpl.UnsafeEnabled { if !protoimpl.UnsafeEnabled {
file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v any, i int) any { file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*TrainerSpec); i { switch v := v.(*TrainerSpec); i {
case 0: case 0:
return &v.state return &v.state
@@ -1406,7 +1406,7 @@ func file_sentencepiece_model_proto_init() {
return nil return nil
} }
} }
file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v any, i int) any { file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*NormalizerSpec); i { switch v := v.(*NormalizerSpec); i {
case 0: case 0:
return &v.state return &v.state
@@ -1420,7 +1420,7 @@ func file_sentencepiece_model_proto_init() {
return nil return nil
} }
} }
file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v any, i int) any { file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SelfTestData); i { switch v := v.(*SelfTestData); i {
case 0: case 0:
return &v.state return &v.state
@@ -1434,7 +1434,7 @@ func file_sentencepiece_model_proto_init() {
return nil return nil
} }
} }
file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v any, i int) any { file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ModelProto); i { switch v := v.(*ModelProto); i {
case 0: case 0:
return &v.state return &v.state
@@ -1448,7 +1448,7 @@ func file_sentencepiece_model_proto_init() {
return nil return nil
} }
} }
file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v any, i int) any { file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SelfTestData_Sample); i { switch v := v.(*SelfTestData_Sample); i {
case 0: case 0:
return &v.state return &v.state
@@ -1460,7 +1460,7 @@ func file_sentencepiece_model_proto_init() {
return nil return nil
} }
} }
file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v any, i int) any { file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ModelProto_SentencePiece); i { switch v := v.(*ModelProto_SentencePiece); i {
case 0: case 0:
return &v.state return &v.state

View File

@@ -6,9 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/fs" "io/fs"
"log/slog"
"os" "os"
"reflect"
"slices" "slices"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
@@ -17,8 +15,6 @@ import (
) )
func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) { func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
slog.Debug("using spm vocabulary")
ast, err := parseAdditionalSpecialTokens(fsys) ast, err := parseAdditionalSpecialTokens(fsys)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -47,19 +43,10 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
v.Types = append(v.Types, int32(t)) v.Types = append(v.Types, int32(t))
default: default:
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL) tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
if slices.Contains(ast, piece.GetPiece()) {
// temporary fix to handle gemma3 broken configs
if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL) tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
} }
for _, t := range ast {
if t.Content == piece.GetPiece() {
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
break
}
}
v.Types = append(v.Types, tt) v.Types = append(v.Types, tt)
} }
} }
@@ -91,16 +78,10 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
return cmp.Compare(i.id, j.id) return cmp.Compare(i.id, j.id)
}) })
for _, t := range ts { n := len(v.Tokens)
if t.id < len(v.Tokens) { for i, t := range ts {
if v.Tokens[t.id] == t.content { if t.id != i+n {
slog.Warn("tokenizer", "duplicate token", t.content, "id", t.id) return nil, fmt.Errorf("invalid token id: %d", t.id)
continue
}
return nil, fmt.Errorf("token mismatch: %s != %s at pos [%d]", t.content, v.Tokens[t.id], t.id)
}
if t.id != len(v.Tokens) {
return nil, fmt.Errorf("invalid token id: [%d] as pos [%d]", t.id, len(v.Tokens))
} }
v.Tokens = append(v.Tokens, t.content) v.Tokens = append(v.Tokens, t.content)
@@ -111,15 +92,7 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
return &v, nil return &v, nil
} }
type specialToken struct { func parseAdditionalSpecialTokens(fsys fs.FS) ([]string, error) {
Content string `json:"content"`
Lstrip bool `json:"lstrip"`
Normalized bool `json:"normalized"`
Rstrip bool `json:"rstrip"`
SingleWord bool `json:"single_word"`
}
func parseAdditionalSpecialTokens(fsys fs.FS) ([]specialToken, error) {
f, err := fsys.Open("special_tokens_map.json") f, err := fsys.Open("special_tokens_map.json")
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
return nil, nil return nil, nil
@@ -129,43 +102,12 @@ func parseAdditionalSpecialTokens(fsys fs.FS) ([]specialToken, error) {
defer f.Close() defer f.Close()
var m struct { var m struct {
AdditionalSpecialTokens any `json:"additional_special_tokens"` AdditionalSpecialTokens []string `json:"additional_special_tokens"`
} }
if err := json.NewDecoder(f).Decode(&m); err != nil { if err := json.NewDecoder(f).Decode(&m); err != nil {
return nil, err return nil, err
} }
var ast []specialToken return m.AdditionalSpecialTokens, nil
switch st := m.AdditionalSpecialTokens.(type) {
case []string:
for _, s := range st {
ast = append(ast, specialToken{Content: s})
}
case []any:
for _, s := range st {
// marshal and unmarshal the object to get the special token
tMap := s.(map[string]any)
data, err := json.Marshal(tMap)
if err != nil {
return nil, err
}
var token specialToken
err = json.Unmarshal(data, &token)
if err != nil {
return nil, err
}
ast = append(ast, token)
}
default:
slog.Warn("special token", "unknown token", reflect.TypeOf(st))
}
slog.Debug("spm tokenizer", "additional tokens", ast)
return ast, nil
} }

View File

@@ -12,7 +12,7 @@ func IsNUMA() bool {
// numa support in llama.cpp is linux only // numa support in llama.cpp is linux only
return false return false
} }
ids := map[string]any{} ids := map[string]interface{}{}
packageIds, _ := filepath.Glob("/sys/devices/system/cpu/cpu*/topology/physical_package_id") packageIds, _ := filepath.Glob("/sys/devices/system/cpu/cpu*/topology/physical_package_id")
for _, packageId := range packageIds { for _, packageId := range packageIds {
id, err := os.ReadFile(packageId) id, err := os.ReadFile(packageId)

View File

@@ -57,8 +57,7 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
} }
} }
// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers if gpuInfo.computeMajor < 6 || gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
return "v11" return "v11"
} }
return "v12" return "v12"

View File

@@ -111,7 +111,6 @@ func GetCPUDetails() ([]CPU, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer file.Close()
return linuxCPUDetails(file) return linuxCPUDetails(file)
} }
@@ -169,11 +168,13 @@ func linuxCPUDetails(file io.Reader) ([]CPU, error) {
for id, s := range socketByID { for id, s := range socketByID {
s.CoreCount = len(coreBySocket[id]) s.CoreCount = len(coreBySocket[id])
s.ThreadCount = 0 s.ThreadCount = 0
for _, tc := range threadsByCoreBySocket[id] {
s.ThreadCount += tc
}
// This only works if HT is enabled, consider a more reliable model, maybe cache size comparisons? // This only works if HT is enabled, consider a more reliable model, maybe cache size comparisons?
efficiencyCoreCount := 0 efficiencyCoreCount := 0
for _, threads := range threadsByCoreBySocket[id] { for _, threads := range threadsByCoreBySocket[id] {
s.ThreadCount += threads
if threads == 1 { if threads == 1 {
efficiencyCoreCount++ efficiencyCoreCount++
} }

View File

@@ -558,10 +558,6 @@ Final response:
{ {
"model": "llama3.2", "model": "llama3.2",
"created_at": "2023-08-04T19:22:45.499127Z", "created_at": "2023-08-04T19:22:45.499127Z",
"message": {
"role": "assistant",
"content": ""
},
"done": true, "done": true,
"total_duration": 4883583458, "total_duration": 4883583458,
"load_duration": 1334875, "load_duration": 1334875,
@@ -1217,7 +1213,7 @@ Show information about a model including details, modelfile, template, parameter
```shell ```shell
curl http://localhost:11434/api/show -d '{ curl http://localhost:11434/api/show -d '{
"model": "llava" "model": "llama3.2"
}' }'
``` ```
@@ -1260,11 +1256,7 @@ curl http://localhost:11434/api/show -d '{
"tokenizer.ggml.pre": "llama-bpe", "tokenizer.ggml.pre": "llama-bpe",
"tokenizer.ggml.token_type": [], // populates if `verbose=true` "tokenizer.ggml.token_type": [], // populates if `verbose=true`
"tokenizer.ggml.tokens": [] // populates if `verbose=true` "tokenizer.ggml.tokens": [] // populates if `verbose=true`
}, }
"capabilities": [
"completion",
"vision"
],
} }
``` ```
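For reference, the same call can be made from Go with this repository's `api` client; a minimal sketch assuming a locally running server (the model name is illustrative):

```go
package main

import (
	"context"
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment() // honours OLLAMA_HOST, defaults to 127.0.0.1:11434
	if err != nil {
		panic(err)
	}

	resp, err := client.Show(context.Background(), &api.ShowRequest{Model: "llama3.2"})
	if err != nil {
		panic(err)
	}

	fmt.Println(resp.Details.Family, resp.Details.ParameterSize)
}
```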

View File

@@ -1,59 +0,0 @@
# Benchmark
Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.
## When to use
Run these benchmarks when:
- Making changes to the model inference engine
- Modifying model loading/unloading logic
- Changing prompt processing or token generation code
- Implementing a new model architecture
- Testing performance across different hardware setups
## Prerequisites
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
## Usage and Examples
> [!NOTE]
> All commands must be run from the root directory of the Ollama project.
Basic syntax:
```bash
go test -bench=. ./benchmark/... -m $MODEL_NAME
```
Required flags:
- `-bench=.`: Run all benchmarks
- `-m`: Model name to benchmark
Optional flags:
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)
Common usage patterns:
Single benchmark run with a model specified:
```bash
go test -bench=. ./benchmark/... -m llama3.3
```
## Output metrics
The benchmark reports several key metrics:
- `gen_tok/s`: Generated tokens per second
- `prompt_tok/s`: Prompt processing tokens per second
- `ttft_ms`: Time to first token in milliseconds
- `load_ms`: Model load time in milliseconds
- `gen_tokens`: Total tokens generated
- `prompt_tokens`: Total prompt tokens processed
Each benchmark runs two scenarios:
- Cold start: Model is loaded from disk for each test
- Warm start: Model is pre-loaded in memory
Three prompt variants are tested for each scenario, each with a different generation limit (`num_predict`):
- Short prompt (up to 100 generated tokens)
- Medium prompt (up to 500 generated tokens)
- Long prompt (up to 1000 generated tokens)

View File

@@ -41,11 +41,20 @@ Install prerequisites:
- [CMake](https://cmake.org/download/) - [CMake](https://cmake.org/download/)
- [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/) including the Native Desktop Workload - [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/) including the Native Desktop Workload
- (Optional) AMD GPU support - (Optional) AMD GPU support
- [ROCm](https://rocm.docs.amd.com/en/latest/) - [ROCm](https://rocm.github.io/install.html)
- [Ninja](https://github.com/ninja-build/ninja/releases) - [Ninja](https://github.com/ninja-build/ninja/releases)
- (Optional) NVIDIA GPU support - (Optional) NVIDIA GPU support
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network) - [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
> [!IMPORTANT]
> Ensure prerequisites are in `PATH` before running CMake.
> [!IMPORTANT]
> ROCm is not compatible with Visual Studio CMake generators. Use `-GNinja` when configuring the project.
> [!IMPORTANT]
> CUDA is only compatible with Visual Studio CMake generators.
Then, configure and build the project: Then, configure and build the project:
```shell ```shell
@@ -53,14 +62,6 @@ cmake -B build
cmake --build build --config Release cmake --build build --config Release
``` ```
> [!IMPORTANT]
> Building for ROCm requires additional flags:
> ```
> cmake -B build -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++
> cmake --build build --config Release
> ```
Lastly, run Ollama: Lastly, run Ollama:
```shell ```shell
@@ -69,7 +70,7 @@ go run . serve
## Windows (ARM) ## Windows (ARM)
Windows ARM does not support additional acceleration libraries at this time. Do not use cmake, simply `go run` or `go build`. Windows ARM does not support additional acceleration libraries at this time.
## Linux ## Linux
@@ -118,35 +119,6 @@ To run tests, use `go test`:
go test ./... go test ./...
``` ```
> NOTE: In rare circumstances, you may need to change a package using the new
> "synctest" package in go1.24.
>
> If you do not have the "synctest" package enabled, you will not see any build or
> test failures resulting from your change(s) locally, but CI will
> break.
>
> If you see failures in CI, you can either keep pushing changes to see if the
> CI build passes, or you can enable the "synctest" package locally to see the
> failures before pushing.
>
> To enable the "synctest" package for testing, run the following command:
>
> ```shell
> GOEXPERIMENT=synctest go test ./...
> ```
>
> If you wish to enable synctest for all go commands, you can set the
> `GOEXPERIMENT` environment variable in your shell profile or by using:
>
> ```shell
> go env -w GOEXPERIMENT=synctest
> ```
>
> Which will enable the "synctest" package for all go commands without needing
> to set it for all shell sessions.
>
> The synctest package is not required for production builds.
## Library detection ## Library detection
Ollama looks for acceleration libraries in the following paths relative to the `ollama` executable: Ollama looks for acceleration libraries in the following paths relative to the `ollama` executable:
@@ -156,4 +128,4 @@ Ollama looks for acceleration libraries in the following paths relative to the `
* `.` (macOS) * `.` (macOS)
* `build/lib/ollama` (for development) * `build/lib/ollama` (for development)
If the libraries are not found, Ollama will not run with any acceleration libraries. If the libraries are not found, Ollama will not run with any acceleration libraries.

View File

@@ -20,13 +20,7 @@ Please refer to the [GPU docs](./gpu.md).
## How can I specify the context window size? ## How can I specify the context window size?
By default, Ollama uses a context window size of 2048 tokens. By default, Ollama uses a context window size of 2048 tokens.
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
```shell
OLLAMA_CONTEXT_LENGTH=8192 ollama serve
```
To change this when using `ollama run`, use `/set parameter`: To change this when using `ollama run`, use `/set parameter`:
@@ -193,13 +187,6 @@ cloudflared tunnel --url http://localhost:11434 --http-host-header="localhost:11
Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`. Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`.
For browser extensions, you'll need to explicitly allow the extension's origin pattern. Set `OLLAMA_ORIGINS` to include `chrome-extension://*`, `moz-extension://*`, and `safari-web-extension://*` if you wish to allow all browser extensions access, or specific extensions as needed:
```
# Allow all Chrome, Firefox, and Safari extensions
OLLAMA_ORIGINS=chrome-extension://*,moz-extension://*,safari-web-extension://* ollama serve
```
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform. Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
## Where are models stored? ## Where are models stored?

View File

@@ -75,7 +75,7 @@ RestartSec=3
Environment="PATH=$PATH" Environment="PATH=$PATH"
[Install] [Install]
WantedBy=multi-user.target WantedBy=default.target
``` ```
Then start the service: Then start the service:
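The steps that typically follow are the standard systemd ones (a sketch, assuming the unit file above was saved as `ollama.service`):

```shell
sudo systemctl daemon-reload
sudo systemctl enable ollama
sudo systemctl start ollama
```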

View File

@@ -9,7 +9,7 @@ cat ~/.ollama/logs/server.log
On **Linux** systems with systemd, the logs can be found with this command: On **Linux** systems with systemd, the logs can be found with this command:
```shell ```shell
journalctl -u ollama --no-pager --follow --pager-end journalctl -u ollama --no-pager
``` ```
When you run Ollama in a **container**, the logs go to stdout/stderr in the container: When you run Ollama in a **container**, the logs go to stdout/stderr in the container:
@@ -26,6 +26,7 @@ When you run Ollama on **Windows**, there are a few different locations. You can
- `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log` - `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
- `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH) - `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH)
- `explorer %HOMEPATH%\.ollama` to browse where models and configuration are stored - `explorer %HOMEPATH%\.ollama` to browse where models and configuration are stored
- `explorer %TEMP%` where temporary executable files are stored in one or more `ollama*` directories
To enable additional debug logging to help troubleshoot problems, first **Quit the running app from the tray menu** then in a powershell terminal To enable additional debug logging to help troubleshoot problems, first **Quit the running app from the tray menu** then in a powershell terminal
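A minimal sketch of that PowerShell step, assuming the tray app binary is named `ollama app.exe` and is on your PATH (the variable matches the `OLLAMA_DEBUG` setting referenced further down):

```powershell
# Sketch: enable debug logging for this session, then relaunch the tray app
$env:OLLAMA_DEBUG="1"
& "ollama app.exe"
```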
@@ -68,9 +69,9 @@ If you run into problems on Linux and want to install an older version, or you'd
curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.5.7 sh curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.5.7 sh
``` ```
## Linux docker ## Linux tmp noexec
If Ollama initially works on the GPU in a docker container, but then switches to running on CPU after some period of time with errors in the server log reporting GPU discovery failures, this can be resolved by disabling systemd cgroup management in Docker. Edit `/etc/docker/daemon.json` on the host and add `"exec-opts": ["native.cgroupdriver=cgroupfs"]` to the docker configuration. If your system is configured with the "noexec" flag where Ollama stores its temporary executable files, you can specify an alternate location by setting OLLAMA_TMPDIR to a location writable by the user ollama runs as. For example OLLAMA_TMPDIR=/usr/share/ollama/
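As a sketch, the resulting `/etc/docker/daemon.json` would contain the following (merge it with any options already present, then restart Docker, for example with `sudo systemctl restart docker`):

```
{
  "exec-opts": ["native.cgroupdriver=cgroupfs"]
}
```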
## NVIDIA GPU Discovery ## NVIDIA GPU Discovery
@@ -99,6 +100,8 @@ On linux, AMD GPU access typically requires `video` and/or `render` group member
When running in a container, in some Linux distributions and container runtimes, the ollama process may be unable to access the GPU. Use `ls -lnd /dev/kfd /dev/dri /dev/dri/*` on the host system to determine the **numeric** group IDs on your system, and pass additional `--group-add ...` arguments to the container so it can access the required devices. For example, in the following output `crw-rw---- 1 0 44 226, 0 Sep 16 16:55 /dev/dri/card0` the group ID column is `44` When running in a container, in some Linux distributions and container runtimes, the ollama process may be unable to access the GPU. Use `ls -lnd /dev/kfd /dev/dri /dev/dri/*` on the host system to determine the **numeric** group IDs on your system, and pass additional `--group-add ...` arguments to the container so it can access the required devices. For example, in the following output `crw-rw---- 1 0 44 226, 0 Sep 16 16:55 /dev/dri/card0` the group ID column is `44`
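For illustration only, a container invocation with the extra groups might look like this; the image tag and the group IDs `44` and `110` are placeholders to be replaced with values from your own `ls -lnd` output:

```shell
# Sketch: expose the AMD devices and add the host's kfd/render group IDs
docker run -d --device /dev/kfd --device /dev/dri \
  --group-add 44 --group-add 110 \
  -v ollama:/root/.ollama -p 11434:11434 ollama/ollama:rocm
```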
If Ollama initially works on the GPU in a docker container, but then switches to running on CPU after some period of time with errors in the server log reporting GPU discovery failures, this can be resolved by disabling systemd cgroup management in Docker. Edit `/etc/docker/daemon.json` on the host and add `"exec-opts": ["native.cgroupdriver=cgroupfs"]` to the docker configuration.
If you are experiencing problems getting Ollama to correctly discover or use your GPU for inference, the following may help isolate the failure. If you are experiencing problems getting Ollama to correctly discover or use your GPU for inference, the following may help isolate the failure.
- `AMD_LOG_LEVEL=3` Enable info log levels in the AMD HIP/ROCm libraries. This can help show more detailed error codes that can help troubleshoot problems - `AMD_LOG_LEVEL=3` Enable info log levels in the AMD HIP/ROCm libraries. This can help show more detailed error codes that can help troubleshoot problems
- `OLLAMA_DEBUG=1` During GPU discovery additional information will be reported - `OLLAMA_DEBUG=1` During GPU discovery additional information will be reported

View File

@@ -62,6 +62,7 @@ the explorer window by hitting `<Ctrl>+R` and type in:
- *upgrade.log* contains log output for upgrades - *upgrade.log* contains log output for upgrades
- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH) - `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
- `explorer %HOMEPATH%\.ollama` contains models and configuration - `explorer %HOMEPATH%\.ollama` contains models and configuration
- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories
## Uninstall ## Uninstall
@@ -80,11 +81,9 @@ help you keep up to date.
If you'd like to install or integrate Ollama as a service, a standalone If you'd like to install or integrate Ollama as a service, a standalone
`ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI `ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
and GPU library dependencies for Nvidia. If you have an AMD GPU, also download and GPU library dependencies for Nvidia and AMD. This allows for embedding
and extract the additional ROCm package `ollama-windows-amd64-rocm.zip` into the Ollama in existing applications, or running it as a system service via `ollama
same directory. This allows for embedding Ollama in existing applications, or serve` with tools such as [NSSM](https://nssm.cc/).
running it as a system service via `ollama serve` with tools such as
[NSSM](https://nssm.cc/).
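A minimal sketch of the NSSM route, assuming the zip was extracted to `C:\ollama` and `nssm.exe` is on the PATH (names and paths are illustrative):

```
nssm install Ollama C:\ollama\ollama.exe serve
nssm start Ollama
```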
> [!NOTE] > [!NOTE]
> If you are upgrading from a prior version, you should remove the old directories first. > If you are upgrading from a prior version, you should remove the old directories first.

View File

@@ -53,8 +53,8 @@ func Host() *url.URL {
} }
} }
// AllowedOrigins returns a list of allowed origins. AllowedOrigins can be configured via the OLLAMA_ORIGINS environment variable. // Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
func AllowedOrigins() (origins []string) { func Origins() (origins []string) {
if s := Var("OLLAMA_ORIGINS"); s != "" { if s := Var("OLLAMA_ORIGINS"); s != "" {
origins = strings.Split(s, ",") origins = strings.Split(s, ",")
} }
@@ -73,7 +73,6 @@ func AllowedOrigins() (origins []string) {
"file://*", "file://*",
"tauri://*", "tauri://*",
"vscode-webview://*", "vscode-webview://*",
"vscode-file://*",
) )
return origins return origins
@@ -168,8 +167,6 @@ var (
MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE") MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
// Enable the new Ollama engine // Enable the new Ollama engine
NewEngine = Bool("OLLAMA_NEW_ENGINE") NewEngine = Bool("OLLAMA_NEW_ENGINE")
// ContextLength sets the default context length
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 2048)
) )
func String(s string) func() string { func String(s string) func() string {
@@ -252,10 +249,9 @@ func AsMap() map[string]EnvVar {
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"}, "OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"}, "OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"}, "OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"}, "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"}, "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 2048)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"}, "OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
// Informational // Informational

View File

@@ -69,7 +69,6 @@ func TestOrigins(t *testing.T) {
"file://*", "file://*",
"tauri://*", "tauri://*",
"vscode-webview://*", "vscode-webview://*",
"vscode-file://*",
}}, }},
{"http://10.0.0.1", []string{ {"http://10.0.0.1", []string{
"http://10.0.0.1", "http://10.0.0.1",
@@ -89,7 +88,6 @@ func TestOrigins(t *testing.T) {
"file://*", "file://*",
"tauri://*", "tauri://*",
"vscode-webview://*", "vscode-webview://*",
"vscode-file://*",
}}, }},
{"http://172.16.0.1,https://192.168.0.1", []string{ {"http://172.16.0.1,https://192.168.0.1", []string{
"http://172.16.0.1", "http://172.16.0.1",
@@ -110,7 +108,6 @@ func TestOrigins(t *testing.T) {
"file://*", "file://*",
"tauri://*", "tauri://*",
"vscode-webview://*", "vscode-webview://*",
"vscode-file://*",
}}, }},
{"http://totally.safe,http://definitely.legit", []string{ {"http://totally.safe,http://definitely.legit", []string{
"http://totally.safe", "http://totally.safe",
@@ -131,14 +128,13 @@ func TestOrigins(t *testing.T) {
"file://*", "file://*",
"tauri://*", "tauri://*",
"vscode-webview://*", "vscode-webview://*",
"vscode-file://*",
}}, }},
} }
for _, tt := range cases { for _, tt := range cases {
t.Run(tt.value, func(t *testing.T) { t.Run(tt.value, func(t *testing.T) {
t.Setenv("OLLAMA_ORIGINS", tt.value) t.Setenv("OLLAMA_ORIGINS", tt.value)
if diff := cmp.Diff(AllowedOrigins(), tt.expect); diff != "" { if diff := cmp.Diff(Origins(), tt.expect); diff != "" {
t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff) t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff)
} }
}) })
@@ -276,19 +272,3 @@ func TestVar(t *testing.T) {
}) })
} }
} }
func TestContextLength(t *testing.T) {
cases := map[string]uint{
"": 2048,
"4096": 4096,
}
for k, v := range cases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_CONTEXT_LENGTH", k)
if i := ContextLength(); i != v {
t.Errorf("%s: expected %d, got %d", k, v, i)
}
})
}
}

View File

@@ -5,7 +5,7 @@ import (
"time" "time"
) )
func assertEqual(t *testing.T, a any, b any) { func assertEqual(t *testing.T, a interface{}, b interface{}) {
if a != b { if a != b {
t.Errorf("Assert failed, expected %v, got %v", b, a) t.Errorf("Assert failed, expected %v, got %v", b, a)
} }

View File

@@ -1,13 +0,0 @@
package fs
type Config interface {
Architecture() string
String(string, ...string) string
Uint(string, ...uint32) uint32
Float(string, ...float32) float32
Bool(string, ...bool) bool
Strings(string, ...[]string) []string
Uints(string, ...[]uint32) []uint32
Floats(string, ...[]float32) []float32
}

View File

@@ -100,10 +100,6 @@ func (kv KV) Float(key string, defaultValue ...float32) float32 {
return keyValue(kv, key, append(defaultValue, 0)...) return keyValue(kv, key, append(defaultValue, 0)...)
} }
func (kv KV) Bool(key string, defaultValue ...bool) bool {
return keyValue(kv, key, append(defaultValue, false)...)
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string { func (kv KV) Strings(key string, defaultValue ...[]string) []string {
r := keyValue(kv, key, &array{}) r := keyValue(kv, key, &array{})
s := make([]string, r.size) s := make([]string, r.size)
@@ -124,23 +120,7 @@ func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
return s return s
} }
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 { func keyValue[T string | uint32 | uint64 | float32 | *array](kv KV, key string, defaultValue ...T) T {
r := keyValue(kv, key, &array{})
s := make([]float32, r.size)
for i := range r.size {
s[i] = float32(r.values[i].(float32))
}
return s
}
func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{
"gemma3",
"mistral3",
}, kv.Architecture())
}
func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") { if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key key = kv.Architecture() + "." + key
} }
@@ -227,26 +207,11 @@ func (t Tensor) block() (n int) {
func (t Tensor) blockSize() uint64 { func (t Tensor) blockSize() uint64 {
switch t.Kind { switch t.Kind {
case case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16
0, // F32
1, // F16
24, // I8
25, // I16
26, // I32
27, // I64
28, // F64
30: // BF16
return 1 return 1
case case 2, 3, 4, 5, 6, 7, 8, 9, 20: // Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1, IQ4_NL
2, // Q4_0
3, // Q4_1
6, // Q5_0
7, // Q5_1
8, // Q8_0
9, // Q8_1
20: // IQ4_NL
return 32 return 32
default: default: // All others
return 256 return 256
} }
} }
@@ -270,7 +235,7 @@ func (t Tensor) typeSize() uint64 {
case 8: // Q8_0 case 8: // Q8_0
return 2 + blockSize return 2 + blockSize
case 9: // Q8_1 case 9: // Q8_1
return 2 + 2 + blockSize return 4 + 4 + blockSize
case 10: // Q2_K case 10: // Q2_K
return blockSize/16 + blockSize/4 + 2 + 2 return blockSize/16 + blockSize/4 + 2 + 2
case 11: // Q3_K case 11: // Q3_K
@@ -282,7 +247,7 @@ func (t Tensor) typeSize() uint64 {
case 14: // Q6_K case 14: // Q6_K
return blockSize/2 + blockSize/4 + blockSize/16 + 2 return blockSize/2 + blockSize/4 + blockSize/16 + 2
case 15: // Q8_K case 15: // Q8_K
return 4 + blockSize + 2*blockSize/16 return 2 + blockSize + 2*blockSize/16
case 16: // IQ2_XXS case 16: // IQ2_XXS
return 2 + 2*blockSize/8 return 2 + 2*blockSize/8
case 17: // IQ2_XS case 17: // IQ2_XS
@@ -311,8 +276,6 @@ func (t Tensor) typeSize() uint64 {
return 8 return 8
case 29: // IQ1_M case 29: // IQ1_M
return blockSize/8 + blockSize/16 + blockSize/32 return blockSize/8 + blockSize/16 + blockSize/32
case 30: // BF16
return 2
default: default:
return 0 return 0
} }
@@ -330,10 +293,6 @@ func (t Tensor) Size() uint64 {
return t.parameters() * t.typeSize() / t.blockSize() return t.parameters() * t.typeSize() / t.blockSize()
} }
func (t Tensor) Type() string {
return fileType(t.Kind).String()
}
type container interface { type container interface {
Name() string Name() string
Decode(io.ReadSeeker) (model, error) Decode(io.ReadSeeker) (model, error)
@@ -416,7 +375,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
}, offset, nil }, offset, nil
} }
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) { func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
embedding := f.KV().EmbeddingLength() embedding := f.KV().EmbeddingLength()
heads := f.KV().HeadCount() heads := f.KV().HeadCount()
headsKV := f.KV().HeadCountKV() headsKV := f.KV().HeadCountKV()
@@ -429,10 +388,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
layers := f.Tensors().GroupLayers() layers := f.Tensors().GroupLayers()
bytesPerElement := kvCacheBytesPerElement(kvCacheType) bytesPerElement := kvCacheBytesPerElement(kvCacheType)
kv = make([]uint64, f.KV().BlockCount()) kv = uint64(float64(context*f.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
for i := range kv {
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
}
switch f.KV().Architecture() { switch f.KV().Architecture() {
case "llama": case "llama":
@@ -466,14 +422,16 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
case "mllama": case "mllama":
var visionTokens, tiles uint64 = 1601, 4 var visionTokens, tiles uint64 = 1601, 4
crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers") if crossAttentionLayers, ok := f.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
for i := range kv { kv = headsKV *
if slices.Contains(crossAttentionLayers, uint32(i)) { (embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) * (2* // sizeof(float16)
4 * // sizeof(float32) (f.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
visionTokens * context +
tiles 4* // sizeof(float32)
} uint64(crossAttentionLayers.size)* // num cross attention layers
visionTokens*
tiles)
} }
fullOffload = max( fullOffload = max(
@@ -497,7 +455,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
// vocab graph // vocab graph
4*batch*(embedding+vocab)+embedding*vocab*105/128, 4*batch*(embedding+vocab)+embedding*vocab*105/128,
) )
case "gemma", "gemma2", "gemma3": case "gemma", "gemma2":
fullOffload = max( fullOffload = max(
4*batch*(embedding+vocab), 4*batch*(embedding+vocab),
4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads), 4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
@@ -509,20 +467,6 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
4*embeddingHeadsK*context*8+ 4*embeddingHeadsK*context*8+
embedding*embeddingHeadsK*heads*9/16, embedding*embeddingHeadsK*heads*9/16,
) )
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
// engine. Gemma3 always uses the Ollama engine.
if f.KV().Architecture() == "gemma3" {
const gemma3GlobalCacheCount = 6
slidingWindow := (uint64(numParallel) * uint64(f.KV().Uint("attention.sliding_window"))) + batch
for i := range kv {
// Every 6th layer is a global layer, which is the full context size that has already been set. The other
// layers are the smaller local (sliding) layers.
if (i+1)%gemma3GlobalCacheCount != 0 {
kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
}
}
}
case "command-r": case "command-r":
fullOffload = max( fullOffload = max(
4*batch*(embedding+vocab), 4*batch*(embedding+vocab),
@@ -600,56 +544,6 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
return return
} }
func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
if llm.KV().Uint("vision.block_count") == 0 {
return
}
for name, layer := range llm.Tensors().GroupLayers() {
if name == "v" || strings.HasPrefix(name, "v.") {
for _, tensor := range layer {
weights += tensor.Size()
}
}
}
imageSize := uint64(llm.KV().Uint("vision.image_size"))
patchSize := uint64(llm.KV().Uint("vision.patch_size"))
if patchSize == 0 {
slog.Warn("unknown patch size for vision model")
return
}
numChannels := uint64(llm.KV().Uint("vision.num_channels"))
numPatches := (imageSize / patchSize) * (imageSize / patchSize)
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
numPatches++
}
headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
switch llm.KV().Architecture() {
case "mllama":
numPaddedPatches := numPatches + 8 - (numPatches%8)%8
maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
graphSize = 4 * (8 +
imageSize*imageSize*numChannels*maxNumTiles +
embeddingLength*numPatches*maxNumTiles +
9*embeddingLength*numPaddedPatches*maxNumTiles +
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
case "gemma3", "mistral3":
graphSize = 4 * (imageSize*imageSize*numChannels +
embeddingLength*patchSize +
numPatches*numPatches*headCount)
}
return weights, graphSize
}
// SupportsKVCacheType checks if the requested cache type is supported // SupportsKVCacheType checks if the requested cache type is supported
func (f GGML) SupportsKVCacheType(cacheType string) bool { func (f GGML) SupportsKVCacheType(cacheType string) bool {
return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType) return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)

View File

@@ -3,7 +3,6 @@ package ggml
import ( import (
"maps" "maps"
"slices" "slices"
"strconv"
"strings" "strings"
"testing" "testing"
@@ -158,55 +157,3 @@ func TestTensorLayers(t *testing.T) {
}) })
} }
} }
// ref: https://github.com/ggml-org/llama.cpp/blob/a82c9e7c23ef6db48cebfa194dc9cebbc4ac3552/ggml/src/ggml.c#L572
func TestTensorTypes(t *testing.T) {
cases := []struct {
kind uint32
blockSize uint64
typeSize uint64
}{
{0, 1, 4},
{1, 1, 2},
{2, 32, 18},
{3, 32, 20},
{6, 32, 22},
{7, 32, 24},
{8, 32, 34},
{9, 32, 36},
{10, 256, 84},
{11, 256, 110},
{12, 256, 144},
{13, 256, 176},
{14, 256, 210},
{15, 256, 292},
{16, 256, 66},
{17, 256, 74},
{18, 256, 98},
{19, 256, 50},
{20, 32, 18},
{21, 256, 110},
{22, 256, 82},
{23, 256, 136},
{24, 1, 1},
{25, 1, 2},
{26, 1, 4},
{27, 1, 8},
{28, 1, 8},
{29, 256, 56},
{30, 1, 2},
}
for _, tt := range cases {
t.Run(strconv.Itoa(int(tt.kind)), func(t *testing.T) {
tensor := Tensor{Kind: tt.kind}
if tensor.blockSize() != tt.blockSize {
t.Errorf("unexpected block size: got=%d want=%d", tensor.blockSize(), tt.blockSize)
}
if tensor.typeSize() != tt.typeSize {
t.Errorf("unexpected type size: got=%d want=%d", tensor.typeSize(), tt.typeSize)
}
})
}
}

19
go.mod
View File

@@ -1,6 +1,6 @@
module github.com/ollama/ollama module github.com/ollama/ollama
go 1.24.0 go 1.23.4
require ( require (
github.com/containerd/console v1.0.3 github.com/containerd/console v1.0.3
@@ -11,7 +11,7 @@ require (
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
github.com/x448/float16 v0.8.4 github.com/x448/float16 v0.8.4
golang.org/x/sync v0.11.0 golang.org/x/sync v0.10.0
) )
require ( require (
@@ -24,7 +24,7 @@ require (
github.com/nlpodyssey/gopickle v0.3.0 github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
golang.org/x/image v0.22.0 golang.org/x/image v0.22.0
golang.org/x/tools v0.30.0 gonum.org/v1/gonum v0.15.0
) )
require ( require (
@@ -44,7 +44,6 @@ require (
github.com/xtgo/set v1.0.0 // indirect github.com/xtgo/set v1.0.0 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
gonum.org/v1/gonum v0.15.0 // indirect
gorgonia.org/vecf32 v0.9.0 // indirect gorgonia.org/vecf32 v0.9.0 // indirect
gorgonia.org/vecf64 v0.9.0 // indirect gorgonia.org/vecf64 v0.9.0 // indirect
) )
@@ -70,12 +69,12 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.33.0 golang.org/x/crypto v0.31.0
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa
golang.org/x/net v0.35.0 // indirect golang.org/x/net v0.25.0 // indirect
golang.org/x/sys v0.30.0 golang.org/x/sys v0.28.0
golang.org/x/term v0.29.0 golang.org/x/term v0.27.0
golang.org/x/text v0.22.0 golang.org/x/text v0.21.0
google.golang.org/protobuf v1.34.1 google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

30
go.sum
View File

@@ -214,16 +214,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE= golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4= golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk= golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
@@ -257,8 +257,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -268,8 +268,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -285,17 +285,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU= golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -309,8 +309,6 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@@ -22,7 +22,7 @@ func TestOrcaMiniBlueSky(t *testing.T) {
Model: "orca-mini", Model: "orca-mini",
Prompt: "why is the sky blue?", Prompt: "why is the sky blue?",
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]interface{}{
"temperature": 0, "temperature": 0,
"seed": 123, "seed": 123,
}, },
@@ -39,7 +39,7 @@ func TestUnicode(t *testing.T) {
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", Model: "deepseek-coder-v2:16b-lite-instruct-q2_K",
Prompt: "天空为什么是蓝色的?", Prompt: "天空为什么是蓝色的?",
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]interface{}{
"temperature": 0, "temperature": 0,
"seed": 123, "seed": 123,
// Workaround deepseek context shifting bug // Workaround deepseek context shifting bug
@@ -61,7 +61,7 @@ func TestExtendedUnicodeOutput(t *testing.T) {
Model: "gemma2:2b", Model: "gemma2:2b",
Prompt: "Output some smily face emoji", Prompt: "Output some smily face emoji",
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]interface{}{
"temperature": 0, "temperature": 0,
"seed": 123, "seed": 123,
}, },
@@ -96,7 +96,7 @@ func TestUnicodeModelDir(t *testing.T) {
Model: "orca-mini", Model: "orca-mini",
Prompt: "why is the sky blue?", Prompt: "why is the sky blue?",
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]interface{}{
"temperature": 0, "temperature": 0,
"seed": 123, "seed": 123,
}, },

View File

@@ -25,7 +25,7 @@ func TestMultiModelConcurrency(t *testing.T) {
Prompt: "why is the ocean blue?", Prompt: "why is the ocean blue?",
Stream: &stream, Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second}, KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{ Options: map[string]interface{}{
"seed": 42, "seed": 42,
"temperature": 0.0, "temperature": 0.0,
}, },
@@ -34,7 +34,7 @@ func TestMultiModelConcurrency(t *testing.T) {
Prompt: "what is the origin of the us thanksgiving holiday?", Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream, Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second}, KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{ Options: map[string]interface{}{
"seed": 42, "seed": 42,
"temperature": 0.0, "temperature": 0.0,
}, },

View File

@@ -23,7 +23,7 @@ func TestLongInputContext(t *testing.T) {
Model: "llama2", Model: "llama2",
Prompt: "Oh, dont speak to me of Austria. Perhaps I dont understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexanders loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I dont believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?", Prompt: "Oh, dont speak to me of Austria. Perhaps I dont understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexanders loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I dont believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]interface{}{
"temperature": 0, "temperature": 0,
"seed": 123, "seed": 123,
"num_ctx": 128, "num_ctx": 128,
@@ -50,7 +50,7 @@ func TestContextExhaustion(t *testing.T) {
Model: "llama2", Model: "llama2",
Prompt: "Write me a story with a ton of emojis?", Prompt: "Write me a story with a ton of emojis?",
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]interface{}{
"temperature": 0, "temperature": 0,
"seed": 123, "seed": 123,
"num_ctx": 128, "num_ctx": 128,

View File

@@ -19,7 +19,7 @@ func TestIntegrationLlava(t *testing.T) {
Model: "llava:7b", Model: "llava:7b",
Prompt: "what does the text in this image say?", Prompt: "what does the text in this image say?",
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]interface{}{
"seed": 42, "seed": 42,
"temperature": 0.0, "temperature": 0.0,
}, },
@@ -47,7 +47,7 @@ func TestIntegrationMllama(t *testing.T) {
Model: "x/llama3.2-vision", Model: "x/llama3.2-vision",
Prompt: "what does the text in this image say?", Prompt: "what does the text in this image say?",
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]interface{}{
"seed": 42, "seed": 42,
"temperature": 0.0, "temperature": 0.0,
}, },
@@ -66,35 +66,6 @@ func TestIntegrationMllama(t *testing.T) {
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second) DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
} }
func TestIntegrationSplitBatch(t *testing.T) {
image, err := base64.StdEncoding.DecodeString(imageEncoding)
require.NoError(t, err)
req := api.GenerateRequest{
Model: "gemma3:4b",
// Fill up a chunk of the batch so the image will partially spill over into the next one
System: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed aliquet, justo in malesuada lobortis, odio ligula volutpat quam, quis faucibus ipsum magna quis sapien. Aliquam in venenatis diam, eu viverra magna. Phasellus imperdiet hendrerit volutpat. Vivamus sem ex, facilisis placerat felis non, dictum elementum est. Phasellus aliquam imperdiet lacus, eget placerat ligula sodales vel. Pellentesque nec auctor mi. Curabitur arcu nisi, faucibus eget nunc id, viverra interdum mi. Curabitur ornare ipsum ex, ac euismod ex aliquam in. Vestibulum id magna at purus accumsan fermentum. Proin scelerisque posuere nunc quis interdum. Maecenas sed mollis nisl. Etiam vitae ipsum interdum, placerat est quis, tincidunt velit. Nullam tempor nibh non lorem volutpat efficitur. Cras laoreet diam imperdiet ipsum auctor bibendum. Suspendisse ultrices urna sed metus sagittis suscipit. Quisque ullamcorper aliquam nibh ut mollis. Aenean dapibus mauris pharetra, venenatis elit ac, hendrerit odio. Cras vestibulum erat tempor, lobortis justo eu, lobortis ipsum. Nam laoreet dapibus sem. Proin vel diam ultrices, elementum ante et, ornare lectus. Proin eu accumsan nisl. Praesent ac ex vitae ipsum vulputate tristique facilisis sit amet lacus. Nullam faucibus magna a pellentesque pretium. Nunc lacinia ullamcorper sollicitudin. Donec vitae accumsan turpis, sed porttitor est. Donec porttitor mi vitae augue faucibus, vel mollis diam tincidunt.",
Prompt: "what does the text in this image say?",
Stream: &stream,
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
Images: []api.ImageData{
image,
},
}
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
resp := "the ollam"
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, req.Model))
// llava models on CPU can be quite slow to start,
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
}
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
AAUAAAABAAAAUgEoAAMAAAABAAIAAIdpAAQAAAABAAAAWgAAAAAAAABIAAAAAQAAAEgAAAABAAOgAQADAAAAAQABAACgAgAEAAAAAQAAANKgAwAEAAAAAQAA AAUAAAABAAAAUgEoAAMAAAABAAIAAIdpAAQAAAABAAAAWgAAAAAAAABIAAAAAQAAAEgAAAABAAOgAQADAAAAAQABAACgAgAEAAAAAQAAANKgAwAEAAAAAQAA
AHgAAAAAXdsepgAAAAlwSFlzAAALEwAACxMBAJqcGAAAAVlpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6 AHgAAAAAXdsepgAAAAlwSFlzAAALEwAACxMBAJqcGAAAAVlpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6

View File

@@ -20,7 +20,7 @@ var (
Model: "orca-mini", Model: "orca-mini",
Prompt: "why is the ocean blue?", Prompt: "why is the ocean blue?",
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]interface{}{
"seed": 42, "seed": 42,
"temperature": 0.0, "temperature": 0.0,
}, },
@@ -28,7 +28,7 @@ var (
Model: "orca-mini", Model: "orca-mini",
Prompt: "what is the origin of the us thanksgiving holiday?", Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]interface{}{
"seed": 42, "seed": 42,
"temperature": 0.0, "temperature": 0.0,
}, },

View File

@@ -32,7 +32,7 @@ func TestMaxQueue(t *testing.T) {
req := api.GenerateRequest{ req := api.GenerateRequest{
Model: "orca-mini", Model: "orca-mini",
Prompt: "write a long historical fiction story about christopher columbus. use at least 10 facts from his actual journey", Prompt: "write a long historical fiction story about christopher columbus. use at least 10 facts from his actual journey",
Options: map[string]any{ Options: map[string]interface{}{
"seed": 42, "seed": 42,
"temperature": 0.0, "temperature": 0.0,
}, },
@@ -52,8 +52,8 @@ func TestMaxQueue(t *testing.T) {
embedCtx := ctx embedCtx := ctx
var genwg sync.WaitGroup var genwg sync.WaitGroup
genwg.Add(1)
go func() { go func() {
genwg.Add(1)
defer genwg.Done() defer genwg.Done()
slog.Info("Starting generate request") slog.Info("Starting generate request")
DoGenerate(ctx, t, client, req, resp, 45*time.Second, 5*time.Second) DoGenerate(ctx, t, client, req, resp, 45*time.Second, 5*time.Second)
@@ -71,8 +71,8 @@ func TestMaxQueue(t *testing.T) {
counterMu := sync.Mutex{} counterMu := sync.Mutex{}
var embedwg sync.WaitGroup var embedwg sync.WaitGroup
for i := 0; i < threadCount; i++ { for i := 0; i < threadCount; i++ {
embedwg.Add(1)
go func(i int) { go func(i int) {
embedwg.Add(1)
defer embedwg.Done() defer embedwg.Done()
slog.Info("embed started", "id", i) slog.Info("embed started", "id", i)
embedReq := api.EmbeddingRequest{ embedReq := api.EmbeddingRequest{

View File

@@ -291,7 +291,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
Prompt: "why is the ocean blue?", Prompt: "why is the ocean blue?",
Stream: &stream, Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second}, KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{ Options: map[string]interface{}{
"seed": 42, "seed": 42,
"temperature": 0.0, "temperature": 0.0,
}, },
@@ -300,7 +300,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
Prompt: "why is the color of dirt brown?", Prompt: "why is the color of dirt brown?",
Stream: &stream, Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second}, KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{ Options: map[string]interface{}{
"seed": 42, "seed": 42,
"temperature": 0.0, "temperature": 0.0,
}, },
@@ -309,7 +309,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
Prompt: "what is the origin of the us thanksgiving holiday?", Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream, Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second}, KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{ Options: map[string]interface{}{
"seed": 42, "seed": 42,
"temperature": 0.0, "temperature": 0.0,
}, },
@@ -318,7 +318,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
Prompt: "what is the origin of independence day?", Prompt: "what is the origin of independence day?",
Stream: &stream, Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second}, KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{ Options: map[string]interface{}{
"seed": 42, "seed": 42,
"temperature": 0.0, "temperature": 0.0,
}, },
@@ -327,7 +327,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
Prompt: "what is the composition of air?", Prompt: "what is the composition of air?",
Stream: &stream, Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second}, KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{ Options: map[string]interface{}{
"seed": 42, "seed": 42,
"temperature": 0.0, "temperature": 0.0,
}, },

View File

@@ -4,7 +4,6 @@ import (
"errors" "errors"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
) )
var ( var (
@@ -30,44 +29,22 @@ type Cache interface {
// cache implementation used. // cache implementation used.
Put(ctx ml.Context, key, value ml.Tensor) Put(ctx ml.Context, key, value ml.Tensor)
// SetConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels. If not called,
// the backend settings will be used. This works well when calling Attention.
//
// The config can be overridden by models, especially if they require vanilla
// output when implementing their own version of attention. To do this, pass
// an empty ml.CacheConfig.
//
// Most models will not need to use this.
SetConfig(ml.CacheConfig)
// ** cache management ** // ** cache management **
// Init sets up runtime parameters. // Init sets up runtime parameters
// backend: Used to allocate cache data storage and execute management operations (such as defrag) Init(backend ml.Backend, dtype ml.DType, capacity int32)
// dtype: The data type for storing cache entries
// maxSequences: The maximum number of sequences stored in the cache - across all batches
// capacity: The number of cache entries to store, per sequence
// maxBatch: The maximum number of tokens that can occur in a single batch
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
// Close closes the cache and frees resources associated with it // Close closes the cache and frees resources associated with it
Close() Close()
// StartForward is called before the start of the model's forward pass. // StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding // For each token in the coming batch, there must be a corresponding
// entry in positions and seqs. reserve is to preallocate memory // entry in positions and seqs.
// without actually storing data in the cache. StartForward(ctx ml.Context, positions []int32, seqs []int) error
StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32) CopyPrefix(srcSeq, dstSeq int, len int32)
// CanResume returns true if the cache can continue with the next token at
// the given position and sequence. Assumes that the caller has already
// verified the contents of the cache.
CanResume(seq int, pos int32) bool
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set // Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
// endIndex to math.MaxInt32 to remove everything starting at beginIndex. // endIndex to math.MaxInt32 to remove everything starting at beginIndex.
// //

View File

@@ -8,7 +8,6 @@ import (
"slices" "slices"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
) )
type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
@@ -20,13 +19,9 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
// The mask is of shape history size, batch size // The mask is of shape history size, batch size
type Causal struct { type Causal struct {
DType ml.DType DType ml.DType
Capacity int32
windowSize int32 windowSize int32
opts CausalOptions
// config controls mostly backend-specific optimizations
config *ml.CacheConfig
// ** current forward pass ** // ** current forward pass **
// the active layer for Get and Put // the active layer for Get and Put
@@ -44,12 +39,6 @@ type Causal struct {
// locations in the cache that are needed for this batch // locations in the cache that are needed for this batch
curCellRange cellRange curCellRange cellRange
// curSequences is the sequences corresponding to this pass's entries in the cache
curSequences []int
// curPositions is the positions corresponding to this pass's entries in the cache
curPositions []int32
// ** cache metadata ** // ** cache metadata **
// for each possible location in the cache, stores the position and set of sequences // for each possible location in the cache, stores the position and set of sequences
@@ -63,8 +52,8 @@ type Causal struct {
shiftFn shiftFn shiftFn shiftFn
backend ml.Backend backend ml.Backend
ctxs map[int]ml.Context cacheCtx ml.Context
keys, values map[int]ml.Tensor keys, values []ml.Tensor
} }
type cacheCell struct { type cacheCell struct {
@@ -78,129 +67,67 @@ type cellRange struct {
} }
func NewCausalCache(shift shiftFn) *Causal { func NewCausalCache(shift shiftFn) *Causal {
return &Causal{ return &Causal{windowSize: math.MaxInt32, shiftFn: shift}
windowSize: math.MaxInt32,
shiftFn: shift,
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
}
} }
func NewSWACache(windowSize int32, shift shiftFn) *Causal { func NewSWACache(windowSize int32, shift shiftFn) *Causal {
return &Causal{ return &Causal{windowSize: windowSize, shiftFn: shift}
windowSize: windowSize,
shiftFn: shift,
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
}
} }
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
if c.config == nil {
var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok {
config = cc.CacheConfig()
}
c.config = &config
}
if c.config.CachePadding == 0 {
c.config.CachePadding = 1
}
if c.config.MaskBatchPadding == 0 {
c.config.MaskBatchPadding = 1
}
if c.config.MaskDType == ml.DTypeOther {
c.config.MaskDType = ml.DTypeF32
}
var cacheSize int
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize) {
cacheSize = maxSequences * capacity
} else {
cacheSize = (maxSequences * int(c.windowSize)) + maxBatch
}
cacheSize = roundUp(cacheSize, c.config.CachePadding)
c.cells = make([]cacheCell, cacheSize)
c.DType = dtype c.DType = dtype
c.Capacity = capacity
c.cells = make([]cacheCell, capacity)
c.cellRanges = make(map[int]cellRange) c.cellRanges = make(map[int]cellRange)
c.backend = backend c.backend = backend
} c.cacheCtx = backend.NewContext()
func (c *Causal) SetConfig(config ml.CacheConfig) {
if c.config != nil {
panic("config cannot be changed after being previously set, either by the model or backend")
}
c.config = &config
} }
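To make the sizing rule in the new Init concrete, here is a worked example with purely illustrative numbers (none of them come from this change):

// maxSequences = 2, capacity = 2048, maxBatch = 256, CachePadding = 32
// Sliding window, windowSize = 512 (and capacity >= windowSize):
//   cacheSize = roundUp(2*512 + 256, 32) = roundUp(1280, 32) = 1280 cells
// Full attention (windowSize = math.MaxInt32):
//   cacheSize = roundUp(2*2048, 32) = 4096 cells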
func (c *Causal) Close() { func (c *Causal) Close() {
for _, ctx := range c.ctxs { c.cacheCtx.Close()
ctx.Close()
}
} }
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
c.curBatchSize = len(batch.Positions) c.curBatchSize = len(positions)
c.curSequences = batch.Sequences
c.curPositions = batch.Positions
c.opts.Except = nil
if !reserve {
c.updateSlidingWindow()
var err error
c.curLoc, err = c.findStartLoc()
if errors.Is(err, ErrKvCacheFull) {
c.defrag()
c.curLoc, err = c.findStartLoc()
}
if err != nil {
return err
}
c.curCellRange = newRange()
for i, pos := range batch.Positions {
seq := batch.Sequences[i]
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
seqRange, ok := c.cellRanges[seq]
if !ok {
seqRange = newRange()
}
if c.curLoc+i > seqRange.max {
seqRange.max = c.curLoc + i
}
if seqRange.max > c.curCellRange.max {
c.curCellRange.max = seqRange.max
}
if c.curLoc+i < seqRange.min {
seqRange.min = c.curLoc + i
}
if seqRange.min < c.curCellRange.min {
c.curCellRange.min = seqRange.min
}
c.cellRanges[seq] = seqRange
}
} else {
// If we are reserving memory, don't update any of the cache metadata but set the size
// to the worst case.
c.curLoc = 0
c.curCellRange.min = 0
c.curCellRange.max = len(c.cells) - 1
}
var err error var err error
c.curMask, err = c.buildMask(ctx) c.curLoc, err = c.findStartLoc()
if errors.Is(err, ErrKvCacheFull) {
c.defrag()
c.curLoc, err = c.findStartLoc()
}
if err != nil {
return err
}
c.curCellRange = newRange()
for i, pos := range positions {
seq := seqs[i]
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
seqRange, ok := c.cellRanges[seq]
if !ok {
seqRange = newRange()
}
if c.curLoc+i > seqRange.max {
seqRange.max = c.curLoc + i
}
if seqRange.max > c.curCellRange.max {
c.curCellRange.max = seqRange.max
}
if c.curLoc+i < seqRange.min {
seqRange.min = c.curLoc + i
}
if seqRange.min < c.curCellRange.min {
c.curCellRange.min = seqRange.min
}
c.cellRanges[seq] = seqRange
}
c.curMask, err = c.buildMask(ctx, positions, seqs)
return err return err
} }
@@ -227,138 +154,39 @@ func (c *Causal) findStartLoc() (int, error) {
} }
} }
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells)) return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
}
func (c *Causal) updateSlidingWindow() {
if c.windowSize == math.MaxInt32 {
return
}
// create a map of unique sequences to the lowest position in that sequence
lowestPos := make(map[int]int32)
for i := range c.curPositions {
seq := c.curSequences[i]
pos, ok := lowestPos[seq]
if !ok {
pos = c.curPositions[i]
} else if c.curPositions[i] < pos {
pos = c.curPositions[i]
}
lowestPos[seq] = pos
}
// delete any entries that are beyond the window of the oldest position in the sequence
for seq, pos := range lowestPos {
oldRange, ok := c.cellRanges[seq]
if !ok {
continue
}
newRange := newRange()
for i := oldRange.min; i <= oldRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
if c.cells[i].pos < pos-c.windowSize {
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
} else {
newRange.min = min(newRange.min, i)
newRange.max = max(newRange.max, i)
}
}
}
c.cellRanges[seq] = newRange
}
}
func roundDown(length, pad int) int {
return (length / pad) * pad
}
func roundUp(length, pad int) int {
return ((length + pad - 1) / pad) * pad
} }
// Builds a mask of history x batch indicating whether for each token in the batch the // Builds a mask of history x batch indicating whether for each token in the batch the
// token in the history should apply. This is based on both the sequence and causality (the // token in the history should apply. This is based on both the sequence and causality (the
// position of the history is not ahead of the token in the batch). // position of the history is not ahead of the token in the batch).
func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) { func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
// Align and pad the two dimensions as required by the backend // TODO(jessegross): This does not do padding, which is required for flash attention
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding) len := c.curCellRange.max - c.curCellRange.min + 1
mask := make([]float32, c.curBatchSize*len)
c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
length := c.curCellRange.max - c.curCellRange.min + 1
mask := make([]float32, batchSize*length)
for i := range c.curBatchSize { for i := range c.curBatchSize {
enabled := !slices.Contains(c.opts.Except, i)
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ { for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) || if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
(enabled && c.cells[j].pos > c.curPositions[i]) || c.cells[j].pos < positions[i]-c.windowSize {
c.cells[j].pos < c.curPositions[i]-c.windowSize { mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
} }
} }
} }
// Mask out any padding tokens we added. For padding that we added to the cache history, this return ctx.FromFloatSlice(mask, len, c.curBatchSize)
// has already been masked out because the sequence doesn't match.
for i := c.curBatchSize * length; i < len(mask); i++ {
mask[i] = float32(math.Inf(-1))
}
maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
if err != nil {
return nil, err
}
if c.config.MaskDType != ml.DTypeF32 {
out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
ctx.Forward(maskTensor.Copy(ctx, out))
maskTensor = out
}
return maskTensor, nil
} }
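The rule implemented by buildMask above can be read in isolation as follows. This is a simplified sketch that ignores padding, the Except list and tensor creation, and it assumes each cached cell belongs to exactly one sequence (the real cache stores a slice of sequences per cell):

package example

import "math"

// buildMaskRow computes one mask row for a batch token at position pos in
// sequence seq over cached cells described by cellPos and cellSeq. A cell
// stays 0 (visible) only if it is in the same sequence, not in the future,
// and not older than the sliding window; everything else becomes -Inf.
func buildMaskRow(pos int32, seq int, cellPos []int32, cellSeq []int, windowSize int32) []float32 {
	row := make([]float32, len(cellPos))
	for j := range cellPos {
		visible := cellSeq[j] == seq &&
			cellPos[j] <= pos &&
			cellPos[j] >= pos-windowSize
		if !visible {
			row[j] = float32(math.Inf(-1))
		}
	}
	return row
}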
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) { func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
for i, key := range c.keys { for _, obj := range objs {
if key == nil { if obj == nil {
continue continue
} }
kHeadDim := key.Dim(0) srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len)
numKVHeads := key.Dim(1) dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len)
rowSize := key.Stride(2)
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length) ctx.Forward(srcView.Copy(ctx, dstView))
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
value := c.values[i]
var vSrcView, vDstView ml.Tensor
if c.config.PermutedV {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)
vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
}
ctx.Forward(
kSrcView.Copy(ctx, kDstView),
vSrcView.Copy(ctx, vDstView),
)
} }
} }
@@ -382,8 +210,7 @@ func (c *Causal) defrag() {
ctx := c.backend.NewContext() ctx := c.backend.NewContext()
// For every move, 6 tensors are required per layer (2 views and a // For every move, 6 tensors are required per layer (2 views and a
// copy for each of k and v). We also need to refer to the original // copy for each of k and v).
// k and v cache tensors - once per layer, not per move.
layers := 0 layers := 0
for _, key := range c.keys { for _, key := range c.keys {
if key == nil { if key == nil {
@@ -392,7 +219,7 @@ func (c *Causal) defrag() {
layers++ layers++
} }
maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers) maxMoves := ctx.MaxTensors() / (6 * layers)
moves := 0 moves := 0
var pendingSrc, pendingDst, pendingLen int var pendingSrc, pendingDst, pendingLen int
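For a rough sense of scale of the move budget computed above (illustrative numbers only, not taken from this change):

//   layers with allocated keys = 32, ctx.MaxGraphNodes() = 8192
//   maxMoves = (8192 - 2*32) / (6*32) = 8128 / 192 = 42
// so up to 42 contiguous block moves fit into one defrag graph; the extra
// 2*layers nodes account for the original k and v cache tensors that each
// layer's views refer to.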
@@ -411,7 +238,8 @@ func (c *Causal) defrag() {
pendingLen++ pendingLen++
break break
} else { } else {
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen) moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
moves++ moves++
} }
} }
@@ -435,7 +263,8 @@ func (c *Causal) defrag() {
} }
if pendingLen > 0 { if pendingLen > 0 {
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen) moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
moves++ moves++
} }
@@ -464,107 +293,45 @@ func (c *Causal) defrag() {
} }
func (c *Causal) SetLayer(layer int) { func (c *Causal) SetLayer(layer int) {
c.curLayer = layer if layer >= len(c.keys) {
} c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
type CausalOptions struct {
	// Except lists batch indices for which causality is not enforced when building the mask
Except []int
}
// SetCausal disables causal mask generation for a particular range of indices in
// the current batch for subsequent calls to Get. The state resets for the next forward pass.
func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
if !slices.Equal(c.opts.Except, opts.Except) {
c.opts = opts
if ctx != nil {
var err error
c.curMask, err = c.buildMask(ctx)
if err != nil {
// This error should never occur because we have previously built a mask with the same shape
panic(fmt.Errorf("SetCausal: %w", err))
}
}
} }
c.curLayer = layer
} }
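A hedged usage sketch of SetCausal (the indices, function name and import paths are assumptions for illustration, not part of this change):

package example

import (
	"github.com/ollama/ollama/kvcache" // import path assumed
	"github.com/ollama/ollama/ml"
)

// prefillBidirectional lets batch indices 0-3 (for example an image prefix)
// attend to each other without the causal restriction, then restores full
// causality for the remaining tokens. Per the comment above, the options
// also reset automatically on the next forward pass.
func prefillBidirectional(cache *kvcache.Causal, ctx ml.Context) {
	cache.SetCausal(ctx, kvcache.CausalOptions{Except: []int{0, 1, 2, 3}})
	// ... run the layers/tokens that should see the prefix bidirectionally ...
	cache.SetCausal(ctx, kvcache.CausalOptions{})
}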
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
key := c.keys[c.curLayer] key := c.keys[c.curLayer]
value := c.values[c.curLayer] value := c.values[c.curLayer]
kHeadDim := key.Dim(0) key = key.View(ctx, key.Stride(2)*c.curCellRange.min,
numKVHeads := key.Dim(1) key.Dim(0), key.Stride(1),
rowSize := key.Stride(2) key.Dim(1), key.Stride(2),
cachedSize := c.curMask.Dim(0) c.curMask.Dim(0),
key = key.View(ctx, rowSize*c.curCellRange.min,
kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2),
cachedSize,
) )
if c.config.PermutedV { value = value.View(ctx, key.Stride(2)*c.curCellRange.min,
vHeadDim := value.Dim(1) value.Dim(0), value.Stride(1),
elemSize := value.Stride(0) value.Dim(1), value.Stride(2),
c.curMask.Dim(0),
value = value.View(ctx, elemSize*c.curCellRange.min, )
cachedSize, value.Stride(1),
vHeadDim, value.Stride(2),
numKVHeads,
)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)
value = value.View(ctx, rowSize*c.curCellRange.min,
vHeadDim, value.Stride(1),
numKVHeads, value.Stride(2),
cachedSize,
)
}
return key, value, c.curMask return key, value, c.curMask
} }
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) { func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
kHeadDim := key.Dim(0) if c.curBatchSize != key.Dim(2) {
vHeadDim := value.Dim(0) panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2)))
numKVHeads := key.Dim(1)
batchSize := key.Dim(2)
if c.curBatchSize != batchSize {
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
} }
if _, ok := c.ctxs[c.curLayer]; !ok { if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer) c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity))
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
} }
if _, ok := c.keys[c.curLayer]; !ok { ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))))
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells)) ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))))
}
if _, ok := c.values[c.curLayer]; !ok {
if c.config.PermutedV {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
} else {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
}
}
rowSize := c.keys[c.curLayer].Stride(2)
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
if c.config.PermutedV {
elemSize := c.values[c.curLayer].Stride(0)
value = value.Permute(ctx, 1, 2, 0, 3)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
} else {
rowSize := c.values[c.curLayer].Stride(2)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
}
} }
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) { func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
@@ -590,35 +357,6 @@ func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
c.cellRanges[dstSeq] = seqRange c.cellRanges[dstSeq] = seqRange
} }
func (c *Causal) CanResume(seq int, pos int32) bool {
if c.windowSize == math.MaxInt32 {
return true
}
seqRange, ok := c.cellRanges[seq]
if !ok {
return false
}
// for sliding window, check that the window of the new sequence is contained in
// the window of what we are storing
var last int32 = -1
for i := seqRange.min; i <= seqRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
last = max(last, c.cells[i].pos)
}
}
if last == -1 {
return false
}
lastWindowStart := max(0, last-c.windowSize)
posWindowStart := max(0, pos-c.windowSize)
return posWindowStart >= lastWindowStart
}
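A worked example of the window-containment check above, using the same numbers as TestCanResume later in this diff:

//   windowSize = 4, cached positions for sequence 0 are 0..5, so last = 5
//   lastWindowStart = max(0, 5-4) = 1
//   CanResume(0, p) needs max(0, p-4) >= 1, i.e. p >= 5
// so only the newest position (5) can be resumed; resuming at any earlier
// position would require history that has already slid out of the window.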
func (c *Causal) shift(seq int, beginIndex, offset int32) error { func (c *Causal) shift(seq int, beginIndex, offset int32) error {
if c.shiftFn == nil { if c.shiftFn == nil {
return ErrNotSupported return ErrNotSupported
@@ -639,7 +377,7 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
} }
} }
kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets)) kShift, err := ctx.FromIntSlice(offsets, len(offsets))
if err != nil { if err != nil {
return err return err
} }
@@ -649,13 +387,9 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
continue continue
} }
kHeadDim := key.Dim(0) key = key.View(ctx, key.Stride(2)*seqRange.min,
numKVHeads := key.Dim(1) key.Dim(0), key.Stride(1),
rowSize := key.Stride(2) key.Dim(1), key.Stride(2),
key = key.View(ctx, rowSize*seqRange.min,
kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2),
size, size,
) )
@@ -673,12 +407,6 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
} }
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error { func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
// TODO(jessegross): We should check to see if removing the middle of the sequence will
// cause the sliding window to encompass tokens that we no longer have. If so, then we
// should return an error, which will trigger the runner to evaluate the full history and
// rebuild the window. However, if we have multimodal inputs in our history, this reuse
// results in use after free, so we don't do it for now.
var offset int32 var offset int32
if endIndex != math.MaxInt32 { if endIndex != math.MaxInt32 {
offset = beginIndex - endIndex offset = beginIndex - endIndex
@@ -693,7 +421,8 @@ func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
} else { } else {
if c.cells[i].pos >= endIndex { if c.cells[i].pos >= endIndex {
if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) { if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
return errors.New("shifting cells shared by multiple sequences not supported") // TODO(jessegross): Need to be careful about data shared between sequences
return errors.New("shifting on cells shared by multiple sequences not yet implemented")
} }
c.cells[i].pos += offset c.cells[i].pos += offset
@@ -6,7 +6,6 @@ import (
"testing" "testing"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
) )
type testCase struct { type testCase struct {
@@ -25,7 +24,7 @@ func TestStore(t *testing.T) {
cache := NewCausalCache(nil) cache := NewCausalCache(nil)
defer cache.Close() defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16) cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{ tests := []testCase{
{ {
@@ -58,11 +57,11 @@ func TestSWA(t *testing.T) {
cache := NewSWACache(1, nil) cache := NewSWACache(1, nil)
defer cache.Close() defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16) cache.Init(backend, ml.DTypeF32, 16)
tests := []testCase{ tests := []testCase{
{ {
name: "FirstBatch", name: "SlidingWindow",
in: []float32{1, 2, 3, 4}, in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4}, inShape: []int{1, 1, 4},
seqs: []int{0, 0, 0, 0}, seqs: []int{0, 0, 0, 0},
@@ -71,16 +70,6 @@ func TestSWA(t *testing.T) {
expectedShape: []int{1, 1, 4}, expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
}, },
{
name: "SecondBatch",
in: []float32{5, 6},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{4, 5},
expected: []float32{5, 6, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
},
} }
testCache(t, backend, cache, tests) testCache(t, backend, cache, tests)
@@ -91,7 +80,7 @@ func TestSequences(t *testing.T) {
cache := NewCausalCache(nil) cache := NewCausalCache(nil)
defer cache.Close() defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16) cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{ tests := []testCase{
{ {
@@ -126,7 +115,7 @@ func TestRemove(t *testing.T) {
}) })
defer cache.Close() defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16) cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{ tests := []testCase{
{ {
@@ -191,7 +180,7 @@ func TestDefrag(t *testing.T) {
}) })
defer cache.Close() defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16) cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{ tests := []testCase{
{ {
@@ -239,7 +228,7 @@ func TestCopy(t *testing.T) {
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil }) cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
defer cache.Close() defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16) cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{ tests := []testCase{
{ {
@@ -280,7 +269,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
context := backend.NewContext() context := backend.NewContext()
defer context.Close() defer context.Close()
err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false) err := cache.StartForward(context, test.pos, test.seqs)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -291,7 +280,9 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
out, _, mask := cache.Get(context) out, _, mask := cache.Get(context)
context.Forward(out, mask).Compute(out, mask) context.Forward(out)
context.Forward(mask)
context.Compute(out, mask)
if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) { if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask) t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)
@@ -300,94 +291,27 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
} }
} }
func TestCanResume(t *testing.T) { type testBackend struct{}
backend := &testBackend{}
windowSize := int32(4)
cache := NewSWACache(windowSize, nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16) func (b *testBackend) Config() ml.Config {
panic("not implemented")
context := backend.NewContext()
defer context.Close()
err := cache.StartForward(context, input.Batch{
Positions: []int32{0, 1, 2, 3},
Sequences: []int{0, 0, 0, 0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
cache.Put(context, tensor, tensor)
// with window size 4, nothing has slid out of the window yet
if !cache.CanResume(0, 0) {
t.Errorf("CanResume(0, 0) = false, want true (within window)")
}
if !cache.CanResume(0, 1) {
t.Errorf("CanResume(0, 1) = false, want true (within window)")
}
if !cache.CanResume(0, 2) {
t.Errorf("CanResume(0, 2) = false, want true (within window)")
}
if !cache.CanResume(0, 3) {
t.Errorf("CanResume(0, 3) = false, want true (latest position)")
}
// shift window by adding position 4
err = cache.StartForward(context, input.Batch{
Positions: []int32{4, 5},
Sequences: []int{0, 0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
cache.Put(context, tensor, tensor)
// only the latest position has overlapping windows
if cache.CanResume(0, 0) {
t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
}
if cache.CanResume(0, 1) {
t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
}
if cache.CanResume(0, 2) {
t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
}
if cache.CanResume(0, 3) {
t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
}
if cache.CanResume(0, 4) {
t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
}
if !cache.CanResume(0, 5) {
t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
}
} }
type testBackend struct { func (b *testBackend) Get(name string) ml.Tensor {
ml.Backend panic("not implemented")
} }
func (b *testBackend) NewContext() ml.Context { func (b *testBackend) NewContext() ml.Context {
return &testContext{} return &testContext{}
} }
func (b *testBackend) NewContextSize(int) ml.Context { func (b *testBackend) SystemInfo() string {
return &testContext{} return "not implemented"
} }
type testContext struct { type testContext struct{}
ml.Context
}
func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor { func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
total := 0 total := 0
if len(shape) > 0 { if len(shape) > 0 {
@@ -400,12 +324,8 @@ func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape} return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
} }
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
return c.Empty(dtype, shape...)
}
func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
t := c.Empty(ml.DTypeF32, shape...).(*testTensor) t := c.Zeros(ml.DTypeF32, shape...).(*testTensor)
copy(t.data, s) copy(t.data, s)
@@ -424,24 +344,17 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
return out, nil return out, nil
} }
func (c *testContext) Input() ml.Context { return c } func (c *testContext) Forward(ml.Tensor) {}
func (c *testContext) Layer(int) ml.Context { return c }
func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
func (c *testContext) Compute(...ml.Tensor) {} func (c *testContext) Compute(...ml.Tensor) {}
func (c *testContext) Reserve() error { return nil } func (c *testContext) MaxTensors() int {
func (c *testContext) MaxGraphNodes() int {
return 10 return 10
} }
func (c *testContext) Close() {} func (c *testContext) Close() {}
type testTensor struct { type testTensor struct {
ml.Tensor
dtype ml.DType dtype ml.DType
elementSize int elementSize int
data []float32 data []float32
@@ -469,22 +382,18 @@ func (t *testTensor) DType() ml.DType {
return t.dtype return t.dtype
} }
func (t *testTensor) Bytes() []byte {
panic("not implemented")
}
func (t *testTensor) Floats() []float32 { func (t *testTensor) Floats() []float32 {
out := make([]float32, len(t.data)) out := make([]float32, len(t.data))
copy(out, t.data) copy(out, t.data)
return out return out
} }
func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
for i := range out.data {
out.data[i] = -t.data[i]
}
return out
}
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor { func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor) out := ctx.Zeros(t.DType(), t.Shape()...).(*testTensor)
for i := range out.data { for i := range out.data {
out.data[i] = t.data[i] + t2.(*testTensor).data[i] out.data[i] = t.data[i] + t2.(*testTensor).data[i]
@@ -493,6 +402,58 @@ func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return out return out
} }
func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Softmax(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) LayerNorm(ctx ml.Context, weight, bias ml.Tensor, eps float32) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) RMSNorm(ctx ml.Context, weight ml.Tensor, eps float32) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim uint32, base, scale float32) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) GELU(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) SILU(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
offset /= t.elementSize offset /= t.elementSize
@@ -509,12 +470,40 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
context := &testContext{} context := &testContext{}
view := context.Empty(t.dtype, s...).(*testTensor) view := context.Zeros(t.dtype, s...).(*testTensor)
view.data = t.data[offset : offset+len(view.data)] view.data = t.data[offset : offset+len(view.data)]
return view return view
} }
func (t *testTensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor { func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
copy(t2.(*testTensor).data, t.data) copy(t2.(*testTensor).data, t.data)
return nil return nil
@@ -1,10 +1,7 @@
package kvcache package kvcache
import ( import (
"fmt"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
) )
// Encoder cache stores K and V tensors that are position independent // Encoder cache stores K and V tensors that are position independent
@@ -14,9 +11,6 @@ import (
// //
// Not currently safe for multiple sequences // Not currently safe for multiple sequences
type EncoderCache struct { type EncoderCache struct {
// config controls mostly backend-specific optimizations
config *ml.CacheConfig
// ** current forward pass ** // ** current forward pass **
// the active layer for Get and Put // the active layer for Get and Put
@@ -27,11 +21,6 @@ type EncoderCache struct {
// anything will be stored) // anything will be stored)
curPos int32 curPos int32
// curReserve indicates that this forward pass is only for
// memory reservation and we should not update our metadata
// based on it.
curReserve bool
// ** cache metadata ** // ** cache metadata **
// was something stored in the cache? // was something stored in the cache?
@@ -41,65 +30,36 @@ type EncoderCache struct {
encoderPos int32 encoderPos int32
// ** cache data storage ** // ** cache data storage **
backend ml.Backend
ctxs map[int]ml.Context cacheCtx ml.Context
keys, values map[int]ml.Tensor keys, values []ml.Tensor
} }
func NewEncoderCache() *EncoderCache { func NewEncoderCache() *EncoderCache {
return &EncoderCache{ return &EncoderCache{}
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
}
} }
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
if c.config == nil { c.cacheCtx = backend.NewContext()
var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok {
config = cc.CacheConfig()
}
c.config = &config
}
if maxSequences > 1 {
panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
}
if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
}
c.backend = backend
}
func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
if c.config != nil {
panic("config cannot be changed after being previously set, either by the model or backend")
}
c.config = &config
} }
func (c *EncoderCache) Close() { func (c *EncoderCache) Close() {
for _, ctx := range c.ctxs { c.cacheCtx.Close()
ctx.Close()
}
} }
func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
// We work with the most recent image // The image is always in the first position
if len(batch.Multimodal) > 0 { c.curPos = positions[0]
c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
}
c.curReserve = reserve
return nil return nil
} }
func (c *EncoderCache) SetLayer(layer int) { func (c *EncoderCache) SetLayer(layer int) {
if layer >= len(c.keys) {
c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
}
c.curLayer = layer c.curLayer = layer
} }
@@ -112,41 +72,22 @@ func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
} }
func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) { func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
if !c.curReserve { c.encoderPos = c.curPos
c.encoderPos = c.curPos c.encoderCached = true
c.encoderCached = true
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
} }
if c.config.PermutedV { ctx.Forward(key.Copy(ctx, c.keys[c.curLayer]))
value = value.Permute(ctx, 1, 2, 0, 3) ctx.Forward(value.Copy(ctx, c.values[c.curLayer]))
}
if _, ok := c.ctxs[c.curLayer]; !ok {
c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
}
if _, ok := c.keys[c.curLayer]; !ok {
c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
}
if _, ok := c.values[c.curLayer]; !ok {
c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
}
ctx.Forward(
key.Copy(ctx, c.keys[c.curLayer]),
value.Copy(ctx, c.values[c.curLayer]),
)
} }
func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) { func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
panic("encoder cache does not support multiple sequences") panic("encoder cache does not support multiple sequences")
} }
func (c *EncoderCache) CanResume(seq int, pos int32) bool {
return true
}
func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error { func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
if c.encoderPos >= beginIndex && c.encoderPos < endIndex { if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
c.encoderCached = false c.encoderCached = false
@@ -4,7 +4,6 @@ import (
"math" "math"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
) )
// Wrapper cache is a container for multiple types of caches, // Wrapper cache is a container for multiple types of caches,
@@ -23,15 +22,9 @@ func NewWrapperCache(caches ...Cache) *WrapperCache {
} }
} }
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
for _, cache := range c.caches { for _, cache := range c.caches {
cache.Init(backend, dtype, maxSequences, capacity, maxBatch) cache.Init(backend, dtype, capacity)
}
}
func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
for _, cache := range c.caches {
cache.SetConfig(config)
} }
} }
@@ -41,14 +34,14 @@ func (c *WrapperCache) Close() {
} }
} }
func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { func (c *WrapperCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
for i, cache := range c.caches { for i, cache := range c.caches {
err := cache.StartForward(ctx, batch, reserve) err := cache.StartForward(ctx, positions, seqs)
if err != nil { if err != nil {
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
for j := i - 1; j >= 0; j-- { for j := i - 1; j >= 0; j-- {
for k := range batch.Positions { for k := range positions {
_ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32) _ = c.caches[j].Remove(seqs[k], positions[k], math.MaxInt32)
} }
} }
return err return err
@@ -87,16 +80,6 @@ func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
} }
} }
func (c *WrapperCache) CanResume(seq int, pos int32) bool {
for _, cache := range c.caches {
if !cache.CanResume(seq, pos) {
return false
}
}
return true
}
func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error { func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
// If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail // If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
for _, cache := range c.caches { for _, cache := range c.caches {
llama/build-info.cpp (generated, vendored)
@@ -1,4 +1,4 @@
int LLAMA_BUILD_NUMBER = 0; int LLAMA_BUILD_NUMBER = 0;
char const *LLAMA_COMMIT = "d7cfe1ffe0f435d0048a6058d529daf76e072d9c"; char const *LLAMA_COMMIT = "46e3556e01b824e52395fb050b29804b6cff2a7c";
char const *LLAMA_COMPILER = ""; char const *LLAMA_COMPILER = "";
char const *LLAMA_BUILD_TARGET = ""; char const *LLAMA_BUILD_TARGET = "";
@@ -2,9 +2,6 @@
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING #define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
#endif #endif
#include "ggml.h"
#include "gguf.h"
#include "common.h" #include "common.h"
#include "log.h" #include "log.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT: // Change JSON_ASSERT from assert() to GGML_ASSERT:
@@ -73,22 +70,6 @@
#include <sys/syslimits.h> #include <sys/syslimits.h>
#endif #endif
#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
//
// CURL utils
//
using curl_ptr = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one
struct curl_slist_ptr {
struct curl_slist * ptr = nullptr;
~curl_slist_ptr() {
if (ptr) {
curl_slist_free_all(ptr);
}
}
};
#endif // LLAMA_USE_CURL #endif // LLAMA_USE_CURL
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
@@ -483,48 +464,6 @@ void string_replace_all(std::string & s, const std::string & search, const std::
s = std::move(builder); s = std::move(builder);
} }
std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
std::ostringstream result;
for (size_t i = 0; i < values.size(); ++i) {
if (i > 0) {
result << separator;
}
result << values[i];
}
return result.str();
}
std::vector<std::string> string_split(const std::string & str, const std::string & delimiter) {
std::vector<std::string> parts;
size_t start = 0;
size_t end = str.find(delimiter);
while (end != std::string::npos) {
parts.push_back(str.substr(start, end - start));
start = end + delimiter.length();
end = str.find(delimiter, start);
}
parts.push_back(str.substr(start));
return parts;
}
std::string string_repeat(const std::string & str, size_t n) {
if (n == 0) {
return "";
}
std::string result;
result.reserve(str.length() * n);
for (size_t i = 0; i < n; ++i) {
result += str;
}
return result;
}
std::string string_from(bool value) { std::string string_from(bool value) {
return value ? "true" : "false"; return value ? "true" : "false";
} }
@@ -907,7 +846,7 @@ struct common_init_result common_init_from_params(common_params & params) {
} else if (!params.model_url.empty()) { } else if (!params.model_url.empty()) {
model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams); model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams);
} else { } else {
model = llama_model_load_from_file(params.model.c_str(), mparams); model = llama_load_model_from_file(params.model.c_str(), mparams);
} }
if (model == NULL) { if (model == NULL) {
@@ -915,28 +854,26 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams; return iparams;
} }
const llama_vocab * vocab = llama_model_get_vocab(model);
if (params.reranking) { if (params.reranking) {
bool ok = true; bool ok = true;
if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) { if (llama_token_bos(model) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__); LOG_WRN("%s: warning: model does not have a BOS token, reranking will not work\n", __func__);
ok = false; ok = false;
} }
if (llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { if (llama_token_eos(model) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__); LOG_WRN("%s: warning: model does not have an EOS token, reranking will not work\n", __func__);
ok = false; ok = false;
} }
if (llama_vocab_sep(vocab) == LLAMA_TOKEN_NULL) { if (llama_token_sep(model) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__); LOG_WRN("%s: warning: model does not have a SEP token, reranking will not work\n", __func__);
ok = false; ok = false;
} }
if (!ok) { if (!ok) {
llama_model_free(model); llama_free_model(model);
return iparams; return iparams;
} }
@@ -944,10 +881,10 @@ struct common_init_result common_init_from_params(common_params & params) {
auto cparams = common_context_params_to_llama(params); auto cparams = common_context_params_to_llama(params);
llama_context * lctx = llama_init_from_model(model, cparams); llama_context * lctx = llama_new_context_with_model(model, cparams);
if (lctx == NULL) { if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str()); LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str());
llama_model_free(model); llama_free_model(model);
return iparams; return iparams;
} }
@@ -958,26 +895,25 @@ struct common_init_result common_init_from_params(common_params & params) {
if (!params.control_vectors.empty()) { if (!params.control_vectors.empty()) {
if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1; if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_model_n_layer(model); if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model);
const auto cvec = common_control_vector_load(params.control_vectors); const auto cvec = common_control_vector_load(params.control_vectors);
if (cvec.n_embd == -1) { if (cvec.n_embd == -1) {
llama_free(lctx); llama_free(lctx);
llama_model_free(model); llama_free_model(model);
return iparams; return iparams;
} }
int err = llama_apply_adapter_cvec( int err = llama_control_vector_apply(lctx,
lctx, cvec.data.data(),
cvec.data.data(), cvec.data.size(),
cvec.data.size(), cvec.n_embd,
cvec.n_embd, params.control_vector_layer_start,
params.control_vector_layer_start, params.control_vector_layer_end);
params.control_vector_layer_end);
if (err) { if (err) {
llama_free(lctx); llama_free(lctx);
llama_model_free(model); llama_free_model(model);
return iparams; return iparams;
} }
@@ -985,12 +921,12 @@ struct common_init_result common_init_from_params(common_params & params) {
// load and optionally apply lora adapters // load and optionally apply lora adapters
for (auto & la : params.lora_adapters) { for (auto & la : params.lora_adapters) {
llama_adapter_lora_ptr lora; llama_lora_adapter_ptr lora;
lora.reset(llama_adapter_lora_init(model, la.path.c_str())); lora.reset(llama_lora_adapter_init(model, la.path.c_str()));
if (lora == nullptr) { if (lora == nullptr) {
LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str()); LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
llama_free(lctx); llama_free(lctx);
llama_model_free(model); llama_free_model(model);
return iparams; return iparams;
} }
@@ -999,17 +935,17 @@ struct common_init_result common_init_from_params(common_params & params) {
} }
if (!params.lora_init_without_apply) { if (!params.lora_init_without_apply) {
common_set_adapter_lora(lctx, params.lora_adapters); common_lora_adapters_apply(lctx, params.lora_adapters);
} }
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__); LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
params.sampling.ignore_eos = false; params.sampling.ignore_eos = false;
} }
if (params.sampling.ignore_eos) { if (params.sampling.ignore_eos) {
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) { for (llama_token i = 0; i < llama_n_vocab(model); i++) {
if (llama_vocab_is_eog(vocab, i)) { if (llama_token_is_eog(model, i)) {
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY); LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
params.sampling.logit_bias.push_back({i, -INFINITY}); params.sampling.logit_bias.push_back({i, -INFINITY});
} }
@@ -1030,9 +966,8 @@ struct common_init_result common_init_from_params(common_params & params) {
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__); LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
std::vector<llama_token> tmp; std::vector<llama_token> tmp;
llama_token bos = llama_vocab_bos(vocab); llama_token bos = llama_token_bos(model);
llama_token eos = llama_vocab_eos(vocab); llama_token eos = llama_token_eos(model);
// some models (e.g. T5) don't have a BOS token // some models (e.g. T5) don't have a BOS token
if (bos != LLAMA_TOKEN_NULL) { if (bos != LLAMA_TOKEN_NULL) {
tmp.push_back(bos); tmp.push_back(bos);
@@ -1047,7 +982,7 @@ struct common_init_result common_init_from_params(common_params & params) {
if (llama_model_has_encoder(model)) { if (llama_model_has_encoder(model)) {
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size())); llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
llama_token decoder_start_token_id = llama_model_decoder_start_token(model); llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == LLAMA_TOKEN_NULL) { if (decoder_start_token_id == -1) {
decoder_start_token_id = bos; decoder_start_token_id = bos;
} }
tmp.clear(); tmp.clear();
@@ -1067,11 +1002,11 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams; return iparams;
} }
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) { void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_info> & lora) {
llama_clear_adapter_lora(ctx); llama_lora_adapter_clear(ctx);
for (auto & la : lora) { for (auto & la : lora) {
if (la.scale != 0.0f) { if (la.scale != 0.0f) {
llama_set_adapter_lora(ctx, la.ptr, la.scale); llama_lora_adapter_set(ctx, la.ptr, la.scale);
} }
} }
} }
@@ -1085,6 +1020,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
if (params.n_gpu_layers != -1) { if (params.n_gpu_layers != -1) {
mparams.n_gpu_layers = params.n_gpu_layers; mparams.n_gpu_layers = params.n_gpu_layers;
} }
mparams.rpc_servers = params.rpc_servers.c_str();
mparams.main_gpu = params.main_gpu; mparams.main_gpu = params.main_gpu;
mparams.split_mode = params.split_mode; mparams.split_mode = params.split_mode;
mparams.tensor_split = params.tensor_split; mparams.tensor_split = params.tensor_split;
@@ -1187,8 +1123,7 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma
static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) { static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
// Initialize libcurl // Initialize libcurl
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
curl_slist_ptr http_headers;
if (!curl) { if (!curl) {
LOG_ERR("%s: error initializing libcurl\n", __func__); LOG_ERR("%s: error initializing libcurl\n", __func__);
return false; return false;
@@ -1202,9 +1137,11 @@ static bool common_download_file(const std::string & url, const std::string & pa
// Check if hf-token or bearer-token was specified // Check if hf-token or bearer-token was specified
if (!hf_token.empty()) { if (!hf_token.empty()) {
std::string auth_header = "Authorization: Bearer " + hf_token; std::string auth_header = "Authorization: Bearer ";
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); auth_header += hf_token.c_str();
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); struct curl_slist *http_headers = NULL;
http_headers = curl_slist_append(http_headers, auth_header.c_str());
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers);
} }
#if defined(_WIN32) #if defined(_WIN32)
@@ -1474,7 +1411,7 @@ struct llama_model * common_load_model_from_url(
} }
} }
return llama_model_load_from_file(local_path.c_str(), params); return llama_load_model_from_file(local_path.c_str(), params);
} }
struct llama_model * common_load_model_from_hf( struct llama_model * common_load_model_from_hf(
@@ -1500,80 +1437,6 @@ struct llama_model * common_load_model_from_hf(
return common_load_model_from_url(model_url, local_path, hf_token, params); return common_load_model_from_url(model_url, local_path, hf_token, params);
} }
/**
* Allow getting the HF file from the HF repo with tag (like ollama), for example:
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
* - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
 * Tag is optional and defaults to "latest" (meaning it checks for Q4_K_M first, then Q4, and if neither is found, returns the first GGUF file in the repo)
*
* Return pair of <repo, file> (with "repo" already having tag removed)
*
 * Note: we use the Ollama-compatible HF API, but we do not use the blobId. Instead, we use the special "ggufFile" field, which returns the value for "hf_file". This keeps backward compatibility with existing cache files.
*/
std::pair<std::string, std::string> common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & hf_token) {
auto parts = string_split<std::string>(hf_repo_with_tag, ':');
std::string tag = parts.size() > 1 ? parts.back() : "latest";
std::string hf_repo = parts[0];
if (string_split<std::string>(hf_repo, '/').size() != 2) {
throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
}
// fetch model info from Hugging Face Hub API
json model_info;
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
curl_slist_ptr http_headers;
std::string res_str;
std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag;
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
static_cast<std::string *>(data)->append((char * ) ptr, size * nmemb);
return size * nmemb;
};
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
#if defined(_WIN32)
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif
if (!hf_token.empty()) {
std::string auth_header = "Authorization: Bearer " + hf_token;
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
}
// Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json");
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
CURLcode res = curl_easy_perform(curl.get());
if (res != CURLE_OK) {
throw std::runtime_error("error: cannot make GET request to HF API");
}
long res_code;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
if (res_code == 200) {
model_info = json::parse(res_str);
} else if (res_code == 401) {
throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
} else {
throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
}
// check response
if (!model_info.contains("ggufFile")) {
throw std::runtime_error("error: model does not have ggufFile");
}
json & gguf_file = model_info.at("ggufFile");
if (!gguf_file.contains("rfilename")) {
throw std::runtime_error("error: ggufFile does not have rfilename");
}
return std::make_pair(hf_repo, gguf_file.at("rfilename"));
}
#else #else
struct llama_model * common_load_model_from_url( struct llama_model * common_load_model_from_url(
@@ -1595,11 +1458,6 @@ struct llama_model * common_load_model_from_hf(
return nullptr; return nullptr;
} }
std::pair<std::string, std::string> common_get_hf_file(const std::string &, const std::string &) {
LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
return std::make_pair("", "");
}
#endif // LLAMA_USE_CURL #endif // LLAMA_USE_CURL
// //
@@ -1698,23 +1556,21 @@ std::vector<llama_token> common_tokenize(
const std::string & text, const std::string & text,
bool add_special, bool add_special,
bool parse_special) { bool parse_special) {
const llama_model * model = llama_get_model(ctx); return common_tokenize(llama_get_model(ctx), text, add_special, parse_special);
const llama_vocab * vocab = llama_model_get_vocab(model);
return common_tokenize(vocab, text, add_special, parse_special);
} }
std::vector<llama_token> common_tokenize( std::vector<llama_token> common_tokenize(
const struct llama_vocab * vocab, const struct llama_model * model,
const std::string & text, const std::string & text,
bool add_special, bool add_special,
bool parse_special) { bool parse_special) {
// upper limit for the number of tokens // upper limit for the number of tokens
int n_tokens = text.length() + 2 * add_special; int n_tokens = text.length() + 2 * add_special;
std::vector<llama_token> result(n_tokens); std::vector<llama_token> result(n_tokens);
n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
if (n_tokens < 0) { if (n_tokens < 0) {
result.resize(-n_tokens); result.resize(-n_tokens);
int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
GGML_ASSERT(check == -n_tokens); GGML_ASSERT(check == -n_tokens);
} else { } else {
result.resize(n_tokens); result.resize(n_tokens);
@@ -1723,18 +1579,12 @@ std::vector<llama_token> common_tokenize(
} }
std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
return common_token_to_piece(vocab, token, special);
}
std::string common_token_to_piece(const struct llama_vocab * vocab, llama_token token, bool special) {
std::string piece; std::string piece;
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special); const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
if (n_chars < 0) { if (n_chars < 0) {
piece.resize(-n_chars); piece.resize(-n_chars);
int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special); int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
GGML_ASSERT(check == -n_chars); GGML_ASSERT(check == -n_chars);
} }
else { else {
@@ -1744,19 +1594,13 @@ std::string common_token_to_piece(const struct llama_vocab * vocab, llama_token
return piece; return piece;
} }
std::string common_detokenize(const struct llama_context * ctx, const std::vector<llama_token> & tokens, bool special) { std::string common_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
return common_detokenize(vocab, tokens, special);
}
std::string common_detokenize(const struct llama_vocab * vocab, const std::vector<llama_token> & tokens, bool special) {
std::string text; std::string text;
text.resize(std::max(text.capacity(), tokens.size())); text.resize(std::max(text.capacity(), tokens.size()));
int32_t n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
if (n_chars < 0) { if (n_chars < 0) {
text.resize(-n_chars); text.resize(-n_chars);
n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
} }
@@ -1766,6 +1610,103 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
return text; return text;
} }
//
// Chat template utils
//
std::string common_get_builtin_chat_template(const struct llama_model * model) {
static const char * template_key = "tokenizer.chat_template";
// call with NULL buffer to get the total size of the string
int32_t res = llama_model_meta_val_str(model, template_key, NULL, 0);
if (res > 0) {
std::vector<char> model_template(res + 1, 0);
llama_model_meta_val_str(model, template_key, model_template.data(), model_template.size());
return std::string(model_template.data(), model_template.size() - 1);
}
return "";
}
bool common_chat_verify_template(const std::string & tmpl) {
llama_chat_message chat[] = {{"user", "test"}};
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
return res >= 0;
}
std::string common_chat_apply_template(const struct llama_model * model,
const std::string & tmpl,
const std::vector<common_chat_msg> & msgs,
bool add_ass) {
int alloc_size = 0;
bool fallback = false; // indicate if we must fallback to default chatml
std::vector<llama_chat_message> chat;
for (auto & msg : msgs) {
chat.push_back({msg.role.c_str(), msg.content.c_str()});
alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
}
const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
std::vector<char> buf(alloc_size);
// run the first time to get the total output length
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
// error: chat template is not supported
if (res < 0) {
if (ptr_tmpl != nullptr) {
// if the custom "tmpl" is not supported, we throw an error
// this is intentionally a bit redundant, since we're not sure whether the user validated the custom template with llama_chat_verify_template()
throw std::runtime_error("this custom template is not supported");
} else {
// If the built-in template is not supported, we default to chatml
res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
fallback = true;
}
}
// if it turns out that our buffer is too small, we resize it
if ((size_t) res > buf.size()) {
buf.resize(res);
res = llama_chat_apply_template(
fallback ? nullptr : model,
fallback ? "chatml" : ptr_tmpl,
chat.data(), chat.size(), add_ass, buf.data(), buf.size());
}
std::string formatted_chat(buf.data(), res);
return formatted_chat;
}
std::string common_chat_format_single(const struct llama_model * model,
const std::string & tmpl,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass) {
std::ostringstream ss;
auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false);
std::vector<common_chat_msg> chat_new(past_msg);
// if the past_msg ends with a newline, we must preserve it in the formatted version
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
ss << "\n";
};
// format chat with new_msg
chat_new.push_back(new_msg);
auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass);
// get the diff part
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
return ss.str();
}
std::string common_chat_format_example(const struct llama_model * model,
const std::string & tmpl) {
std::vector<common_chat_msg> msgs = {
{"system", "You are a helpful assistant"},
{"user", "Hello"},
{"assistant", "Hi there"},
{"user", "How are you?"},
};
return common_chat_apply_template(model, tmpl, msgs, true);
}
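For context, a hedged sketch of how the chat-template helpers added above are typically driven; model is assumed to be an already loaded llama_model *, user_template stands in for a user-supplied --chat-template value, and error handling is omitted:

// Hypothetical call site for the helpers above (not code from this diff).
std::vector<common_chat_msg> history = {
    {"system", "You are a helpful assistant"},
    {"user",   "Hello"},
};

std::string user_template = "";  // empty -> use the model's built-in template (or the chatml fallback)
if (!user_template.empty() && !common_chat_verify_template(user_template)) {
    LOG_WRN("custom chat template is not supported, falling back to the built-in one\n");
    user_template.clear();
}

// full prompt, with the assistant prefix appended so generation can start right away
std::string prompt = common_chat_apply_template(model, user_template, history, /* add_ass= */ true);

// or: log what the active template produces for a short example conversation
LOG_INF("chat template example:\n%s\n", common_chat_format_example(model, user_template).c_str());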
// //
// KV cache utils // KV cache utils
// //

View File

@@ -4,7 +4,6 @@
#include "llama-cpp.h" #include "llama-cpp.h"
#include <set>
#include <string> #include <string>
#include <vector> #include <vector>
#include <sstream> #include <sstream>
@@ -25,11 +24,11 @@
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
struct common_adapter_lora_info { struct common_lora_adapter_info {
std::string path; std::string path;
float scale; float scale;
struct llama_adapter_lora * ptr; struct llama_lora_adapter * ptr;
}; };
using llama_tokens = std::vector<llama_token>; using llama_tokens = std::vector<llama_token>;
@@ -104,17 +103,6 @@ enum dimre_method {
DIMRE_METHOD_MEAN, DIMRE_METHOD_MEAN,
}; };
enum common_conversation_mode {
COMMON_CONVERSATION_MODE_DISABLED = 0,
COMMON_CONVERSATION_MODE_ENABLED = 1,
COMMON_CONVERSATION_MODE_AUTO = 2,
};
struct common_grammar_trigger {
std::string word;
bool at_start;
};
// sampling parameters // sampling parameters
struct common_params_sampling { struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
@@ -140,7 +128,6 @@ struct common_params_sampling {
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float top_n_sigma = -1.00f;// -1.0 = disabled
float mirostat_tau = 5.00f; // target entropy float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate float mirostat_eta = 0.10f; // learning rate
bool ignore_eos = false; bool ignore_eos = false;
@@ -161,11 +148,7 @@ struct common_params_sampling {
COMMON_SAMPLER_TYPE_TEMPERATURE, COMMON_SAMPLER_TYPE_TEMPERATURE,
}; };
std::string grammar; // optional BNF-like grammar to constrain sampling std::string grammar; // optional BNF-like grammar to constrain sampling
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to trigger lazy grammar
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens.
std::set<llama_token> preserved_tokens;
std::vector<llama_logit_bias> logit_bias; // logit biases to apply std::vector<llama_logit_bias> logit_bias; // logit biases to apply
@@ -178,19 +161,15 @@ struct common_params_speculative {
int32_t n_ctx = 0; // draft context size int32_t n_ctx = 0; // draft context size
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
float p_split = 0.1f; // speculative decoding split probability float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy) float p_min = 0.9f; // minimum speculative decoding probability (greedy)
struct cpu_params cpuparams; struct cpu_params cpuparams;
struct cpu_params cpuparams_batch; struct cpu_params cpuparams_batch;
std::string hf_repo = ""; // HF repo // NOLINT std::string model = ""; // draft model for speculative decoding // NOLINT
std::string hf_file = ""; // HF file // NOLINT
std::string model = ""; // draft model for speculative decoding // NOLINT
std::string model_url = ""; // model url to download // NOLINT
}; };
struct common_params_vocoder { struct common_params_vocoder {
@@ -199,13 +178,6 @@ struct common_params_vocoder {
std::string model = ""; // model path // NOLINT std::string model = ""; // model path // NOLINT
std::string model_url = ""; // model url to download // NOLINT std::string model_url = ""; // model url to download // NOLINT
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
};
enum common_reasoning_format {
COMMON_REASONING_FORMAT_NONE,
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`
}; };
struct common_params { struct common_params {
@@ -268,13 +240,14 @@ struct common_params {
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
std::string logits_file = ""; // file for saving *all* logits // NOLINT std::string logits_file = ""; // file for saving *all* logits // NOLINT
std::string rpc_servers = ""; // comma separated list of RPC servers // NOLINT
std::vector<std::string> in_files; // all input files std::vector<std::string> in_files; // all input files
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
std::vector<llama_model_kv_override> kv_overrides; std::vector<llama_model_kv_override> kv_overrides;
bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply) bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply)
std::vector<common_adapter_lora_info> lora_adapters; // lora adapter path with user defined scale std::vector<common_lora_adapter_info> lora_adapters; // lora adapter path with user defined scale
std::vector<common_control_vector_load_info> control_vectors; // control vector with user defined scale std::vector<common_control_vector_load_info> control_vectors; // control vector with user defined scale
@@ -298,11 +271,11 @@ struct common_params {
bool kl_divergence = false; // compute KL divergence bool kl_divergence = false; // compute KL divergence
bool usage = false; // print usage bool usage = false; // print usage
bool completion = false; // print source-able completion script
bool use_color = false; // use color to distinguish generations and inputs bool use_color = false; // use color to distinguish generations and inputs
bool special = false; // enable special token output bool special = false; // enable special token output
bool interactive = false; // interactive mode bool interactive = false; // interactive mode
bool interactive_first = false; // wait for user input immediately bool interactive_first = false; // wait for user input immediately
bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix)
bool prompt_cache_all = false; // save user input and generations to prompt cache bool prompt_cache_all = false; // save user input and generations to prompt cache
bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it
@@ -328,8 +301,6 @@ struct common_params {
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
// multimodal models (see examples/llava) // multimodal models (see examples/llava)
std::string mmproj = ""; // path to multimodal projector // NOLINT std::string mmproj = ""; // path to multimodal projector // NOLINT
std::vector<std::string> image; // path to image file(s) std::vector<std::string> image; // path to image file(s)
@@ -351,9 +322,7 @@ struct common_params {
std::string hostname = "127.0.0.1"; std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT std::string public_path = ""; // NOLINT
std::string chat_template = ""; // NOLINT std::string chat_template = ""; // NOLINT
bool use_jinja = false; // NOLINT
bool enable_chat_template = true; bool enable_chat_template = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
std::vector<std::string> api_keys; std::vector<std::string> api_keys;
@@ -432,13 +401,13 @@ bool set_process_priority(enum ggml_sched_priority prio);
// //
#ifdef __GNUC__ #ifdef __GNUC__
# if defined(__MINGW32__) && !defined(__clang__) #ifdef __MINGW32__
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) #define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
# else
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
# endif
#else #else
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) #define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
#endif
#else
#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
#endif #endif
LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2) LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
@@ -447,10 +416,6 @@ std::string string_format(const char * fmt, ...);
std::string string_strip(const std::string & str); std::string string_strip(const std::string & str);
std::string string_get_sortable_timestamp(); std::string string_get_sortable_timestamp();
std::string string_join(const std::vector<std::string> & values, const std::string & separator);
std::vector<std::string> string_split(const std::string & str, const std::string & delimiter);
std::string string_repeat(const std::string & str, size_t n);
void string_replace_all(std::string & s, const std::string & search, const std::string & replace); void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
template<class T> template<class T>
@@ -489,11 +454,6 @@ static bool string_starts_with(const std::string & str,
return str.rfind(prefix, 0) == 0; return str.rfind(prefix, 0) == 0;
} }
static bool string_ends_with(const std::string & str,
const std::string & suffix) { // While we wait for C++20's std::string::ends_with...
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
}
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides); bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
void string_process_escapes(std::string & input); void string_process_escapes(std::string & input);
@@ -521,7 +481,7 @@ struct common_init_result {
llama_model_ptr model; llama_model_ptr model;
llama_context_ptr context; llama_context_ptr context;
std::vector<llama_adapter_lora_ptr> lora; std::vector<llama_lora_adapter_ptr> lora;
}; };
struct common_init_result common_init_from_params(common_params & params); struct common_init_result common_init_from_params(common_params & params);
@@ -535,7 +495,6 @@ struct llama_model * common_load_model_from_url(
const std::string & local_path, const std::string & local_path,
const std::string & hf_token, const std::string & hf_token,
const struct llama_model_params & params); const struct llama_model_params & params);
struct llama_model * common_load_model_from_hf( struct llama_model * common_load_model_from_hf(
const std::string & repo, const std::string & repo,
const std::string & remote_path, const std::string & remote_path,
@@ -543,12 +502,8 @@ struct llama_model * common_load_model_from_hf(
const std::string & hf_token, const std::string & hf_token,
const struct llama_model_params & params); const struct llama_model_params & params);
std::pair<std::string, std::string> common_get_hf_file(
const std::string & hf_repo_with_tag,
const std::string & hf_token);
// clear LoRA adapters from context, then apply new list of adapters // clear LoRA adapters from context, then apply new list of adapters
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora); void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_info> & lora);
// //
// Batch utils // Batch utils
@@ -586,7 +541,7 @@ std::vector<llama_token> common_tokenize(
bool parse_special = false); bool parse_special = false);
std::vector<llama_token> common_tokenize( std::vector<llama_token> common_tokenize(
const struct llama_vocab * vocab, const struct llama_model * model,
const std::string & text, const std::string & text,
bool add_special, bool add_special,
bool parse_special = false); bool parse_special = false);
@@ -598,23 +553,48 @@ std::string common_token_to_piece(
llama_token token, llama_token token,
bool special = true); bool special = true);
std::string common_token_to_piece(
const struct llama_vocab * vocab,
llama_token token,
bool special = true);
// detokenizes a vector of tokens into a string // detokenizes a vector of tokens into a string
// should work similar to Python's `tokenizer.decode` // should work similar to Python's `tokenizer.decode`
// optionally renders special/control tokens // optionally renders special/control tokens
std::string common_detokenize( std::string common_detokenize(
const struct llama_context * ctx, llama_context * ctx,
const std::vector<llama_token> & tokens, const std::vector<llama_token> & tokens,
bool special = true); bool special = true);
std::string common_detokenize( //
const struct llama_vocab * vocab, // Chat template utils
const std::vector<llama_token> & tokens, //
bool special = true);
// same with llama_chat_message, but uses std::string
struct common_chat_msg {
std::string role;
std::string content;
};
// Get the built-in chat template for the model. Return empty string if not present.
std::string common_get_builtin_chat_template(const struct llama_model * model);
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl);
// CPP wrapper for llama_chat_apply_template
// If the built-in template is not supported, we default to chatml
// If the custom "tmpl" is not supported, we throw an error
std::string common_chat_apply_template(const struct llama_model * model,
const std::string & tmpl,
const std::vector<common_chat_msg> & chat,
bool add_ass);
// Format single message, while taking into account the position of that message in chat history
std::string common_chat_format_single(const struct llama_model * model,
const std::string & tmpl,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass);
// Returns an example of formatted chat
std::string common_chat_format_example(const struct llama_model * model,
const std::string & tmpl);
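A hedged sketch of the incremental pattern these declarations enable: common_chat_format_single renders only the delta for the newest message, so an interactive loop can grow its prompt without re-rendering the whole history. model, tmpl and generate_reply below are stand-ins, not symbols from this diff:

// Illustrative chat loop (assumes <iostream> and the declarations above).
std::vector<common_chat_msg> history;
std::string line;
while (std::getline(std::cin, line)) {
    common_chat_msg user_msg = {"user", line};

    // only the newly formatted part is produced, including the assistant prefix
    std::string delta = common_chat_format_single(model, tmpl, history, user_msg, /* add_ass= */ true);
    history.push_back(user_msg);

    std::string reply = generate_reply(delta);  // assumption: wraps tokenize/decode/sample
    history.push_back({"assistant", reply});
}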
// //
// KV cache utils // KV cache utils

View File

@@ -1,6 +1,4 @@
#include "json-schema-to-grammar.h" #include "json-schema-to-grammar.h"
#include "common.h"
#include <algorithm> #include <algorithm>
#include <fstream> #include <fstream>
#include <map> #include <map>
@@ -13,6 +11,11 @@
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
template <typename Iterator>
static std::string join(Iterator begin, Iterator end, const std::string & separator);
static std::string repeat(const std::string & str, size_t n);
static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") { static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
auto has_max = max_items != std::numeric_limits<int>::max(); auto has_max = max_items != std::numeric_limits<int>::max();
@@ -125,8 +128,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
if (sub_len > 0) { if (sub_len > 0) {
auto from_sub = from.substr(i + 1); auto from_sub = from.substr(i + 1);
auto to_sub = to.substr(i + 1); auto to_sub = to.substr(i + 1);
auto sub_zeros = string_repeat("0", sub_len); auto sub_zeros = repeat("0", sub_len);
auto sub_nines = string_repeat("9", sub_len); auto sub_nines = repeat("9", sub_len);
auto to_reached = false; auto to_reached = false;
out << "("; out << "(";
@@ -185,8 +188,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
auto max_digits = max_s.length(); auto max_digits = max_s.length();
for (auto digits = min_digits; digits < max_digits; digits++) { for (auto digits = min_digits; digits < max_digits; digits++) {
uniform_range(min_s, string_repeat("9", digits)); uniform_range(min_s, repeat("9", digits));
min_s = "1" + string_repeat("0", digits); min_s = "1" + repeat("0", digits);
out << " | "; out << " | ";
} }
uniform_range(min_s, max_s); uniform_range(min_s, max_s);
@@ -315,6 +318,49 @@ std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'}; std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'}; std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};
template <typename Iterator>
std::string join(Iterator begin, Iterator end, const std::string & separator) {
std::ostringstream result;
if (begin != end) {
result << *begin;
for (Iterator it = begin + 1; it != end; ++it) {
result << separator << *it;
}
}
return result.str();
}
static std::vector<std::string> split(const std::string & str, const std::string & delimiter) {
std::vector<std::string> tokens;
size_t start = 0;
size_t end = str.find(delimiter);
while (end != std::string::npos) {
tokens.push_back(str.substr(start, end - start));
start = end + delimiter.length();
end = str.find(delimiter, start);
}
tokens.push_back(str.substr(start));
return tokens;
}
static std::string repeat(const std::string & str, size_t n) {
if (n == 0) {
return "";
}
std::string result;
result.reserve(str.length() * n);
for (size_t i = 0; i < n; ++i) {
result += str;
}
return result;
}
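Since the behavior of these file-local helpers matters for the grammar rules they build, here is a small self-checking sketch of what they return (assuming it is compiled in the same translation unit as the static helpers above):

// Quick illustration of join/split/repeat; the asserts document the expected output.
#include <cassert>

static void demo_string_helpers() {
    std::vector<std::string> parts = split("a,b,,c", ",");      // {"a", "b", "", "c"}
    assert(parts.size() == 4 && parts[2].empty());

    assert(join(parts.begin(), parts.end(), " | ") == "a | b |  | c");
    assert(repeat("9", 3) == "999");
}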
static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) { static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) {
std::smatch match; std::smatch match;
std::string result; std::string result;
@@ -343,7 +389,6 @@ static std::string format_literal(const std::string & literal) {
class SchemaConverter { class SchemaConverter {
private: private:
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
std::function<json(const std::string &)> _fetch_json; std::function<json(const std::string &)> _fetch_json;
bool _dotall; bool _dotall;
std::unordered_map<std::string, std::string> _rules; std::unordered_map<std::string, std::string> _rules;
@@ -373,7 +418,7 @@ private:
for (size_t i = 0; i < alt_schemas.size(); i++) { for (size_t i = 0; i < alt_schemas.size(); i++) {
rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i))); rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i)));
} }
return string_join(rules, " | "); return join(rules.begin(), rules.end(), " | ");
} }
std::string _visit_pattern(const std::string & pattern, const std::string & name) { std::string _visit_pattern(const std::string & pattern, const std::string & name) {
@@ -436,7 +481,7 @@ private:
for (const auto & item : ret) { for (const auto & item : ret) {
results.push_back(to_rule(item)); results.push_back(to_rule(item));
} }
return std::make_pair(string_join(results, " "), false); return std::make_pair(join(results.begin(), results.end(), " "), false);
}; };
while (i < length) { while (i < length) {
@@ -494,7 +539,7 @@ private:
} }
curly_brackets += '}'; curly_brackets += '}';
i++; i++;
auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ","); auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
int min_times = 0; int min_times = 0;
int max_times = std::numeric_limits<int>::max(); int max_times = std::numeric_limits<int>::max();
try { try {
@@ -764,11 +809,10 @@ private:
public: public:
SchemaConverter( SchemaConverter(
const std::function<json(const std::string &)> & fetch_json, const std::function<json(const std::string &)> & fetch_json,
bool dotall, bool dotall)
bool compact_spaces)
: _fetch_json(fetch_json), _dotall(dotall) : _fetch_json(fetch_json), _dotall(dotall)
{ {
_rules["space"] = compact_spaces ? "\" \"?" : SPACE_RULE; _rules["space"] = SPACE_RULE;
} }
void resolve_refs(json & schema, const std::string & url) { void resolve_refs(json & schema, const std::string & url) {
@@ -810,7 +854,7 @@ public:
return; return;
} }
std::string pointer = ref.substr(ref.find('#') + 1); std::string pointer = ref.substr(ref.find('#') + 1);
std::vector<std::string> tokens = string_split(pointer, "/"); std::vector<std::string> tokens = split(pointer, "/");
for (size_t i = 1; i < tokens.size(); ++i) { for (size_t i = 1; i < tokens.size(); ++i) {
std::string sel = tokens[i]; std::string sel = tokens[i];
if (target.is_null() || !target.contains(sel)) { if (target.is_null() || !target.contains(sel)) {
@@ -861,7 +905,7 @@ public:
for (const auto & v : schema["enum"]) { for (const auto & v : schema["enum"]) {
enum_values.push_back(_generate_constant_rule(v)); enum_values.push_back(_generate_constant_rule(v));
} }
return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space"); return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space");
} else if ((schema_type.is_null() || schema_type == "object") } else if ((schema_type.is_null() || schema_type == "object")
&& (schema.contains("properties") || && (schema.contains("properties") ||
(schema.contains("additionalProperties") && schema["additionalProperties"] != true))) { (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
@@ -975,10 +1019,10 @@ public:
void check_errors() { void check_errors() {
if (!_errors.empty()) { if (!_errors.empty()) {
throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n")); throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n"));
} }
if (!_warnings.empty()) { if (!_warnings.empty()) {
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str()); fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", join(_warnings.begin(), _warnings.end(), "; ").c_str());
} }
} }
@@ -991,35 +1035,11 @@ public:
} }
}; };
std::string json_schema_to_grammar(const json & schema, bool force_gbnf) { std::string json_schema_to_grammar(const json & schema) {
#ifdef LLAMA_USE_LLGUIDANCE SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
if (!force_gbnf) { auto copy = schema;
return "%llguidance {}\nstart: %json " + schema.dump(); converter.resolve_refs(copy, "input");
} converter.visit(copy, "");
#else
(void)force_gbnf;
#endif // LLAMA_USE_LLGUIDANCE
return build_grammar([&](const common_grammar_builder & callbacks) {
auto copy = schema;
callbacks.resolve_refs(copy);
callbacks.add_schema("", copy);
});
}
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall, options.compact_spaces);
common_grammar_builder builder {
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
return converter._add_rule(name, rule);
},
/* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) {
return converter.visit(schema, name == "root" ? "" : name);
},
/* .resolve_refs = */ [&](nlohmann::ordered_json & schema) {
converter.resolve_refs(schema, "");
}
};
cb(builder);
converter.check_errors(); converter.check_errors();
return converter.format_grammar(); return converter.format_grammar();
} }
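To make the entry point concrete, a hypothetical call site for json_schema_to_grammar (using the right-hand, single-argument signature; the left-hand one adds an optional force_gbnf flag). The schema literal is illustrative only:

// Standalone sketch: convert a small JSON schema into a GBNF grammar string.
#include "json-schema-to-grammar.h"
#include <cstdio>

int main() {
    nlohmann::ordered_json schema = nlohmann::ordered_json::parse(R"({
        "type": "object",
        "properties": {
            "name": { "type": "string" },
            "age":  { "type": "integer", "minimum": 0 }
        },
        "required": ["name"]
    })");

    std::string gbnf = json_schema_to_grammar(schema);
    std::printf("%s\n", gbnf.c_str());
    return 0;
}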

View File

@@ -5,18 +5,4 @@
#define JSON_ASSERT GGML_ASSERT #define JSON_ASSERT GGML_ASSERT
#include "json.hpp" #include "json.hpp"
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema, std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);
bool force_gbnf = false);
struct common_grammar_builder {
std::function<std::string(const std::string &, const std::string &)> add_rule;
std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
std::function<void(nlohmann::ordered_json &)> resolve_refs;
};
struct common_grammar_options {
bool dotall = false;
bool compact_spaces = false;
};
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});

View File

@@ -1,6 +1,5 @@
#include "log.h" #include "log.h"
#include <chrono>
#include <condition_variable> #include <condition_variable>
#include <cstdarg> #include <cstdarg>
#include <cstdio> #include <cstdio>
@@ -15,6 +14,16 @@ void common_log_set_verbosity_thold(int verbosity) {
common_log_verbosity_thold = verbosity; common_log_verbosity_thold = verbosity;
} }
#define LOG_COL_DEFAULT "\033[0m"
#define LOG_COL_BOLD "\033[1m"
#define LOG_COL_RED "\033[31m"
#define LOG_COL_GREEN "\033[32m"
#define LOG_COL_YELLOW "\033[33m"
#define LOG_COL_BLUE "\033[34m"
#define LOG_COL_MAGENTA "\033[35m"
#define LOG_COL_CYAN "\033[36m"
#define LOG_COL_WHITE "\033[37m"
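These are plain ANSI SGR escape sequences; a tiny standalone illustration of what the LOG_COL_* macros expand to (nothing here is specific to this repo):

// Prints "error" in red on terminals that honor ANSI escapes, then resets the color.
#include <cstdio>

#define DEMO_COL_RED     "\033[31m"
#define DEMO_COL_DEFAULT "\033[0m"

int main() {
    std::printf("%serror%s: something went wrong\n", DEMO_COL_RED, DEMO_COL_DEFAULT);
    return 0;
}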
static int64_t t_us() { static int64_t t_us() {
return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count(); return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
} }
@@ -197,7 +206,6 @@ public:
vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args_copy); vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args_copy);
} }
#endif #endif
va_end(args_copy);
} }
entry.level = level; entry.level = level;

View File

@@ -2,20 +2,9 @@
#include "ggml.h" // for ggml_log_level #include "ggml.h" // for ggml_log_level
#define LOG_CLR_TO_EOL "\033[K\r"
#define LOG_COL_DEFAULT "\033[0m"
#define LOG_COL_BOLD "\033[1m"
#define LOG_COL_RED "\033[31m"
#define LOG_COL_GREEN "\033[32m"
#define LOG_COL_YELLOW "\033[33m"
#define LOG_COL_BLUE "\033[34m"
#define LOG_COL_MAGENTA "\033[35m"
#define LOG_COL_CYAN "\033[36m"
#define LOG_COL_WHITE "\033[37m"
#ifndef __GNUC__ #ifndef __GNUC__
# define LOG_ATTRIBUTE_FORMAT(...) # define LOG_ATTRIBUTE_FORMAT(...)
#elif defined(__MINGW32__) && !defined(__clang__) #elif defined(__MINGW32__)
# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) # define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
#else #else
# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) # define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
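The point of LOG_ATTRIBUTE_FORMAT is compile-time checking of variadic log calls; a minimal sketch (independent of this repo) of how the printf/gnu_printf format attribute is applied and what it catches:

// With the attribute, GCC/Clang warn when arguments do not match the format string.
#include <cstdarg>
#include <cstdio>

#if defined(__GNUC__)
#  define DEMO_LOG_FORMAT(fmt_idx, args_idx) __attribute__((format(printf, fmt_idx, args_idx)))
#else
#  define DEMO_LOG_FORMAT(fmt_idx, args_idx)
#endif

DEMO_LOG_FORMAT(1, 2)
static void demo_log(const char * fmt, ...) {
    va_list args;
    va_start(args, fmt);
    std::vfprintf(stderr, fmt, args);
    va_end(args);
}

int main() {
    demo_log("loaded %d tensors from %s\n", 42, "model.gguf");
    // demo_log("%s", 123);  // would trigger a -Wformat warning thanks to the attribute
    return 0;
}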

View File

@@ -113,10 +113,7 @@ struct common_sampler {
void set_logits(struct llama_context * ctx, int idx) { void set_logits(struct llama_context * ctx, int idx) {
const auto * logits = llama_get_logits_ith(ctx, idx); const auto * logits = llama_get_logits_ith(ctx, idx);
const llama_model * model = llama_get_model(ctx); const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_vocab = llama_vocab_n_tokens(vocab);
cur.resize(n_vocab); cur.resize(n_vocab);
@@ -134,47 +131,24 @@ std::string common_params_sampling::print() const {
snprintf(result, sizeof(result), snprintf(result, sizeof(result),
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n" "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n" "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp, top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
mirostat, mirostat_eta, mirostat_tau); mirostat, mirostat_eta, mirostat_tau);
return std::string(result); return std::string(result);
} }
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) { struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
lparams.no_perf = params.no_perf; lparams.no_perf = params.no_perf;
struct llama_sampler * grmr;
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
#else
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
} else {
std::vector<const char *> trigger_words;
trigger_words.reserve(params.grammar_trigger_words.size());
for (const auto & str : params.grammar_trigger_words) {
trigger_words.push_back(str.word.c_str());
}
grmr = params.grammar_lazy
? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root",
trigger_words.data(), trigger_words.size(),
params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size())
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
}
auto * result = new common_sampler { auto * result = new common_sampler {
/* .params = */ params, /* .params = */ params,
/* .grmr = */ grmr, /* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
/* .chain = */ llama_sampler_chain_init(lparams), /* .chain = */ llama_sampler_chain_init(lparams),
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)), /* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {}, /* .cur = */ {},
@@ -183,62 +157,56 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
llama_sampler_chain_add(result->chain, llama_sampler_chain_add(result->chain,
llama_sampler_init_logit_bias( llama_sampler_init_logit_bias(
llama_vocab_n_tokens(vocab), llama_n_vocab(model),
params.logit_bias.size(), params.logit_bias.size(),
params.logit_bias.data())); params.logit_bias.data()));
if (params.mirostat == 0) { if (params.mirostat == 0) {
if (params.top_n_sigma >= 0) { for (const auto & cnstr : params.samplers) {
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); switch (cnstr) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp)); case COMMON_SAMPLER_TYPE_DRY:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma)); {
} else { std::vector<const char *> c_breakers;
for (const auto & cnstr : params.samplers) { c_breakers.reserve(params.dry_sequence_breakers.size());
switch (cnstr) { for (const auto & str : params.dry_sequence_breakers) {
case COMMON_SAMPLER_TYPE_DRY: c_breakers.push_back(str.c_str());
{
std::vector<const char *> c_breakers;
c_breakers.reserve(params.dry_sequence_breakers.size());
for (const auto & str : params.dry_sequence_breakers) {
c_breakers.push_back(str.c_str());
}
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
} }
break;
case COMMON_SAMPLER_TYPE_TOP_K: llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); }
break; break;
case COMMON_SAMPLER_TYPE_TOP_P: case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
break; break;
case COMMON_SAMPLER_TYPE_MIN_P: case COMMON_SAMPLER_TYPE_TOP_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
break; break;
case COMMON_SAMPLER_TYPE_XTC: case COMMON_SAMPLER_TYPE_MIN_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
break; break;
case COMMON_SAMPLER_TYPE_TYPICAL_P: case COMMON_SAMPLER_TYPE_XTC:
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
break; break;
case COMMON_SAMPLER_TYPE_TEMPERATURE: case COMMON_SAMPLER_TYPE_TYPICAL_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
break; break;
case COMMON_SAMPLER_TYPE_INFILL: case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab)); llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break; break;
case COMMON_SAMPLER_TYPE_PENALTIES: case COMMON_SAMPLER_TYPE_INFILL:
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
break; break;
default: case COMMON_SAMPLER_TYPE_PENALTIES:
GGML_ASSERT(false && "unknown sampler type"); llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
} break;
default:
GGML_ASSERT(false && "unknown sampler type");
} }
} }
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
} else if (params.mirostat == 1) { } else if (params.mirostat == 1) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
} else if (params.mirostat == 2) { } else if (params.mirostat == 2) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));

View File

@@ -102,6 +102,3 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names); std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars); std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
const char * grammar_kind, const char * grammar_data);

View File

@@ -7,7 +7,6 @@
#include "ggml-cpu.h" #include "ggml-cpu.h"
#include "ggml-alloc.h" #include "ggml-alloc.h"
#include "ggml-backend.h" #include "ggml-backend.h"
#include "gguf.h"
#ifdef GGML_USE_CUDA #ifdef GGML_USE_CUDA
#include "ggml-cuda.h" #include "ggml-cuda.h"
@@ -40,7 +39,6 @@
#include <map> #include <map>
#include <regex> #include <regex>
#include <stdexcept> #include <stdexcept>
#include <unordered_set>
#include <vector> #include <vector>
#include <sstream> #include <sstream>
#include <cinttypes> #include <cinttypes>
@@ -116,7 +114,6 @@ static std::string format(const char * fmt, ...) {
#define KEY_HAS_VIS_ENC "clip.has_vision_encoder" #define KEY_HAS_VIS_ENC "clip.has_vision_encoder"
#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector" #define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector"
#define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector" #define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector"
#define KEY_HAS_GLM_PROJ "clip.has_glm_projector"
#define KEY_MINICPMV_VERSION "clip.minicpmv_version" #define KEY_MINICPMV_VERSION "clip.minicpmv_version"
#define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger" #define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger"
#define KEY_USE_GELU "clip.use_gelu" #define KEY_USE_GELU "clip.use_gelu"
@@ -134,7 +131,6 @@ static std::string format(const char * fmt, ...) {
#define KEY_IMAGE_MEAN "clip.vision.image_mean" #define KEY_IMAGE_MEAN "clip.vision.image_mean"
#define KEY_IMAGE_STD "clip.vision.image_std" #define KEY_IMAGE_STD "clip.vision.image_std"
#define KEY_PROJ_TYPE "clip.projector_type" #define KEY_PROJ_TYPE "clip.projector_type"
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
@@ -176,15 +172,6 @@ static std::string format(const char * fmt, ...) {
#define TN_MINICPMV_ATTN "resampler.attn.%s.%s" #define TN_MINICPMV_ATTN "resampler.attn.%s.%s"
#define TN_MINICPMV_LN "resampler.ln_%s.%s" #define TN_MINICPMV_LN "resampler.ln_%s.%s"
#define TN_GLM_ADAPER_CONV "adapter.conv.%s"
#define TN_GLM_ADAPTER_LINEAR "adapter.linear.linear.%s"
#define TN_GLM_ADAPTER_NORM_1 "adapter.linear.norm1.%s"
#define TN_GLM_ADAPTER_D_H_2_4H "adapter.linear.dense_h_to_4h.%s"
#define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s"
#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s"
#define TN_GLM_BOI_W "adapter.boi"
#define TN_GLM_EOI_W "adapter.eoi"
enum projector_type { enum projector_type {
PROJECTOR_TYPE_MLP, PROJECTOR_TYPE_MLP,
@@ -192,7 +179,6 @@ enum projector_type {
PROJECTOR_TYPE_LDP, PROJECTOR_TYPE_LDP,
PROJECTOR_TYPE_LDPV2, PROJECTOR_TYPE_LDPV2,
PROJECTOR_TYPE_RESAMPLER, PROJECTOR_TYPE_RESAMPLER,
PROJECTOR_TYPE_GLM_EDGE,
PROJECTOR_TYPE_MERGER, PROJECTOR_TYPE_MERGER,
PROJECTOR_TYPE_UNKNOWN, PROJECTOR_TYPE_UNKNOWN,
}; };
@@ -202,7 +188,6 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_LDP, "ldp" }, { PROJECTOR_TYPE_LDP, "ldp" },
{ PROJECTOR_TYPE_LDPV2, "ldpv2"}, { PROJECTOR_TYPE_LDPV2, "ldpv2"},
{ PROJECTOR_TYPE_RESAMPLER, "resampler"}, { PROJECTOR_TYPE_RESAMPLER, "resampler"},
{ PROJECTOR_TYPE_GLM_EDGE, "adapter"},
{ PROJECTOR_TYPE_MERGER, "qwen2vl_merger"}, { PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
}; };
@@ -290,7 +275,7 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
{ {
const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i); const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i);
int arr_n = gguf_get_arr_n(ctx_gguf, i); int arr_n = gguf_get_arr_n(ctx_gguf, i);
const void * data = arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx_gguf, i); const void * data = gguf_get_arr_data(ctx_gguf, i);
std::stringstream ss; std::stringstream ss;
ss << "["; ss << "[";
for (int j = 0; j < arr_n; j++) { for (int j = 0; j < arr_n; j++) {
@@ -459,9 +444,8 @@ struct clip_hparams {
char mm_patch_merge_type[32] = "flat"; // spatial_unpad or flat (default) char mm_patch_merge_type[32] = "flat"; // spatial_unpad or flat (default)
std::vector<int32_t> image_grid_pinpoints; int32_t image_grid_pinpoints[32];
int32_t image_crop_resolution; int32_t image_crop_resolution;
std::unordered_set<int32_t> vision_feature_layer;
}; };
struct clip_layer { struct clip_layer {
@@ -528,12 +512,6 @@ struct clip_vision_model {
struct ggml_tensor * mm_4_w = NULL; struct ggml_tensor * mm_4_w = NULL;
struct ggml_tensor * mm_4_b = NULL; struct ggml_tensor * mm_4_b = NULL;
//GLMV-Edge projection
struct ggml_tensor * mm_model_adapter_conv_w;
struct ggml_tensor * mm_model_adapter_conv_b;
struct ggml_tensor * boi_w;
struct ggml_tensor * eoi_w;
// MobileVLM projection // MobileVLM projection
struct ggml_tensor * mm_model_mlp_1_w; struct ggml_tensor * mm_model_mlp_1_w;
struct ggml_tensor * mm_model_mlp_1_b; struct ggml_tensor * mm_model_mlp_1_b;
@@ -594,14 +572,12 @@ struct clip_ctx {
bool has_vision_encoder = false; bool has_vision_encoder = false;
bool has_llava_projector = false; bool has_llava_projector = false;
bool has_minicpmv_projector = false; bool has_minicpmv_projector = false;
bool has_glm_projector = false;
bool has_qwen2vl_merger = false; bool has_qwen2vl_merger = false;
int minicpmv_version = 2; int minicpmv_version = 2;
struct clip_vision_model vision_model; struct clip_vision_model vision_model;
projector_type proj_type = PROJECTOR_TYPE_MLP; projector_type proj_type = PROJECTOR_TYPE_MLP;
int32_t max_feature_layer;
float image_mean[3]; float image_mean[3];
float image_std[3]; float image_std[3];
bool use_gelu = false; bool use_gelu = false;
@@ -668,12 +644,13 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
const int hidden_size = hparams.hidden_size; const int hidden_size = hparams.hidden_size;
const int n_head = hparams.n_head; const int n_head = hparams.n_head;
const int d_head = hidden_size / n_head; const int d_head = hidden_size / n_head;
int n_layer = hparams.n_layer;
const float eps = hparams.eps; const float eps = hparams.eps;
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
const int batch_size = imgs->size; const int batch_size = imgs->size;
if (ctx->has_llava_projector || ctx->has_minicpmv_projector || ctx->has_glm_projector) { if (ctx->has_llava_projector || ctx->has_minicpmv_projector) {
GGML_ASSERT(batch_size == 1); GGML_ASSERT(batch_size == 1);
} }
@@ -753,9 +730,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
else if (ctx->minicpmv_version == 3) { else if (ctx->minicpmv_version == 3) {
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
} }
else if (ctx->minicpmv_version == 4) {
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
}
ggml_set_name(pos_embed, "pos_embed"); ggml_set_name(pos_embed, "pos_embed");
ggml_set_input(pos_embed); ggml_set_input(pos_embed);
} }
@@ -768,19 +742,14 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b); embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
} }
std::vector<struct ggml_tensor *> embedding_stack;
const auto & vision_feature_layer = hparams.vision_feature_layer;
// loop over layers // loop over layers
for (int il = 0; il < ctx->max_feature_layer; il++) { if (ctx->has_minicpmv_projector || ctx->has_qwen2vl_merger) {
// TODO: figure out why we are doing it this way
n_layer += 1;
}
for (int il = 0; il < n_layer - 1; il++) {
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
// If this is an embedding feature layer, save the output.
// NOTE: 0 index here refers to the input to the encoder.
if (vision_feature_layer.find(il) != vision_feature_layer.end()) {
embedding_stack.push_back(embeddings);
}
//const size_t nb_q_w = model.layers[il].q_w->nb[0]; //const size_t nb_q_w = model.layers[il].q_w->nb[0];
// layernorm1 // layernorm1
@@ -868,6 +837,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
cur = ggml_add(ctx0, embeddings, cur); cur = ggml_add(ctx0, embeddings, cur);
embeddings = cur; embeddings = cur;
} }
// post-layernorm // post-layernorm
@@ -878,19 +848,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b); embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
} }
// final layer is a vision feature layer
if (vision_feature_layer.find(ctx->max_feature_layer) != vision_feature_layer.end()) {
embedding_stack.push_back(embeddings);
}
// If feature layers are explicitly set, stack them (if we have multiple)
if (!embedding_stack.empty()) {
embeddings = embedding_stack[0];
for (size_t i = 1; i < embedding_stack.size(); i++) {
embeddings = ggml_concat(ctx0, embeddings, embedding_stack[i], 0);
}
}
// llava projector // llava projector
if (ctx->has_llava_projector) { if (ctx->has_llava_projector) {
embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
@@ -1108,11 +1065,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
n_head = hidden_size/d_head; n_head = hidden_size/d_head;
num_query = 64; num_query = 64;
} }
else if (ctx->minicpmv_version == 4) {
hidden_size = 3584;
n_head = hidden_size/d_head;
num_query = 64;
}
struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b); struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
@@ -1147,33 +1099,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
GGML_ASSERT(false); GGML_ASSERT(false);
} }
} }
// glm projector else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
else if (ctx->has_glm_projector) {
if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
size_t gridsz = (size_t)sqrt(embeddings->ne[1]);
embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3));
embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]);
embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1);
embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size);
embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3));
embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
//GLU
{
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
embeddings = ggml_norm(ctx0, embeddings, eps);
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
embeddings = ggml_gelu_inplace(ctx0, embeddings);
struct ggml_tensor * x = embeddings;
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
embeddings = ggml_silu_inplace(ctx0, embeddings);
embeddings = ggml_mul(ctx0, embeddings,x);
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
}
} else {
GGML_ABORT("fatel error");
}
} else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size); embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
@@ -1342,11 +1268,6 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx); new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx);
} }
idx = gguf_find_key(ctx, KEY_HAS_GLM_PROJ);
if (idx != -1) {
new_clip->has_glm_projector = gguf_get_val_bool(ctx, idx);
}
idx = gguf_find_key(ctx, KEY_HAS_QWEN2VL_MERGER); idx = gguf_find_key(ctx, KEY_HAS_QWEN2VL_MERGER);
if (idx != -1) { if (idx != -1) {
new_clip->has_qwen2vl_merger = gguf_get_val_bool(ctx, idx); new_clip->has_qwen2vl_merger = gguf_get_val_bool(ctx, idx);
@@ -1371,7 +1292,6 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
LOG_INF("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder); LOG_INF("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
LOG_INF("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector); LOG_INF("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector);
LOG_INF("%s: minicpmv_projector: %d\n", __func__, new_clip->has_minicpmv_projector); LOG_INF("%s: minicpmv_projector: %d\n", __func__, new_clip->has_minicpmv_projector);
LOG_INF("%s: glm_projector: %d\n", __func__, new_clip->has_glm_projector);
LOG_INF("%s: model size: %.2f MB\n", __func__, model_size / 1024.0 / 1024.0); LOG_INF("%s: model size: %.2f MB\n", __func__, model_size / 1024.0 / 1024.0);
LOG_INF("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0); LOG_INF("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0);
} }
@@ -1482,26 +1402,14 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS); int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS);
int n = gguf_get_arr_n(ctx, idx); int n = gguf_get_arr_n(ctx, idx);
const int32_t * pinpoints = (const int32_t *)gguf_get_arr_data(ctx, idx); const int32_t * pinpoints = (const int32_t *)gguf_get_arr_data(ctx, idx);
for (int i = 0; i < n; ++i) { for (int i = 0; i < 32 && i < n && pinpoints[i] != 0; ++i) {
hparams.image_grid_pinpoints.push_back(pinpoints[i]); hparams.image_grid_pinpoints[i] = pinpoints[i];
} }
} catch (std::runtime_error & /*e*/) { } if (n < 32)
hparams.image_grid_pinpoints[n] = 0;
// Load the vision feature layer indices if they are explicitly provided; } catch (std::runtime_error & /*e*/) {
// if multiple vision feature layers are present, the values will be concatenated hparams.image_grid_pinpoints[0]=0;
// to form the final visual features. }
// NOTE: gguf conversions should standardize the values of the vision feature layer to
// be non-negative, since we use -1 to mark values as unset here.
try {
int idx = get_key_idx(ctx, KEY_FEATURE_LAYER);
int n = gguf_get_arr_n(ctx, idx);
const int32_t * vision_feature_layer = (const int32_t *)gguf_get_arr_data(ctx, idx);
for (int i = 0; i < n; ++i) {
hparams.vision_feature_layer.insert(vision_feature_layer[i]);
}
} catch (std::runtime_error & /*e*/) { }
try { try {
int idx = get_key_idx(ctx, KEY_MM_PATCH_MERGE_TYPE); int idx = get_key_idx(ctx, KEY_MM_PATCH_MERGE_TYPE);
@@ -1527,9 +1435,6 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
new_clip->image_std[i] = std_data[i]; new_clip->image_std[i] = std_data[i];
} }
// Calculate the deepest feature layer based on hparams and projector type
new_clip->max_feature_layer = get_deepest_feature_layer(new_clip);
if (verbosity >= 2) { if (verbosity >= 2) {
LOG_INF("\n%s: vision model hparams\n", __func__); LOG_INF("\n%s: vision model hparams\n", __func__);
LOG_INF("image_size %d\n", hparams.image_size); LOG_INF("image_size %d\n", hparams.image_size);
@@ -1543,13 +1448,8 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
LOG_INF("v_image_mean %f %f %f\n", new_clip->image_mean[0], new_clip->image_mean[1], new_clip->image_mean[2]); LOG_INF("v_image_mean %f %f %f\n", new_clip->image_mean[0], new_clip->image_mean[1], new_clip->image_mean[2]);
LOG_INF("v_image_std %f %f %f\n", new_clip->image_std[0], new_clip->image_std[1], new_clip->image_std[2]); LOG_INF("v_image_std %f %f %f\n", new_clip->image_std[0], new_clip->image_std[1], new_clip->image_std[2]);
LOG_INF("v_image_grid_pinpoints: "); LOG_INF("v_image_grid_pinpoints: ");
for (const auto & pp : hparams.image_grid_pinpoints) { for (int i = 0; i < 32 && (hparams.image_grid_pinpoints[i] != 0); ++i) {
LOG_INF("%d ", pp); LOG_INF("%d ", hparams.image_grid_pinpoints[i]);
}
LOG_INF("\n");
LOG_INF("v_vision_feature_layer: ");
for (const auto & feature_layer: hparams.vision_feature_layer) {
LOG_INF("%d ", feature_layer);
} }
LOG_INF("\n"); LOG_INF("\n");
LOG_INF("v_mm_patch_merge_type: %s\n", hparams.mm_patch_merge_type); LOG_INF("v_mm_patch_merge_type: %s\n", hparams.mm_patch_merge_type);
@@ -1684,18 +1584,6 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
vision_model.mm_model_ln_post_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "weight")); vision_model.mm_model_ln_post_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "weight"));
vision_model.mm_model_ln_post_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "bias")); vision_model.mm_model_ln_post_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "bias"));
} }
else if (new_clip->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
vision_model.mm_model_adapter_conv_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPER_CONV, "weight"));
vision_model.mm_model_adapter_conv_b = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPER_CONV, "bias"));
vision_model.mm_model_mlp_0_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_LINEAR,"weight"));
vision_model.mm_model_ln_q_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_NORM_1,"weight"));
vision_model.mm_model_ln_q_b = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_NORM_1,"bias"));
vision_model.mm_model_mlp_1_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_D_H_2_4H,"weight"));
vision_model.mm_model_mlp_2_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_GATE,"weight"));
vision_model.mm_model_mlp_3_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_D_4H_2_H,"weight"));
vision_model.boi_w = get_tensor(new_clip->ctx_data, TN_GLM_BOI_W);
vision_model.eoi_w = get_tensor(new_clip->ctx_data, TN_GLM_EOI_W);
}
else if (new_clip->proj_type == PROJECTOR_TYPE_MERGER) { else if (new_clip->proj_type == PROJECTOR_TYPE_MERGER) {
vision_model.mm_0_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight")); vision_model.mm_0_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight"));
vision_model.mm_0_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias")); vision_model.mm_0_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias"));
@@ -1788,11 +1676,11 @@ void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) {
} }
} }
void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img) { static void build_clip_img_from_data(const stbi_uc * data, int nx, int ny, clip_image_u8 * img) {
img->nx = nx; img->nx = nx;
img->ny = ny; img->ny = ny;
img->buf.resize(3 * nx * ny); img->buf.resize(3 * nx * ny);
memcpy(img->buf.data(), rgb_pixels, img->buf.size()); memcpy(img->buf.data(), data, img->buf.size());
} }
bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) { bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
@@ -1802,7 +1690,7 @@ bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
LOG_ERR("%s: failed to load image '%s'\n", __func__, fname); LOG_ERR("%s: failed to load image '%s'\n", __func__, fname);
return false; return false;
} }
clip_build_img_from_pixels(data, nx, ny, img); build_clip_img_from_data(data, nx, ny, img);
stbi_image_free(data); stbi_image_free(data);
return true; return true;
} }
@@ -1814,7 +1702,7 @@ bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length
LOG_ERR("%s: failed to decode image bytes\n", __func__); LOG_ERR("%s: failed to decode image bytes\n", __func__);
return false; return false;
} }
clip_build_img_from_pixels(data, nx, ny, img); build_clip_img_from_data(data, nx, ny, img);
stbi_image_free(data); stbi_image_free(data);
return true; return true;
} }
@@ -2170,7 +2058,6 @@ static std::vector<std::vector<clip_image_u8 *>> uhd_slice_image(const clip_imag
images[images.size()-1].push_back(patch); images[images.size()-1].push_back(patch);
} }
} }
clip_image_u8_free(refine_image);
} }
return images; return images;
} }
@@ -2209,13 +2096,6 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
clip_image_f32_free(res); clip_image_f32_free(res);
} }
} }
for (size_t i = 0; i < imgs.size(); ++i) {
for (size_t j = 0; j < imgs[i].size(); ++j) {
if (imgs[i][j] != nullptr) {
clip_image_u8_free(imgs[i][j]);
}
}
}
return true; return true;
} }
else if (ctx->has_qwen2vl_merger) { else if (ctx->has_qwen2vl_merger) {
@@ -2236,20 +2116,6 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
return true; return true;
} }
if (ctx->has_glm_projector) {
res_imgs->size = 1;
res_imgs->data = new clip_image_f32[res_imgs->size];
clip_image_u8 resized_image;
int32_t sz=ctx->vision_model.hparams.image_size;
bicubic_resize(*img, resized_image,sz,sz);
clip_image_f32 * res = clip_image_f32_init();
//clip_image_save_to_bmp(resized_image, "resized.bmp");
normalize_image_u8_to_f32(&resized_image, res, ctx->image_mean, ctx->image_std);
res_imgs->data[0] = *res;
clip_image_f32_free(res);
return true;
}
bool pad_to_square = true; bool pad_to_square = true;
if (!ctx->has_vision_encoder) { if (!ctx->has_vision_encoder) {
LOG_ERR("This gguf file seems to have no vision encoder\n"); LOG_ERR("This gguf file seems to have no vision encoder\n");
@@ -2294,10 +2160,10 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
} }
} }
} else { } else {
if (!params.image_grid_pinpoints.empty()) { if (params.image_grid_pinpoints[0] != 0) {
// "spatial_unpad" with "anyres" processing for llava-1.6 // "spatial_unpad" with "anyres" processing for llava-1.6
std::vector<std::pair<int, int>> possible_resolutions; std::vector<std::pair<int, int>> possible_resolutions;
for (size_t i = 0; i < params.image_grid_pinpoints.size(); i+=2) { for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) {
possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]}); possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]});
} }
std::pair<int, int> best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions); std::pair<int, int> best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions);
@@ -2435,8 +2301,7 @@ void clip_free(clip_ctx * ctx) {
} }
size_t clip_embd_nbytes(const struct clip_ctx * ctx) { size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
int extra_tokens = ctx->has_glm_projector ? 2 : 0; return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
return (clip_n_patches(ctx) + extra_tokens) * clip_n_mmproj_embd(ctx) * sizeof(float);
} }
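For reference, a minimal sketch (not part of this diff) of how the sizing and encoding entry points around this hunk are normally driven together. It assumes the clip.h declarations used elsewhere in this file (clip_model_load, clip_image_u8_init, clip_image_load_from_file, clip_image_preprocess, clip_image_encode, clip_embd_nbytes); error cleanup is glossed over, and real callers iterate over every slice in the preprocessed batch rather than only the first one.
// rough usage sketch: load a projector, preprocess one image, encode the first slice
#include "clip.h"
#include <cstdlib>

static float * encode_first_slice(const char * mmproj_path, const char * image_path) {
    clip_ctx * ctx = clip_model_load(mmproj_path, /*verbosity=*/1);
    clip_image_u8 * img = clip_image_u8_init();
    if (ctx == nullptr || img == nullptr || !clip_image_load_from_file(image_path, img)) {
        return nullptr;
    }
    clip_image_f32_batch batch = {};                 // filled by the preprocessor (may hold several slices)
    if (!clip_image_preprocess(ctx, img, &batch)) {
        return nullptr;
    }
    float * embd = (float *) malloc(clip_embd_nbytes(ctx));    // sized by the function above
    bool ok = clip_image_encode(ctx, /*n_threads=*/4, &batch.data[0], embd);
    clip_image_f32_batch_free(&batch);
    clip_image_u8_free(img);
    clip_free(ctx);
    if (!ok) {
        free(embd);
        embd = nullptr;
    }
    return embd;
}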
size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w) { size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w) {
@@ -2463,14 +2328,7 @@ const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
} }
const int32_t * clip_image_grid(const struct clip_ctx * ctx) { const int32_t * clip_image_grid(const struct clip_ctx * ctx) {
if (ctx->vision_model.hparams.image_grid_pinpoints.size()) { return ctx->vision_model.hparams.image_grid_pinpoints;
return &ctx->vision_model.hparams.image_grid_pinpoints.front();
}
return nullptr;
}
size_t get_clip_image_grid_size(const struct clip_ctx * ctx) {
return ctx->vision_model.hparams.image_grid_pinpoints.size();
} }
int clip_n_patches(const struct clip_ctx * ctx) { int clip_n_patches(const struct clip_ctx * ctx) {
@@ -2485,7 +2343,7 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size); int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2) {
n_patches /= 4; n_patches /= 4;
} else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { } else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
if (ctx->minicpmv_version == 2) { if (ctx->minicpmv_version == 2) {
@@ -2494,9 +2352,6 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
else if (ctx->minicpmv_version == 3) { else if (ctx->minicpmv_version == 3) {
n_patches = 64; n_patches = 64;
} }
else if (ctx->minicpmv_version == 4) {
n_patches = 64;
}
} else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
int patch_size = params.patch_size * 2; int patch_size = params.patch_size * 2;
int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
@@ -2618,12 +2473,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
if (ctx->has_minicpmv_projector) { if (ctx->has_minicpmv_projector) {
GGML_ASSERT(batch_size == 1); GGML_ASSERT(batch_size == 1);
} }
if (ctx->has_glm_projector) {
GGML_ASSERT(batch_size == 1);
ggml_tensor * boi = ctx->vision_model.boi_w;
ggml_backend_tensor_get(boi,vec,0,ggml_nbytes(boi));
vec = (float*)(vec+ggml_nelements(boi)); //offset for boi
}
// build the inference graph // build the inference graph
ggml_cgraph * gf = clip_image_build_graph(ctx, imgs, ctx->load_image_size, true); ggml_cgraph * gf = clip_image_build_graph(ctx, imgs, ctx->load_image_size, true);
@@ -2682,8 +2531,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
// -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316 // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
int* positions_data = (int*)malloc(ggml_nbytes(positions)); int* positions_data = (int*)malloc(ggml_nbytes(positions));
int bucket_coords_h[1024]; int bucket_coords_h[70];
int bucket_coords_w[1024]; int bucket_coords_w[70];
for (int i = 0; i < pos_h; i++){ for (int i = 0; i < pos_h; i++){
bucket_coords_h[i] = std::floor(70.0*i/pos_h); bucket_coords_h[i] = std::floor(70.0*i/pos_h);
} }
@@ -2711,9 +2560,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
else if (ctx->minicpmv_version == 3) { else if (ctx->minicpmv_version == 3) {
embed_dim = 3584; embed_dim = 3584;
} }
else if (ctx->minicpmv_version == 4) {
embed_dim = 3584;
}
auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h)); auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed)); float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed));
@@ -2776,15 +2622,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
free(positions_data); free(positions_data);
if (!ctx->has_glm_projector) { {
struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches"); struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
// The patches vector is used to get rows to index into the embeds with;
// we should skip dim 0 only if we have CLS to avoid going out of bounds
// when retrieving the rows.
int patch_offset = ctx->has_class_embedding ? 1 : 0;
int* patches_data = (int*)malloc(ggml_nbytes(patches)); int* patches_data = (int*)malloc(ggml_nbytes(patches));
for (int i = 0; i < num_patches; i++) { for (int i = 0; i < num_patches; i++) {
patches_data[i] = i + patch_offset; patches_data[i] = i + 1;
} }
ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches)); ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
free(patches_data); free(patches_data);
@@ -2804,19 +2646,14 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
// copy the embeddings to the location passed by the user // copy the embeddings to the location passed by the user
ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings)); ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
if (ctx->has_glm_projector) {
//eoi
ggml_tensor * eoi = ctx->vision_model.eoi_w;
int offset = ggml_nelements(embeddings);
ggml_backend_tensor_get(eoi, vec+offset, 0, ggml_nbytes(eoi));
}
return true; return true;
} }
bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype) { bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype) {
ggml_type type = GGML_TYPE_Q4_1;
assert(itype < GGML_TYPE_COUNT); assert(itype < GGML_TYPE_COUNT);
ggml_type type = static_cast<ggml_type>(itype); type = static_cast<ggml_type>(itype);
auto * ctx_clip = clip_model_load(fname_inp, 2); auto * ctx_clip = clip_model_load(fname_inp, 2);
@@ -2869,8 +2706,8 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
} }
} }
// quantize only 2D tensors and bigger than block size // quantize only 2D tensors
quantize &= (ggml_n_dims(cur) == 2) && cur->ne[0] > ggml_blck_size(type); quantize &= (ggml_n_dims(cur) == 2);
if (quantize) { if (quantize) {
new_type = type; new_type = type;
@@ -2915,8 +2752,7 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
total_size_org += orig_size; total_size_org += orig_size;
total_size_new += new_size; total_size_new += new_size;
gguf_set_tensor_type(ctx_out, name.c_str(), new_type); gguf_set_tensor_type(ctx_out, name.c_str(), new_type);
GGML_ASSERT(gguf_get_tensor_size(ctx_out, gguf_find_tensor(ctx_out, name.c_str())) == new_size); gguf_set_tensor_data(ctx_out, name.c_str(), new_data, new_size);
gguf_set_tensor_data(ctx_out, name.c_str(), new_data);
fout.write((const char *)new_data, new_size); fout.write((const char *)new_data, new_size);
size_t pad = GGML_PAD(new_size, gguf_get_alignment(ctx_out)) - new_size; size_t pad = GGML_PAD(new_size, gguf_get_alignment(ctx_out)) - new_size;
for (size_t j = 0; j < pad; ++j) { for (size_t j = 0; j < pad; ++j) {
@@ -2966,12 +2802,6 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
else if (ctx->minicpmv_version == 3) { else if (ctx->minicpmv_version == 3) {
return 3584; return 3584;
} }
else if (ctx->minicpmv_version == 4) {
return 3584;
}
}
if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE){
return ctx->vision_model.mm_model_mlp_3_w->ne[1];
} }
if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
return ctx->vision_model.mm_1_b->ne[0]; return ctx->vision_model.mm_1_b->ne[0];
@@ -2988,35 +2818,10 @@ int clip_is_minicpmv(const struct clip_ctx * ctx) {
return 0; return 0;
} }
bool clip_is_glm(const struct clip_ctx * ctx) {
return ctx->has_glm_projector;
}
bool clip_is_qwen2vl(const struct clip_ctx * ctx) { bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
return ctx->has_qwen2vl_merger; return ctx->has_qwen2vl_merger;
} }
// Determine the number of encoder layers to iterate over
int get_deepest_feature_layer(const struct clip_ctx * ctx) {
// Get the index of the second to last layer; this is the
// default for models that have a llava projector
const auto & hparams = ctx->vision_model.hparams;
int n_layer = hparams.n_layer - 1;
int deepest_feature_layer = -1;
// Handle other projectors; incrementing here indicates that we
// should use the last encoder layer for the vision features.
if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) {
n_layer += 1;
}
// If we set explicit vision feature layers, only go up to the deepest one
for (const auto & feature_layer : hparams.vision_feature_layer) {
if (feature_layer > deepest_feature_layer) {
deepest_feature_layer = feature_layer;
}
}
return deepest_feature_layer < 0 ? n_layer : deepest_feature_layer;
}
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) { bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
clip_image_f32 clip_img; clip_image_f32 clip_img;


@@ -55,7 +55,6 @@ CLIP_API int32_t clip_hidden_size(const struct clip_ctx * ctx);
CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx); CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx); CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
CLIP_API size_t get_clip_image_grid_size(const struct clip_ctx * ctx);
CLIP_API int clip_n_patches (const struct clip_ctx * ctx); CLIP_API int clip_n_patches (const struct clip_ctx * ctx);
CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img); CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img);
@@ -74,9 +73,6 @@ CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch); CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch);
CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch); CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
/** build image from pixels decoded by other libraries instead of stb_image.h for better performance. The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes */
CLIP_API void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img);
CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img); CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
/** interpret bytes as an image file with length bytes_length, and use the result to populate img */ /** interpret bytes as an image file with length bytes_length, and use the result to populate img */
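The pixel-based entry point removed above (left column) exists so callers can hand over an already-decoded RGB buffer instead of going through stb_image. A short sketch of the intended call pattern, assuming the clip_image_u8_init()/clip_image_u8_free() helpers declared elsewhere in this header:
#include "clip.h"

// hypothetical example: wrap a caller-decoded RGB buffer (layout RGBRGBRGB..., exactly 3*nx*ny bytes)
static clip_image_u8 * image_from_decoded_rgb(const unsigned char * rgb, int nx, int ny) {
    clip_image_u8 * img = clip_image_u8_init();
    clip_build_img_from_pixels(rgb, nx, ny, img);   // copies the buffer into the image
    return img;                                     // caller releases it with clip_image_u8_free()
}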
@@ -93,14 +89,10 @@ CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, cons
CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype); CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype);
CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx); CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
CLIP_API bool clip_is_glm(const struct clip_ctx * ctx);
CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx); CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
CLIP_API int get_deepest_feature_layer(const struct clip_ctx * ctx);
CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec); CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif


@@ -216,7 +216,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
return true; return true;
} }
static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size) { static clip_image_f32 * only_v2_5_reshape_by_patch(clip_image_f32 * image, int patch_size) {
int width = image->nx; int width = image->nx;
int height = image->ny; int height = image->ny;
int num_patches = (height / patch_size) * (width / patch_size); int num_patches = (height / patch_size) * (width / patch_size);
@@ -277,7 +277,13 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
} }
else { else {
encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]); int has_minicpmv_projector = clip_is_minicpmv(ctx_clip);
if (has_minicpmv_projector == 2) {
encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
}
else if (has_minicpmv_projector == 3) {
encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
}
} }
if (!encoded) { if (!encoded) {
@@ -307,23 +313,6 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
load_image_size->height = img->ny; load_image_size->height = img->ny;
clip_add_load_image_size(ctx_clip, load_image_size); clip_add_load_image_size(ctx_clip, load_image_size);
LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height); LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height);
delete[] img_res_v.data;
img_res_v.size = 0;
img_res_v.data = nullptr;
}
else if (clip_is_glm(ctx_clip)){
struct clip_image_size * load_image_size = clip_image_size_init();
load_image_size->width = img_res_v.data[0].nx;
load_image_size->height = img_res_v.data[0].ny;
clip_add_load_image_size(ctx_clip, load_image_size);
bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd);
int pos = int(load_image_size->width/clip_patch_size(ctx_clip)/2);
*n_img_pos = (pos * pos + 2);
if (!encoded){
LOG_ERR("Unable to encode image \n");
return false;
}
} }
else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) { else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) {
// flat / default llava-1.5 type embedding // flat / default llava-1.5 type embedding
@@ -353,10 +342,9 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);
const int32_t * image_grid = clip_image_grid(ctx_clip); const int32_t * image_grid = clip_image_grid(ctx_clip);
const size_t num_gridpoints = get_clip_image_grid_size(ctx_clip);
std::vector<std::pair<int, int>> grid_pinpoints; std::vector<std::pair<int, int>> grid_pinpoints;
for (size_t i = 0; i < num_gridpoints; i += 2) { for (int i = 0; i < 32 && image_grid[i] != 0; i += 2) {
grid_pinpoints.push_back({image_grid[i], image_grid[i+1]}); grid_pinpoints.push_back({image_grid[i], image_grid[i+1]});
} }
@@ -396,7 +384,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) { bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) {
// make sure that the correct mmproj was used, i.e., compare apples to apples // make sure that the correct mmproj was used, i.e., compare apples to apples
int n_llama_embd = llama_model_n_embd(llama_get_model(ctx_llama)); int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));
auto n_image_embd = clip_n_mmproj_embd(ctx_clip); auto n_image_embd = clip_n_mmproj_embd(ctx_clip);
if (n_image_embd != n_llama_embd) { if (n_image_embd != n_llama_embd) {
LOG_ERR("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd); LOG_ERR("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd);
@@ -406,14 +394,10 @@ bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx *
} }
bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) { bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) {
// Granite vision uses up to 10 patches + base patch int num_max_patches = 6;
int num_max_patches = 11;
if (clip_is_minicpmv(ctx_clip)) { if (clip_is_minicpmv(ctx_clip)) {
num_max_patches = 10; num_max_patches = 10;
} }
if (clip_is_glm(ctx_clip)) {
num_max_patches = 1;
}
float * image_embd; float * image_embd;
if (clip_is_qwen2vl(ctx_clip)) { if (clip_is_qwen2vl(ctx_clip)) {
// qwen2vl don't split image into chunks, so `num_max_patches` is not needed. // qwen2vl don't split image into chunks, so `num_max_patches` is not needed.
@@ -473,7 +457,7 @@ struct llava_embd_batch {
}; };
bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) { bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
int n_embd = llama_model_n_embd(llama_get_model(ctx_llama)); int n_embd = llama_n_embd(llama_get_model(ctx_llama));
for (int i = 0; i < image_embed->n_image_pos; i += n_batch) { for (int i = 0; i < image_embed->n_image_pos; i += n_batch) {
int n_eval = image_embed->n_image_pos - i; int n_eval = image_embed->n_image_pos - i;


@@ -9,7 +9,7 @@
#include "llama.h" #include "llama.h"
struct llama_model_deleter { struct llama_model_deleter {
void operator()(llama_model * model) { llama_model_free(model); } void operator()(llama_model * model) { llama_free_model(model); }
}; };
struct llama_context_deleter { struct llama_context_deleter {
@@ -20,11 +20,11 @@ struct llama_sampler_deleter {
void operator()(llama_sampler * sampler) { llama_sampler_free(sampler); } void operator()(llama_sampler * sampler) { llama_sampler_free(sampler); }
}; };
struct llama_adapter_lora_deleter { struct llama_lora_adapter_deleter {
void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); } void operator()(llama_lora_adapter * lora_adapter) { llama_lora_adapter_free(lora_adapter); }
}; };
typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr; typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr;
typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr; typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr;
typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr; typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr;
typedef std::unique_ptr<llama_adapter_lora, llama_adapter_lora_deleter> llama_adapter_lora_ptr; typedef std::unique_ptr<llama_lora_adapter, llama_lora_adapter_deleter> llama_lora_adapter_ptr;
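For reference, a minimal sketch (not part of this diff) of what these deleters buy: the wrapped handles are released automatically through the C API calls named above, in reverse declaration order. It uses the right-hand-column entry points (llama_load_model_from_file, llama_new_context_with_model), and the model path is a placeholder.
#include "llama-cpp.h"

int main() {
    llama_model_ptr model(llama_load_model_from_file("model.gguf", llama_model_default_params()));
    if (!model) {
        return 1;
    }
    llama_context_ptr ctx(llama_new_context_with_model(model.get(), llama_context_default_params()));
    // ctx and model are freed by their unique_ptr deleters (llama_free / llama_free_model)
    // when main returns, so no explicit cleanup calls are needed here.
    return ctx ? 0 : 1;
}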


@@ -34,6 +34,7 @@
#define LLAMA_DEFAULT_SEED 0xFFFFFFFF #define LLAMA_DEFAULT_SEED 0xFFFFFFFF
// TODO: use everywhere in the implementation
#define LLAMA_TOKEN_NULL -1 #define LLAMA_TOKEN_NULL -1
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla' #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
@@ -56,7 +57,7 @@ extern "C" {
// TODO: show sample usage // TODO: show sample usage
// //
struct llama_vocab; // struct llama_vocab; // TODO: add in the future
struct llama_model; struct llama_model;
struct llama_context; struct llama_context;
struct llama_sampler; struct llama_sampler;
@@ -105,7 +106,6 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
}; };
enum llama_rope_type { enum llama_rope_type {
@@ -214,7 +214,7 @@ extern "C" {
LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported
}; };
// TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979) // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
typedef struct llama_token_data { typedef struct llama_token_data {
llama_token id; // token id llama_token id; // token id
float logit; // log-odds of the token float logit; // log-odds of the token
@@ -290,6 +290,9 @@ extern "C" {
// proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices() // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
const float * tensor_split; const float * tensor_split;
// comma separated list of RPC servers to use for offloading
const char * rpc_servers;
// Called with a progress value between 0.0 and 1.0. Pass NULL to disable. // Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
// If the provided progress_callback returns true, model loading continues. // If the provided progress_callback returns true, model loading continues.
// If it returns false, model loading is immediately aborted. // If it returns false, model loading is immediately aborted.
@@ -309,7 +312,7 @@ extern "C" {
}; };
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
// https://github.com/ggml-org/llama.cpp/pull/7544 // https://github.com/ggerganov/llama.cpp/pull/7544
struct llama_context_params { struct llama_context_params {
uint32_t n_ctx; // text context, 0 = from model uint32_t n_ctx; // text context, 0 = from model
uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode
@@ -322,7 +325,7 @@ extern "C" {
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
enum llama_attention_type attention_type; // attention type to use for embeddings enum llama_attention_type attention_type; // attention type to use for embeddings
// ref: https://github.com/ggml-org/llama.cpp/pull/2054 // ref: https://github.com/ggerganov/llama.cpp/pull/2054
float rope_freq_base; // RoPE base frequency, 0 = from model float rope_freq_base; // RoPE base frequency, 0 = from model
float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
float yarn_ext_factor; // YaRN extrapolation mix factor, negative = from model float yarn_ext_factor; // YaRN extrapolation mix factor, negative = from model
@@ -385,10 +388,11 @@ extern "C" {
} llama_chat_message; } llama_chat_message;
// lora adapter // lora adapter
struct llama_adapter_lora; // TODO: rename to llama_adapter_lora
struct llama_lora_adapter;
// Helpers for getting default parameters // Helpers for getting default parameters
// TODO: update API to start accepting pointers to params structs (https://github.com/ggml-org/llama.cpp/discussions/9172) // TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172)
LLAMA_API struct llama_model_params llama_model_default_params(void); LLAMA_API struct llama_model_params llama_model_default_params(void);
LLAMA_API struct llama_context_params llama_context_default_params(void); LLAMA_API struct llama_context_params llama_context_default_params(void);
LLAMA_API struct llama_sampler_chain_params llama_sampler_chain_default_params(void); LLAMA_API struct llama_sampler_chain_params llama_sampler_chain_default_params(void);
@@ -399,53 +403,31 @@ extern "C" {
// Call once at the start of the program // Call once at the start of the program
LLAMA_API void llama_backend_init(void); LLAMA_API void llama_backend_init(void);
// Call once at the end of the program - currently only used for MPI
LLAMA_API void llama_backend_free(void);
//optional: //optional:
LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa); LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa);
// Optional: an auto threadpool gets created in ggml if not passed explicitly // Optional: an auto threadpool gets created in ggml if not passed explicitly
LLAMA_API void llama_attach_threadpool( LLAMA_API void llama_attach_threadpool(
struct llama_context * ctx, struct llama_context * ctx,
ggml_threadpool_t threadpool, ggml_threadpool_t threadpool,
ggml_threadpool_t threadpool_batch); ggml_threadpool_t threadpool_batch);
LLAMA_API void llama_detach_threadpool(struct llama_context * ctx); LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file( // Call once at the end of the program - currently only used for MPI
const char * path_model, LLAMA_API void llama_backend_free(void);
struct llama_model_params params),
"use llama_model_load_from_file instead");
// Load the model from a file LLAMA_API struct llama_model * llama_load_model_from_file(
// If the file is split into multiple parts, the file name must follow this pattern: <name>-%05d-of-%05d.gguf
// If the split file name does not follow this pattern, use llama_model_load_from_splits
LLAMA_API struct llama_model * llama_model_load_from_file(
const char * path_model, const char * path_model,
struct llama_model_params params); struct llama_model_params params);
// Load the model from multiple splits (support custom naming scheme) // TODO: rename to llama_model_free
// The paths must be in the correct order LLAMA_API void llama_free_model(struct llama_model * model);
LLAMA_API struct llama_model * llama_model_load_from_splits(
const char ** paths,
size_t n_paths,
struct llama_model_params params);
DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model), // TODO: rename to llama_init_from_model
"use llama_model_free instead"); LLAMA_API struct llama_context * llama_new_context_with_model(
LLAMA_API void llama_model_free(struct llama_model * model);
LLAMA_API struct llama_context * llama_init_from_model(
struct llama_model * model, struct llama_model * model,
struct llama_context_params params); struct llama_context_params params);
DEPRECATED(LLAMA_API struct llama_context * llama_new_context_with_model(
struct llama_model * model,
struct llama_context_params params),
"use llama_init_from_model instead");
// TODO (jmorganca): this should most likely be passed in as part of a batch // TODO (jmorganca): this should most likely be passed in as part of a batch
// and not set on the context for all batches. // and not set on the context for all batches.
LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state); LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state);
@@ -467,31 +449,20 @@ extern "C" {
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead"); LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead"); LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
DEPRECATED(LLAMA_API int32_t llama_n_layer (const struct llama_model * model), "use llama_model_n_layer instead"); LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
DEPRECATED(LLAMA_API int32_t llama_n_head (const struct llama_model * model), "use llama_model_n_head instead"); LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
LLAMA_API int32_t llama_n_head (const struct llama_model * model);
DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead"); LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
// Get the model's RoPE frequency scaling factor // Get the model's RoPE frequency scaling factor
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab);
LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);
// Functions to access the model's GGUF metadata scalar values // Functions to access the model's GGUF metadata scalar values
// - The functions return the length of the string on success, or -1 on failure // - The functions return the length of the string on success, or -1 on failure
@@ -517,10 +488,6 @@ extern "C" {
// Returns the total size of all the tensors in the model in bytes // Returns the total size of all the tensors in the model in bytes
LLAMA_API uint64_t llama_model_size(const struct llama_model * model); LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
// Get the default chat template. Returns nullptr if not available
// If name is NULL, returns the default chat template
LLAMA_API const char * llama_model_chat_template(const struct llama_model * model, const char * name);
// Returns the total number of parameters in the model // Returns the total number of parameters in the model
LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model); LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
@@ -548,31 +515,34 @@ extern "C" {
// //
// Load a LoRA adapter from file // Load a LoRA adapter from file
LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init( // TODO: rename to llama_adapter_lora_init
LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init(
struct llama_model * model, struct llama_model * model,
const char * path_lora); const char * path_lora);
// Manually free a LoRA adapter
// Note: loaded adapters will be free when the associated model is deleted
LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
// The following functions operate on a llama_context, hence the naming: llama_verb_...
// Add a loaded LoRA adapter to given context // Add a loaded LoRA adapter to given context
// This will not modify model's weight // This will not modify model's weight
LLAMA_API int32_t llama_set_adapter_lora( // TODO: rename to llama_set_adapter_lora
LLAMA_API int32_t llama_lora_adapter_set(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_adapter_lora * adapter, struct llama_lora_adapter * adapter,
float scale); float scale);
// Remove a specific LoRA adapter from given context // Remove a specific LoRA adapter from given context
// Return -1 if the adapter is not present in the context // Return -1 if the adapter is not present in the context
LLAMA_API int32_t llama_rm_adapter_lora( // TODO: rename to llama_rm_adapter_lora
LLAMA_API int32_t llama_lora_adapter_remove(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_adapter_lora * adapter); struct llama_lora_adapter * adapter);
// Remove all LoRA adapters from given context // Remove all LoRA adapters from given context
LLAMA_API void llama_clear_adapter_lora(struct llama_context * ctx); // TODO: rename to llama_clear_adapter_lora
LLAMA_API void llama_lora_adapter_clear(struct llama_context * ctx);
// Manually free a LoRA adapter
// Note: loaded adapters will be free when the associated model is deleted
// TODO: rename to llama_adapter_lora_free
LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter);
// Apply a loaded control vector to a llama_context, or if data is NULL, clear // Apply a loaded control vector to a llama_context, or if data is NULL, clear
// the currently loaded vector. // the currently loaded vector.
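For reference, a rough sketch (not part of this diff) of the adapter lifecycle spelled out by the LoRA comments above, written against the right-hand-column names; the adapter path and scale are placeholders.
#include "llama.h"

// load a LoRA adapter onto the model and attach it to one context at full strength
static llama_lora_adapter * apply_lora(llama_model * model, llama_context * ctx, const char * lora_path) {
    llama_lora_adapter * adapter = llama_lora_adapter_init(model, lora_path);
    if (adapter == nullptr) {
        return nullptr;
    }
    llama_lora_adapter_set(ctx, adapter, 1.0f);   // model weights themselves are not modified
    return adapter;                               // detach later with llama_lora_adapter_remove()
                                                  // or llama_lora_adapter_clear(); the adapter is freed
                                                  // with the model, or explicitly via llama_lora_adapter_free()
}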
@@ -580,8 +550,9 @@ extern "C" {
// to an n_embd x n_layers buffer starting from layer 1. // to an n_embd x n_layers buffer starting from layer 1.
// il_start and il_end are the layer range the vector should apply to (both inclusive) // il_start and il_end are the layer range the vector should apply to (both inclusive)
// See llama_control_vector_load in common to load a control vector. // See llama_control_vector_load in common to load a control vector.
LLAMA_API int32_t llama_apply_adapter_cvec( // TODO: rename to llama_adapter_cvec_apply
struct llama_context * ctx, LLAMA_API int32_t llama_control_vector_apply(
struct llama_context * lctx,
const float * data, const float * data,
size_t len, size_t len,
int32_t n_embd, int32_t n_embd,
@@ -937,60 +908,41 @@ extern "C" {
// Vocab // Vocab
// //
LLAMA_API const char * llama_vocab_get_text(const struct llama_vocab * vocab, llama_token token); LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token);
LLAMA_API float llama_vocab_get_score(const struct llama_vocab * vocab, llama_token token); LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
LLAMA_API enum llama_token_attr llama_vocab_get_attr(const struct llama_vocab * vocab, llama_token token); LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token);
// Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.) // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
LLAMA_API bool llama_vocab_is_eog(const struct llama_vocab * vocab, llama_token token); LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);
// Identify if Token Id is a control token or a render-able token // Identify if Token Id is a control token or a render-able token
LLAMA_API bool llama_vocab_is_control(const struct llama_vocab * vocab, llama_token token); LLAMA_API bool llama_token_is_control(const struct llama_model * model, llama_token token);
// Special tokens // Special tokens
LLAMA_API llama_token llama_vocab_bos(const struct llama_vocab * vocab); // beginning-of-sentence LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
LLAMA_API llama_token llama_vocab_eos(const struct llama_vocab * vocab); // end-of-sentence LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
LLAMA_API llama_token llama_vocab_eot(const struct llama_vocab * vocab); // end-of-turn LLAMA_API llama_token llama_token_eot(const struct llama_model * model); // end-of-turn
LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
LLAMA_API llama_token llama_token_pad(const struct llama_model * model); // padding
LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab); LLAMA_API bool llama_add_bos_token(const struct llama_model * model);
LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab); LLAMA_API bool llama_add_eos_token(const struct llama_model * model);
LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab); // infill tokens
LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab); DEPRECATED(LLAMA_API llama_token llama_token_prefix(const struct llama_model * model), "use llama_token_fim_pre instead");
LLAMA_API llama_token llama_vocab_fim_mid(const struct llama_vocab * vocab); DEPRECATED(LLAMA_API llama_token llama_token_middle(const struct llama_model * model), "use llama_token_fim_mid instead");
LLAMA_API llama_token llama_vocab_fim_pad(const struct llama_vocab * vocab); DEPRECATED(LLAMA_API llama_token llama_token_suffix(const struct llama_model * model), "use llama_token_fim_suf instead");
LLAMA_API llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab);
LLAMA_API llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab);
DEPRECATED(LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_text instead"); LLAMA_API llama_token llama_token_fim_pre(const struct llama_model * model);
DEPRECATED(LLAMA_API float llama_token_get_score(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_score instead"); LLAMA_API llama_token llama_token_fim_suf(const struct llama_model * model);
DEPRECATED(LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_attr instead"); LLAMA_API llama_token llama_token_fim_mid(const struct llama_model * model);
DEPRECATED(LLAMA_API bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_eog instead"); LLAMA_API llama_token llama_token_fim_pad(const struct llama_model * model);
DEPRECATED(LLAMA_API bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_control instead"); LLAMA_API llama_token llama_token_fim_rep(const struct llama_model * model);
DEPRECATED(LLAMA_API llama_token llama_token_bos(const struct llama_vocab * vocab), "use llama_vocab_bos instead"); LLAMA_API llama_token llama_token_fim_sep(const struct llama_model * model);
DEPRECATED(LLAMA_API llama_token llama_token_eos(const struct llama_vocab * vocab), "use llama_vocab_eos instead");
DEPRECATED(LLAMA_API llama_token llama_token_eot(const struct llama_vocab * vocab), "use llama_vocab_eot instead");
DEPRECATED(LLAMA_API llama_token llama_token_cls(const struct llama_vocab * vocab), "use llama_vocab_cls instead");
DEPRECATED(LLAMA_API llama_token llama_token_sep(const struct llama_vocab * vocab), "use llama_vocab_sep instead");
DEPRECATED(LLAMA_API llama_token llama_token_nl (const struct llama_vocab * vocab), "use llama_vocab_nl instead");
DEPRECATED(LLAMA_API llama_token llama_token_pad(const struct llama_vocab * vocab), "use llama_vocab_pad instead");
DEPRECATED(LLAMA_API bool llama_add_bos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_bos instead");
DEPRECATED(LLAMA_API bool llama_add_eos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_eos instead");
DEPRECATED(LLAMA_API llama_token llama_token_fim_pre(const struct llama_vocab * vocab), "use llama_vocab_fim_pre instead");
DEPRECATED(LLAMA_API llama_token llama_token_fim_suf(const struct llama_vocab * vocab), "use llama_vocab_fim_suf instead");
DEPRECATED(LLAMA_API llama_token llama_token_fim_mid(const struct llama_vocab * vocab), "use llama_vocab_fim_mid instead");
DEPRECATED(LLAMA_API llama_token llama_token_fim_pad(const struct llama_vocab * vocab), "use llama_vocab_fim_pad instead");
DEPRECATED(LLAMA_API llama_token llama_token_fim_rep(const struct llama_vocab * vocab), "use llama_vocab_fim_rep instead");
DEPRECATED(LLAMA_API llama_token llama_token_fim_sep(const struct llama_vocab * vocab), "use llama_vocab_fim_sep instead");
// CLS is equivalent to BOS
DEPRECATED(LLAMA_API llama_token llama_vocab_cls(const struct llama_vocab * vocab), // classification
"use llama_vocab_bos instead");
// //
// Tokenization // Tokenization
@@ -1006,7 +958,7 @@ extern "C" {
/// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
/// as plaintext. Does not insert a leading space. /// as plaintext. Does not insert a leading space.
LLAMA_API int32_t llama_tokenize( LLAMA_API int32_t llama_tokenize(
const struct llama_vocab * vocab, const struct llama_model * model,
const char * text, const char * text,
int32_t text_len, int32_t text_len,
llama_token * tokens, llama_token * tokens,
@@ -1020,7 +972,7 @@ extern "C" {
// User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix') // User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix')
// @param special If true, special tokens are rendered in the output. // @param special If true, special tokens are rendered in the output.
LLAMA_API int32_t llama_token_to_piece( LLAMA_API int32_t llama_token_to_piece(
const struct llama_vocab * vocab, const struct llama_model * model,
llama_token token, llama_token token,
char * buf, char * buf,
int32_t length, int32_t length,
@@ -1034,7 +986,7 @@ extern "C" {
/// @param remove_special Allow to remove BOS and EOS tokens if model is configured to do so. /// @param remove_special Allow to remove BOS and EOS tokens if model is configured to do so.
/// @param unparse_special If true, special tokens are rendered in the output. /// @param unparse_special If true, special tokens are rendered in the output.
LLAMA_API int32_t llama_detokenize( LLAMA_API int32_t llama_detokenize(
const struct llama_vocab * vocab, const struct llama_model * model,
const llama_token * tokens, const llama_token * tokens,
int32_t n_tokens, int32_t n_tokens,
char * text, char * text,
@@ -1048,7 +1000,7 @@ extern "C" {
/// Apply chat template. Inspired by hf apply_chat_template() on python. /// Apply chat template. Inspired by hf apply_chat_template() on python.
/// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
/// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
/// @param tmpl A Jinja template to use for this chat. If this is nullptr, the models default chat template will be used instead. /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the models default chat template will be used instead.
/// @param chat Pointer to a list of multiple llama_chat_message /// @param chat Pointer to a list of multiple llama_chat_message
/// @param n_msg Number of llama_chat_message in this chat /// @param n_msg Number of llama_chat_message in this chat
@@ -1057,6 +1009,7 @@ extern "C" {
/// @param length The size of the allocated buffer /// @param length The size of the allocated buffer
/// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template. /// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template.
LLAMA_API int32_t llama_chat_apply_template( LLAMA_API int32_t llama_chat_apply_template(
const struct llama_model * model,
const char * tmpl, const char * tmpl,
const struct llama_chat_message * chat, const struct llama_chat_message * chat,
size_t n_msg, size_t n_msg,
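For reference, a sketch (not part of this diff) of the grow-and-retry pattern the note above describes, using the right-hand-column signature that still takes the model as its first argument; the initial buffer size and the add_ass flag are arbitrary choices.
#include "llama.h"
#include <string>
#include <vector>

static std::string render_chat(const llama_model * model, const std::vector<llama_chat_message> & msgs) {
    std::vector<char> buf(1024);
    int32_t n = llama_chat_apply_template(model, /*tmpl=*/nullptr, msgs.data(), msgs.size(),
                                          /*add_ass=*/true, buf.data(), (int32_t) buf.size());
    if (n > (int32_t) buf.size()) {               // buffer too small: re-alloc and re-apply
        buf.resize(n);
        n = llama_chat_apply_template(model, nullptr, msgs.data(), msgs.size(),
                                      true, buf.data(), (int32_t) buf.size());
    }
    return n < 0 ? std::string() : std::string(buf.data(), n);
}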
@@ -1104,6 +1057,7 @@ extern "C" {
// llama_sampler_free(smpl); // llama_sampler_free(smpl);
// //
// TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU). // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
// TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab
// //
typedef void * llama_sampler_context_t; typedef void * llama_sampler_context_t;
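For reference, a minimal sketch (not part of this diff) of the lifecycle outlined in the comment block above: build a sampler chain, sample from the last set of logits, then free the chain, which also frees the samplers added to it. The particular samplers and seed are just an example selection.
#include "llama.h"

static llama_token sample_next(llama_context * ctx) {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_top_p(0.95f, 1));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
    llama_token tok = llama_sampler_sample(chain, ctx, -1);   // -1 = logits of the last token
    llama_sampler_free(chain);                                // releases the whole chain
    return tok;
}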
@@ -1122,12 +1076,11 @@ extern "C" {
}; };
struct llama_sampler { struct llama_sampler {
const struct llama_sampler_i * iface; struct llama_sampler_i * iface;
llama_sampler_context_t ctx; llama_sampler_context_t ctx;
}; };
// mirror of llama_sampler_i: // mirror of llama_sampler_i:
LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl); LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
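With llama_sampler reduced to an iface/ctx pair and one side adding a public llama_sampler_init constructor, a user-defined sampler becomes a small table of callbacks. The sketch below assumes the conventional name/accept/apply/reset/clone/free callback set for llama_sampler_i (not all of it is visible in this hunk) and implements a sampler that suppresses a single banned token:

```cpp
#include <cmath>
#include "llama.h"

struct ban_ctx { llama_token banned; };

static const char * ban_name(const llama_sampler * /*smpl*/) { return "ban-one-token"; }

static void ban_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
    const auto * c = (const ban_ctx *) smpl->ctx;
    for (size_t i = 0; i < cur_p->size; ++i) {
        if (cur_p->data[i].id == c->banned) {
            cur_p->data[i].logit = -INFINITY;   // downstream samplers will never pick it
        }
    }
}

// Sketch only: unspecified callbacks are left null and the ctx is never freed here;
// a real implementation would also fill in clone/free.
static llama_sampler * make_ban_sampler(llama_token banned) {
    static const llama_sampler_i iface = {
        /*.name   =*/ ban_name,
        /*.accept =*/ nullptr,
        /*.apply  =*/ ban_apply,
        /*.reset  =*/ nullptr,
        /*.clone  =*/ nullptr,
        /*.free   =*/ nullptr,
    };
    return llama_sampler_init(&iface, new ban_ctx{banned});
}
```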
@@ -1157,7 +1110,7 @@ extern "C" {
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
/// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first. /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void), DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void),
"will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)"); "will be removed in the future (see https://github.com/ggerganov/llama.cpp/pull/9896#discussion_r1800920915)");
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k); LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
@@ -1165,7 +1118,7 @@ extern "C" {
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, size_t min_keep); LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, size_t min_keep);
/// @details Minimum P sampling as described in https://github.com/ggml-org/llama.cpp/pull/3841 /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep); LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep);
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
@@ -1180,9 +1133,6 @@ extern "C" {
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335 /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed); LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
/// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641
LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(float n);
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
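In practice the initializers above are composed into a chain rather than used in isolation. A hedged sketch of a common pipeline; the chain helpers (llama_sampler_chain_init/_add, llama_sampler_sample) are assumed from the surrounding API rather than shown in this hunk, and the cutoff values are arbitrary:

```cpp
#include "llama.h"

// Sketch only: top-k -> top-p -> min-p -> temperature -> seeded sampling on the
// logits of the most recent position of `ctx`.
static llama_token sample_next(llama_context * ctx) {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_top_p(0.95f, /*min_keep=*/1));
    llama_sampler_chain_add(chain, llama_sampler_init_min_p(0.05f, /*min_keep=*/1));
    llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    const llama_token id = llama_sampler_sample(chain, ctx, /*idx=*/-1);

    llama_sampler_free(chain);   // frees the chain together with the samplers added to it
    return id;
}
```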
@@ -1207,22 +1157,10 @@ extern "C" {
float eta); float eta);
LLAMA_API struct llama_sampler * llama_sampler_init_grammar( LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
const struct llama_vocab * vocab, const struct llama_model * model,
const char * grammar_str, const char * grammar_str,
const char * grammar_root); const char * grammar_root);
/// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639
/// @param trigger_words A list of words that will trigger the grammar sampler. This may be updated to a loose regex syntax (w/ ^) in the near future.
/// @param trigger_tokens A list of tokens that will trigger the grammar sampler.
LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy(
const struct llama_vocab * vocab,
const char * grammar_str,
const char * grammar_root,
const char ** trigger_words,
size_t num_trigger_words,
const llama_token * trigger_tokens,
size_t num_trigger_tokens);
/// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first. /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
LLAMA_API struct llama_sampler * llama_sampler_init_penalties( LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size)
@@ -1231,9 +1169,8 @@ extern "C" {
float penalty_present); // 0.0 = disabled float penalty_present); // 0.0 = disabled
/// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982 /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
LLAMA_API struct llama_sampler * llama_sampler_init_dry( LLAMA_API struct llama_sampler * llama_sampler_init_dry(
const struct llama_vocab * vocab, const struct llama_model * model,
int32_t n_ctx_train,
float dry_multiplier, float dry_multiplier,
float dry_base, float dry_base,
int32_t dry_allowed_length, int32_t dry_allowed_length,
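The grammar samplers declared a few lines up are typically fed a small GBNF string. A minimal sketch, assuming the vocab-taking overload from the left column (the other side passes a llama_model *); the grammar text is illustrative, and the lazy variant would additionally take the trigger words or tokens that arm the grammar mid-generation:

```cpp
#include "llama.h"

// Sketch only: constrain generation to the literal strings "yes" or "no".
static llama_sampler * make_yes_no_grammar(const llama_vocab * vocab) {
    static const char * grammar = R"(root ::= "yes" | "no")";
    return llama_sampler_init_grammar(vocab, grammar, "root");
}
```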
@@ -1267,7 +1204,7 @@ extern "C" {
// 3. discard non-EOG tokens with low prob // 3. discard non-EOG tokens with low prob
// 4. if no tokens are left -> pick EOT // 4. if no tokens are left -> pick EOT
// //
LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab); LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);
// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl); LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);

View File

@@ -1,7 +1,5 @@
#include "llama-adapter.h" #include "llama-adapter.h"
#include "llama-impl.h"
#include "llama-mmap.h"
#include "llama-model.h" #include "llama-model.h"
#include <algorithm> #include <algorithm>
@@ -11,7 +9,7 @@
// vec // vec
struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const { struct ggml_tensor * llama_control_vector::tensor_for(int il) const {
if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) { if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
return nullptr; return nullptr;
} }
@@ -19,7 +17,7 @@ struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
return tensors[il]; return tensors[il];
} }
struct ggml_tensor * llama_adapter_cvec::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const { struct ggml_tensor * llama_control_vector::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const {
ggml_tensor * layer_dir = tensor_for(il); ggml_tensor * layer_dir = tensor_for(il);
if (layer_dir != nullptr) { if (layer_dir != nullptr) {
cur = ggml_add(ctx, cur, layer_dir); cur = ggml_add(ctx, cur, layer_dir);
@@ -28,12 +26,12 @@ struct ggml_tensor * llama_adapter_cvec::apply_to(struct ggml_context * ctx, str
return cur; return cur;
} }
bool llama_adapter_cvec::init(const llama_model & model) { static bool llama_control_vector_init(struct llama_control_vector & cvec, const llama_model & model) {
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
GGML_ASSERT(tensors.empty()); GGML_ASSERT(cvec.tensors.empty());
GGML_ASSERT(ctxs.empty()); GGML_ASSERT(cvec.ctxs.empty());
GGML_ASSERT(bufs.empty()); GGML_ASSERT(cvec.bufs.empty());
// create a context for each buffer type // create a context for each buffer type
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map; std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
@@ -52,7 +50,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
} }
ctx_map[buft] = ctx; ctx_map[buft] = ctx;
ctxs.emplace_back(ctx); cvec.ctxs.emplace_back(ctx);
return ctx; return ctx;
} }
@@ -61,21 +59,21 @@ bool llama_adapter_cvec::init(const llama_model & model) {
}; };
// make tensors // make tensors
tensors.reserve(hparams.n_layer); cvec.tensors.reserve(hparams.n_layer);
tensors.push_back(nullptr); // there's never a tensor for layer 0 cvec.tensors.push_back(nullptr); // there's never a tensor for layer 0
for (size_t il = 1; il < hparams.n_layer; il++) { for (size_t il = 1; il < hparams.n_layer; il++) {
ggml_backend_buffer_type_t buft = model.select_buft(il); ggml_backend_buffer_type_t buft = llama_model_select_buft(model, il);
ggml_context * ctx = ctx_for_buft(buft); ggml_context * ctx = ctx_for_buft(buft);
if (!ctx) { if (!ctx) {
LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__); LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__);
return false; return false;
} }
ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd);
tensors.push_back(tensor); cvec.tensors.push_back(tensor);
} }
// allocate tensors / buffers and zero // allocate tensors / buffers and zero
bufs.reserve(ctx_map.size()); cvec.bufs.reserve(ctx_map.size());
for (auto it : ctx_map) { for (auto it : ctx_map) {
ggml_backend_buffer_type_t buft = it.first; ggml_backend_buffer_type_t buft = it.first;
ggml_context * ctx = it.second; ggml_context * ctx = it.second;
@@ -85,13 +83,14 @@ bool llama_adapter_cvec::init(const llama_model & model) {
return false; return false;
} }
ggml_backend_buffer_clear(buf, 0); ggml_backend_buffer_clear(buf, 0);
bufs.emplace_back(buf); cvec.bufs.emplace_back(buf);
} }
return true; return true;
} }
int32_t llama_adapter_cvec::apply( int32_t llama_control_vector_apply(
struct llama_control_vector & cvec,
const llama_model & model, const llama_model & model,
const float * data, const float * data,
size_t len, size_t len,
@@ -102,8 +101,8 @@ int32_t llama_adapter_cvec::apply(
if (data == nullptr) { if (data == nullptr) {
// disable the current control vector (but leave allocated for later) // disable the current control vector (but leave allocated for later)
layer_start = -1; cvec.layer_start = -1;
layer_end = -1; cvec.layer_end = -1;
return 0; return 0;
} }
@@ -112,21 +111,21 @@ int32_t llama_adapter_cvec::apply(
return 1; return 1;
} }
if (tensors.empty()) { if (cvec.tensors.empty()) {
if (!init(model)) { if (!llama_control_vector_init(cvec, model)) {
return 1; return 1;
} }
} }
layer_start = il_start; cvec.layer_start = il_start;
layer_end = il_end; cvec.layer_end = il_end;
for (size_t il = 1; il < hparams.n_layer; il++) { for (size_t il = 1; il < hparams.n_layer; il++) {
assert(tensors[il] != nullptr); assert(cvec.tensors[il] != nullptr);
const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present
if (off + n_embd <= len) { if (off + n_embd <= len) {
ggml_backend_tensor_set(tensors[il], data + off, 0, n_embd * ggml_element_size(tensors[il])); ggml_backend_tensor_set(cvec.tensors[il], data + off, 0, n_embd * ggml_element_size(cvec.tensors[il]));
} }
} }
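As the loop above shows, a control-vector buffer holds n_embd floats per layer starting at layer 1 (layer 0 never gets a tensor), so the slice for layer il begins at n_embd * (il - 1). A hedged sketch of filling such a buffer before passing it to the public apply entry point, whose name differs between the two sides of this diff (llama_control_vector_apply vs. llama_apply_adapter_cvec):

```cpp
#include <vector>

// Sketch only: replicate one direction vector, scaled by `strength`, for layers 1..n_layer-1.
// n_layer and the direction would come from the loaded model and a steering dataset.
static std::vector<float> build_cvec_data(const std::vector<float> & direction, // n_embd floats
                                          int32_t n_layer, float strength) {
    const int32_t n_embd = (int32_t) direction.size();
    std::vector<float> data((size_t) n_embd * (size_t) (n_layer - 1), 0.0f); // no slot for layer 0
    for (int32_t il = 1; il < n_layer; ++il) {
        const size_t off = (size_t) n_embd * (il - 1);   // same offset rule as the loop above
        for (int32_t i = 0; i < n_embd; ++i) {
            data[off + i] = strength * direction[i];
        }
    }
    return data;
}
```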
@@ -135,7 +134,7 @@ int32_t llama_adapter_cvec::apply(
// lora // lora
llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor * w) { llama_lora_weight * llama_lora_adapter::get_weight(struct ggml_tensor * w) {
const std::string name(w->name); const std::string name(w->name);
const auto pos = ab_map.find(name); const auto pos = ab_map.find(name);
@@ -146,7 +145,11 @@ llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor *
return nullptr; return nullptr;
} }
static void llama_adapter_lora_init_impl(struct llama_model & model, const char * path_lora, struct llama_adapter_lora & adapter) { void llama_lora_adapter_free(struct llama_lora_adapter * adapter) {
delete adapter;
}
static void llama_lora_adapter_init_impl(struct llama_model & model, const char * path_lora, struct llama_lora_adapter & adapter) {
LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora); LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
ggml_context * ctx_init; ggml_context * ctx_init;
@@ -218,7 +221,7 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
}; };
// bundle lora_a and lora_b into pairs // bundle lora_a and lora_b into pairs
std::map<std::string, llama_adapter_lora_weight> ab_map; std::map<std::string, llama_lora_weight> ab_map;
auto str_endswith = [](const std::string & str, const std::string & suffix) { auto str_endswith = [](const std::string & str, const std::string & suffix) {
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
}; };
@@ -228,21 +231,17 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
if (str_endswith(name, ".lora_a")) { if (str_endswith(name, ".lora_a")) {
replace_all(name, ".lora_a", ""); replace_all(name, ".lora_a", "");
if (ab_map.find(name) == ab_map.end()) { if (ab_map.find(name) == ab_map.end()) {
ab_map[name] = llama_adapter_lora_weight(cur, nullptr); ab_map[name] = llama_lora_weight(cur, nullptr);
} else { } else {
ab_map[name].a = cur; ab_map[name].a = cur;
} }
} else if (str_endswith(name, ".lora_b")) { } else if (str_endswith(name, ".lora_b")) {
replace_all(name, ".lora_b", ""); replace_all(name, ".lora_b", "");
if (ab_map.find(name) == ab_map.end()) { if (ab_map.find(name) == ab_map.end()) {
ab_map[name] = llama_adapter_lora_weight(nullptr, cur); ab_map[name] = llama_lora_weight(nullptr, cur);
} else { } else {
ab_map[name].b = cur; ab_map[name].b = cur;
} }
} else if (str_endswith(name, "_norm.weight")) {
// TODO: add support for norm vector
// for now, we don't really care because most adapters still work fine without it
continue;
} else { } else {
throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix"); throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix");
} }
@@ -251,33 +250,25 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
// add tensors // add tensors
for (auto & it : ab_map) { for (auto & it : ab_map) {
const std::string & name = it.first; const std::string & name = it.first;
llama_adapter_lora_weight & w = it.second; llama_lora_weight & w = it.second;
bool is_token_embd = str_endswith(name, "token_embd.weight");
if (!w.a || !w.b) { if (!w.a || !w.b) {
throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component"); throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component");
} }
// device buft and device ctx // device buft and device ctx
const auto * model_tensor = model.get_tensor(name.c_str()); auto * model_tensor = llama_model_get_tensor(model, name.c_str());
if (!model_tensor) { if (!model_tensor) {
throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)"); throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model");
} }
struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer)); struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
// validate tensor shape // validate tensor shape
if (is_token_embd) { if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) {
// expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd() throw std::runtime_error("tensor '" + name + "' has incorrect shape");
if (model_tensor->ne[0] != w.b->ne[1] || model_tensor->ne[1] != w.a->ne[1]) { }
throw std::runtime_error("tensor '" + name + "' has incorrect shape (hint: maybe wrong base model?)"); if (w.a->ne[1] != w.b->ne[0]) {
} throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)");
} else {
if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) {
throw std::runtime_error("tensor '" + name + "' has incorrect shape (hint: maybe wrong base model?)");
}
if (w.a->ne[1] != w.b->ne[0]) {
throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)");
}
} }
// save tensor to adapter // save tensor to adapter
@@ -285,7 +276,7 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b); struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
ggml_set_name(tensor_a, w.a->name); ggml_set_name(tensor_a, w.a->name);
ggml_set_name(tensor_b, w.b->name); ggml_set_name(tensor_b, w.b->name);
adapter.ab_map[name] = llama_adapter_lora_weight(tensor_a, tensor_b); adapter.ab_map[name] = llama_lora_weight(tensor_a, tensor_b);
} }
// allocate tensors / buffers and zero // allocate tensors / buffers and zero
@@ -327,11 +318,11 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2); LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
} }
struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model, const char * path_lora) { struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, const char * path_lora) {
struct llama_adapter_lora * adapter = new llama_adapter_lora(); struct llama_lora_adapter * adapter = new llama_lora_adapter();
try { try {
llama_adapter_lora_init_impl(*model, path_lora, *adapter); llama_lora_adapter_init_impl(*model, path_lora, *adapter);
return adapter; return adapter;
} catch (const std::exception & err) { } catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what()); LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
@@ -341,7 +332,3 @@ struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model,
return nullptr; return nullptr;
} }
void llama_adapter_lora_free(struct llama_adapter_lora * adapter) {
delete adapter;
}

View File

@@ -1,74 +1,66 @@
#pragma once #pragma once
#include "llama.h" #include "llama-impl.h"
#include "llama-hparams.h"
#include "ggml-cpp.h" #include "ggml-cpp.h"
#include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
// TODO: pimpl
// //
// llama_adapter_cvec // llama_adapter_cvec
// //
struct llama_adapter_cvec { // TODO: rename to llama_adapter_cvec
struct ggml_tensor * tensor_for(int il) const; struct llama_control_vector {
struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const;
int32_t apply(
const llama_model & model,
const float * data,
size_t len,
int32_t n_embd,
int32_t il_start,
int32_t il_end);
private:
bool init(const llama_model & model);
int32_t layer_start = -1;
int32_t layer_end = -1;
std::vector<ggml_context_ptr> ctxs; std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs; std::vector<ggml_backend_buffer_ptr> bufs;
std::vector<struct ggml_tensor *> tensors; // per layer std::vector<struct ggml_tensor *> tensors; // per layer
int32_t layer_start = -1;
int32_t layer_end = -1;
struct ggml_tensor * tensor_for(int il) const;
struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const;
}; };
int32_t llama_control_vector_apply(
struct llama_control_vector & cvec,
const llama_model & model,
const float * data,
size_t len,
int32_t n_embd,
int32_t il_start,
int32_t il_end);
// //
// llama_adapter_lora // llama_adapter_lora
// //
struct llama_adapter_lora_weight { // TODO: rename to llama_adapter_lora_weight
struct llama_lora_weight {
struct ggml_tensor * a = nullptr; struct ggml_tensor * a = nullptr;
struct ggml_tensor * b = nullptr; struct ggml_tensor * b = nullptr;
// get actual scale based on rank and alpha llama_lora_weight() = default;
float get_scale(float alpha, float adapter_scale) const { llama_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {}
const float rank = (float) b->ne[0];
const float scale = alpha ? adapter_scale * alpha / rank : adapter_scale;
return scale;
}
llama_adapter_lora_weight() = default;
llama_adapter_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {}
}; };
struct llama_adapter_lora { // TODO: rename to llama_adapter_lora
struct llama_lora_adapter {
// map tensor name to lora_a_b // map tensor name to lora_a_b
std::unordered_map<std::string, struct llama_adapter_lora_weight> ab_map; std::unordered_map<std::string, struct llama_lora_weight> ab_map;
std::vector<ggml_context_ptr> ctxs; std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs; std::vector<ggml_backend_buffer_ptr> bufs;
float alpha; float alpha;
llama_adapter_lora() = default; llama_lora_adapter() = default;
~llama_adapter_lora() = default; ~llama_lora_adapter() = default;
llama_adapter_lora_weight * get_weight(struct ggml_tensor * w); llama_lora_weight * get_weight(struct ggml_tensor * w);
}; };
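The get_scale helper visible on the left of this hunk scales a LoRA pair by adapter_scale * alpha / rank, with the rank taken from b->ne[0] and a fallback to adapter_scale when no alpha is stored. A tiny numeric sketch of that formula (the values are illustrative):

```cpp
// Sketch of the effective LoRA scale:
//   alpha = 32, rank = 16, adapter_scale = 1.0  ->  1.0 * 32 / 16 = 2.0
//   alpha = 0 (not stored in the GGUF)          ->  scale == adapter_scale
static float lora_scale(float alpha, float rank, float adapter_scale) {
    return alpha != 0.0f ? adapter_scale * alpha / rank : adapter_scale;
}
```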

View File

@@ -28,7 +28,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_QWEN2VL, "qwen2vl" }, { LLM_ARCH_QWEN2VL, "qwen2vl" },
{ LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI2, "phi2" },
{ LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PHI3, "phi3" },
{ LLM_ARCH_PHIMOE, "phimoe" },
{ LLM_ARCH_PLAMO, "plamo" }, { LLM_ARCH_PLAMO, "plamo" },
{ LLM_ARCH_CODESHELL, "codeshell" }, { LLM_ARCH_CODESHELL, "codeshell" },
{ LLM_ARCH_ORION, "orion" }, { LLM_ARCH_ORION, "orion" },
@@ -37,7 +36,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MINICPM3, "minicpm3" }, { LLM_ARCH_MINICPM3, "minicpm3" },
{ LLM_ARCH_GEMMA, "gemma" }, { LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_GEMMA2, "gemma2" }, { LLM_ARCH_GEMMA2, "gemma2" },
{ LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_XVERSE, "xverse" },
@@ -59,13 +57,11 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_NEMOTRON, "nemotron" }, { LLM_ARCH_NEMOTRON, "nemotron" },
{ LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_EXAONE, "exaone" },
{ LLM_ARCH_RWKV6, "rwkv6" }, { LLM_ARCH_RWKV6, "rwkv6" },
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
{ LLM_ARCH_GRANITE, "granite" }, { LLM_ARCH_GRANITE, "granite" },
{ LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_GRANITE_MOE, "granitemoe" },
{ LLM_ARCH_CHAMELEON, "chameleon" }, { LLM_ARCH_CHAMELEON, "chameleon" },
{ LLM_ARCH_SOLAR, "solar" }, { LLM_ARCH_SOLAR, "solar" },
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
{ LLM_ARCH_MISTRAL3, "mistral3" },
{ LLM_ARCH_UNKNOWN, "(unknown)" }, { LLM_ARCH_UNKNOWN, "(unknown)" },
}; };
@@ -111,26 +107,25 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" }, { LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" },
{ LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" }, { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
{ LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
{ LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
{ LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" }, { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" },
{ LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" }, { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" },
{ LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" }, { LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" },
{ LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" }, { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
{ LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
{ LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" }, { LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" },
{ LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" }, { LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" },
{ LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
{ LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" }, { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" },
{ LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" }, { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@@ -184,8 +179,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" },
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" },
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" },
{ LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" },
{ LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" },
{ LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" },
@@ -629,27 +622,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
}, },
}, },
{
LLM_ARCH_PHIMOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" },
{ LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{ {
LLM_ARCH_PLAMO, LLM_ARCH_PLAMO,
{ {
@@ -806,24 +778,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
}, },
}, },
{
LLM_ARCH_GEMMA3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{ {
LLM_ARCH_STARCODER2, LLM_ARCH_STARCODER2,
{ {
@@ -1082,9 +1036,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_OUTPUT, "output" }, { LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
@@ -1231,7 +1182,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_TIME_MIX_LERP_V, "blk.%d.time_mix_lerp_v" }, { LLM_TENSOR_TIME_MIX_LERP_V, "blk.%d.time_mix_lerp_v" },
{ LLM_TENSOR_TIME_MIX_LERP_R, "blk.%d.time_mix_lerp_r" }, { LLM_TENSOR_TIME_MIX_LERP_R, "blk.%d.time_mix_lerp_r" },
{ LLM_TENSOR_TIME_MIX_LERP_G, "blk.%d.time_mix_lerp_g" }, { LLM_TENSOR_TIME_MIX_LERP_G, "blk.%d.time_mix_lerp_g" },
{ LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
{ LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" }, { LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" },
{ LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" }, { LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" },
{ LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" }, { LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" },
@@ -1249,32 +1199,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" }, { LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" },
}, },
}, },
{
LLM_ARCH_RWKV6QWEN2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
{ LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
{ LLM_TENSOR_TIME_MIX_LERP_X, "blk.%d.time_mix_lerp_x" },
{ LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
{ LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" },
{ LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" },
{ LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" },
{ LLM_TENSOR_TIME_MIX_DECAY_W2, "blk.%d.time_mix_decay_w2" },
{ LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
{ LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
{ LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
{ LLM_TENSOR_TIME_MIX_GATE, "blk.%d.time_mix_gate" },
{ LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{ {
LLM_ARCH_GRANITE, LLM_ARCH_GRANITE,
{ {
@@ -1329,24 +1253,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
}, },
}, },
{
LLM_ARCH_SOLAR,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_BSKCN_TV, "bskcn_tv" },
},
},
{ {
LLM_ARCH_WAVTOKENIZER_DEC, LLM_ARCH_WAVTOKENIZER_DEC,
{ {
@@ -1373,20 +1279,22 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
}, },
}, },
{ {
LLM_ARCH_MISTRAL3, LLM_ARCH_SOLAR,
{ {
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" }, { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
} { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_BSKCN_TV, "bskcn_tv" },
},
}, },
{ {
LLM_ARCH_UNKNOWN, LLM_ARCH_UNKNOWN,
@@ -1491,7 +1399,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}}, {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
{LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@@ -1548,11 +1455,10 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
}; };
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} LLM_KV::LLM_KV(llm_arch arch) : arch(arch) {}
std::string LLM_KV::operator()(llm_kv kv) const { std::string LLM_KV::operator()(llm_kv kv) const {
return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix) return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
: ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
} }
std::string LLM_TN_IMPL::str() const { std::string LLM_TN_IMPL::str() const {

View File

@@ -32,7 +32,6 @@ enum llm_arch {
LLM_ARCH_QWEN2VL, LLM_ARCH_QWEN2VL,
LLM_ARCH_PHI2, LLM_ARCH_PHI2,
LLM_ARCH_PHI3, LLM_ARCH_PHI3,
LLM_ARCH_PHIMOE,
LLM_ARCH_PLAMO, LLM_ARCH_PLAMO,
LLM_ARCH_CODESHELL, LLM_ARCH_CODESHELL,
LLM_ARCH_ORION, LLM_ARCH_ORION,
@@ -41,7 +40,6 @@ enum llm_arch {
LLM_ARCH_MINICPM3, LLM_ARCH_MINICPM3,
LLM_ARCH_GEMMA, LLM_ARCH_GEMMA,
LLM_ARCH_GEMMA2, LLM_ARCH_GEMMA2,
LLM_ARCH_GEMMA3,
LLM_ARCH_STARCODER2, LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA, LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE, LLM_ARCH_XVERSE,
@@ -63,13 +61,11 @@ enum llm_arch {
LLM_ARCH_NEMOTRON, LLM_ARCH_NEMOTRON,
LLM_ARCH_EXAONE, LLM_ARCH_EXAONE,
LLM_ARCH_RWKV6, LLM_ARCH_RWKV6,
LLM_ARCH_RWKV6QWEN2,
LLM_ARCH_GRANITE, LLM_ARCH_GRANITE,
LLM_ARCH_GRANITE_MOE, LLM_ARCH_GRANITE_MOE,
LLM_ARCH_CHAMELEON, LLM_ARCH_CHAMELEON,
LLM_ARCH_SOLAR, LLM_ARCH_SOLAR,
LLM_ARCH_WAVTOKENIZER_DEC, LLM_ARCH_WAVTOKENIZER_DEC,
LLM_ARCH_MISTRAL3,
LLM_ARCH_UNKNOWN, LLM_ARCH_UNKNOWN,
}; };
@@ -115,7 +111,6 @@ enum llm_kv {
LLM_KV_TIME_DECAY_EXTRA_DIM, LLM_KV_TIME_DECAY_EXTRA_DIM,
LLM_KV_RESIDUAL_SCALE, LLM_KV_RESIDUAL_SCALE,
LLM_KV_EMBEDDING_SCALE, LLM_KV_EMBEDDING_SCALE,
LLM_KV_TOKEN_SHIFT_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV, LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -182,8 +177,6 @@ enum llm_kv {
LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_HF_JSON,
LLM_KV_TOKENIZER_RWKV, LLM_KV_TOKENIZER_RWKV,
LLM_KV_TOKENIZER_CHAT_TEMPLATE,
LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
LLM_KV_TOKENIZER_FIM_PRE_ID, LLM_KV_TOKENIZER_FIM_PRE_ID,
LLM_KV_TOKENIZER_FIM_SUF_ID, LLM_KV_TOKENIZER_FIM_SUF_ID,
LLM_KV_TOKENIZER_FIM_MID_ID, LLM_KV_TOKENIZER_FIM_MID_ID,
@@ -263,7 +256,6 @@ enum llm_tensor {
LLM_TENSOR_TIME_MIX_LERP_V, LLM_TENSOR_TIME_MIX_LERP_V,
LLM_TENSOR_TIME_MIX_LERP_R, LLM_TENSOR_TIME_MIX_LERP_R,
LLM_TENSOR_TIME_MIX_LERP_G, LLM_TENSOR_TIME_MIX_LERP_G,
LLM_TENSOR_TIME_MIX_LERP_FUSED,
LLM_TENSOR_TIME_MIX_FIRST, LLM_TENSOR_TIME_MIX_FIRST,
LLM_TENSOR_TIME_MIX_DECAY, LLM_TENSOR_TIME_MIX_DECAY,
LLM_TENSOR_TIME_MIX_DECAY_W1, LLM_TENSOR_TIME_MIX_DECAY_W1,
@@ -351,10 +343,9 @@ enum llm_tensor_layer {
}; };
struct LLM_KV { struct LLM_KV {
LLM_KV(llm_arch arch, const char * suffix = nullptr); LLM_KV(llm_arch arch);
llm_arch arch; llm_arch arch;
const char * suffix;
std::string operator()(llm_kv kv) const; std::string operator()(llm_kv kv) const;
}; };

View File

@@ -35,7 +35,6 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN }, { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
{ "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 }, { "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 },
{ "phi3", LLM_CHAT_TEMPLATE_PHI_3 }, { "phi3", LLM_CHAT_TEMPLATE_PHI_3 },
{ "phi4", LLM_CHAT_TEMPLATE_PHI_4 },
{ "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 }, { "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 },
{ "zephyr", LLM_CHAT_TEMPLATE_ZEPHYR }, { "zephyr", LLM_CHAT_TEMPLATE_ZEPHYR },
{ "monarch", LLM_CHAT_TEMPLATE_MONARCH }, { "monarch", LLM_CHAT_TEMPLATE_MONARCH },
@@ -51,7 +50,6 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 }, { "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 },
{ "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 }, { "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 },
{ "chatglm4", LLM_CHAT_TEMPLATE_CHATGML_4 }, { "chatglm4", LLM_CHAT_TEMPLATE_CHATGML_4 },
{ "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE },
{ "minicpm", LLM_CHAT_TEMPLATE_MINICPM }, { "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
{ "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 }, { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
{ "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
@@ -75,9 +73,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return tmpl.find(haystack) != std::string::npos; return tmpl.find(haystack) != std::string::npos;
}; };
if (tmpl_contains("<|im_start|>")) { if (tmpl_contains("<|im_start|>")) {
return tmpl_contains("<|im_sep|>") return LLM_CHAT_TEMPLATE_CHATML;
? LLM_CHAT_TEMPLATE_PHI_4
: LLM_CHAT_TEMPLATE_CHATML;
} else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) { } else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) {
if (tmpl_contains("[SYSTEM_PROMPT]")) { if (tmpl_contains("[SYSTEM_PROMPT]")) {
return LLM_CHAT_TEMPLATE_MISTRAL_V7; return LLM_CHAT_TEMPLATE_MISTRAL_V7;
@@ -116,7 +112,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) { } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) {
return LLM_CHAT_TEMPLATE_PHI_3; return LLM_CHAT_TEMPLATE_PHI_3;
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) { } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
return tmpl_contains("</s>") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE; return LLM_CHAT_TEMPLATE_FALCON_3;
} else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) { } else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
return LLM_CHAT_TEMPLATE_ZEPHYR; return LLM_CHAT_TEMPLATE_ZEPHYR;
} else if (tmpl_contains("bos_token + message['role']")) { } else if (tmpl_contains("bos_token + message['role']")) {
@@ -153,7 +149,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_MINICPM; return LLM_CHAT_TEMPLATE_MINICPM;
} else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) { } else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
return LLM_CHAT_TEMPLATE_DEEPSEEK_2; return LLM_CHAT_TEMPLATE_DEEPSEEK_2;
} else if (tmpl_contains(LU8("<Assistant>")) && tmpl_contains(LU8("<User>")) && tmpl_contains(LU8("<end▁of▁sentence>"))) { } else if (tmpl_contains(LU8("'<Assistant>' + message['content'] + '<end▁of▁sentence>'"))) {
return LLM_CHAT_TEMPLATE_DEEPSEEK_3; return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
} else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) { } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
// ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
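The detection changes in this file all follow the same substring heuristic: probe the Jinja source for distinctive markers and map them to a template enum. A standalone sketch of the two splits this diff touches (ChatML vs. the <|im_sep|> variant, and </s> separating Falcon-3 from GLM-Edge); the enum and function here are illustrative stand-ins for llm_chat_detect_template:

```cpp
#include <string>

enum tmpl_kind { TMPL_CHATML, TMPL_PHI_4, TMPL_FALCON_3, TMPL_GLMEDGE, TMPL_UNKNOWN };

// Sketch of the substring heuristics used above.
static tmpl_kind detect(const std::string & tmpl) {
    auto contains = [&](const char * s) { return tmpl.find(s) != std::string::npos; };

    if (contains("<|im_start|>")) {
        return contains("<|im_sep|>") ? TMPL_PHI_4 : TMPL_CHATML;
    }
    if (contains("<|assistant|>") && contains("<|user|>")) {
        return contains("</s>") ? TMPL_FALCON_3 : TMPL_GLMEDGE;
    }
    return TMPL_UNKNOWN;
}
```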
@@ -273,14 +269,6 @@ int32_t llm_chat_apply_template(
if (add_ass) { if (add_ass) {
ss << "<|assistant|>\n"; ss << "<|assistant|>\n";
} }
} else if (tmpl == LLM_CHAT_TEMPLATE_PHI_4) {
// chatml template
for (auto message : chat) {
ss << "<|im_start|>" << message->role << "<|im_sep|>" << message->content << "<|im_end|>";
}
if (add_ass) {
ss << "<|im_start|>assistant<|im_sep|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_FALCON_3) { } else if (tmpl == LLM_CHAT_TEMPLATE_FALCON_3) {
// Falcon 3 // Falcon 3
for (auto message : chat) { for (auto message : chat) {
@@ -441,14 +429,6 @@ int32_t llm_chat_apply_template(
if (add_ass) { if (add_ass) {
ss << "<|assistant|>"; ss << "<|assistant|>";
} }
} else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
for (auto message : chat) {
std::string role(message->role);
ss << "<|" << role << "|>" << "\n" << message->content;
}
if (add_ass) {
ss << "<|assistant|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) { } else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) {
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
for (auto message : chat) { for (auto message : chat) {

View File

@@ -15,7 +15,6 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN, LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
LLM_CHAT_TEMPLATE_MISTRAL_V7, LLM_CHAT_TEMPLATE_MISTRAL_V7,
LLM_CHAT_TEMPLATE_PHI_3, LLM_CHAT_TEMPLATE_PHI_3,
LLM_CHAT_TEMPLATE_PHI_4,
LLM_CHAT_TEMPLATE_FALCON_3, LLM_CHAT_TEMPLATE_FALCON_3,
LLM_CHAT_TEMPLATE_ZEPHYR, LLM_CHAT_TEMPLATE_ZEPHYR,
LLM_CHAT_TEMPLATE_MONARCH, LLM_CHAT_TEMPLATE_MONARCH,
@@ -31,7 +30,6 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_LLAMA_3, LLM_CHAT_TEMPLATE_LLAMA_3,
LLM_CHAT_TEMPLATE_CHATGML_3, LLM_CHAT_TEMPLATE_CHATGML_3,
LLM_CHAT_TEMPLATE_CHATGML_4, LLM_CHAT_TEMPLATE_CHATGML_4,
LLM_CHAT_TEMPLATE_GLMEDGE,
LLM_CHAT_TEMPLATE_MINICPM, LLM_CHAT_TEMPLATE_MINICPM,
LLM_CHAT_TEMPLATE_EXAONE_3, LLM_CHAT_TEMPLATE_EXAONE_3,
LLM_CHAT_TEMPLATE_RWKV_WORLD, LLM_CHAT_TEMPLATE_RWKV_WORLD,

View File

@@ -1,8 +1,5 @@
#include "llama-context.h" #include "llama-context.h"
#include "llama-impl.h"
#include "llama-mmap.h"
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
#include <cstring> #include <cstring>
@@ -516,7 +513,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
auto * buft = ggml_backend_cpu_buffer_type(); auto * buft = ggml_backend_cpu_buffer_type();
// try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
auto * output_dev = lctx.model.dev_output(); auto * output_dev = lctx.model.dev_output.dev;
auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr; auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
if (output_dev_host_buft) { if (output_dev_host_buft) {
buft = output_dev_host_buft; buft = output_dev_host_buft;

View File

@@ -22,12 +22,12 @@ struct llama_context {
const struct llama_model & model; const struct llama_model & model;
struct llama_cparams cparams; struct llama_cparams cparams;
struct llama_sbatch sbatch; // TODO: revisit if needed struct llama_sbatch sbatch; // TODO: revisit if needed
struct llama_kv_cache kv_self; struct llama_kv_cache kv_self;
struct llama_adapter_cvec cvec; struct llama_control_vector cvec;
std::unordered_map<struct llama_adapter_lora *, float> lora; std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;
std::vector<ggml_backend_ptr> backends; std::vector<ggml_backend_ptr> backends;
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns; std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;

View File

@@ -345,194 +345,194 @@ const char * llama_grammar_parser::parse_sequence(
size_t last_sym_start = rule.size(); size_t last_sym_start = rule.size();
const char * pos = src; const char * pos = src;
auto handle_repetitions = [&](int min_times, int max_times) { auto handle_repetitions = [&](int min_times, int max_times) {
if (last_sym_start == rule.size()) { if (last_sym_start == rule.size()) {
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
}
// apply transformation to previous symbol (last_sym_start to end) according to
// the following rewrite rules:
// S{m,n} --> S S S (m times) S'(n-m)
// S'(x) ::= S S'(x-1) |
// (... n-m definitions of these S' rules ...)
// S'(1) ::= S |
// S{m,} --> S S S (m times) S'
// S' ::= S S' |
// S* --> S{0,}
// --> S' ::= S S' |
// S+ --> S{1,}
// --> S S'
// S' ::= S S' |
// S? --> S{0,1}
// --> S'
// S' ::= S |
llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
if (min_times == 0) {
rule.resize(last_sym_start);
} else {
// Repeat the previous elements (min_times - 1) times
for (int i = 1; i < min_times; i++) {
rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
} }
}
uint32_t last_rec_rule_id = 0; // apply transformation to previous symbol (last_sym_start to end) according to
auto n_opt = max_times < 0 ? 1 : max_times - min_times; // the following rewrite rules:
// S{m,n} --> S S S (m times) S'(n-m)
// S'(x) ::= S S'(x-1) |
// (... n-m definitions of these S' rules ...)
// S'(1) ::= S |
// S{m,} --> S S S (m times) S'
// S' ::= S S' |
// S* --> S{0,}
// --> S' ::= S S' |
// S+ --> S{1,}
// --> S S'
// S' ::= S S' |
// S? --> S{0,1}
// --> S'
// S' ::= S |
llama_grammar_rule rec_rule(prev_rule); llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
for (int i = 0; i < n_opt; i++) { if (min_times == 0) {
rec_rule.resize(prev_rule.size()); rule.resize(last_sym_start);
uint32_t rec_rule_id = generate_symbol_id( rule_name); } else {
if (i > 0 || max_times < 0) { // Repeat the previous elements (min_times - 1) times
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); for (int i = 1; i < min_times; i++) {
} rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
add_rule( rec_rule_id, rec_rule);
last_rec_rule_id = rec_rule_id;
}
if (n_opt > 0) {
rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
}
};
while (*pos) {
if (*pos == '"') { // literal string
pos++;
last_sym_start = rule.size();
while (*pos != '"') {
if (!*pos) {
throw std::runtime_error("unexpected end of input");
} }
auto char_pair = parse_char(pos);
pos = char_pair.second;
rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
} }
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '[') { // char range(s) uint32_t last_rec_rule_id = 0;
pos++; auto n_opt = max_times < 0 ? 1 : max_times - min_times;
enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
if (*pos == '^') { llama_grammar_rule rec_rule(prev_rule);
for (int i = 0; i < n_opt; i++) {
rec_rule.resize(prev_rule.size());
uint32_t rec_rule_id = generate_symbol_id( rule_name);
if (i > 0 || max_times < 0) {
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
}
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
add_rule( rec_rule_id, rec_rule);
last_rec_rule_id = rec_rule_id;
}
if (n_opt > 0) {
rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
}
};
while (*pos) {
if (*pos == '"') { // literal string
pos++; pos++;
start_type = LLAMA_GRETYPE_CHAR_NOT; last_sym_start = rule.size();
} while (*pos != '"') {
last_sym_start = rule.size(); if (!*pos) {
while (*pos != ']') {
if (!*pos) {
throw std::runtime_error("unexpected end of input");
}
auto char_pair = parse_char(pos);
pos = char_pair.second;
enum llama_gretype type = last_sym_start < rule.size()
? LLAMA_GRETYPE_CHAR_ALT
: start_type;
rule.push_back({type, char_pair.first});
if (pos[0] == '-' && pos[1] != ']') {
if (!pos[1]) {
throw std::runtime_error("unexpected end of input"); throw std::runtime_error("unexpected end of input");
} }
auto endchar_pair = parse_char(pos + 1); auto char_pair = parse_char(pos);
pos = endchar_pair.second; pos = char_pair.second;
rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
}
}
pos = parse_space(pos + 1, is_nested);
} else if (is_word_char(*pos)) { // rule reference
const char * name_end = parse_name(pos);
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
pos = parse_space(name_end, is_nested);
last_sym_start = rule.size();
rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
} else if (*pos == '[') { // char range(s)
pos++;
enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
if (*pos == '^') {
pos++;
start_type = LLAMA_GRETYPE_CHAR_NOT;
}
last_sym_start = rule.size();
while (*pos != ']') {
if (!*pos) {
throw std::runtime_error("unexpected end of input");
}
auto char_pair = parse_char(pos);
pos = char_pair.second;
enum llama_gretype type = last_sym_start < rule.size()
? LLAMA_GRETYPE_CHAR_ALT
: start_type;
rule.push_back({type, char_pair.first});
if (pos[0] == '-' && pos[1] != ']') {
if (!pos[1]) {
throw std::runtime_error("unexpected end of input");
}
auto endchar_pair = parse_char(pos + 1);
pos = endchar_pair.second;
rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
}
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '(') { // grouping
// parse nested alternates into synthesized rule
pos = parse_space(pos + 1, true);
uint32_t sub_rule_id = generate_symbol_id(rule_name);
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
last_sym_start = rule.size();
// output reference to synthesized rule
rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
if (*pos != ')') {
throw std::runtime_error(std::string("expecting ')' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '.') { // any char
last_sym_start = rule.size();
rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '*') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, -1);
} else if (*pos == '+') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(1, -1);
} else if (*pos == '?') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, 1);
} else if (*pos == '{') {
pos = parse_space(pos + 1, is_nested);
if (!is_digit_char(*pos)) {
throw std::runtime_error(std::string("expecting an int at ") + pos);
}
const char * int_end = parse_int(pos);
int min_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
int max_times = -1;
if (*pos == '}') {
max_times = min_times;
pos = parse_space(pos + 1, is_nested);
} else if (*pos == ',') {
pos = parse_space(pos + 1, is_nested);
if (is_digit_char(*pos)) {
const char * int_end = parse_int(pos);
max_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
}
if (*pos != '}') {
throw std::runtime_error(std::string("expecting '}' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
} else {
throw std::runtime_error(std::string("expecting ',' at ") + pos);
}
handle_repetitions(min_times, max_times);
} else {
break;
}
}
return pos;
}
const char * llama_grammar_parser::parse_rule(const char * src) {
const char * name_end = parse_name(src);
const char * pos = parse_space(name_end, false);
size_t name_len = name_end - src;
uint32_t rule_id = get_symbol_id(src, name_len);
const std::string name(src, name_len);
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
throw std::runtime_error(std::string("expecting ::= at ") + pos);
}
pos = parse_space(pos + 3, true);
pos = parse_alternates(pos, name, rule_id, false);
if (*pos == '\r') {
pos += pos[1] == '\n' ? 2 : 1;
} else if (*pos == '\n') {
pos++;
} else if (*pos) {
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
}
return parse_space(pos, true);
}
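For a concrete sense of what the repetition branches above accept, here is an illustrative GBNF fragment and how it maps onto handle_repetitions. The grammar text and the vocab variable are assumptions for the example; the two-argument init signature is the one shown on the right-hand side of this compare.

// Illustrative only: a grammar exercising the repetition operators parsed above.
const char * grammar_str =
    "root ::= word{2,4}\n"         // bounded repetition -> handle_repetitions(2, 4)
    "word ::= [a-z]+ \" \"?\n";    // '+' -> (1, -1), '?' -> (0, 1), '*' would be (0, -1)
// 'vocab' is assumed to be an already-loaded llama_vocab; "root" names the start rule.
struct llama_grammar * g = llama_grammar_init_impl(vocab, grammar_str, "root");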
bool llama_grammar_parser::parse(const char * src) { bool llama_grammar_parser::parse(const char * src) {
try { try {
@@ -560,7 +560,7 @@ bool llama_grammar_parser::parse(const char * src) {
} }
} }
} catch (const std::exception & err) { } catch (const std::exception & err) {
fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src); fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
rules.clear(); rules.clear();
return false; return false;
} }
@@ -960,28 +960,10 @@ struct llama_grammar * llama_grammar_init_impl(
// Important: vec_rules has to be moved here, not copied, because stacks contains // Important: vec_rules has to be moved here, not copied, because stacks contains
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
// then the pointers would be invalidated when the local vec_rules goes out of scope. // then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar { return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
vocab,
std::move(vec_rules),
std::move(stacks),
/* .partial_utf8 = */ {},
/* .lazy =*/ false,
/* .awaiting_trigger = */ false,
/* .trigger_buffer = */ "",
/* .trigger_tokens = */ {},
/* .trigger_words = */ {},
};
} }
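The comment above is the whole reason for the std::move calls; a minimal standalone sketch in plain C++ (not the project's types) of why a copy would leave pointers into the local vector dangling:

#include <cstdio>
#include <vector>

struct holder {
    std::vector<int> rules;
    const int *      first;    // points into rules' heap storage
};

static holder make_holder() {
    std::vector<int> local = {1, 2, 3};
    const int * p = local.data();           // pointer into local's buffer
    return holder { std::move(local), p };  // move hands over the same buffer, p stays valid
    // returning holder { local, p } instead would copy into a new buffer,
    // and p would dangle once local goes out of scope
}

int main() {
    holder h = make_holder();
    std::printf("%d\n", *h.first);          // prints 1
    return 0;
}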
struct llama_grammar * llama_grammar_init_impl( struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
const struct llama_vocab * vocab,
const char * grammar_str,
const char * grammar_root,
bool lazy,
const char ** trigger_words,
size_t num_trigger_words,
const llama_token * trigger_tokens,
size_t num_trigger_tokens) {
llama_grammar_parser parser; llama_grammar_parser parser;
// if there is a grammar, parse it // if there is a grammar, parse it
@@ -1053,31 +1035,10 @@ struct llama_grammar * llama_grammar_init_impl(
} }
} while (true); } while (true);
std::vector<llama_token> vec_trigger_tokens;
std::vector<std::string> vec_trigger_words;
for (size_t i = 0; i < num_trigger_tokens; i++) {
GGML_ASSERT(trigger_tokens != nullptr);
vec_trigger_tokens.push_back(trigger_tokens[i]);
}
for (size_t i = 0; i < num_trigger_words; i++) {
GGML_ASSERT(trigger_words != nullptr);
vec_trigger_words.push_back(trigger_words[i]);
}
// Important: vec_rules has to be moved here, not copied, because stacks contains // Important: vec_rules has to be moved here, not copied, because stacks contains
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
// then the pointers would be invalidated when the local vec_rules goes out of scope. // then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar { return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
vocab,
std::move(vec_rules),
std::move(stacks),
/* .partial_utf8 = */ {},
/* .lazy = */ lazy,
/* .awaiting_trigger = */ lazy,
/* .trigger_buffer = */ "",
std::move(vec_trigger_tokens),
std::move(vec_trigger_words),
};
} }
void llama_grammar_free_impl(struct llama_grammar * grammar) { void llama_grammar_free_impl(struct llama_grammar * grammar) {
@@ -1094,11 +1055,6 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
grammar.rules, grammar.rules,
grammar.stacks, grammar.stacks,
grammar.partial_utf8, grammar.partial_utf8,
grammar.lazy,
grammar.awaiting_trigger,
grammar.trigger_buffer,
grammar.trigger_tokens,
grammar.trigger_words,
}; };
// redirect elements in stacks to point to new rules // redirect elements in stacks to point to new rules
@@ -1120,10 +1076,6 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) { void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
GGML_ASSERT(grammar.vocab != nullptr); GGML_ASSERT(grammar.vocab != nullptr);
if (grammar.awaiting_trigger) {
return;
}
bool allow_eog = false; bool allow_eog = false;
for (const auto & stack : grammar.stacks) { for (const auto & stack : grammar.stacks) {
if (stack.empty()) { if (stack.empty()) {
@@ -1140,9 +1092,9 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
for (size_t i = 0; i < cur_p->size; ++i) { for (size_t i = 0; i < cur_p->size; ++i) {
const llama_token id = cur_p->data[i].id; const llama_token id = cur_p->data[i].id;
const std::string & piece = grammar.vocab->token_to_piece(id); const std::string & piece = grammar.vocab->cache_token_to_piece.at(id);
if (grammar.vocab->is_eog(id)) { if (llama_token_is_eog_impl(*grammar.vocab, id)) {
if (!allow_eog) { if (!allow_eog) {
cur_p->data[i].logit = -INFINITY; cur_p->data[i].logit = -INFINITY;
} }
@@ -1163,35 +1115,7 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) { void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
GGML_ASSERT(grammar.vocab != nullptr); GGML_ASSERT(grammar.vocab != nullptr);
const auto & piece = grammar.vocab->token_to_piece(token); if (llama_token_is_eog_impl(*grammar.vocab, token)) {
if (grammar.awaiting_trigger) {
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
grammar.awaiting_trigger = false;
grammar.trigger_buffer.clear();
llama_grammar_accept_str(grammar, piece);
LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
return;
} else {
// TODO: consider a smarter incremental substring search algorithm (store last position to search from).
grammar.trigger_buffer += piece;
for (const auto & word : grammar.trigger_words) {
auto pos = grammar.trigger_buffer.find(word);
if (pos != std::string::npos) {
grammar.awaiting_trigger = false;
auto constrained_str = grammar.trigger_buffer.substr(pos);
grammar.trigger_buffer.clear();
llama_grammar_accept_str(grammar, constrained_str);
LLAMA_LOG_DEBUG("Grammar triggered on word `%s`", word.c_str());
return;
}
}
LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str());
return;
}
}
if (grammar.vocab->is_eog(token)) {
for (const auto & stack : grammar.stacks) { for (const auto & stack : grammar.stacks) {
if (stack.empty()) { if (stack.empty()) {
return; return;
@@ -1200,10 +1124,8 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
llama_grammar_accept_str(grammar, piece); const std::string & piece = grammar.vocab->cache_token_to_piece.at(token);
}
void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) {
// Note terminating 0 in decoded string // Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto decoded = decode_utf8(piece, grammar.partial_utf8);
const auto & code_points = decoded.first; const auto & code_points = decoded.first;
@@ -1213,7 +1135,5 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
} }
grammar.partial_utf8 = decoded.second; grammar.partial_utf8 = decoded.second;
if (grammar.stacks.empty()) { GGML_ASSERT(!grammar.stacks.empty());
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
}
} }

View File

@@ -114,15 +114,6 @@ struct llama_grammar {
// buffer for partially generated UTF-8 sequence from accepted tokens // buffer for partially generated UTF-8 sequence from accepted tokens
llama_partial_utf8 partial_utf8; llama_partial_utf8 partial_utf8;
// lazy grammars wait for trigger words or tokens before constraining the sampling.
// we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
// (useful e.g. for tool_choice=required)
bool lazy = false;
bool awaiting_trigger = false; // Initialized to true for lazy grammars only
std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
std::vector<std::string> trigger_words;
}; };
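The members removed above implement the lazy-grammar behaviour that the earlier llama-grammar.cpp hunk deletes. A condensed sketch of that trigger flow, with simplified names and types that are not this tree's API:

#include <string>
#include <vector>

struct lazy_state {
    bool                     awaiting_trigger = true;
    std::string              trigger_buffer;
    std::vector<std::string> trigger_words;
};

// Returns the text the grammar should start constraining from, or "" while still waiting.
static std::string feed_piece(lazy_state & st, const std::string & piece) {
    if (!st.awaiting_trigger) {
        return piece;                        // already constrained: pass straight through
    }
    st.trigger_buffer += piece;              // buffer output until a trigger word shows up
    for (const auto & word : st.trigger_words) {
        auto pos = st.trigger_buffer.find(word);
        if (pos != std::string::npos) {
            st.awaiting_trigger = false;
            std::string constrained = st.trigger_buffer.substr(pos);
            st.trigger_buffer.clear();
            return constrained;              // constrain from the trigger word onward
        }
    }
    return "";                               // trigger not seen yet: grammar stays inactive
}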
// //
@@ -136,15 +127,7 @@ struct llama_grammar * llama_grammar_init_impl(
size_t n_rules, size_t n_rules,
size_t start_rule_index); size_t start_rule_index);
struct llama_grammar * llama_grammar_init_impl( struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root);
const struct llama_vocab * vocab,
const char * grammar_str,
const char * grammar_root,
bool lazy,
const char ** trigger_words,
size_t num_trigger_words,
const llama_token * trigger_tokens,
size_t num_trigger_tokens);
void llama_grammar_free_impl(struct llama_grammar * grammar); void llama_grammar_free_impl(struct llama_grammar * grammar);
@@ -158,7 +141,3 @@ void llama_grammar_apply_impl(
void llama_grammar_accept_impl( void llama_grammar_accept_impl(
struct llama_grammar & grammar, struct llama_grammar & grammar,
llama_token token); llama_token token);
void llama_grammar_accept_str(
struct llama_grammar & grammar,
const std::string & piece);

View File

@@ -54,7 +54,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
uint32_t llama_hparams::n_embd_k_s() const { uint32_t llama_hparams::n_embd_k_s() const {
if (wkv_head_size != 0) { if (wkv_head_size != 0) {
// for RWKV models // for RWKV models
return token_shift_count * n_embd; return 2 * n_embd;
} }
// TODO: maybe support other convolution strides than 1 // TODO: maybe support other convolution strides than 1
@@ -82,4 +82,4 @@ bool llama_hparams::n_bskcn(uint32_t n, uint32_t il) const {
bool llama_hparams::cross_attention_layers(uint32_t il) const { bool llama_hparams::cross_attention_layers(uint32_t il) const {
return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end(); return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
} }

View File

@@ -30,6 +30,7 @@ struct llama_hparams {
bool use_par_res; bool use_par_res;
bool swin_norm; bool swin_norm;
uint32_t n_vocab = 0;
uint32_t n_ctx_train; // context size the model was trained on uint32_t n_ctx_train; // context size the model was trained on
uint32_t n_embd; uint32_t n_embd;
uint32_t n_embd_features = 0; uint32_t n_embd_features = 0;
@@ -40,8 +41,8 @@ struct llama_hparams {
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
uint32_t n_expert = 0; uint32_t n_expert = 0;
uint32_t n_expert_used = 0; uint32_t n_expert_used = 0;
uint32_t n_vocab_type = 0; // for BERT-style token types
uint32_t n_rel_attn_bkts = 0; uint32_t n_rel_attn_bkts = 0;
uint32_t n_vocab = 0;
// for WavTokenizer // for WavTokenizer
struct llama_hparams_posnet posnet; struct llama_hparams_posnet posnet;
@@ -78,7 +79,6 @@ struct llama_hparams {
uint32_t time_mix_extra_dim = 0; uint32_t time_mix_extra_dim = 0;
uint32_t time_decay_extra_dim = 0; uint32_t time_decay_extra_dim = 0;
uint32_t wkv_head_size = 0; uint32_t wkv_head_size = 0;
uint32_t token_shift_count = 2;
float rope_attn_factor = 1.0f; float rope_attn_factor = 1.0f;
float rope_freq_base_train; float rope_freq_base_train;
@@ -141,7 +141,7 @@ struct llama_hparams {
// Block skip connection // Block skip connection
bool n_bskcn(uint32_t n, uint32_t il) const; bool n_bskcn(uint32_t n, uint32_t il) const;
// cross attention layers // cross attention layers
bool cross_attention_layers(uint32_t il) const; bool cross_attention_layers(uint32_t il) const;
}; };

View File

@@ -1,6 +1,5 @@
#include "llama-impl.h" #include "llama-impl.h"
#include "gguf.h"
#include "llama.h" #include "llama.h"
#include <cinttypes> #include <cinttypes>
@@ -139,7 +138,7 @@ std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
{ {
const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i); const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i);
int arr_n = gguf_get_arr_n(ctx_gguf, i); int arr_n = gguf_get_arr_n(ctx_gguf, i);
const void * data = arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx_gguf, i); const void * data = gguf_get_arr_data(ctx_gguf, i);
std::stringstream ss; std::stringstream ss;
ss << "["; ss << "[";
for (int j = 0; j < arr_n; j++) { for (int j = 0; j < arr_n; j++) {

View File

@@ -6,13 +6,13 @@
#include <vector> #include <vector>
#ifdef __GNUC__ #ifdef __GNUC__
# if defined(__MINGW32__) && !defined(__clang__) #ifdef __MINGW32__
# define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
# else
# define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
# endif
#else #else
# define LLAMA_ATTRIBUTE_FORMAT(...) #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
#endif
#else
#define LLAMA_ATTRIBUTE_FORMAT(...)
#endif #endif
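Both sides of this hunk only differ in how the format attribute is selected; a representative use of the macro, assuming the usual printf-style helper declaration, looks like:

// Argument 1 is the format string and the variadic arguments start at position 2,
// so GCC/Clang can warn on mismatches such as format("%d", "not an int").
LLAMA_ATTRIBUTE_FORMAT(1, 2)
std::string format(const char * fmt, ...);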
// //

View File

@@ -72,6 +72,39 @@ bool llama_kv_cache_init(
cache.v_l.reserve(n_layer); cache.v_l.reserve(n_layer);
for (int i = 0; i < n_layer; i++) { for (int i = 0; i < n_layer; i++) {
// for cross attention layers
if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const llama_model::buft_list_t * buft_list;
if (offload) {
buft_list = model.dev_layer.at(i).buft_list;
} else {
buft_list = &model.cpu_buft_list;
}
ggml_backend_buffer_type_t buft = select_buft(*buft_list,
[&](ggml_context * ctx) {
ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
if (hparams.rope_type == LLAMA_ROPE_TYPE_NONE) {
return k;
}
ggml_tensor * p = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
return ggml_rope(ctx, k, p, hparams.n_rot, hparams.rope_type);
});
ggml_context * ctx = ctx_for_buft(buft);
if (!ctx) {
LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
return false;
}
ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
cache.k_l.push_back(k);
cache.v_l.push_back(v);
continue;
}
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
@@ -79,7 +112,7 @@ bool llama_kv_cache_init(
ggml_backend_buffer_type_t buft; ggml_backend_buffer_type_t buft;
if (offload) { if (offload) {
auto * dev = model.dev_layer(i); auto * dev = model.dev_layer.at(i).dev;
buft = ggml_backend_dev_buffer_type(dev); buft = ggml_backend_dev_buffer_type(dev);
} else { } else {
buft = ggml_backend_cpu_buffer_type(); buft = ggml_backend_cpu_buffer_type();
@@ -91,17 +124,8 @@ bool llama_kv_cache_init(
return false; return false;
} }
ggml_tensor * k, *v; ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
// for cross attention layers
if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) {
k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
} else {
k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
}
ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i); ggml_format_name(v, "cache_v_l%d", i);
cache.k_l.push_back(k); cache.k_l.push_back(k);
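A note on the hard-coded 6404 above, which the hunk itself does not explain: for the MLLAMA (Llama 3.2 Vision) cross-attention layers the K/V length is fixed rather than tied to kv_size, and 6404 appears to correspond to 4 image tiles times 1601 patch embeddings per tile (1600 patches plus one class embedding), i.e. 4 x 1601 = 6404 positions per layer, kept unquantized as F32. The tile arithmetic is an informed guess, not something stated in this diff.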
@@ -128,10 +152,10 @@ bool llama_kv_cache_init(
struct llama_kv_cache_slot_info llama_kv_cache_find_slot( struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
struct llama_kv_cache & cache, struct llama_kv_cache & cache,
const struct llama_ubatch & ubatch) { const struct llama_ubatch & batch) {
const uint32_t n_tokens = ubatch.n_tokens; const uint32_t n_tokens = batch.n_tokens;
const uint32_t n_seqs = ubatch.n_seqs; const uint32_t n_seqs = batch.n_seqs;
const uint32_t n_seq_tokens = ubatch.n_seq_tokens; const uint32_t n_seq_tokens = batch.n_seq_tokens;
if (cache.recurrent) { if (cache.recurrent) {
// For recurrent state architectures (like Mamba or RWKV), // For recurrent state architectures (like Mamba or RWKV),
@@ -139,16 +163,16 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
// A slot should be always be contiguous. // A slot should be always be contiguous.
// can only process batches with an equal number of new tokens in each sequence // can only process batches with an equal number of new tokens in each sequence
GGML_ASSERT(ubatch.equal_seqs); GGML_ASSERT(batch.equal_seqs);
int32_t min = cache.size - 1; int32_t min = cache.size - 1;
int32_t max = 0; int32_t max = 0;
// everything should fit if all seq_ids are smaller than the max // everything should fit if all seq_ids are smaller than the max
for (uint32_t s = 0; s < n_seqs; ++s) { for (uint32_t s = 0; s < n_seqs; ++s) {
const uint32_t n_seq_id = ubatch.n_seq_id[s]; const uint32_t n_seq_id = batch.n_seq_id[s];
for (uint32_t j = 0; j < n_seq_id; ++j) { for (uint32_t j = 0; j < n_seq_id; ++j) {
const llama_seq_id seq_id = ubatch.seq_id[s][j]; const llama_seq_id seq_id = batch.seq_id[s][j];
if (seq_id < 0 || (uint32_t) seq_id >= cache.size) { if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
// too big seq_id // too big seq_id
@@ -207,7 +231,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
// find usable cell range // find usable cell range
for (uint32_t s = 0; s < n_seqs; ++s) { for (uint32_t s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0]; const llama_seq_id seq_id = batch.seq_id[s][0];
llama_kv_cell & seq_meta = cache.cells[seq_id]; llama_kv_cell & seq_meta = cache.cells[seq_id];
bool has_cell = false; bool has_cell = false;
if (seq_meta.tail >= 0) { if (seq_meta.tail >= 0) {
@@ -246,7 +270,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
// gather and re-order // gather and re-order
for (uint32_t s = 0; s < n_seqs; ++s) { for (uint32_t s = 0; s < n_seqs; ++s) {
int32_t dst_id = s + min; int32_t dst_id = s + min;
int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail; int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
if (dst_id != src_id) { if (dst_id != src_id) {
llama_kv_cell & dst_cell = cache.cells[dst_id]; llama_kv_cell & dst_cell = cache.cells[dst_id];
llama_kv_cell & src_cell = cache.cells[src_id]; llama_kv_cell & src_cell = cache.cells[src_id];
@@ -267,7 +291,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
// update the pos of the used seqs // update the pos of the used seqs
for (uint32_t s = 0; s < n_seqs; ++s) { for (uint32_t s = 0; s < n_seqs; ++s) {
const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
int32_t cell_id = s + min; int32_t cell_id = s + min;
llama_kv_cell & cell = cache.cells[cell_id]; llama_kv_cell & cell = cache.cells[cell_id];
@@ -275,12 +299,12 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
// What should happen when the pos backtracks or skips a value? // What should happen when the pos backtracks or skips a value?
// Clearing the state mid-batch would require special-casing which isn't done. // Clearing the state mid-batch would require special-casing which isn't done.
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
__func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
} }
cell.pos = last_pos; cell.pos = last_pos;
cell.seq_id.clear(); cell.seq_id.clear();
for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
const llama_seq_id seq_id = ubatch.seq_id[s][j]; const llama_seq_id seq_id = batch.seq_id[s][j];
cell.seq_id.insert(seq_id); cell.seq_id.insert(seq_id);
cache.cells[seq_id].tail = cell_id; cache.cells[seq_id].tail = cell_id;
} }
@@ -334,10 +358,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
for (uint32_t s = 0; s < n_seqs; s++) { for (uint32_t s = 0; s < n_seqs; s++) {
for (uint32_t i = 0; i < n_seq_tokens; ++i) { for (uint32_t i = 0; i < n_seq_tokens; ++i) {
uint32_t k = s*n_seq_tokens + i; uint32_t k = s*n_seq_tokens + i;
cache.cells[cache.head + k].pos = ubatch.pos[k]; cache.cells[cache.head + k].pos = batch.pos[k];
for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) { for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]); cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
} }
} }
} }

View File

@@ -37,7 +37,7 @@ struct llama_kv_cache {
bool can_shift = false; bool can_shift = false;
// Note: The value of head isn't only used to optimize searching // Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_impl also uses it, so it // for a free KV slot. llama_decode_internal also uses it, so it
// cannot be freely changed after a slot has been allocated. // cannot be freely changed after a slot has been allocated.
uint32_t head = 0; uint32_t head = 0;
uint32_t size = 0; uint32_t size = 0;

View File

@@ -7,7 +7,6 @@
#include <cstring> #include <cstring>
#include <climits> #include <climits>
#include <stdexcept> #include <stdexcept>
#include <cerrno>
#ifdef __has_include #ifdef __has_include
#if __has_include(<unistd.h>) #if __has_include(<unistd.h>)
@@ -36,7 +35,7 @@
// TODO: consider moving to llama-impl.h if needed in more places // TODO: consider moving to llama-impl.h if needed in more places
#if defined(_WIN32) #if defined(_WIN32)
static std::string llama_format_win_err(DWORD err) { std::string llama_format_win_err(DWORD err) {
LPSTR buf; LPSTR buf;
size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL); NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL);
@@ -242,16 +241,12 @@ llama_file::~llama_file() = default;
size_t llama_file::tell() const { return pimpl->tell(); } size_t llama_file::tell() const { return pimpl->tell(); }
size_t llama_file::size() const { return pimpl->size; } size_t llama_file::size() const { return pimpl->size; }
int llama_file::file_id() const { int llama_file::fileno() const {
#ifdef _WIN32 #ifdef _WIN32
return _fileno(pimpl->fp); return _fileno(pimpl->fp);
#else
#if defined(fileno)
return fileno(pimpl->fp);
#else #else
return ::fileno(pimpl->fp); return ::fileno(pimpl->fp);
#endif #endif
#endif
} }
void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); } void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); }
@@ -270,7 +265,7 @@ struct llama_mmap::impl {
impl(struct llama_file * file, size_t prefetch, bool numa) { impl(struct llama_file * file, size_t prefetch, bool numa) {
size = file->size(); size = file->size();
int fd = file->file_id(); int fd = file->fileno();
int flags = MAP_SHARED; int flags = MAP_SHARED;
if (numa) { prefetch = 0; } if (numa) { prefetch = 0; }
#ifdef __linux__ #ifdef __linux__
@@ -362,7 +357,7 @@ struct llama_mmap::impl {
size = file->size(); size = file->size();
HANDLE hFile = (HANDLE) _get_osfhandle(file->file_id()); HANDLE hFile = (HANDLE) _get_osfhandle(file->fileno());
HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);

View File

@@ -1,6 +1,5 @@
#pragma once #pragma once
#include <cstdint>
#include <memory> #include <memory>
#include <vector> #include <vector>
@@ -19,7 +18,7 @@ struct llama_file {
size_t tell() const; size_t tell() const;
size_t size() const; size_t size() const;
int file_id() const; // fileno overload int fileno() const;
void seek(size_t offset, int whence) const; void seek(size_t offset, int whence) const;

View File

@@ -7,10 +7,6 @@
#include <cstring> #include <cstring>
#include <future> #include <future>
static const size_t kiB = 1024;
static const size_t MiB = 1024*kiB;
static const size_t GiB = 1024*MiB;
const char * llama_file_version_name(llama_fver version) { const char * llama_file_version_name(llama_fver version) {
switch (version) { switch (version) {
case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)"; case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)";
@@ -21,78 +17,8 @@ const char * llama_file_version_name(llama_fver version) {
return "unknown"; return "unknown";
} }
static std::string llama_model_ftype_name(llama_ftype ftype) {
if (ftype & LLAMA_FTYPE_GUESSED) {
return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)";
}
switch (ftype) {
case LLAMA_FTYPE_ALL_F32: return "all F32";
case LLAMA_FTYPE_MOSTLY_F16: return "F16";
case LLAMA_FTYPE_MOSTLY_BF16: return "BF16";
case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0";
case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1";
case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0";
case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1";
case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0";
case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small";
case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small";
case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large";
case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "Q4_K - Small";
case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small";
case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K";
case LLAMA_FTYPE_MOSTLY_TQ1_0: return "TQ1_0 - 1.69 bpw ternary";
case LLAMA_FTYPE_MOSTLY_TQ2_0: return "TQ2_0 - 2.06 bpw ternary";
case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_M: return "IQ2_M - 2.7 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return "IQ3_XXS - 3.0625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ1_S: return "IQ1_S - 1.5625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw";
default: return "unknown, may not work";
}
}
// return a list of splits for a given path
// for example, given "<name>-00002-of-00004.gguf", returns list of all 4 splits
static std::vector<std::string> llama_get_list_splits(const std::string & path, const int idx, const int n_split) {
std::vector<std::string> paths;
std::string split_prefix;
std::vector<char> buf(llama_path_max(), 0);
{
int ret = llama_split_prefix(buf.data(), buf.size(), path.c_str(), idx, n_split);
if (!ret) {
throw std::runtime_error(format("invalid split file name: %s", path.c_str()));
}
split_prefix = std::string(buf.data(), ret);
}
if (split_prefix.empty()) {
throw std::runtime_error(format("invalid split file: %s", path.c_str()));
}
for (int idx = 0; idx < n_split; ++idx) {
int ret = llama_split_path(buf.data(), buf.size(), split_prefix.c_str(), idx, n_split);
paths.push_back(std::string(buf.data(), ret));
}
return paths;
}
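To make the comment at the top of llama_get_list_splits concrete: given the path model-00002-of-00004.gguf, the detected split prefix is model and the generated list is

model-00001-of-00004.gguf
model-00002-of-00004.gguf
model-00003-of-00004.gguf
model-00004-of-00004.gguf

(the zero-padded pattern comes from the llama_split_path helper used in the loop; the file name here is just the example from the comment).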
namespace GGUFMeta { namespace GGUFMeta {
template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int64_t)> template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int)>
struct GKV_Base_Type { struct GKV_Base_Type {
static constexpr gguf_type gt = gt_; static constexpr gguf_type gt = gt_;
@@ -134,11 +60,10 @@ namespace GGUFMeta {
public: public:
static constexpr gguf_type gt = GGUF_TYPE_ARRAY; static constexpr gguf_type gt = GGUF_TYPE_ARRAY;
static ArrayInfo getter(const gguf_context *ctx, const int k) { static ArrayInfo getter(const gguf_context *ctx, const int k) {
const enum gguf_type arr_type = gguf_get_arr_type(ctx, k);
return ArrayInfo { return ArrayInfo {
arr_type, gguf_get_arr_type(ctx, k),
size_t(gguf_get_arr_n(ctx, k)), size_t(gguf_get_arr_n(ctx, k)),
arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx, k), gguf_get_arr_data(ctx, k),
}; };
} }
}; };
@@ -443,12 +368,7 @@ namespace GGUFMeta {
template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required);
template bool llama_model_loader::get_key_or_arr<uint32_t>(const std::string & key, std::array<uint32_t, 512> & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr<uint32_t>(const std::string & key, std::array<uint32_t, 512> & result, uint32_t n, bool required);
llama_model_loader::llama_model_loader( llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) {
const std::string & fname,
std::vector<std::string> & splits,
bool use_mmap,
bool check_tensors,
const struct llama_model_kv_override * param_overrides_p) {
int trace = 0; int trace = 0;
if (getenv("LLAMA_TRACE")) { if (getenv("LLAMA_TRACE")) {
trace = atoi(getenv("LLAMA_TRACE")); trace = atoi(getenv("LLAMA_TRACE"));
@@ -460,7 +380,6 @@ llama_model_loader::llama_model_loader(
} }
} }
// Load the main GGUF
struct ggml_context * ctx = NULL; struct ggml_context * ctx = NULL;
struct gguf_init_params params = { struct gguf_init_params params = {
/*.no_alloc = */ true, /*.no_alloc = */ true,
@@ -496,54 +415,35 @@ llama_model_loader::llama_model_loader(
// Load additional GGML contexts // Load additional GGML contexts
if (n_split > 1) { if (n_split > 1) {
// make sure the main file is loaded first
uint16_t idx = 0; uint16_t idx = 0;
const std::string kv_split_no = llm_kv(LLM_KV_SPLIT_NO); get_key(llm_kv(LLM_KV_SPLIT_NO), idx);
get_key(kv_split_no, idx);
if (idx != 0) { if (idx != 0) {
throw std::runtime_error(format("illegal split file idx: %d (file: %s), model must be loaded with the first split", idx, fname.c_str())); throw std::runtime_error(format("illegal split file: %d, model must be loaded with the first split", idx));
} }
// generate list of splits if needed std::vector<char> split_prefix(llama_path_max(), 0);
if (splits.empty()) { if (!llama_split_prefix(split_prefix.data(), split_prefix.size(), fname.c_str(), idx, n_split)) {
splits = llama_get_list_splits(fname, idx, n_split); throw std::runtime_error(format("invalid split file: %s", fname.c_str()));
}
// in case user give a custom list of splits, check if it matches the expected number
if (n_split != (uint16_t)splits.size()) {
throw std::runtime_error(format("invalid split count, given: %zu splits, but expected %d", splits.size(), n_split));
} }
if (trace > 0) { if (trace > 0) {
LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split); LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split);
} }
// load other splits std::vector<char> split_path(llama_path_max(), 0);
for (idx = 1; idx < n_split; idx++) { for (idx = 1; idx < n_split; idx++) {
const char * fname_split = splits[idx].c_str(); llama_split_path(split_path.data(), split_path.size(), split_prefix.data(), idx, n_split);
struct gguf_init_params split_params = { struct gguf_init_params split_params = {
/*.no_alloc = */ true, /*.no_alloc = */ true,
/*.ctx = */ &ctx, /*.ctx = */ &ctx,
}; };
gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) }; gguf_context_ptr ctx_gguf { gguf_init_from_file(split_path.data(), split_params) };
if (!ctx_gguf) { if (!ctx_gguf) {
throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, fname_split)); throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, split_path.data()));
} }
// check idx files.emplace_back(new llama_file(split_path.data(), "rb"));
{
const int kid = gguf_find_key(ctx_gguf.get(), kv_split_no.c_str());
if (kid < 0) {
throw std::runtime_error(format("missing key %s in GGUF split %s", kv_split_no.c_str(), fname_split));
}
int idx_gguf = gguf_get_val_u16(ctx_gguf.get(), kid);
if (idx_gguf != idx) {
throw std::runtime_error(format("invalid split file idx: %d (file: %s), expected %d", idx_gguf, fname_split, idx));
}
}
files.emplace_back(new llama_file(fname_split, "rb"));
contexts.emplace_back(ctx); contexts.emplace_back(ctx);
// Save tensors data offset info of the shard. // Save tensors data offset info of the shard.
@@ -656,7 +556,7 @@ llama_model_loader::llama_model_loader(
const enum gguf_type type = gguf_get_kv_type(meta.get(), i); const enum gguf_type type = gguf_get_kv_type(meta.get(), i);
const std::string type_name = const std::string type_name =
type == GGUF_TYPE_ARRAY type == GGUF_TYPE_ARRAY
? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta.get(), i)), gguf_get_arr_n(meta.get(), i)) ? format("%s[%s,%d]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta.get(), i)), gguf_get_arr_n(meta.get(), i))
: gguf_type_name(type); : gguf_type_name(type);
std::string value = gguf_kv_to_str(meta.get(), i); std::string value = gguf_kv_to_str(meta.get(), i);
@@ -822,7 +722,7 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps
for (const auto & file : files) { for (const auto & file : files) {
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU)); auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa"); auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa_fn()); std::unique_ptr<llama_mmap> mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, is_numa_fn()));
mmaps_used.emplace_back(mapping->size(), 0); mmaps_used.emplace_back(mapping->size(), 0);
if (mlock_mmaps) { if (mlock_mmaps) {
std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock()); std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());
@@ -1111,17 +1011,3 @@ bool llama_model_loader::load_all_data(
return true; return true;
} }
std::string llama_model_loader::ftype_name() const {
return llama_model_ftype_name(ftype);
}
void llama_model_loader::print_info() const {
LLAMA_LOG_INFO("%s: file format = %s\n", __func__, llama_file_version_name(fver));
LLAMA_LOG_INFO("%s: file type = %s\n", __func__, llama_model_ftype_name(ftype).c_str());
if (n_bytes < GiB) {
LLAMA_LOG_INFO("%s: file size = %.2f MiB (%.2f BPW) \n", __func__, n_bytes/1024.0/1024.0, n_bytes*8.0/n_elements);
} else {
LLAMA_LOG_INFO("%s: file size = %.2f GiB (%.2f BPW) \n", __func__, n_bytes/1024.0/1024.0/1024.0, n_bytes*8.0/n_elements);
}
}
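As a worked example of the BPW (bits per weight) figure printed above, using made-up but typical numbers: a model with n_elements = 8.0e9 parameters stored in n_bytes = 4.58 GiB gives 4.58 * 1024^3 * 8 / 8.0e9, roughly 4.9 BPW, which is in the ballpark of a Q4_K quantization.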

View File

@@ -90,12 +90,7 @@ struct llama_model_loader {
size_t size_data = 0; size_t size_data = 0;
std::vector<std::pair<size_t, size_t>> mmaps_used; std::vector<std::pair<size_t, size_t>> mmaps_used;
llama_model_loader( llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p);
const std::string & fname,
std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
bool use_mmap,
bool check_tensors,
const struct llama_model_kv_override * param_overrides_p);
template<typename T> template<typename T>
typename std::enable_if<std::is_integral<T>::value, bool>::type typename std::enable_if<std::is_integral<T>::value, bool>::type
@@ -160,8 +155,4 @@ struct llama_model_loader {
llama_mlocks * lmlocks, llama_mlocks * lmlocks,
llama_progress_callback progress_callback, llama_progress_callback progress_callback,
void * progress_callback_user_data); void * progress_callback_user_data);
std::string ftype_name() const;
void print_info() const;
}; };

View File

File diff suppressed because it is too large

View File

@@ -4,83 +4,81 @@
#include "llama-arch.h" #include "llama-arch.h"
#include "llama-hparams.h" #include "llama-hparams.h"
#include "llama-vocab.h" #include "llama-vocab.h"
#include "llama-mmap.h"
#include "ggml-cpp.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector> #include <vector>
#include <stdexcept> #include <stdexcept>
struct llama_model_loader;
// available models // available models
// TODO: this enum does not follow the enum naming convention
enum llm_type { enum llm_type {
LLM_TYPE_UNKNOWN, MODEL_UNKNOWN,
LLM_TYPE_14M, MODEL_14M,
LLM_TYPE_17M, MODEL_17M,
LLM_TYPE_22M, MODEL_22M,
LLM_TYPE_33M, MODEL_33M,
LLM_TYPE_60M, MODEL_60M,
LLM_TYPE_70M, MODEL_70M,
LLM_TYPE_80M, MODEL_80M,
LLM_TYPE_109M, MODEL_109M,
LLM_TYPE_137M, MODEL_137M,
LLM_TYPE_160M, MODEL_160M,
LLM_TYPE_220M, MODEL_220M,
LLM_TYPE_250M, MODEL_250M,
LLM_TYPE_270M, MODEL_270M,
LLM_TYPE_335M, MODEL_335M,
LLM_TYPE_410M, MODEL_410M,
LLM_TYPE_450M, MODEL_450M,
LLM_TYPE_770M, MODEL_770M,
LLM_TYPE_780M, MODEL_780M,
LLM_TYPE_0_5B, MODEL_0_5B,
LLM_TYPE_1B, MODEL_1B,
LLM_TYPE_1_3B, MODEL_1_3B,
LLM_TYPE_1_4B, MODEL_1_4B,
LLM_TYPE_1_5B, MODEL_1_5B,
LLM_TYPE_1_6B, MODEL_1_6B,
LLM_TYPE_2B, MODEL_2B,
LLM_TYPE_2_8B, MODEL_2_8B,
LLM_TYPE_3B, MODEL_3B,
LLM_TYPE_4B, MODEL_4B,
LLM_TYPE_6B, MODEL_6B,
LLM_TYPE_6_9B, MODEL_6_9B,
LLM_TYPE_7B, MODEL_7B,
LLM_TYPE_8B, MODEL_8B,
LLM_TYPE_9B, MODEL_9B,
LLM_TYPE_11B, MODEL_11B,
LLM_TYPE_12B, MODEL_12B,
LLM_TYPE_13B, MODEL_13B,
LLM_TYPE_14B, MODEL_14B,
LLM_TYPE_15B, MODEL_15B,
LLM_TYPE_16B, MODEL_16B,
LLM_TYPE_20B, MODEL_20B,
LLM_TYPE_22B, MODEL_22B,
LLM_TYPE_30B, MODEL_30B,
LLM_TYPE_32B, MODEL_32B,
LLM_TYPE_34B, MODEL_34B,
LLM_TYPE_35B, MODEL_35B,
LLM_TYPE_40B, MODEL_40B,
LLM_TYPE_65B, MODEL_65B,
LLM_TYPE_70B, MODEL_70B,
LLM_TYPE_90B, MODEL_90B,
LLM_TYPE_236B, MODEL_236B,
LLM_TYPE_314B, MODEL_314B,
LLM_TYPE_671B, MODEL_671B,
LLM_TYPE_SMALL, MODEL_SMALL,
LLM_TYPE_MEDIUM, MODEL_MEDIUM,
LLM_TYPE_LARGE, MODEL_LARGE,
LLM_TYPE_XL, MODEL_XL,
LLM_TYPE_A1_7B, MODEL_A1_7B,
LLM_TYPE_A2_7B, MODEL_A2_7B,
LLM_TYPE_8x7B, MODEL_8x7B,
LLM_TYPE_8x22B, MODEL_8x22B,
LLM_TYPE_16x12B, MODEL_16x12B,
LLM_TYPE_16x3_8B, MODEL_10B_128x3_66B,
LLM_TYPE_10B_128x3_66B, MODEL_57B_A14B,
LLM_TYPE_57B_A14B, MODEL_27B,
LLM_TYPE_27B,
}; };
struct llama_layer_posnet { struct llama_layer_posnet {
@@ -245,19 +243,15 @@ struct llama_layer {
struct ggml_tensor * time_mix_lerp_v = nullptr; struct ggml_tensor * time_mix_lerp_v = nullptr;
struct ggml_tensor * time_mix_lerp_r = nullptr; struct ggml_tensor * time_mix_lerp_r = nullptr;
struct ggml_tensor * time_mix_lerp_g = nullptr; struct ggml_tensor * time_mix_lerp_g = nullptr;
struct ggml_tensor * time_mix_lerp_fused = nullptr;
struct ggml_tensor * time_mix_first = nullptr; struct ggml_tensor * time_mix_first = nullptr;
struct ggml_tensor * time_mix_decay = nullptr; struct ggml_tensor * time_mix_decay = nullptr;
struct ggml_tensor * time_mix_decay_w1 = nullptr; struct ggml_tensor * time_mix_decay_w1 = nullptr;
struct ggml_tensor * time_mix_decay_w2 = nullptr; struct ggml_tensor * time_mix_decay_w2 = nullptr;
struct ggml_tensor * time_mix_key = nullptr; struct ggml_tensor * time_mix_key = nullptr;
struct ggml_tensor * time_mix_key_b = nullptr; struct ggml_tensor * time_mix_value = nullptr;
struct ggml_tensor * time_mix_value = nullptr; struct ggml_tensor * time_mix_receptance = nullptr;
struct ggml_tensor * time_mix_value_b = nullptr; struct ggml_tensor * time_mix_gate = nullptr;
struct ggml_tensor * time_mix_receptance = nullptr;
struct ggml_tensor * time_mix_receptance_b = nullptr;
struct ggml_tensor * time_mix_gate = nullptr;
struct ggml_tensor * time_mix_ln = nullptr; struct ggml_tensor * time_mix_ln = nullptr;
struct ggml_tensor * time_mix_ln_b = nullptr; struct ggml_tensor * time_mix_ln_b = nullptr;
@@ -286,7 +280,7 @@ struct llama_layer {
struct ggml_tensor * bskcn_tv = nullptr; struct ggml_tensor * bskcn_tv = nullptr;
// cross attention // cross attention
struct ggml_tensor * cross_attn_k_norm = nullptr; struct ggml_tensor * cross_attn_k_norm = nullptr;
struct ggml_tensor * cross_attn_k_proj = nullptr; struct ggml_tensor * cross_attn_k_proj = nullptr;
struct ggml_tensor * cross_attn_o_proj = nullptr; struct ggml_tensor * cross_attn_o_proj = nullptr;
@@ -302,9 +296,11 @@ struct llama_layer {
}; };
struct llama_model { struct llama_model {
llm_type type = LLM_TYPE_UNKNOWN; llm_type type = MODEL_UNKNOWN;
llm_arch arch = LLM_ARCH_UNKNOWN; llm_arch arch = LLM_ARCH_UNKNOWN;
llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
std::string name = "n/a"; std::string name = "n/a";
llama_hparams hparams = {}; llama_hparams hparams = {};
@@ -333,53 +329,117 @@ struct llama_model {
std::vector<llama_layer> layers; std::vector<llama_layer> layers;
llama_model_params params;
// gguf metadata // gguf metadata
std::unordered_map<std::string, std::string> gguf_kv; std::unordered_map<std::string, std::string> gguf_kv;
llama_split_mode split_mode;
int main_gpu;
int n_gpu_layers;
std::vector<std::string> rpc_servers;
// list of devices used in this model // list of devices used in this model
std::vector<ggml_backend_dev_t> devices; std::vector<ggml_backend_dev_t> devices;
// lists of buffer types used for each layer
using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>;
buft_list_t cpu_buft_list;
std::map<ggml_backend_dev_t, buft_list_t> gpu_buft_list;
struct layer_dev {
ggml_backend_dev_t dev;
buft_list_t * buft_list;
};
layer_dev dev_input = {};
layer_dev dev_output = {};
std::vector<layer_dev> dev_layer;
// contexts where the model tensors metadata is stored
std::vector<ggml_context_ptr> ctxs;
// the model memory buffers for the tensor data
std::vector<ggml_backend_buffer_ptr> bufs;
// model memory mapped files
llama_mmaps mappings;
// objects representing data potentially being locked in memory
llama_mlocks mlock_bufs;
llama_mlocks mlock_mmaps;
// for quantize-stats only // for quantize-stats only
std::vector<std::pair<std::string, struct ggml_tensor *>> tensors_by_name; std::vector<std::pair<std::string, struct ggml_tensor *>> tensors_by_name;
int64_t t_load_us = 0; int64_t t_load_us = 0;
int64_t t_start_us = 0; int64_t t_start_us = 0;
explicit llama_model(const struct llama_model_params & params);
~llama_model();
void load_stats (llama_model_loader & ml);
void load_arch (llama_model_loader & ml);
void load_hparams(llama_model_loader & ml);
void load_vocab (llama_model_loader & ml);
bool load_tensors(llama_model_loader & ml); // returns false if cancelled by progress_callback
std::string arch_name() const;
std::string type_name() const;
std::string desc() const;
size_t size() const;
size_t max_nodes() const;
size_t n_devices() const;
// total number of parameters in the model // total number of parameters in the model
uint64_t n_elements() const; uint64_t n_elements = 0;
void print_info() const; // total size of all the tensors in the model in bytes
size_t n_bytes = 0;
ggml_backend_dev_t dev_layer(int il) const;
ggml_backend_dev_t dev_output() const;
ggml_backend_buffer_type_t select_buft(int il) const;
const struct ggml_tensor * get_tensor(const char * name) const;
private:
struct impl;
std::unique_ptr<impl> pimpl;
}; };
const char * llm_type_name(llm_type type); const char * llm_type_name(llm_type type);
std::string llama_model_arch_name (const llama_model & model);
std::string llama_model_type_name (const llama_model & model);
std::string llama_model_ftype_name(const llama_model & model);
template<typename F>
bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) {
ggml_init_params params = {
/*.mem_size =*/ ggml_tensor_overhead()*8,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
ggml_context_ptr ctx { ggml_init(params) };
if (!ctx) {
throw std::runtime_error("failed to create ggml context");
}
ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) };
ggml_tensor * op_tensor = fn(ctx.get());
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (op_tensor->src[i] != nullptr) {
op_tensor->src[i]->buffer = buf.get();
}
}
bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
return op_supported;
}
template<typename F>
ggml_backend_buffer_type_t select_buft(const llama_model::buft_list_t & buft_list, const F & fn) {
for (const auto & cur : buft_list) {
ggml_backend_dev_t cur_dev = cur.first;
ggml_backend_buffer_type_t cur_buft = cur.second;
if (buft_supported(cur_buft, cur_dev, fn)) {
return cur_buft;
}
}
throw std::runtime_error("no suitable buffer type found");
}
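A usage sketch for select_buft, mirroring the call site in the llama-kv-cache.cpp hunk earlier in this compare; model, il and the tensor sizes are assumptions for illustration, not code from this tree:

// Pick a buffer type on which a RoPE over a K-cache row can actually run for layer il.
ggml_backend_buffer_type_t buft = select_buft(*model.dev_layer.at(il).buft_list,
    [&](ggml_context * ctx) {
        ggml_tensor * k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 4096);
        ggml_tensor * p = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
        return ggml_rope(ctx, k, p, 128, LLAMA_ROPE_TYPE_NORM);
    });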
// used by llama_adapter_cvec
ggml_backend_buffer_type_t llama_model_select_buft(const llama_model & model, int il);
// used by llama_adapter_lora
struct ggml_tensor * llama_model_get_tensor(const struct llama_model & model, const char * name);
size_t llama_model_max_nodes(const llama_model & model);
struct llama_model_loader;
// TODO: become llama_model methods
void llm_load_stats (llama_model_loader & ml, llama_model & model);
void llm_load_arch (llama_model_loader & ml, llama_model & model);
void llm_load_hparams (llama_model_loader & ml, llama_model & model);
void llm_load_vocab (llama_model_loader & ml, llama_model & model);
void llm_load_print_meta(llama_model_loader & ml, llama_model & model);

View File

@@ -7,12 +7,14 @@
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include <cstring> #include <cstring>
#include <cinttypes>
#include <fstream> #include <fstream>
#include <mutex> #include <mutex>
#include <thread> #include <thread>
#include <unordered_map> #include <unordered_map>
// TODO: replace with ggml API call
#define QK_K 256
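QK_K = 256 is the super-block size of the k-quant formats, which is why the type-selection code below keeps checking ne[0] % QK_K: a row width of 4096 is fine (4096 / 256 = 16), while a width of 4000 is not (4000 mod 256 = 160) and forces the incompatible-tensor fallback.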
static void zeros(std::ofstream & file, size_t n) { static void zeros(std::ofstream & file, size_t n) {
char zero = 0; char zero = 0;
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
@@ -20,7 +22,7 @@ static void zeros(std::ofstream & file, size_t n) {
} }
} }
struct quantize_state_impl { struct quantize_state_internal {
const llama_model & model; const llama_model & model;
const llama_model_quantize_params * params; const llama_model_quantize_params * params;
@@ -41,13 +43,13 @@ struct quantize_state_impl {
// used to figure out if a model shares tok_embd with the output weight // used to figure out if a model shares tok_embd with the output weight
bool has_output = false; bool has_output = false;
quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params) quantize_state_internal(const llama_model & model, const llama_model_quantize_params * params)
: model(model) : model(model)
, params(params) , params(params)
{} {}
}; };
static void llama_tensor_dequantize_impl( static void llama_tensor_dequantize_internal(
struct ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers, struct ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
const size_t nelements, const int nthread const size_t nelements, const int nthread
) { ) {
@@ -119,7 +121,7 @@ static void llama_tensor_dequantize_impl(
workers.clear(); workers.clear();
} }
static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) { static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
const std::string name = ggml_get_name(tensor); const std::string name = ggml_get_name(tensor);
// TODO: avoid hardcoded tensor names - use the TN_* constants // TODO: avoid hardcoded tensor names - use the TN_* constants
@@ -152,10 +154,8 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
if (qs.params->output_tensor_type < GGML_TYPE_COUNT) { if (qs.params->output_tensor_type < GGML_TYPE_COUNT) {
new_type = qs.params->output_tensor_type; new_type = qs.params->output_tensor_type;
} else { } else {
const int64_t nx = tensor->ne[0]; int nx = tensor->ne[0];
const int64_t qk_k = ggml_blck_size(new_type); if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) {
if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) {
new_type = GGML_TYPE_Q8_0; new_type = GGML_TYPE_Q8_0;
} }
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
@@ -235,7 +235,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K; use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K;
if (qs.model.type == LLM_TYPE_70B) { if (qs.model.type == MODEL_70B) {
// In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is
// 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with
// nearly negligible increase in model size by quantizing this tensor with more bits: // nearly negligible increase in model size by quantizing this tensor with more bits:
@@ -367,19 +367,20 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
// if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K; // if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
//} //}
bool convert_incompatible_tensor = false; bool convert_incompatible_tensor = false;
{ if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K ||
const int64_t nx = tensor->ne[0]; new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K || new_type == GGML_TYPE_IQ4_XS ||
const int64_t ny = tensor->ne[1]; new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S ||
const int64_t qk_k = ggml_blck_size(new_type); new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S ||
new_type == GGML_TYPE_IQ1_M) {
if (nx % qk_k != 0) { int nx = tensor->ne[0];
LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type)); int ny = tensor->ne[1];
if (nx % QK_K != 0) {
LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for %s", __func__, nx, ny, QK_K, ggml_type_name(new_type));
convert_incompatible_tensor = true; convert_incompatible_tensor = true;
} else { } else {
++qs.n_k_quantized; ++qs.n_k_quantized;
} }
} }
if (convert_incompatible_tensor) { if (convert_incompatible_tensor) {
switch (new_type) { switch (new_type) {
case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ1_0:
@@ -409,7 +410,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
return new_type; return new_type;
} }
static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector<std::thread> & workers, const int nthread) { static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector<std::thread> & workers, const int nthread) {
if (nthread < 2) { if (nthread < 2) {
// single-thread // single-thread
size_t new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix); size_t new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix);
@@ -463,7 +464,7 @@ static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float *
return new_size; return new_size;
} }
static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
ggml_type default_type; ggml_type default_type;
llama_ftype ftype = params->ftype; llama_ftype ftype = params->ftype;
@@ -525,21 +526,18 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides; auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
kv_overrides = v->data(); kv_overrides = v->data();
} }
llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides);
std::vector<std::string> splits = {};
llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides);
ml.init_mappings(false); // no prefetching ml.init_mappings(false); // no prefetching
llama_model model(llama_model_default_params()); llama_model model;
llm_load_arch (ml, model);
llm_load_hparams(ml, model);
llm_load_stats (ml, model);
model.load_arch (ml); struct quantize_state_internal qs(model, params);
model.load_hparams(ml);
model.load_stats (ml);
struct quantize_state_impl qs(model, params);
if (params->only_copy) { if (params->only_copy) {
ftype = ml.ftype; ftype = model.ftype;
} }
const std::unordered_map<std::string, std::vector<float>> * imatrix_data = nullptr; const std::unordered_map<std::string, std::vector<float>> * imatrix_data = nullptr;
if (params->imatrix) { if (params->imatrix) {
@@ -623,8 +621,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer; qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
// sanity checks for models that have attention layers // sanity checks
if (qs.n_attention_wv != 0)
{ {
const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin(); const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
// attention layers have a non-zero number of kv heads // attention layers have a non-zero number of kv heads
@@ -737,10 +734,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
// This used to be a regex, but <regex> has an extreme cost to compile times. // This used to be a regex, but <regex> has an extreme cost to compile times.
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
// don't quantize vision stuff
quantize &= name.find("v.") == std::string::npos;
quantize &= name.find("mm.") == std::string::npos;
// quantize only 2D and 3D tensors (experts) // quantize only 2D and 3D tensors (experts)
quantize &= (ggml_n_dims(tensor) >= 2); quantize &= (ggml_n_dims(tensor) >= 2);
@@ -768,7 +761,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
quantize &= name.find("time_mix_w2.weight") == std::string::npos; quantize &= name.find("time_mix_w2.weight") == std::string::npos;
quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos; quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos; quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
// do not quantize relative position bias (T5) // do not quantize relative position bias (T5)
quantize &= name.find("attn_rel_b.weight") == std::string::npos; quantize &= name.find("attn_rel_b.weight") == std::string::npos;
@@ -847,7 +839,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
} else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) { } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) {
throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type))); throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type)));
} else { } else {
llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread); llama_tensor_dequantize_internal(tensor, f32_conv_buf, workers, nelements, nthread);
f32_data = (float *) f32_conv_buf.data(); f32_data = (float *) f32_conv_buf.data();
} }
@@ -876,7 +868,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows; void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows;
const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr; const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr;
new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use); new_size += llama_tensor_quantize_internal(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
} }
LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0); LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
} }
@@ -885,8 +877,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
// update the gguf meta data as we go // update the gguf meta data as we go
gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type); gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type);
GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size); gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data, new_size);
gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data);
// write tensor data + padding // write tensor data + padding
fout.write((const char *) new_data, new_size); fout.write((const char *) new_data, new_size);
@@ -930,7 +921,7 @@ uint32_t llama_model_quantize(
const char * fname_out, const char * fname_out,
const llama_model_quantize_params * params) { const llama_model_quantize_params * params) {
try { try {
llama_model_quantize_impl(fname_inp, fname_out, params); llama_model_quantize_internal(fname_inp, fname_out, params);
} catch (const std::exception & err) { } catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what()); LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what());
return 1; return 1;

View File

@@ -257,7 +257,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
for (int i = 0; i < (int)cur_p->size; ++i) { for (int i = 0; i < (int)cur_p->size; ++i) {
const float val = cur_p->data[i].logit; const float val = cur_p->data[i].logit;
int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low); int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
ib = std::max(0, std::min(nbuckets - 1, ib)); ib = std::max(0, std::min(nbuckets-1, ib));
bucket_idx[i] = ib; bucket_idx[i] = ib;
++histo[ib]; ++histo[ib];
} }
@@ -280,13 +280,13 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
for (int i = 0; i < (int)cur_p->size; ++i) { for (int i = 0; i < (int)cur_p->size; ++i) {
int j = bucket_idx[i]; int j = bucket_idx[i];
if (j >= ib) { if (j >= ib) {
*bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i]; *bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
} }
} }
ptr = tmp_tokens.data(); ptr = tmp_tokens.data();
int ndone = 0; int ndone = 0;
for (int j = nbuckets - 1; j > ib; --j) { for (int j = nbuckets-1; j > ib; --j) {
std::sort(ptr, ptr + histo[j], comp); std::sort(ptr, ptr + histo[j], comp);
ptr += histo[j]; ptr += histo[j];
ndone += histo[j]; ndone += histo[j];
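The hunk above implements top-k with a logit histogram: candidates are binned by logit value, and only the highest buckets, enough to cover k entries, are actually sorted. An illustrative standalone version of the same idea (not the repo's exact code):

    #include <algorithm>
    #include <vector>

    // Illustrative only: map each logit into one of nbuckets equal-width bins,
    // walk bins from the highest down until k candidates are covered, and sort
    // just those bins instead of the whole candidate list.
    static std::vector<int> top_k_bucketed(const std::vector<float> & logits, int k, int nbuckets = 128) {
        if (logits.empty() || k <= 0) {
            return {};
        }
        const auto [mn, mx] = std::minmax_element(logits.begin(), logits.end());
        const float lo = *mn, hi = *mx;
        const float scale = hi > lo ? (nbuckets - 1) / (hi - lo) : 0.0f;

        std::vector<std::vector<int>> buckets(nbuckets);
        for (int i = 0; i < (int) logits.size(); ++i) {
            int ib = (int) (scale * (logits[i] - lo));
            ib = std::max(0, std::min(nbuckets - 1, ib));
            buckets[ib].push_back(i);
        }

        std::vector<int> out;
        for (int j = nbuckets - 1; j >= 0 && (int) out.size() < k; --j) {
            std::sort(buckets[j].begin(), buckets[j].end(),
                      [&](int a, int b) { return logits[a] > logits[b]; });
            out.insert(out.end(), buckets[j].begin(), buckets[j].end());
        }
        out.resize(std::min(out.size(), (size_t) k));
        return out;
    }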
@@ -316,13 +316,6 @@ static uint32_t get_rng_seed(uint32_t seed) {
// llama_sampler API // llama_sampler API
struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) {
return new llama_sampler {
/* .iface = */ iface,
/* .ctx = */ ctx,
};
}
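The hunks below switch every built-in constructor from aggregate-initializing `new llama_sampler { ... }` to this shared helper, so sampler allocation happens in one place. A minimal sketch of a constructor written against the new helper; the no-op sampler itself is hypothetical, only llama_sampler_init and llama_sampler_i come from this file:

    #include "llama.h"

    static const char * noop_name(const struct llama_sampler * /*smpl*/) { return "noop"; }
    static void noop_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * /*cur_p*/) {}

    static struct llama_sampler_i noop_i = {
        /* .name   = */ noop_name,
        /* .accept = */ nullptr,
        /* .apply  = */ noop_apply,
        /* .reset  = */ nullptr,
        /* .clone  = */ nullptr,
        /* .free   = */ nullptr,
    };

    // Hypothetical sampler built through the new helper instead of `new llama_sampler {...}`.
    struct llama_sampler * llama_sampler_init_noop() {
        return llama_sampler_init(/* iface = */ &noop_i, /* ctx = */ nullptr);
    }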
const char * llama_sampler_name(const struct llama_sampler * smpl) { const char * llama_sampler_name(const struct llama_sampler * smpl) {
if (!smpl->iface) { if (!smpl->iface) {
return "(null)"; return "(null)";
@@ -354,10 +347,10 @@ struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
} }
if (smpl->ctx == nullptr) { if (smpl->ctx == nullptr) {
return llama_sampler_init( return new llama_sampler {
/* .iface = */ smpl->iface, /* .iface = */ smpl->iface,
/* .ctx = */ nullptr /* .ctx = */ nullptr,
); };
} }
GGML_ABORT("the sampler does not support cloning"); GGML_ABORT("the sampler does not support cloning");
@@ -378,10 +371,7 @@ void llama_sampler_free(struct llama_sampler * smpl) {
llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
const auto * logits = llama_get_logits_ith(ctx, idx); const auto * logits = llama_get_logits_ith(ctx, idx);
const llama_model * model = llama_get_model(ctx); const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_vocab = llama_vocab_n_tokens(vocab);
// TODO: do not allocate each time // TODO: do not allocate each time
std::vector<llama_token_data> cur; std::vector<llama_token_data> cur;
@@ -479,15 +469,15 @@ static struct llama_sampler_i llama_sampler_chain_i = {
}; };
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
return llama_sampler_init( return new llama_sampler {
/* .iface = */ &llama_sampler_chain_i, /* .iface = */ &llama_sampler_chain_i,
/* .ctx = */ new llama_sampler_chain { /* .ctx = */ new llama_sampler_chain {
/* .params = */ params, /* .params = */ params,
/* .samplers = */ {}, /* .samplers = */ {},
/* .t_sample_us = */ 0, /* .t_sample_us = */ 0,
/* .n_sample = */ 0, /* .n_sample = */ 0,
} },
); };
} }
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) { void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
@@ -553,10 +543,10 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
}; };
struct llama_sampler * llama_sampler_init_greedy() { struct llama_sampler * llama_sampler_init_greedy() {
return llama_sampler_init( return new llama_sampler {
/* .iface = */ &llama_sampler_greedy_i, /* .iface = */ &llama_sampler_greedy_i,
/* .ctx = */ nullptr /* .ctx = */ nullptr,
); };
} }
// dist // dist
@@ -615,14 +605,14 @@ static struct llama_sampler_i llama_sampler_dist_i = {
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
auto seed_cur = get_rng_seed(seed); auto seed_cur = get_rng_seed(seed);
return llama_sampler_init( return new llama_sampler {
/* .iface = */ &llama_sampler_dist_i, /* .iface = */ &llama_sampler_dist_i,
/* .ctx = */ new llama_sampler_dist { /* .ctx = */ new llama_sampler_dist {
/* .seed = */ seed, /* .seed = */ seed,
/* .seed_cur = */ seed_cur, /* .seed_cur = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur), /* .rng = */ std::mt19937(seed_cur),
} },
); };
} }
// softmax // softmax
@@ -645,10 +635,10 @@ static struct llama_sampler_i llama_sampler_softmax_i = {
}; };
struct llama_sampler * llama_sampler_init_softmax() { struct llama_sampler * llama_sampler_init_softmax() {
return llama_sampler_init( return new llama_sampler {
/* .iface = */ &llama_sampler_softmax_i, /* .iface = */ &llama_sampler_softmax_i,
/* .ctx = */ nullptr /* .ctx = */ nullptr,
); };
} }
// top-k // top-k
@@ -685,12 +675,12 @@ static struct llama_sampler_i llama_sampler_top_k_i = {
}; };
struct llama_sampler * llama_sampler_init_top_k(int32_t k) { struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
return llama_sampler_init( return new llama_sampler {
/* .iface = */ &llama_sampler_top_k_i, /* .iface = */ &llama_sampler_top_k_i,
/* .ctx = */ new llama_sampler_top_k { /* .ctx = */ new llama_sampler_top_k {
/* .k = */ k, /* .k = */ k,
} },
); };
} }
// top-p // top-p
@@ -751,13 +741,13 @@ static struct llama_sampler_i llama_sampler_top_p_i = {
}; };
struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
return llama_sampler_init( return new llama_sampler {
/* .iface = */ &llama_sampler_top_p_i, /* .iface = */ &llama_sampler_top_p_i,
/* .ctx = */ new llama_sampler_top_p { /* .ctx = */ new llama_sampler_top_p {
/* .p = */ p, /* .p = */ p,
/* .min_keep = */ min_keep, /* .min_keep = */ min_keep,
} },
); };
} }
// min-p // min-p
@@ -847,13 +837,13 @@ static struct llama_sampler_i llama_sampler_min_p_i = {
}; };
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
return llama_sampler_init( return new llama_sampler {
/* .iface = */ &llama_sampler_min_p_i, /* .iface = */ &llama_sampler_min_p_i,
/* .ctx = */ new llama_sampler_min_p { /* .ctx = */ new llama_sampler_min_p {
/* .p = */ p, /* .p = */ p,
/* .min_keep = */ min_keep, /* .min_keep = */ min_keep,
} },
); };
} }
// typical // typical
@@ -946,13 +936,13 @@ static struct llama_sampler_i llama_sampler_typical_i = {
}; };
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
return llama_sampler_init( return new llama_sampler {
/* .iface = */ &llama_sampler_typical_i, /* .iface = */ &llama_sampler_typical_i,
/* .ctx = */ new llama_sampler_typical { /* .ctx = */ new llama_sampler_typical {
/* .p = */ p, /* .p = */ p,
/* .min_keep = */ min_keep, /* .min_keep = */ min_keep,
} },
); };
} }
// temp // temp
@@ -990,12 +980,12 @@ static struct llama_sampler_i llama_sampler_temp_i = {
}; };
struct llama_sampler * llama_sampler_init_temp(float temp) { struct llama_sampler * llama_sampler_init_temp(float temp) {
return llama_sampler_init( return new llama_sampler {
/* .iface = */ &llama_sampler_temp_i, /* .iface = */ &llama_sampler_temp_i,
/* .ctx = */ new llama_sampler_temp { /* .ctx = */ new llama_sampler_temp {
/*.temp = */ temp, /*.temp = */ temp,
} },
); };
} }
// temp-ext // temp-ext
@@ -1100,14 +1090,14 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
}; };
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
return llama_sampler_init( return new llama_sampler {
/* .iface = */ &llama_sampler_temp_ext_i, /* .iface = */ &llama_sampler_temp_ext_i,
/* .ctx = */ new llama_sampler_temp_ext { /* .ctx = */ new llama_sampler_temp_ext {
/* .temp = */ temp, /* .temp = */ temp,
/* .delta = */ delta, /* .delta = */ delta,
/* .exponent = */ exponent, /* .exponent = */ exponent,
} },
); };
} }
// xtc // xtc
@@ -1192,7 +1182,7 @@ static struct llama_sampler_i llama_sampler_xtc_i = {
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) { struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
auto seed_cur = get_rng_seed(seed); auto seed_cur = get_rng_seed(seed);
return llama_sampler_init( return new llama_sampler {
/* .iface = */ &llama_sampler_xtc_i, /* .iface = */ &llama_sampler_xtc_i,
/* .ctx = */ new llama_sampler_xtc { /* .ctx = */ new llama_sampler_xtc {
/* .probability = */ p, /* .probability = */ p,
@@ -1201,8 +1191,8 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
/* .seed = */ seed, /* .seed = */ seed,
/* .seed_cur = */ seed_cur, /* .seed_cur = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur), /* .rng = */ std::mt19937(seed_cur),
} },
); };
} }
// mirostat // mirostat
@@ -1299,7 +1289,7 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
auto seed_cur = get_rng_seed(seed); auto seed_cur = get_rng_seed(seed);
return llama_sampler_init( return new llama_sampler {
/* .iface = */ &llama_sampler_mirostat_i, /* .iface = */ &llama_sampler_mirostat_i,
/* .ctx = */ new llama_sampler_mirostat { /* .ctx = */ new llama_sampler_mirostat {
/* .n_vocab = */ n_vocab, /* .n_vocab = */ n_vocab,
@@ -1310,8 +1300,8 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
/* .m = */ m, /* .m = */ m,
/* .mu = */ 2.0f*tau, /* .mu = */ 2.0f*tau,
/* .rng = */ std::mt19937(seed_cur), /* .rng = */ std::mt19937(seed_cur),
} },
); };
} }
// mirostat v2 // mirostat v2
@@ -1398,7 +1388,7 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
auto seed_cur = get_rng_seed(seed); auto seed_cur = get_rng_seed(seed);
return llama_sampler_init( return new llama_sampler {
/* .iface = */ &llama_sampler_mirostat_v2_i, /* .iface = */ &llama_sampler_mirostat_v2_i,
/* .ctx = */ new llama_sampler_mirostat_v2 { /* .ctx = */ new llama_sampler_mirostat_v2 {
/* .seed = */ seed, /* .seed = */ seed,
@@ -1407,8 +1397,8 @@ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau,
/* .eta = */ eta, /* .eta = */ eta,
/* .mu = */ 2.0f*tau, /* .mu = */ 2.0f*tau,
/* .rng = */ std::mt19937(seed_cur), /* .rng = */ std::mt19937(seed_cur),
} },
); };
} }
// grammar // grammar
@@ -1440,30 +1430,13 @@ static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token
} }
} }
// Fwd declare to break reset --> init_impl --> llama_sampler_grammar_i --> reset cycle.
static struct llama_sampler * llama_sampler_init_grammar_impl(
const struct llama_vocab * vocab,
const char * grammar_str,
const char * grammar_root,
bool lazy,
const char ** trigger_words,
size_t num_trigger_words,
const llama_token * trigger_tokens,
size_t num_trigger_tokens);
static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_grammar *) smpl->ctx; auto * ctx = (llama_sampler_grammar *) smpl->ctx;
if (!ctx->grammar) { if (!ctx->grammar) {
return; return;
} }
std::vector<const char *> trigger_words; auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
for (auto & word : ctx->grammar->trigger_words) {
trigger_words.push_back(word.c_str());
}
auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
ctx->grammar->lazy, trigger_words.data(), trigger_words.size(),
ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
llama_grammar_free_impl(ctx->grammar); llama_grammar_free_impl(ctx->grammar);
ctx->grammar = grammar_new; ctx->grammar = grammar_new;
@@ -1472,7 +1445,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) { static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_grammar *) smpl->ctx; const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0); auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
// copy the state // copy the state
{ {
@@ -1508,55 +1481,29 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
/* .free = */ llama_sampler_grammar_free, /* .free = */ llama_sampler_grammar_free,
}; };
static struct llama_sampler * llama_sampler_init_grammar_impl( struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
const struct llama_vocab * vocab,
const char * grammar_str,
const char * grammar_root,
bool lazy,
const char ** trigger_words,
size_t num_trigger_words,
const llama_token * trigger_tokens,
size_t num_trigger_tokens) {
auto * ctx = new llama_sampler_grammar; auto * ctx = new llama_sampler_grammar;
if (grammar_str != nullptr && grammar_str[0] != '\0') { if (grammar_str != nullptr && grammar_str[0] != '\0') {
*ctx = { *ctx = {
/* .vocab = */ vocab, /* .vocab = */ &vocab,
/* .grammar_str = */ grammar_str, /* .grammar_str = */ grammar_str,
/* .grammar_root = */ grammar_root, /* .grammar_root = */ grammar_root,
/* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens), /* .grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
}; };
} else { } else {
*ctx = { *ctx = {
/* .vocab = */ vocab, /* .vocab = */ &vocab,
/* .grammar_str = */ {}, /* .grammar_str = */ {},
/* .grammar_root = */ {}, /* .grammar_root = */ {},
/* .grammar = */ nullptr, /* .grammar = */ nullptr,
}; };
} }
return llama_sampler_init( return new llama_sampler {
/* .iface = */ &llama_sampler_grammar_i, /* .iface = */ &llama_sampler_grammar_i,
/* .ctx = */ ctx /* .ctx = */ ctx,
); };
}
struct llama_sampler * llama_sampler_init_grammar(
const struct llama_vocab * vocab,
const char * grammar_str,
const char * grammar_root) {
return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0);
}
struct llama_sampler * llama_sampler_init_grammar_lazy(
const struct llama_vocab * vocab,
const char * grammar_str,
const char * grammar_root,
const char ** trigger_words,
size_t num_trigger_words,
const llama_token * trigger_tokens,
size_t num_trigger_tokens) {
return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens);
} }
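The new llama_sampler_init_grammar_lazy entry point defers grammar enforcement until one of the supplied trigger words or trigger tokens is seen in the output. A hedged usage sketch against the signature above; the grammar text and trigger word are illustrative:

    #include "llama.h"

    // Enforce a tiny grammar, but only after the model emits the trigger word
    // "<tool_call>". vocab is e.g. the result of llama_model_get_vocab(model).
    static struct llama_sampler * make_lazy_tool_grammar(const struct llama_vocab * vocab) {
        const char * trigger_words[] = { "<tool_call>" };
        return llama_sampler_init_grammar_lazy(
            vocab,
            "root ::= \"yes\" | \"no\"",   // grammar_str (illustrative)
            "root",                        // grammar_root
            trigger_words, 1,
            /* trigger_tokens     = */ nullptr,
            /* num_trigger_tokens = */ 0);
    }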
// penalties // penalties
@@ -1685,7 +1632,7 @@ struct llama_sampler * llama_sampler_init_penalties(
float penalty_present) { float penalty_present) {
penalty_last_n = std::max(penalty_last_n, 0); penalty_last_n = std::max(penalty_last_n, 0);
return llama_sampler_init( return new llama_sampler {
/* .iface = */ &llama_sampler_penalties_i, /* .iface = */ &llama_sampler_penalties_i,
/* .ctx = */ new llama_sampler_penalties { /* .ctx = */ new llama_sampler_penalties {
/* .penalty_last_n = */ penalty_last_n, /* .penalty_last_n = */ penalty_last_n,
@@ -1694,75 +1641,8 @@ struct llama_sampler * llama_sampler_init_penalties(
/* .penalty_present = */ penalty_present, /* .penalty_present = */ penalty_present,
/* .prev = */ ring_buffer<llama_token>(penalty_last_n), /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
/* .token_count = */ {}, /* .token_count = */ {},
} },
); };
}
// top-n-sigma
struct llama_sampler_top_n_sigma {
const float n;
};
static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
return "top-n-sigma";
}
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
// find max logit and calculate mean
float max = cur_p->data[0].logit;
float logits_sum = 0;
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].logit > max) {
max = cur_p->data[i].logit;
}
logits_sum += cur_p->data[i].logit;
}
float mean = logits_sum/cur_p->size;
// calculate standard deviation
float acc = 0;
for (size_t i = 0; i < cur_p->size; ++i) {
acc += pow(cur_p->data[i].logit - mean, 2);
}
float std = sqrt(acc/cur_p->size);
//apply mask
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].logit < max - (ctx->n * std)) {
cur_p->data[i].logit = -INFINITY;
}
}
llama_sampler_softmax_impl(cur_p);
}
static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx;
return llama_sampler_init_top_n_sigma(ctx->n);
}
static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
delete (llama_sampler_top_n_sigma *) smpl->ctx;
}
static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
/* .name = */ llama_sampler_top_n_sigma_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_top_n_sigma_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_top_n_sigma_clone,
/* .free = */ llama_sampler_top_n_sigma_free,
};
struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
return llama_sampler_init(
/* .iface = */ &llama_sampler_top_n_sigma_i,
/* .ctx = */ new llama_sampler_top_n_sigma {
/* .n = */ n,
}
);
} }
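The new top-n-sigma sampler keeps only tokens whose logit lies within n standard deviations of the maximum. For example, with logits {2.0, 1.0, 0.0, -3.0} the mean is 0.0 and the standard deviation is about 1.87, so with n = 1 every candidate below 2.0 - 1.87 = 0.13 is masked to -INFINITY, leaving only the first two before the final softmax. A hedged usage sketch; llama_sampler_chain_default_params and LLAMA_DEFAULT_SEED are assumed from llama.h:

    #include "llama.h"

    // A chain that applies top-n-sigma (n = 1) before the distribution sampler.
    static struct llama_sampler * make_top_n_sigma_chain() {
        struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
        llama_sampler_chain_add(chain, llama_sampler_init_top_n_sigma(1.0f));
        llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
        return chain;
    }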
// DRY // DRY
@@ -1783,8 +1663,8 @@ struct llama_sampler_dry {
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am) // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) { static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
for (llama_token token_id = 0; token_id < (llama_token) vocab.n_tokens(); token_id++) { for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) {
std::string word = vocab.detokenize({token_id}, true); std::string word = llama_detokenize(vocab, {token_id}, true);
if (word.find(str) != std::string::npos) { if (word.find(str) != std::string::npos) {
token_sequences.emplace(token_id, std::vector<llama_token>()); token_sequences.emplace(token_id, std::vector<llama_token>());
} else { } else {
@@ -1801,7 +1681,7 @@ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std
} }
} }
if (match) { if (match) {
std::vector<llama_token> tokenization = vocab.tokenize(str.substr(i), false, false); std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) { if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
tokenization.resize(max_tail_len); tokenization.resize(max_tail_len);
} }
@@ -1952,7 +1832,7 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_dat
ctx->dry_repeat_count[last - k] = std::min(n, rep_limit); ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
if (n > 0) { if (n > 0) {
lt = k; lt = k;
rt = k + n - 1; rt = k+n-1;
} }
} else { } else {
// If k is inside the current Z-box, consider two cases. // If k is inside the current Z-box, consider two cases.
@@ -2057,7 +1937,7 @@ static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler
llama_vocab dummy_vocab; llama_vocab dummy_vocab;
// dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying
auto * result = llama_sampler_init_dry(&dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0); auto * result = llama_sampler_init_dry_impl(dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
// Copy the state, including the processed breakers // Copy the state, including the processed breakers
{ {
@@ -2084,7 +1964,7 @@ static struct llama_sampler_i llama_sampler_dry_i = {
/* .free = */ llama_sampler_dry_free, /* .free = */ llama_sampler_dry_free,
}; };
struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0); int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers; std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
const int MAX_CHAR_LEN = 40; const int MAX_CHAR_LEN = 40;
@@ -2111,11 +1991,11 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
sequence_break.resize(MAX_CHAR_LEN); sequence_break.resize(MAX_CHAR_LEN);
} }
get_overlapping_token_sequences(*vocab, sequence_break, processed_breakers, MAX_SEQ_LEN); get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
} }
} }
return llama_sampler_init( return new llama_sampler {
/* .iface = */ &llama_sampler_dry_i, /* .iface = */ &llama_sampler_dry_i,
/* .ctx = */ new llama_sampler_dry { /* .ctx = */ new llama_sampler_dry {
/* .total_context_size = */ context_size, /* .total_context_size = */ context_size,
@@ -2127,14 +2007,14 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
/* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{}, /* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
/* .dry_max_token_repeat = */ {}, /* .dry_max_token_repeat = */ {},
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0), /* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
} },
); };
} }
// wrapper for test-sampling.cpp // wrapper for test-sampling.cpp
struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) { struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
llama_vocab dummy_vocab; llama_vocab dummy_vocab;
auto * result = llama_sampler_init_dry(&dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0); auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
auto * ctx = (llama_sampler_dry *) result->ctx; auto * ctx = (llama_sampler_dry *) result->ctx;
// Process the token-based sequence breakers // Process the token-based sequence breakers
@@ -2229,14 +2109,14 @@ struct llama_sampler * llama_sampler_init_logit_bias(
int32_t n_vocab, int32_t n_vocab,
int32_t n_logit_bias, int32_t n_logit_bias,
const llama_logit_bias * logit_bias) { const llama_logit_bias * logit_bias) {
return llama_sampler_init( return new llama_sampler {
/* .iface = */ &llama_sampler_logit_bias_i, /* .iface = */ &llama_sampler_logit_bias_i,
/* .ctx = */ new llama_sampler_logit_bias { /* .ctx = */ new llama_sampler_logit_bias {
/* .n_vocab = */ n_vocab, /* .n_vocab = */ n_vocab,
/* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias), /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
/* .to_search = */ {}, /* .to_search = */ {},
} },
); };
} }
// infill // infill
@@ -2273,7 +2153,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
float p_eog_sum = 0.0f; float p_eog_sum = 0.0f;
for (size_t i = 0; i < cur_p->size; ++i) { for (size_t i = 0; i < cur_p->size; ++i) {
if (ctx->vocab->is_eog(cur_p->data[i].id)) { if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
p_eog_sum += cur_p->data[i].p; p_eog_sum += cur_p->data[i].p;
} else { } else {
p_txt_sum += cur_p->data[i].p; p_txt_sum += cur_p->data[i].p;
@@ -2295,7 +2175,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
float p_sum = 0.0f; float p_sum = 0.0f;
for (size_t i = 0; i < size_org; ++i) { for (size_t i = 0; i < size_org; ++i) {
if (ctx->vocab->is_eog(cur_p->data[i].id)) { if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
p_sum += cur_p->data[i].p; p_sum += cur_p->data[i].p;
cur_p->data[cur_p->size++] = cur_p->data[i]; cur_p->data[cur_p->size++] = cur_p->data[i];
@@ -2323,17 +2203,17 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
continue; continue;
} }
int len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false); int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
if (len0 < 0) { if (len0 < 0) {
ctx->buf0.resize(len0); ctx->buf0.resize(len0);
len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false); len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
assert(len0 > 0); assert(len0 > 0);
} }
int len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false); int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
if (len1 < 0) { if (len1 < 0) {
ctx->buf1.resize(len1); ctx->buf1.resize(len1);
len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false); len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
assert(len1 > 0); assert(len1 > 0);
} }
@@ -2368,7 +2248,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold); LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
for (size_t i = 0; i < size_org; ++i) { for (size_t i = 0; i < size_org; ++i) {
const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id); const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
if (cur_p->data[i].p < thold && !is_eog) { if (cur_p->data[i].p < thold && !is_eog) {
continue; continue;
@@ -2389,7 +2269,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
// if no non-EOG tokens are left -> reduce cur_p to single EOT token // if no non-EOG tokens are left -> reduce cur_p to single EOT token
if (n_non_eog == 0) { if (n_non_eog == 0) {
cur_p->size = 1; cur_p->size = 1;
cur_p->data[0].id = ctx->vocab->token_eot(); cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
cur_p->data[0].logit = 1.0f; cur_p->data[0].logit = 1.0f;
return; return;
@@ -2411,7 +2291,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold); LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
for (size_t i = 0; i < size_org; ++i) { for (size_t i = 0; i < size_org; ++i) {
const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id); const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
if (cur_p->data[i].p < thold && !is_eog) { if (cur_p->data[i].p < thold && !is_eog) {
continue; continue;
@@ -2434,7 +2314,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) { static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_infill *) smpl->ctx; const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
return llama_sampler_init_infill(ctx->vocab); return llama_sampler_init_infill_impl(*ctx->vocab);
} }
static void llama_sampler_infill_free(struct llama_sampler * smpl) { static void llama_sampler_infill_free(struct llama_sampler * smpl) {
@@ -2450,15 +2330,16 @@ static struct llama_sampler_i llama_sampler_infill_i = {
/* .free = */ llama_sampler_infill_free, /* .free = */ llama_sampler_infill_free,
}; };
struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) { struct llama_sampler * llama_sampler_init_infill_impl(
return llama_sampler_init( const struct llama_vocab & vocab) {
return new llama_sampler {
/* .iface = */ &llama_sampler_infill_i, /* .iface = */ &llama_sampler_infill_i,
/* .ctx = */ new llama_sampler_infill { /* .ctx = */ new llama_sampler_infill {
/* .vocab = */ vocab, /* .vocab = */ &vocab,
/* .buf0 = */ std::vector<char>(512), /* .buf0 = */ std::vector<char>(512),
/* .buf1 = */ std::vector<char>(512), /* .buf1 = */ std::vector<char>(512),
} },
); };
} }
// utils // utils

View File

@@ -2,9 +2,7 @@
// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ? // TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?
#include "llama.h" #include "llama-grammar.h"
#include <vector>
struct llama_vocab; struct llama_vocab;
struct llama_grammar; struct llama_grammar;
@@ -23,6 +21,24 @@ struct llama_sampler_chain {
mutable int32_t n_sample; mutable int32_t n_sample;
}; };
struct llama_sampler * llama_sampler_init_grammar_impl(
const struct llama_vocab & vocab,
const char * grammar_str,
const char * grammar_root);
struct llama_sampler * llama_sampler_init_infill_impl(
const struct llama_vocab & vocab);
struct llama_sampler * llama_sampler_init_dry_impl(
const struct llama_vocab & vocab,
int32_t context_size,
float dry_multiplier,
float dry_base,
int32_t dry_allowed_length,
int32_t dry_penalty_last_n,
const char ** seq_breakers,
size_t num_breakers);
struct llama_sampler * llama_sampler_init_dry_testing( struct llama_sampler * llama_sampler_init_dry_testing(
int32_t context_size, int32_t context_size,
float dry_multiplier, float dry_multiplier,

View File

File diff suppressed because it is too large.

View File

@@ -4,122 +4,179 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory> #include <unordered_map>
#include <map>
#include <set>
struct LLM_KV; static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
struct llama_model_loader; switch (type) {
case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
case LLAMA_VOCAB_TYPE_SPM: return "SPM";
case LLAMA_VOCAB_TYPE_BPE: return "BPE";
case LLAMA_VOCAB_TYPE_WPM: return "WPM";
case LLAMA_VOCAB_TYPE_UGM: return "UGM";
case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
default: return "unknown";
}
}
struct llm_tokenizer;
struct llama_vocab { struct llama_vocab {
using id = llama_token;
using token = std::string;
using tattr = llama_token_attr;
struct token_data { struct token_data {
std::string text; token text;
float score; float score;
llama_token_attr attr; tattr attr;
}; };
llama_vocab(); uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab
enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
int max_token_len = 0; // used for optimizing longest token search
std::unordered_map<token, id> token_to_id;
std::vector<token_data> id_to_token;
std::vector<id> cache_special_tokens;
std::vector<token> cache_token_to_piece; // llama_token_to_piece(special = true);
std::map<std::pair<std::string, std::string>, int> bpe_ranks;
// default LLaMA special tokens
// TODO: should we set all of these to LLAMA_TOKEN_NULL?
id special_bos_id = 1;
id special_eos_id = 2;
id special_eot_id = LLAMA_TOKEN_NULL;
id special_eom_id = LLAMA_TOKEN_NULL;
id special_unk_id = 0;
id special_sep_id = LLAMA_TOKEN_NULL;
id special_pad_id = LLAMA_TOKEN_NULL;
id special_cls_id = LLAMA_TOKEN_NULL; // TODO: revisit if this is really needed https://github.com/ggerganov/llama.cpp/pull/10930
id special_mask_id = LLAMA_TOKEN_NULL;
id linefeed_id = 13;
// fim tokens
id special_fim_pre_id = LLAMA_TOKEN_NULL;
id special_fim_suf_id = LLAMA_TOKEN_NULL;
id special_fim_mid_id = LLAMA_TOKEN_NULL;
id special_fim_pad_id = LLAMA_TOKEN_NULL;
id special_fim_rep_id = LLAMA_TOKEN_NULL; // repo
id special_fim_sep_id = LLAMA_TOKEN_NULL; // file separator
// set of all tokens that cause "end of generation"
std::set<id> special_eog_ids;
// tokenizer flags
bool tokenizer_add_space_prefix = false;
bool tokenizer_add_bos = false;
bool tokenizer_add_eos = false;
bool tokenizer_ignore_merges = false;
bool tokenizer_clean_spaces = false; // clean_up_tokenization_spaces
bool tokenizer_remove_extra_whitespaces = false;
bool tokenizer_escape_whitespaces = true;
bool tokenizer_treat_whitespace_as_suffix = false;
std::vector<char> precompiled_charsmap;
llm_tokenizer * tokenizer = nullptr;
llama_vocab() = default;
~llama_vocab(); ~llama_vocab();
void load(llama_model_loader & ml, const LLM_KV & kv);
enum llama_vocab_type get_type() const;
enum llama_vocab_pre_type get_pre_type() const;
uint32_t n_tokens() const;
uint32_t n_token_types() const;
std::string type_name() const;
bool is_normal (llama_token id) const;
bool is_unknown (llama_token id) const;
bool is_control (llama_token id) const;
bool is_byte (llama_token id) const;
bool is_user_defined(llama_token id) const;
bool is_unused (llama_token id) const;
bool is_eog (llama_token id) const;
uint8_t token_to_byte(llama_token id) const;
llama_token byte_to_token(uint8_t ch) const;
llama_token text_to_token(const std::string & text) const;
const token_data & get_token_data(llama_token id) const;
const char * token_get_text (llama_token id) const;
float token_get_score(llama_token id) const;
llama_token_attr token_get_attr (llama_token id) const;
llama_token token_bos() const;
llama_token token_eos() const;
llama_token token_eot() const;
llama_token token_eom() const;
llama_token token_unk() const;
llama_token token_sep() const;
llama_token token_nl () const;
llama_token token_pad() const;
llama_token token_prefix() const;
llama_token token_middle() const;
llama_token token_suffix() const;
llama_token token_fim_pre() const;
llama_token token_fim_suf() const;
llama_token token_fim_mid() const;
llama_token token_fim_pad() const;
llama_token token_fim_rep() const;
llama_token token_fim_sep() const;
bool get_add_space_prefix () const;
bool get_add_bos () const;
bool get_add_eos () const;
bool get_ignore_merges () const;
bool get_clean_spaces () const;
bool get_remove_extra_whitespaces () const;
bool get_escape_whitespaces () const;
bool get_treat_whitespace_as_suffix() const;
int max_token_len() const;
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const; int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
int32_t tokenize( void init_tokenizer();
const char * text,
int32_t text_len,
llama_token * tokens,
int32_t n_tokens_max,
bool add_special,
bool parse_special) const;
std::vector<llama_token> tokenize(
const std::string & raw_text,
bool add_special,
bool parse_special = false) const;
// does not write null-terminator to buf
int32_t token_to_piece(
llama_token token,
char * buf,
int32_t length,
int32_t lstrip,
bool special) const;
// use cached data
const std::string & token_to_piece(llama_token token) const;
int32_t detokenize(
const llama_token * tokens,
int32_t n_tokens,
char * text,
int32_t text_len_max,
bool remove_special,
bool unparse_special) const;
std::string detokenize(
const std::vector<llama_token> & tokens,
bool special) const;
void print_info() const;
private:
struct impl;
std::unique_ptr<impl> pimpl;
}; };
//
// internal API
//
// TODO: rename to llama_tokenize_impl
// TODO: This should probably be in llama.h
std::vector<llama_vocab::id> llama_tokenize_internal(
const llama_vocab & vocab,
std::string raw_text,
bool add_special,
bool parse_special = false);
// TODO: move the API below as member functions of llama_vocab
llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token);
llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token);
bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token);
bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token);
llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
llama_token llama_token_eot_impl(const struct llama_vocab & vocab);
llama_token llama_token_eom_impl(const struct llama_vocab & vocab);
llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab);
llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab);
llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab);
llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab);
llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab);
llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab);
bool llama_add_bos_token_impl(const struct llama_vocab & vocab);
bool llama_add_eos_token_impl(const struct llama_vocab & vocab);
int32_t llama_tokenize_impl(
const struct llama_vocab & vocab,
const char * text,
int32_t text_len,
llama_token * tokens,
int32_t n_tokens_max,
bool add_special,
bool parse_special);
// does not write null-terminator to buf
int32_t llama_token_to_piece_impl(
const struct llama_vocab & vocab,
llama_token token,
char * buf,
int32_t length,
int32_t lstrip,
bool special);
// check if token0 is contained as a prefix in token1
bool llama_token_is_prefix_impl(
const struct llama_vocab & vocab,
llama_token token0,
llama_token token1);
int32_t llama_detokenize_impl(
const struct llama_vocab & vocab,
const llama_token * tokens,
int32_t n_tokens,
char * text,
int32_t text_len_max,
bool remove_special,
bool unparse_special);
std::string llama_detokenize(
const struct llama_vocab & vocab,
const std::vector<llama_token> & tokens,
bool special);
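The header now hides tokenizer state behind a pimpl and exposes the former *_impl free functions (listed on the old side) as member functions. A hedged round-trip sketch using only the members declared above; the vocab is assumed to have been loaded elsewhere (e.g. via llama_vocab::load during model loading):

    #include "llama-vocab.h"

    #include <string>
    #include <vector>

    // Tokenize a string with the new member-function API, then detokenize it back.
    static std::string round_trip(const llama_vocab & vocab, const std::string & text) {
        std::vector<llama_token> toks = vocab.tokenize(text, /*add_special=*/true, /*parse_special=*/false);
        return vocab.detokenize(toks, /*special=*/false);
    }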

View File

File diff suppressed because it is too large.

View File

@@ -12,17 +12,18 @@
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <codecvt>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <locale>
#include <map> #include <map>
#include <regex> #include <regex>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <locale>
#include <codecvt>
size_t unicode_len_utf8(char src) { size_t unicode_len_utf8(char src) {
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
@@ -640,14 +641,7 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
result.reserve(utf8.size()); result.reserve(utf8.size());
size_t offset = 0; size_t offset = 0;
while (offset < utf8.size()) { while (offset < utf8.size()) {
try { result.push_back(unicode_cpt_from_utf8(utf8, offset));
result.push_back(unicode_cpt_from_utf8(utf8, offset));
}
catch (const std::invalid_argument & /*ex*/) {
// Silently ignore invalid UTF-8 input to avoid leaking the exception beyond llama_tokenize
++offset;
result.emplace_back(0xFFFD); // replacement character
}
} }
return result; return result;
} }
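Previously unicode_cpt_from_utf8 could throw std::invalid_argument out of llama_tokenize on malformed input; the new loop catches it, skips one byte, and emits the U+FFFD replacement character instead. A small illustrative expectation (not a test from the repo), assuming only the unicode_cpts_from_utf8 signature shown above:

    #include <cstdint>
    #include <cstdio>
    #include <string>
    #include <vector>

    extern std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);

    // A lone 0xFF byte is not valid UTF-8, so with the change above it decodes
    // to U+FFFD instead of propagating an exception out of tokenization.
    static void demo_invalid_utf8() {
        const std::string bad = std::string("ok") + char(0xFF);
        const auto cpts = unicode_cpts_from_utf8(bad);
        // cpts is expected to be { 'o', 'k', 0xFFFD } after this change
        printf("decoded %zu codepoints, last = U+%04X\n", cpts.size(), (unsigned) cpts.back());
    }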
@@ -730,7 +724,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
const auto cpts = unicode_cpts_from_utf8(text); const auto cpts = unicode_cpts_from_utf8(text);
// generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte // generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
// ref: https://github.com/ggml-org/llama.cpp/pull/6920#issuecomment-2081479935 // ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
std::string text_collapsed; std::string text_collapsed;
if (need_collapse) { if (need_collapse) {
// collapse all unicode categories // collapse all unicode categories

View File

@@ -14,52 +14,83 @@ package llama
#include "llama.h" #include "llama.h"
#include "clip.h" #include "clip.h"
#include "llava.h" #include "llava.h"
#include "gguf.h"
#include "mllama.h" #include "mllama.h"
#include "sampling_ext.h" #include "sampling_ext.h"
extern bool llamaProgressCallback(float progress, void *user_data); extern bool llamaProgressCallback(float progress, void *user_data);
extern void llamaLog(int level, char* text, void* user_data); extern void llamaLog(int level, char* text, void* user_data);
typedef enum {COMP_UNKNOWN,COMP_GCC,COMP_CLANG} COMPILER;
COMPILER inline get_compiler() {
#if defined(__clang__)
return COMP_CLANG;
#elif defined(__GNUC__)
return COMP_GCC;
#else
return UNKNOWN_COMPILER;
#endif
}
*/ */
import "C" import "C"
import ( import (
"context"
_ "embed" _ "embed"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"os" "os"
"runtime" "runtime"
"runtime/cgo" "runtime/cgo"
"slices" "slices"
"strings" "strings"
"sync/atomic"
"unsafe" "unsafe"
_ "github.com/ollama/ollama/llama/llama.cpp/common" _ "github.com/ollama/ollama/llama/llama.cpp/common"
_ "github.com/ollama/ollama/llama/llama.cpp/examples/llava" _ "github.com/ollama/ollama/llama/llama.cpp/examples/llava"
_ "github.com/ollama/ollama/llama/llama.cpp/src" _ "github.com/ollama/ollama/llama/llama.cpp/src"
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src" "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
) )
func init() {
C.llama_log_set(C.ggml_log_callback(C.llamaLog), nil)
}
//export llamaLog
func llamaLog(level C.int, text *C.char, _ unsafe.Pointer) {
// slog levels zeros INFO and are multiples of 4
if slog.Default().Enabled(context.TODO(), slog.Level(int(level-C.GGML_LOG_LEVEL_INFO)*4)) {
fmt.Fprint(os.Stderr, C.GoString(text))
}
}
func BackendInit() { func BackendInit() {
ggml.OnceLoad() ggml.OnceLoad()
C.llama_backend_init() C.llama_backend_init()
} }
func PrintSystemInfo() string {
var compiler string
switch C.get_compiler() {
case C.COMP_UNKNOWN:
compiler = "cgo(unknown_compiler)"
case C.COMP_GCC:
compiler = "cgo(gcc)"
case C.COMP_CLANG:
compiler = "cgo(clang)"
}
return C.GoString(C.llama_print_system_info()) + compiler
}
var logLevel atomic.Int32
func init() {
logLevel.Store(int32(C.GGML_LOG_LEVEL_INFO))
C.llama_log_set((C.ggml_log_callback)(C.llamaLog), nil)
}
func EnableDebug() {
logLevel.Store(int32(C.GGML_LOG_LEVEL_DEBUG))
}
//export llamaLog
func llamaLog(level int32, text *C.char, _ unsafe.Pointer) {
if level < logLevel.Load() {
return
}
fmt.Fprint(os.Stderr, C.GoString(text))
}
func GetModelArch(modelPath string) (string, error) { func GetModelArch(modelPath string) (string, error) {
mp := C.CString(modelPath) mp := C.CString(modelPath)
defer C.free(unsafe.Pointer(mp)) defer C.free(unsafe.Pointer(mp))
@@ -166,10 +197,6 @@ func (c *Context) KvCacheDefrag() {
C.llama_kv_cache_defrag(c.c) C.llama_kv_cache_defrag(c.c)
} }
func (c *Context) KvCacheCanShift() bool {
return bool(C.llama_kv_cache_can_shift(c.c))
}
// Get the embeddings for a sequence id // Get the embeddings for a sequence id
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 { func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId))) e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
@@ -241,7 +268,7 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
cparams.progress_callback_user_data = unsafe.Pointer(&handle) cparams.progress_callback_user_data = unsafe.Pointer(&handle)
} }
m := Model{c: C.llama_model_load_from_file(C.CString(modelPath), cparams)} m := Model{c: C.llama_load_model_from_file(C.CString(modelPath), cparams)}
if m.c == nil { if m.c == nil {
return nil, fmt.Errorf("unable to load model: %s", modelPath) return nil, fmt.Errorf("unable to load model: %s", modelPath)
} }
@@ -249,27 +276,13 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
return &m, nil return &m, nil
} }
func LoadVocabFromFile(path string) (*Vocab, error) {
mp := C.CString(path)
defer C.free(unsafe.Pointer(mp))
v := Vocab{c: C.llama_load_vocab_from_file(mp)}
if v.c == nil {
return nil, fmt.Errorf("unable to load vocab: %s", path)
}
return &v, nil
}
func FreeVocab(vocab *Vocab) {
C.llama_free_vocab(vocab.c)
}
func FreeModel(model *Model) { func FreeModel(model *Model) {
C.llama_model_free(model.c) C.llama_free_model(model.c)
} }
func NewContextWithModel(model *Model, params ContextParams) (*Context, error) { func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
c := Context{ c := Context{
c: C.llama_init_from_model(model.c, params.c), c: C.llama_new_context_with_model(model.c, params.c),
numThreads: int(params.c.n_threads), numThreads: int(params.c.n_threads),
} }
if c.c == nil { if c.c == nil {
@@ -280,29 +293,29 @@ func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
} }
func (m *Model) NumVocab() int { func (m *Model) NumVocab() int {
return int(C.llama_vocab_n_tokens(m.Vocab())) return int(C.llama_n_vocab(m.c))
} }
func (m *Model) TokenIsEog(token int) bool { func (m *Model) TokenIsEog(token int) bool {
return bool(C.llama_vocab_is_eog(m.Vocab(), C.llama_token(token))) return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
} }
func (m *Model) AddBOSToken() bool { func (m *Model) AddBOSToken() bool {
return bool(C.llama_vocab_get_add_bos(m.Vocab())) return bool(C.llama_add_bos_token(m.c))
} }
func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error { func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
cLoraPath := C.CString(loraPath) cLoraPath := C.CString(loraPath)
defer C.free(unsafe.Pointer(cLoraPath)) defer C.free(unsafe.Pointer(cLoraPath))
loraAdapter := C.llama_adapter_lora_init(m.c, cLoraPath) loraAdapter := C.llama_lora_adapter_init(m.c, cLoraPath)
if loraAdapter == nil { if loraAdapter == nil {
return errors.New("unable to load lora") return errors.New("unable to load lora")
} }
err := -1 err := -1
if loraAdapter != nil { if loraAdapter != nil {
err = int(C.llama_set_adapter_lora(context.c, loraAdapter, C.float(scale))) err = int(C.llama_lora_adapter_set(context.c, loraAdapter, C.float(scale)))
} }
if err != 0 { if err != 0 {
return errors.New("error applying lora from file") return errors.New("error applying lora from file")
@@ -311,14 +324,6 @@ func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float
return nil return nil
} }
type Vocab struct {
c *C.struct_llama_vocab
}
func (m *Model) Vocab() *C.struct_llama_vocab {
return C.llama_model_get_vocab(m.c)
}
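The bindings above now call the renamed C entry points (llama_model_load_from_file, llama_model_get_vocab, llama_init_from_model, llama_model_free). A hedged C++ sketch of that sequence; llama_model_default_params, llama_context_default_params and llama_free are assumed from llama.h, and error handling is minimal:

    #include "llama.h"

    static void load_and_free(const char * path) {
        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_model_load_from_file(path, mparams);
        if (model == nullptr) {
            return;
        }

        const llama_vocab * vocab = llama_model_get_vocab(model);
        const int n_vocab = llama_vocab_n_tokens(vocab);
        (void) n_vocab;

        llama_context_params cparams = llama_context_default_params();
        llama_context * ctx = llama_init_from_model(model, cparams);

        llama_free(ctx);
        llama_model_free(model);
    }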
type Batch struct { type Batch struct {
c C.struct_llama_batch c C.struct_llama_batch
batchSize int batchSize int
@@ -409,7 +414,7 @@ func (m *Model) TokenToPiece(token int) string {
tokenLen := 12 tokenLen := 12
buf := make([]byte, tokenLen) buf := make([]byte, tokenLen)
tokenLen = int(C.llama_token_to_piece( tokenLen = int(C.llama_token_to_piece(
m.Vocab(), m.c,
C.int32_t(token), C.int32_t(token),
(*C.char)(unsafe.Pointer(&buf[0])), (*C.char)(unsafe.Pointer(&buf[0])),
C.int32_t(tokenLen), C.int32_t(tokenLen),
@@ -421,7 +426,7 @@ func (m *Model) TokenToPiece(token int) string {
buf = make([]byte, tokenLen) buf = make([]byte, tokenLen)
C.llama_token_to_piece( C.llama_token_to_piece(
m.Vocab(), m.c,
C.int32_t(token), C.int32_t(token),
(*C.char)(unsafe.Pointer(&buf[0])), (*C.char)(unsafe.Pointer(&buf[0])),
C.int32_t(tokenLen), C.int32_t(tokenLen),
@@ -439,7 +444,7 @@ func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int
defer C.free(unsafe.Pointer(cText)) defer C.free(unsafe.Pointer(cText))
result := C.llama_tokenize( result := C.llama_tokenize(
m.Vocab(), m.c,
cText, cText,
C.int32_t(len(text)), C.int32_t(len(text)),
&cTokens[0], &cTokens[0],
@@ -453,7 +458,7 @@ func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int
maxTokens = int(-result) maxTokens = int(-result)
cTokens = make([]C.llama_token, maxTokens) cTokens = make([]C.llama_token, maxTokens)
result = C.llama_tokenize( result = C.llama_tokenize(
m.Vocab(), m.c,
cText, cText,
C.int32_t(len(text)), C.int32_t(len(text)),
&cTokens[0], &cTokens[0],
@@ -475,7 +480,7 @@ func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int
} }
func (m *Model) NEmbd() int { func (m *Model) NEmbd() int {
return int(C.llama_model_n_embd(m.c)) return int(C.llama_n_embd(m.c))
} }
func Quantize(infile, outfile string, ftype uint32) error { func Quantize(infile, outfile string, ftype uint32) error {
@@ -691,53 +696,3 @@ func SchemaToGrammar(schema []byte) []byte {
} }
return buf[:n] return buf[:n]
} }
type Sampler struct {
c *C.struct_llama_sampler
}
func NewGrammarSampler(vocab *Vocab, grammar string) *Sampler {
cGrammar := C.CString(grammar)
cRoot := C.CString("root")
defer C.free(unsafe.Pointer(cGrammar))
defer C.free(unsafe.Pointer(cRoot))
sampler := &Sampler{c: C.llama_sampler_init_grammar(vocab.c, cGrammar, cRoot)}
return sampler
}
func (s *Sampler) Accept(token int32) {
C.llama_sampler_accept(s.c, C.llama_token(token))
}
type TokenData struct {
Id int32
Logit float32
}
func (s *Sampler) Apply(tokens []TokenData) {
tds := make([]C.struct_llama_token_data, len(tokens))
for i, token := range tokens {
tds[i] = C.struct_llama_token_data{
id: C.int32_t(token.Id),
logit: C.float(token.Logit),
p: C.float(0.0),
}
}
tda := &C.llama_token_data_array{
data: (*C.struct_llama_token_data)(unsafe.Pointer(&tds[0])),
size: C.size_t(len(tokens)),
selected: C.int64_t(-1),
sorted: C.bool(false),
}
var pinner runtime.Pinner
pinner.Pin(&tds[0])
defer pinner.Unpin()
C.llama_sampler_apply(s.c, tda)
for i := range tokens {
tokens[i].Logit = float32(tds[i].logit)
}
}

llama/mllama.cpp vendored
View File

@@ -5,7 +5,6 @@
#include "ggml-backend.h" #include "ggml-backend.h"
#include "ggml-cpu.h" #include "ggml-cpu.h"
#include "ggml.h" #include "ggml.h"
#include "gguf.h"
#ifdef GGML_USE_CUDA #ifdef GGML_USE_CUDA
#include "ggml-cuda.h" #include "ggml-cuda.h"

Some files were not shown because too many files have changed in this diff.