mirror of
https://github.com/ollama/ollama.git
synced 2025-12-29 02:28:12 -05:00
Compare commits
1 Commits
jmorganca/
...
jmorganca/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
19638cec55 |
28
.github/workflows/release.yaml
vendored
28
.github/workflows/release.yaml
vendored
@@ -65,36 +65,14 @@ jobs:
|
||||
arch: amd64
|
||||
preset: 'CUDA 12'
|
||||
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
|
||||
cuda-components:
|
||||
- '"cudart"'
|
||||
- '"nvcc"'
|
||||
- '"cublas"'
|
||||
- '"cublas_dev"'
|
||||
cuda-version: '12.8'
|
||||
flags: ''
|
||||
runner_dir: 'cuda_v12'
|
||||
- os: windows
|
||||
arch: amd64
|
||||
preset: 'CUDA 13'
|
||||
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
|
||||
cuda-components:
|
||||
- '"cudart"'
|
||||
- '"nvcc"'
|
||||
- '"cublas"'
|
||||
- '"cublas_dev"'
|
||||
- '"crt"'
|
||||
- '"nvvm"'
|
||||
- '"nvptxcompiler"'
|
||||
cuda-version: '13.0'
|
||||
flags: ''
|
||||
runner_dir: 'cuda_v13'
|
||||
- os: windows
|
||||
arch: amd64
|
||||
preset: 'ROCm 6'
|
||||
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
|
||||
rocm-version: '6.2'
|
||||
flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
||||
runner_dir: ''
|
||||
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
||||
environment: release
|
||||
env:
|
||||
@@ -118,7 +96,7 @@ jobs:
|
||||
$ErrorActionPreference = "Stop"
|
||||
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
|
||||
$subpackages = @(${{ join(matrix.cuda-components, ', ') }}) | Foreach-Object {"${_}_${{ matrix.cuda-version }}"}
|
||||
$subpackages = @("cudart", "nvcc", "cublas", "cublas_dev") | Foreach-Object {"${_}_${{ matrix.cuda-version }}"}
|
||||
Start-Process -FilePath .\install.exe -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait
|
||||
}
|
||||
|
||||
@@ -160,7 +138,7 @@ jobs:
|
||||
run: |
|
||||
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} -DOLLAMA_RUNNER_DIR="${{ matrix.runner_dir }}"
|
||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
|
||||
cmake --build --parallel --preset "${{ matrix.preset }}"
|
||||
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8
|
||||
env:
|
||||
@@ -254,7 +232,7 @@ jobs:
|
||||
case "$COMPONENT" in
|
||||
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_sbsa) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
||||
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
||||
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
||||
|
||||
16
.github/workflows/test.yaml
vendored
16
.github/workflows/test.yaml
vendored
@@ -46,7 +46,7 @@ jobs:
|
||||
include:
|
||||
- preset: CPU
|
||||
- preset: CUDA
|
||||
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
||||
container: nvidia/cuda:12.8.1-devel-ubuntu22.04
|
||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
|
||||
- preset: ROCm
|
||||
container: rocm/dev-ubuntu-22.04:6.1.2
|
||||
@@ -78,17 +78,8 @@ jobs:
|
||||
include:
|
||||
- preset: CPU
|
||||
- preset: CUDA
|
||||
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
|
||||
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
|
||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
|
||||
cuda-components:
|
||||
- '"cudart"'
|
||||
- '"nvcc"'
|
||||
- '"cublas"'
|
||||
- '"cublas_dev"'
|
||||
- '"crt"'
|
||||
- '"nvvm"'
|
||||
- '"nvptxcompiler"'
|
||||
cuda-version: '13.0'
|
||||
- preset: ROCm
|
||||
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
|
||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
||||
@@ -111,8 +102,7 @@ jobs:
|
||||
$ErrorActionPreference = "Stop"
|
||||
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
|
||||
$subpackages = @(${{ join(matrix.cuda-components, ', ') }}) | Foreach-Object {"${_}_${{ matrix.cuda-version }}"}
|
||||
Start-Process -FilePath .\install.exe -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait
|
||||
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_12.8", "nvcc_12.8", "cublas_12.8", "cublas_dev_12.8")) -NoNewWindow -Wait
|
||||
}
|
||||
|
||||
$cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path
|
||||
|
||||
@@ -38,7 +38,7 @@ if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
|
||||
endif()
|
||||
|
||||
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
|
||||
set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama/${OLLAMA_RUNNER_DIR})
|
||||
set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama)
|
||||
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
|
||||
@@ -81,7 +81,7 @@ if(CMAKE_CUDA_COMPILER)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
|
||||
install(TARGETS ggml-cuda
|
||||
RUNTIME_DEPENDENCIES
|
||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR}
|
||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA
|
||||
|
||||
@@ -18,14 +18,6 @@
|
||||
"name": "CUDA",
|
||||
"inherits": [ "Default" ]
|
||||
},
|
||||
{
|
||||
"name": "CUDA 11",
|
||||
"inherits": [ "CUDA" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_CUDA_ARCHITECTURES": "50-virtual;60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual",
|
||||
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "CUDA 12",
|
||||
"inherits": [ "CUDA" ],
|
||||
@@ -34,14 +26,6 @@
|
||||
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "CUDA 13",
|
||||
"inherits": [ "CUDA" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;110-virtual;120-virtual;121-virtual",
|
||||
"CMAKE_CUDA_FLAGS": "-t 2"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "JetPack 5",
|
||||
"inherits": [ "CUDA" ],
|
||||
@@ -88,21 +72,11 @@
|
||||
"configurePreset": "CUDA",
|
||||
"targets": [ "ggml-cuda" ]
|
||||
},
|
||||
{
|
||||
"name": "CUDA 11",
|
||||
"inherits": [ "CUDA" ],
|
||||
"configurePreset": "CUDA 11"
|
||||
},
|
||||
{
|
||||
"name": "CUDA 12",
|
||||
"inherits": [ "CUDA" ],
|
||||
"configurePreset": "CUDA 12"
|
||||
},
|
||||
{
|
||||
"name": "CUDA 13",
|
||||
"inherits": [ "CUDA" ],
|
||||
"configurePreset": "CUDA 13"
|
||||
},
|
||||
{
|
||||
"name": "JetPack 5",
|
||||
"inherits": [ "CUDA" ],
|
||||
|
||||
58
Dockerfile
58
Dockerfile
@@ -1,7 +1,6 @@
|
||||
# vim: filetype=dockerfile
|
||||
|
||||
ARG FLAVOR=${TARGETARCH}
|
||||
ARG PARALLEL=8
|
||||
|
||||
ARG ROCMVERSION=6.3.3
|
||||
ARG JETPACK5VERSION=r35.4.1
|
||||
@@ -35,51 +34,26 @@ ENV LDFLAGS=-s
|
||||
FROM base AS cpu
|
||||
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
|
||||
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CPU' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CPU' \
|
||||
&& cmake --install build --component CPU --strip --parallel ${PARALLEL}
|
||||
|
||||
FROM base AS cuda-11
|
||||
ARG CUDA11VERSION=11.8
|
||||
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
||||
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CUDA 11' -DOLLAMA_RUNNER_DIR="cuda_v11" \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
&& cmake --build --parallel --preset 'CPU' \
|
||||
&& cmake --install build --component CPU --strip --parallel 8
|
||||
|
||||
FROM base AS cuda-12
|
||||
ARG CUDA12VERSION=12.8
|
||||
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
||||
ENV PATH=/usr/local/cuda-12/bin:$PATH
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CUDA 12' -DOLLAMA_RUNNER_DIR="cuda_v12"\
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
|
||||
|
||||
FROM base AS cuda-13
|
||||
ARG CUDA13VERSION=13.0
|
||||
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
|
||||
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CUDA 13' -DOLLAMA_RUNNER_DIR="cuda_v13" \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
|
||||
cmake --preset 'CUDA 12' \
|
||||
&& cmake --build --parallel --preset 'CUDA 12' \
|
||||
&& cmake --install build --component CUDA --strip --parallel 8
|
||||
|
||||
FROM base AS rocm-6
|
||||
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'ROCm 6' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
|
||||
&& cmake --install build --component HIP --strip --parallel ${PARALLEL}
|
||||
&& cmake --build --parallel --preset 'ROCm 6' \
|
||||
&& cmake --install build --component HIP --strip --parallel 8
|
||||
|
||||
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
|
||||
ARG CMAKEVERSION
|
||||
@@ -87,11 +61,10 @@ RUN apt-get update && apt-get install -y curl ccache \
|
||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'JetPack 5' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
&& cmake --build --parallel --preset 'JetPack 5' \
|
||||
&& cmake --install build --component CUDA --strip --parallel 8
|
||||
|
||||
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
|
||||
ARG CMAKEVERSION
|
||||
@@ -99,11 +72,10 @@ RUN apt-get update && apt-get install -y curl ccache \
|
||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'JetPack 6' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
&& cmake --build --parallel --preset 'JetPack 6' \
|
||||
&& cmake --install build --component CUDA --strip --parallel 8
|
||||
|
||||
FROM base AS build
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
@@ -120,14 +92,10 @@ RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -trimpath -buildmode=pie -o /bin/ollama .
|
||||
|
||||
FROM --platform=linux/amd64 scratch AS amd64
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/
|
||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama
|
||||
|
||||
FROM --platform=linux/arm64 scratch AS arm64
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/
|
||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/cuda_sbsa
|
||||
COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5
|
||||
COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6
|
||||
|
||||
|
||||
@@ -411,10 +411,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
|
||||
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
|
||||
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
|
||||
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
|
||||
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
|
||||
- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
|
||||
- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)
|
||||
|
||||
### Cloud
|
||||
|
||||
@@ -541,9 +537,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
|
||||
- [Ollama for D](https://github.com/kassane/ollama-d)
|
||||
- [OllamaPlusPlus](https://github.com/HardCodeDev777/OllamaPlusPlus) (Very simple C++ library for Ollama)
|
||||
- [any-llm](https://github.com/mozilla-ai/any-llm) (A single interface to use different llm providers by [mozilla.ai](https://www.mozilla.ai/))
|
||||
- [any-agent](https://github.com/mozilla-ai/any-agent) (A single interface to use and evaluate different agent frameworks by [mozilla.ai](https://www.mozilla.ai/))
|
||||
- [Neuro SAN](https://github.com/cognizant-ai-lab/neuro-san-studio) (Data-driven multi-agent orchestration framework) with [example](https://github.com/cognizant-ai-lab/neuro-san-studio/blob/main/docs/user_guide.md#ollama)
|
||||
|
||||
### Mobile
|
||||
|
||||
@@ -604,7 +597,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
|
||||
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)
|
||||
- [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases)
|
||||
- [NOMYO Router](https://github.com/nomyo-ai/nomyo-router) (A transparent Ollama proxy with model deployment aware routing which auto-manages multiple Ollama instances in a given network)
|
||||
|
||||
### Supported backends
|
||||
|
||||
|
||||
@@ -222,17 +222,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
return fmt.Errorf("unmarshal: %w", err)
|
||||
}
|
||||
|
||||
if response.StatusCode == http.StatusUnauthorized {
|
||||
pubKey, pkErr := auth.GetPublicKey()
|
||||
if pkErr != nil {
|
||||
return pkErr
|
||||
}
|
||||
return AuthorizationError{
|
||||
StatusCode: response.StatusCode,
|
||||
Status: response.Status,
|
||||
PublicKey: pubKey,
|
||||
}
|
||||
} else if response.StatusCode >= http.StatusBadRequest {
|
||||
if response.StatusCode >= http.StatusBadRequest {
|
||||
return StatusError{
|
||||
StatusCode: response.StatusCode,
|
||||
Status: response.Status,
|
||||
@@ -438,16 +428,3 @@ func (c *Client) Version(ctx context.Context) (string, error) {
|
||||
|
||||
return version.Version, nil
|
||||
}
|
||||
|
||||
// Signout will disconnect an ollama instance from ollama.com
|
||||
func (c *Client) Signout(ctx context.Context, encodedKey string) error {
|
||||
return c.do(ctx, http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey), nil, nil)
|
||||
}
|
||||
|
||||
func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
|
||||
var resp UserResponse
|
||||
if err := c.do(ctx, http.MethodPost, "/api/me", nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
171
api/types.go
171
api/types.go
@@ -11,8 +11,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
@@ -38,19 +36,6 @@ func (e StatusError) Error() string {
|
||||
}
|
||||
}
|
||||
|
||||
type AuthorizationError struct {
|
||||
StatusCode int
|
||||
Status string
|
||||
PublicKey string `json:"public_key"`
|
||||
}
|
||||
|
||||
func (e AuthorizationError) Error() string {
|
||||
if e.Status != "" {
|
||||
return e.Status
|
||||
}
|
||||
return "something went wrong, please see the ollama server logs for details"
|
||||
}
|
||||
|
||||
// ImageData represents the raw binary data of an image file.
|
||||
type ImageData []byte
|
||||
|
||||
@@ -105,10 +90,6 @@ type GenerateRequest struct {
|
||||
// (request that thinking _not_ be used) and unset (use the old behavior
|
||||
// before this option was introduced)
|
||||
Think *ThinkValue `json:"think,omitempty"`
|
||||
|
||||
// DebugRenderOnly is a debug option that, when set to true, returns the rendered
|
||||
// template instead of calling the model.
|
||||
DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
|
||||
}
|
||||
|
||||
// ChatRequest describes a request sent by [Client.Chat].
|
||||
@@ -139,10 +120,6 @@ type ChatRequest struct {
|
||||
// responding. Can be a boolean (true/false) or a string ("high", "medium", "low")
|
||||
// for supported models.
|
||||
Think *ThinkValue `json:"think,omitempty"`
|
||||
|
||||
// DebugRenderOnly is a debug option that, when set to true, returns the rendered
|
||||
// template instead of calling the model.
|
||||
DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
|
||||
}
|
||||
|
||||
type Tools []Tool
|
||||
@@ -301,23 +278,16 @@ func mapToTypeScriptType(jsonType string) string {
|
||||
}
|
||||
}
|
||||
|
||||
type ToolFunctionParameters struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]ToolProperty `json:"properties"`
|
||||
}
|
||||
|
||||
func (t *ToolFunctionParameters) String() string {
|
||||
bts, _ := json.Marshal(t)
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
type ToolFunction struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters ToolFunctionParameters `json:"parameters"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]ToolProperty `json:"properties"`
|
||||
} `json:"parameters"`
|
||||
}
|
||||
|
||||
func (t *ToolFunction) String() string {
|
||||
@@ -328,38 +298,16 @@ func (t *ToolFunction) String() string {
|
||||
// ChatResponse is the response returned by [Client.Chat]. Its fields are
|
||||
// similar to [GenerateResponse].
|
||||
type ChatResponse struct {
|
||||
// Model is the model name that generated the response.
|
||||
Model string `json:"model"`
|
||||
Model string `json:"model"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Message Message `json:"message"`
|
||||
DoneReason string `json:"done_reason,omitempty"`
|
||||
|
||||
// RemoteModel is the name of the upstream model that generated the response.
|
||||
RemoteModel string `json:"remote_model,omitempty"`
|
||||
|
||||
// RemoteHost is the URL of the upstream Ollama host that generated the response.
|
||||
RemoteHost string `json:"remote_host,omitempty"`
|
||||
|
||||
// CreatedAt is the timestamp of the response.
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
// Message contains the message or part of a message from the model.
|
||||
Message Message `json:"message"`
|
||||
|
||||
// Done specifies if the response is complete.
|
||||
Done bool `json:"done"`
|
||||
|
||||
// DoneReason is the reason the model stopped generating text.
|
||||
DoneReason string `json:"done_reason,omitempty"`
|
||||
|
||||
DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
|
||||
|
||||
Metrics
|
||||
}
|
||||
|
||||
// DebugInfo contains debug information for template rendering
|
||||
type DebugInfo struct {
|
||||
RenderedTemplate string `json:"rendered_template"`
|
||||
ImageCount int `json:"image_count,omitempty"`
|
||||
}
|
||||
|
||||
type Metrics struct {
|
||||
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
||||
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
||||
@@ -412,12 +360,8 @@ type EmbedRequest struct {
|
||||
// this request.
|
||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||
|
||||
// Truncate truncates the input to fit the model's max sequence length.
|
||||
Truncate *bool `json:"truncate,omitempty"`
|
||||
|
||||
// Dimensions truncates the output embedding to the specified dimension.
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
|
||||
// Options lists model-specific options.
|
||||
Options map[string]any `json:"options"`
|
||||
}
|
||||
@@ -455,47 +399,18 @@ type EmbeddingResponse struct {
|
||||
|
||||
// CreateRequest is the request passed to [Client.Create].
|
||||
type CreateRequest struct {
|
||||
// Model is the model name to create.
|
||||
Model string `json:"model"`
|
||||
|
||||
// Stream specifies whether the response is streaming; it is true by default.
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
|
||||
// Quantize is the quantization format for the model; leave blank to not change the quantization level.
|
||||
Model string `json:"model"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Quantize string `json:"quantize,omitempty"`
|
||||
|
||||
// From is the name of the model or file to use as the source.
|
||||
From string `json:"from,omitempty"`
|
||||
|
||||
// RemoteHost is the URL of the upstream ollama API for the model (if any).
|
||||
RemoteHost string `json:"remote_host,omitempty"`
|
||||
|
||||
// Files is a map of files include when creating the model.
|
||||
Files map[string]string `json:"files,omitempty"`
|
||||
|
||||
// Adapters is a map of LoRA adapters to include when creating the model.
|
||||
Adapters map[string]string `json:"adapters,omitempty"`
|
||||
|
||||
// Template is the template used when constructing a request to the model.
|
||||
Template string `json:"template,omitempty"`
|
||||
|
||||
// License is a string or list of strings for licenses.
|
||||
License any `json:"license,omitempty"`
|
||||
|
||||
// System is the system prompt for the model.
|
||||
System string `json:"system,omitempty"`
|
||||
|
||||
// Parameters is a map of hyper-parameters which are applied to the model.
|
||||
Parameters map[string]any `json:"parameters,omitempty"`
|
||||
|
||||
// Messages is a list of messages added to the model before chat and generation requests.
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
|
||||
Renderer string `json:"renderer,omitempty"`
|
||||
Parser string `json:"parser,omitempty"`
|
||||
|
||||
// Info is a map of additional information for the model
|
||||
Info map[string]any `json:"info,omitempty"`
|
||||
From string `json:"from,omitempty"`
|
||||
Files map[string]string `json:"files,omitempty"`
|
||||
Adapters map[string]string `json:"adapters,omitempty"`
|
||||
Template string `json:"template,omitempty"`
|
||||
License any `json:"license,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
Parameters map[string]any `json:"parameters,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
|
||||
// Deprecated: set the model name with Model instead
|
||||
Name string `json:"name"`
|
||||
@@ -533,12 +448,8 @@ type ShowResponse struct {
|
||||
Parameters string `json:"parameters,omitempty"`
|
||||
Template string `json:"template,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
Renderer string `json:"renderer,omitempty"`
|
||||
Parser string `json:"parser,omitempty"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
RemoteModel string `json:"remote_model,omitempty"`
|
||||
RemoteHost string `json:"remote_host,omitempty"`
|
||||
ModelInfo map[string]any `json:"model_info,omitempty"`
|
||||
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
|
||||
Tensors []Tensor `json:"tensors,omitempty"`
|
||||
@@ -597,14 +508,12 @@ type ProcessResponse struct {
|
||||
|
||||
// ListModelResponse is a single model description in [ListResponse].
|
||||
type ListModelResponse struct {
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
RemoteModel string `json:"remote_model,omitempty"`
|
||||
RemoteHost string `json:"remote_host,omitempty"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// ProcessModelResponse is a single model description in [ProcessResponse].
|
||||
@@ -628,12 +537,6 @@ type GenerateResponse struct {
|
||||
// Model is the model name that generated the response.
|
||||
Model string `json:"model"`
|
||||
|
||||
// RemoteModel is the name of the upstream model that generated the response.
|
||||
RemoteModel string `json:"remote_model,omitempty"`
|
||||
|
||||
// RemoteHost is the URL of the upstream Ollama host that generated the response.
|
||||
RemoteHost string `json:"remote_host,omitempty"`
|
||||
|
||||
// CreatedAt is the timestamp of the response.
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
@@ -657,8 +560,6 @@ type GenerateResponse struct {
|
||||
Metrics
|
||||
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
|
||||
DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
|
||||
}
|
||||
|
||||
// ModelDetails provides details about a model.
|
||||
@@ -671,18 +572,6 @@ type ModelDetails struct {
|
||||
QuantizationLevel string `json:"quantization_level"`
|
||||
}
|
||||
|
||||
// UserResponse provides information about a user.
|
||||
type UserResponse struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Bio string `json:"bio,omitempty"`
|
||||
AvatarURL string `json:"avatarurl,omitempty"`
|
||||
FirstName string `json:"firstname,omitempty"`
|
||||
LastName string `json:"lastname,omitempty"`
|
||||
Plan string `json:"plan,omitempty"`
|
||||
}
|
||||
|
||||
// Tensor describes the metadata for a given tensor.
|
||||
type Tensor struct {
|
||||
Name string `json:"name"`
|
||||
@@ -971,7 +860,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
||||
if t < 0 {
|
||||
d.Duration = time.Duration(math.MaxInt64)
|
||||
} else {
|
||||
d.Duration = time.Duration(t * float64(time.Second))
|
||||
d.Duration = time.Duration(int(t) * int(time.Second))
|
||||
}
|
||||
case string:
|
||||
d.Duration, err = time.ParseDuration(t)
|
||||
|
||||
@@ -17,11 +17,6 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
|
||||
req string
|
||||
exp *Duration
|
||||
}{
|
||||
{
|
||||
name: "Unset",
|
||||
req: `{ }`,
|
||||
exp: nil,
|
||||
},
|
||||
{
|
||||
name: "Positive Integer",
|
||||
req: `{ "keep_alive": 42 }`,
|
||||
@@ -30,7 +25,7 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
|
||||
{
|
||||
name: "Positive Float",
|
||||
req: `{ "keep_alive": 42.5 }`,
|
||||
exp: &Duration{42500 * time.Millisecond},
|
||||
exp: &Duration{42 * time.Second},
|
||||
},
|
||||
{
|
||||
name: "Positive Integer String",
|
||||
@@ -441,50 +436,3 @@ func TestThinking_UnmarshalJSON(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolFunctionParameters_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
params ToolFunctionParameters
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple object with string property",
|
||||
params: ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"name"},
|
||||
Properties: map[string]ToolProperty{
|
||||
"name": {
|
||||
Type: PropertyType{"string"},
|
||||
Description: "The name of the person",
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
|
||||
},
|
||||
{
|
||||
name: "marshal failure returns empty string",
|
||||
params: ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Defs: func() any {
|
||||
// Create a cycle that will cause json.Marshal to fail
|
||||
type selfRef struct {
|
||||
Self *selfRef
|
||||
}
|
||||
s := &selfRef{}
|
||||
s.Self = s
|
||||
return s
|
||||
}(),
|
||||
Properties: map[string]ToolProperty{},
|
||||
},
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := test.params.String()
|
||||
assert.Equal(t, test.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
25
auth/auth.go
25
auth/auth.go
@@ -19,31 +19,6 @@ import (
|
||||
const defaultPrivateKey = "id_ed25519"
|
||||
|
||||
func keyPath() (string, error) {
|
||||
fileIsReadable := func(fp string) bool {
|
||||
info, err := os.Stat(fp)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check that it's a regular file, not a directory or other file type
|
||||
if !info.Mode().IsRegular() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Try to open it to check readability
|
||||
file, err := os.Open(fp)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
file.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
systemPath := filepath.Join("/usr/share/ollama/.ollama", defaultPrivateKey)
|
||||
if fileIsReadable(systemPath) {
|
||||
return systemPath, nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
||||
161
cmd/cmd.go
161
cmd/cmd.go
@@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
@@ -15,7 +14,6 @@ import (
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
@@ -37,7 +35,6 @@ import (
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/parser"
|
||||
@@ -50,8 +47,6 @@ import (
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
const ConnectInstructions = "To sign in, navigate to:\n https://ollama.com/connect?name=%s&key=%s\n\n"
|
||||
|
||||
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
|
||||
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
|
||||
if name == "" {
|
||||
@@ -61,8 +56,10 @@ func ensureThinkingSupport(ctx context.Context, client *api.Client, name string)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
|
||||
return
|
||||
for _, cap := range resp.Capabilities {
|
||||
if cap == model.CapabilityThinking {
|
||||
return
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
|
||||
}
|
||||
@@ -291,17 +288,7 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
Think: opts.Think,
|
||||
}
|
||||
|
||||
return client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error {
|
||||
if r.RemoteModel != "" && opts.ShowConnect {
|
||||
p.StopAndClear()
|
||||
if strings.HasPrefix(r.RemoteHost, "https://ollama.com") {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", r.RemoteModel)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", r.RemoteModel, r.RemoteHost)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
|
||||
}
|
||||
|
||||
func StopHandler(cmd *cobra.Command, args []string) error {
|
||||
@@ -322,10 +309,9 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
interactive := true
|
||||
|
||||
opts := runOptions{
|
||||
Model: args[0],
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]any{},
|
||||
ShowConnect: true,
|
||||
Model: args[0],
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]any{},
|
||||
}
|
||||
|
||||
format, err := cmd.Flags().GetString("format")
|
||||
@@ -383,7 +369,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
prompts = append([]string{string(in)}, prompts...)
|
||||
opts.ShowConnect = false
|
||||
opts.WordWrap = false
|
||||
interactive = false
|
||||
}
|
||||
@@ -450,21 +435,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
if interactive {
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
var sErr api.AuthorizationError
|
||||
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
|
||||
pubKey, pkErr := auth.GetPublicKey()
|
||||
if pkErr != nil {
|
||||
return pkErr
|
||||
}
|
||||
// the server and the client both have the same public key
|
||||
if pubKey == sErr.PublicKey {
|
||||
h, _ := os.Hostname()
|
||||
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
|
||||
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
|
||||
fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -485,56 +455,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return generate(cmd, opts)
|
||||
}
|
||||
|
||||
func SigninHandler(cmd *cobra.Command, args []string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := client.Whoami(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user != nil && user.Name != "" {
|
||||
fmt.Printf("You are already signed in as user '%s'\n", user.Name)
|
||||
fmt.Println()
|
||||
return nil
|
||||
}
|
||||
|
||||
pubKey, pkErr := auth.GetPublicKey()
|
||||
if pkErr != nil {
|
||||
return pkErr
|
||||
}
|
||||
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
|
||||
|
||||
h, _ := os.Hostname()
|
||||
fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func SignoutHandler(cmd *cobra.Command, args []string) error {
|
||||
pubKey, pkErr := auth.GetPublicKey()
|
||||
if pkErr != nil {
|
||||
return pkErr
|
||||
}
|
||||
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = client.Signout(cmd.Context(), encKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("You have signed out of ollama.com")
|
||||
fmt.Println()
|
||||
return nil
|
||||
}
|
||||
|
||||
func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
@@ -587,8 +507,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
errStr := strings.ToLower(err.Error())
|
||||
if strings.Contains(errStr, "access denied") || strings.Contains(errStr, "unauthorized") {
|
||||
if strings.Contains(err.Error(), "access denied") {
|
||||
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
|
||||
}
|
||||
return err
|
||||
@@ -622,14 +541,7 @@ func ListHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
for _, m := range models.Models {
|
||||
if len(args) == 0 || strings.HasPrefix(strings.ToLower(m.Name), strings.ToLower(args[0])) {
|
||||
var size string
|
||||
if m.RemoteModel != "" {
|
||||
size = "-"
|
||||
} else {
|
||||
size = format.HumanBytes(m.Size)
|
||||
}
|
||||
|
||||
data = append(data, []string{m.Name, m.Digest[:12], size, format.HumanTime(m.ModifiedAt, "Never")})
|
||||
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -714,8 +626,8 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
|
||||
KeepAlive: &api.Duration{Duration: 0},
|
||||
}
|
||||
if err := loadOrUnloadModel(cmd, opts); err != nil {
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "not found") {
|
||||
fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", args[0])
|
||||
if !strings.Contains(err.Error(), "not found") {
|
||||
return fmt.Errorf("unable to stop existing running model \"%s\": %s", args[0], err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -826,36 +738,12 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
||||
}
|
||||
|
||||
tableRender("Model", func() (rows [][]string) {
|
||||
if resp.RemoteHost != "" {
|
||||
rows = append(rows, []string{"", "Remote model", resp.RemoteModel})
|
||||
rows = append(rows, []string{"", "Remote URL", resp.RemoteHost})
|
||||
}
|
||||
|
||||
if resp.ModelInfo != nil {
|
||||
arch := resp.ModelInfo["general.architecture"].(string)
|
||||
rows = append(rows, []string{"", "architecture", arch})
|
||||
|
||||
var paramStr string
|
||||
if resp.Details.ParameterSize != "" {
|
||||
paramStr = resp.Details.ParameterSize
|
||||
} else if v, ok := resp.ModelInfo["general.parameter_count"]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
paramStr = format.HumanNumber(uint64(f))
|
||||
}
|
||||
}
|
||||
rows = append(rows, []string{"", "parameters", paramStr})
|
||||
|
||||
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)})
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)})
|
||||
}
|
||||
}
|
||||
rows = append(rows, []string{"", "parameters", format.HumanNumber(uint64(resp.ModelInfo["general.parameter_count"].(float64)))})
|
||||
rows = append(rows, []string{"", "context length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64), 'f', -1, 64)})
|
||||
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64), 'f', -1, 64)})
|
||||
} else {
|
||||
rows = append(rows, []string{"", "architecture", resp.Details.Family})
|
||||
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
|
||||
@@ -1103,7 +991,6 @@ type runOptions struct {
|
||||
KeepAlive *api.Duration
|
||||
Think *api.ThinkValue
|
||||
HideThinking bool
|
||||
ShowConnect bool
|
||||
}
|
||||
|
||||
type displayResponseState struct {
|
||||
@@ -1659,22 +1546,6 @@ func NewCLI() *cobra.Command {
|
||||
|
||||
pushCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
|
||||
signinCmd := &cobra.Command{
|
||||
Use: "signin",
|
||||
Short: "Sign in to ollama.com",
|
||||
Args: cobra.ExactArgs(0),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: SigninHandler,
|
||||
}
|
||||
|
||||
signoutCmd := &cobra.Command{
|
||||
Use: "signout",
|
||||
Short: "Sign out from ollama.com",
|
||||
Args: cobra.ExactArgs(0),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: SignoutHandler,
|
||||
}
|
||||
|
||||
listCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
@@ -1769,8 +1640,6 @@ func NewCLI() *cobra.Command {
|
||||
stopCmd,
|
||||
pullCmd,
|
||||
pushCmd,
|
||||
signinCmd,
|
||||
signoutCmd,
|
||||
listCmd,
|
||||
psCmd,
|
||||
copyCmd,
|
||||
|
||||
@@ -3,7 +3,6 @@ package cmd
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -305,8 +304,6 @@ func TestDeleteHandler(t *testing.T) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
errPayload := `{"error":"model '%s' not found"}`
|
||||
w.Write([]byte(fmt.Sprintf(errPayload, req.Name)))
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -349,7 +346,7 @@ func TestDeleteHandler(t *testing.T) {
|
||||
}
|
||||
|
||||
err := DeleteHandler(cmd, []string{"test-model-not-found"})
|
||||
if err == nil || !strings.Contains(err.Error(), "model 'test-model-not-found' not found") {
|
||||
if err == nil || !strings.Contains(err.Error(), "unable to stop existing running model \"test-model-not-found\"") {
|
||||
t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -502,7 +499,7 @@ func TestPushHandler(t *testing.T) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "403: {\"errors\":[{\"code\":\"ACCESS DENIED\", \"message\":\"access denied\"}]}",
|
||||
"error": "access denied",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -525,7 +522,6 @@ func TestPushHandler(t *testing.T) {
|
||||
defer mockServer.Close()
|
||||
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
initializeKeypair()
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().Bool("insecure", false, "")
|
||||
|
||||
@@ -28,7 +28,6 @@ type bertModel struct {
|
||||
LayerNormEPS float32 `json:"layer_norm_eps"`
|
||||
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
|
||||
NormEpsilon float32 `json:"norm_epsilon"`
|
||||
normalizeEmbeddings bool
|
||||
|
||||
PoolingType uint32
|
||||
}
|
||||
@@ -55,11 +54,9 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
|
||||
|
||||
var pooling string
|
||||
for _, m := range modules {
|
||||
switch m.Type {
|
||||
case "sentence_transformers.models.Pooling":
|
||||
if m.Type == "sentence_transformers.models.Pooling" {
|
||||
pooling = m.Path
|
||||
case "sentence_transformers.models.Normalize":
|
||||
p.normalizeEmbeddings = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,7 +90,6 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv["general.architecture"] = "bert"
|
||||
kv["bert.attention.causal"] = false
|
||||
kv["bert.pooling_type"] = p.PoolingType
|
||||
kv["bert.normalize_embeddings"] = p.normalizeEmbeddings
|
||||
|
||||
kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)
|
||||
|
||||
|
||||
@@ -15,24 +15,19 @@ import (
|
||||
|
||||
type gptossModel struct {
|
||||
ModelParameters
|
||||
HiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
AttentionHeads uint32 `json:"num_attention_heads"`
|
||||
KeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
Experts uint32 `json:"num_experts"`
|
||||
LocalExperts uint32 `json:"num_local_experts"`
|
||||
ExpertsPerToken uint32 `json:"experts_per_token"`
|
||||
RMSNormEpsilon float32 `json:"rms_norm_eps"`
|
||||
InitialContextLength uint32 `json:"initial_context_length"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RopeScalingFactor float32 `json:"rope_scaling_factor"`
|
||||
RopeScaling struct {
|
||||
Factor float32 `json:"factor"`
|
||||
} `json:"rope_scaling"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
HiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
AttentionHeads uint32 `json:"num_attention_heads"`
|
||||
KeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
Experts uint32 `json:"num_experts"`
|
||||
ExpertsPerToken uint32 `json:"experts_per_token"`
|
||||
RMSNormEpsilon float32 `json:"rms_norm_eps"`
|
||||
InitialContextLength uint32 `json:"initial_context_length"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RopeScalingFactor float32 `json:"rope_scaling_factor"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*gptossModel)(nil)
|
||||
@@ -41,11 +36,11 @@ func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gptoss"
|
||||
kv["general.file_type"] = uint32(4)
|
||||
kv["gptoss.context_length"] = cmp.Or(m.MaxPositionEmbeddings, uint32(m.RopeScalingFactor*float32(m.InitialContextLength)))
|
||||
kv["gptoss.context_length"] = uint32(m.RopeScalingFactor * float32(m.InitialContextLength))
|
||||
kv["gptoss.block_count"] = m.HiddenLayers
|
||||
kv["gptoss.embedding_length"] = m.HiddenSize
|
||||
kv["gptoss.feed_forward_length"] = m.IntermediateSize
|
||||
kv["gptoss.expert_count"] = cmp.Or(m.Experts, m.LocalExperts)
|
||||
kv["gptoss.expert_count"] = m.Experts
|
||||
kv["gptoss.expert_used_count"] = m.ExpertsPerToken
|
||||
kv["gptoss.attention.head_count"] = m.AttentionHeads
|
||||
kv["gptoss.attention.head_count_kv"] = m.KeyValueHeads
|
||||
@@ -54,7 +49,7 @@ func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv["gptoss.attention.layer_norm_rms_epsilon"] = cmp.Or(m.RMSNormEpsilon, 1e-5)
|
||||
kv["gptoss.attention.sliding_window"] = m.SlidingWindow
|
||||
kv["gptoss.rope.freq_base"] = m.RopeTheta
|
||||
kv["gptoss.rope.scaling.factor"] = cmp.Or(m.RopeScalingFactor, m.RopeScaling.Factor)
|
||||
kv["gptoss.rope.scaling.factor"] = m.RopeScalingFactor
|
||||
kv["gptoss.rope.scaling.original_context_length"] = m.InitialContextLength
|
||||
kv["tokenizer.ggml.bos_token_id"] = uint32(199998) // <|startoftext|>
|
||||
kv["tokenizer.ggml.add_bos_token"] = false
|
||||
@@ -97,11 +92,6 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
|
||||
for name, mxfp4 := range mxfp4s {
|
||||
dims := mxfp4.blocks.Shape()
|
||||
|
||||
if !strings.HasSuffix(name, ".weight") {
|
||||
name += ".weight"
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: uint32(ggml.TensorTypeMXFP4),
|
||||
@@ -114,47 +104,25 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
}
|
||||
|
||||
func (m *gptossModel) Replacements() []string {
|
||||
var replacements []string
|
||||
if m.MaxPositionEmbeddings > 0 {
|
||||
// hf flavored model
|
||||
replacements = []string{
|
||||
"lm_head", "output",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.layers", "blk",
|
||||
"input_layernorm", "attn_norm",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_out",
|
||||
"self_attn.sinks", "attn_sinks",
|
||||
"post_attention_layernorm", "ffn_norm",
|
||||
"mlp.router", "ffn_gate_inp",
|
||||
"mlp.experts.gate_up_proj_", "ffn_gate_up_exps.",
|
||||
"mlp.experts.down_proj_", "ffn_down_exps.",
|
||||
"model.norm", "output_norm",
|
||||
}
|
||||
} else {
|
||||
replacements = []string{
|
||||
// noop replacements so other replacements will not be applied
|
||||
".blocks", ".blocks",
|
||||
".scales", ".scales",
|
||||
// real replacements
|
||||
"block", "blk",
|
||||
"attn.norm", "attn_norm",
|
||||
"attn.qkv", "attn_qkv",
|
||||
"attn.sinks", "attn_sinks",
|
||||
"attn.out", "attn_out",
|
||||
"mlp.norm", "ffn_norm",
|
||||
"mlp.gate", "ffn_gate_inp",
|
||||
"mlp.mlp1_", "ffn_gate_up_exps.",
|
||||
"mlp.mlp2_", "ffn_down_exps.",
|
||||
"embedding", "token_embd",
|
||||
"norm", "output_norm",
|
||||
"unembedding", "output",
|
||||
"scale", "weight",
|
||||
}
|
||||
return []string{
|
||||
// noop replacements so other replacements will not be applied
|
||||
".blocks", ".blocks",
|
||||
".scales", ".scales",
|
||||
// real replacements
|
||||
"block", "blk",
|
||||
"attn.norm", "attn_norm",
|
||||
"attn.qkv", "attn_qkv",
|
||||
"attn.sinks", "attn_sinks",
|
||||
"attn.out", "attn_out",
|
||||
"mlp.norm", "ffn_norm",
|
||||
"mlp.gate", "ffn_gate_inp",
|
||||
"mlp.mlp1_", "ffn_gate_up_exps.",
|
||||
"mlp.mlp2_", "ffn_down_exps.",
|
||||
"embedding", "token_embd",
|
||||
"norm", "output_norm",
|
||||
"unembedding", "output",
|
||||
"scale", "weight",
|
||||
}
|
||||
return replacements
|
||||
}
|
||||
|
||||
type mxfp4 struct {
|
||||
@@ -172,20 +140,7 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
|
||||
blocksDims[i] = int(d)
|
||||
}
|
||||
|
||||
bts := b.Bytes()
|
||||
var tmp [16]byte
|
||||
for i := 0; i < b.Len(); i += 16 {
|
||||
for j := range 8 {
|
||||
// transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
|
||||
a, b := bts[i+j], bts[i+j+8]
|
||||
tmp[2*j+0] = (a & 0x0F) | (b << 4)
|
||||
tmp[2*j+1] = (a >> 4) | (b & 0xF0)
|
||||
}
|
||||
|
||||
copy(bts[i:i+16], tmp[:])
|
||||
}
|
||||
|
||||
var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(bts))
|
||||
var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(b.Bytes()))
|
||||
|
||||
var s bytes.Buffer
|
||||
if _, err := m.scales.WriteTo(&s); err != nil {
|
||||
@@ -219,5 +174,5 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return int64(len(u8s)), nil
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
@@ -33,8 +33,8 @@ func (t tensorBase) Shape() []uint64 {
|
||||
const (
|
||||
tensorKindFP32 uint32 = iota
|
||||
tensorKindFP16
|
||||
tensorKindMXFP4 = 4
|
||||
tensorKindBF16 = 30
|
||||
tensorKindMXFP4 = 39
|
||||
)
|
||||
|
||||
func (t tensorBase) Kind() uint32 {
|
||||
|
||||
@@ -96,7 +96,7 @@ type safetensor struct {
|
||||
|
||||
func (st safetensor) Kind() uint32 {
|
||||
kind := st.tensorBase.Kind()
|
||||
if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 {
|
||||
if st.dtype == "BF16" && kind != tensorKindFP32 {
|
||||
kind = tensorKindBF16
|
||||
}
|
||||
|
||||
@@ -188,17 +188,17 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
||||
|
||||
switch st.Kind() {
|
||||
case tensorKindFP32:
|
||||
return int64(len(f32s) * 4), binary.Write(w, binary.LittleEndian, f32s)
|
||||
return 0, binary.Write(w, binary.LittleEndian, f32s)
|
||||
case tensorKindFP16:
|
||||
f16s := make([]uint16, len(f32s))
|
||||
for i := range f32s {
|
||||
f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
|
||||
}
|
||||
|
||||
return int64(len(f16s) * 2), binary.Write(w, binary.LittleEndian, f16s)
|
||||
return 0, binary.Write(w, binary.LittleEndian, f16s)
|
||||
case tensorKindBF16:
|
||||
u8s := bfloat16.EncodeFloat32(f32s)
|
||||
return int64(len(u8s)), binary.Write(w, binary.LittleEndian, u8s)
|
||||
return 0, binary.Write(w, binary.LittleEndian, u8s)
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
|
||||
}
|
||||
|
||||
@@ -230,65 +230,3 @@ func TestSafetensors(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafetensorKind(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
st safetensor
|
||||
expected uint32
|
||||
}{
|
||||
{
|
||||
name: "BF16 dtype with non-v. prefix and non-FP32 base kind should return BF16",
|
||||
st: safetensor{
|
||||
tensorBase: &tensorBase{
|
||||
name: "weight.matrix",
|
||||
shape: []uint64{10, 10}, // will default to FP16
|
||||
},
|
||||
dtype: "BF16",
|
||||
},
|
||||
expected: tensorKindBF16,
|
||||
},
|
||||
{
|
||||
name: "BF16 dtype with v. prefix should return base kind",
|
||||
st: safetensor{
|
||||
tensorBase: &tensorBase{
|
||||
name: "v.weight.matrix",
|
||||
shape: []uint64{10, 10}, // will default to FP16
|
||||
},
|
||||
dtype: "BF16",
|
||||
},
|
||||
expected: tensorKindFP16,
|
||||
},
|
||||
{
|
||||
name: "BF16 dtype with FP32 base kind should return FP32",
|
||||
st: safetensor{
|
||||
tensorBase: &tensorBase{
|
||||
name: "weight.matrix",
|
||||
shape: []uint64{10}, // will default to FP32
|
||||
},
|
||||
dtype: "BF16",
|
||||
},
|
||||
expected: tensorKindFP32,
|
||||
},
|
||||
{
|
||||
name: "Non-BF16 dtype should return base kind",
|
||||
st: safetensor{
|
||||
tensorBase: &tensorBase{
|
||||
name: "weight.matrix",
|
||||
shape: []uint64{10, 10}, // will default to FP16
|
||||
},
|
||||
dtype: "FP16",
|
||||
},
|
||||
expected: tensorKindFP16,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.st.Kind()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Kind() = %d, expected %d", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -277,7 +277,6 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
||||
FreeMemory: (totalMemory - usedMemory),
|
||||
},
|
||||
ID: ID,
|
||||
filterID: gpuOrdinalID,
|
||||
Name: name,
|
||||
Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch),
|
||||
MinimumMemory: rocmMinimumMemory,
|
||||
@@ -395,7 +394,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
||||
|
||||
// Check for env var workarounds
|
||||
if name == "1002:687f" { // Vega RX 56
|
||||
gpuInfo.EnvWorkarounds = append(gpuInfo.EnvWorkarounds, "HSA_ENABLE_SDMA=0")
|
||||
gpuInfo.EnvWorkarounds = append(gpuInfo.EnvWorkarounds, [2]string{"HSA_ENABLE_SDMA", "0"})
|
||||
}
|
||||
|
||||
// The GPU has passed all the verification steps and is supported
|
||||
@@ -524,26 +523,19 @@ func verifyKFDDriverAccess() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
|
||||
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||
ids := []string{}
|
||||
for _, info := range gpuInfo {
|
||||
if info.Library != "rocm" {
|
||||
// TODO shouldn't happen if things are wired correctly...
|
||||
slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
|
||||
continue
|
||||
}
|
||||
// If the devices requires a numeric ID, for filtering purposes, we use the unfiltered ID number
|
||||
if _, err := strconv.Atoi(info.ID); err == nil {
|
||||
ids = append(ids, fmt.Sprintf("%d", info.filterID))
|
||||
} else {
|
||||
ids = append(ids, info.ID)
|
||||
}
|
||||
ids = append(ids, info.ID)
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// There are 3 potential env vars to use to select GPUs.
|
||||
// ROCR_VISIBLE_DEVICES supports UUID or numeric so is our preferred on linux
|
||||
// GPU_DEVICE_ORDINAL supports numeric IDs only
|
||||
// HIP_VISIBLE_DEVICES supports numeric IDs only
|
||||
return "ROCR_VISIBLE_DEVICES=" + strings.Join(ids, ",")
|
||||
return "ROCR_VISIBLE_DEVICES", strings.Join(ids, ",")
|
||||
}
|
||||
|
||||
@@ -111,7 +111,6 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
||||
UnreliableFreeMemory: true,
|
||||
|
||||
ID: strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices
|
||||
filterID: i,
|
||||
DependencyPath: []string{libDir},
|
||||
MinimumMemory: rocmMinimumMemory,
|
||||
Name: name,
|
||||
@@ -201,26 +200,19 @@ func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
|
||||
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||
ids := []string{}
|
||||
for _, info := range gpuInfo {
|
||||
if info.Library != "rocm" {
|
||||
// TODO shouldn't happen if things are wired correctly...
|
||||
slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
|
||||
continue
|
||||
}
|
||||
// If the devices requires a numeric ID, for filtering purposes, we use the unfiltered ID number
|
||||
if _, err := strconv.Atoi(info.ID); err == nil {
|
||||
ids = append(ids, fmt.Sprintf("%d", info.filterID))
|
||||
} else {
|
||||
ids = append(ids, info.ID)
|
||||
}
|
||||
ids = append(ids, info.ID)
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// There are 3 potential env vars to use to select GPUs.
|
||||
// ROCR_VISIBLE_DEVICES supports UUID or numeric but does not work on Windows
|
||||
// HIP_VISIBLE_DEVICES supports numeric IDs only
|
||||
// GPU_DEVICE_ORDINAL supports numeric IDs only
|
||||
return "HIP_VISIBLE_DEVICES=" + strings.Join(ids, ",")
|
||||
return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",")
|
||||
}
|
||||
|
||||
@@ -16,7 +16,20 @@ import (
|
||||
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
|
||||
var CudaTegra string = os.Getenv("JETSON_JETPACK")
|
||||
|
||||
func cudaVariant(gpuInfos []CudaGPUInfo) string {
|
||||
func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||
ids := []string{}
|
||||
for _, info := range gpuInfo {
|
||||
if info.Library != "cuda" {
|
||||
// TODO shouldn't happen if things are wired correctly...
|
||||
slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
|
||||
continue
|
||||
}
|
||||
ids = append(ids, info.ID)
|
||||
}
|
||||
return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
|
||||
}
|
||||
|
||||
func cudaVariant(gpuInfo CudaGPUInfo) string {
|
||||
if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
|
||||
if CudaTegra != "" {
|
||||
ver := strings.Split(CudaTegra, ".")
|
||||
@@ -43,22 +56,14 @@ func cudaVariant(gpuInfos []CudaGPUInfo) string {
|
||||
}
|
||||
}
|
||||
}
|
||||
return "sbsa"
|
||||
}
|
||||
|
||||
// Check GPU compute capability FIRST, lowest common denominator if multi-gpu
|
||||
for _, gpuInfo := range gpuInfos {
|
||||
if gpuInfo.computeMajor < 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor < 5) {
|
||||
// GPU is Pascal or older (CC <= 7.4) - use CUDA v12 (supports CC 6.1)
|
||||
return "v12"
|
||||
}
|
||||
// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers
|
||||
if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
|
||||
// The detected driver is older than Feb 2023
|
||||
slog.Warn("old CUDA driver detected - please upgrade to a newer driver", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
|
||||
return "v11"
|
||||
}
|
||||
|
||||
// GPU is Turing or newer (CC >= 7.5) - can use newer CUDA
|
||||
if len(gpuInfos) > 0 && gpuInfos[0].DriverMajor < 13 {
|
||||
// The detected driver is older than 580 (Aug 2025)
|
||||
// Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance
|
||||
slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfos[0].DriverMajor, gpuInfos[0].DriverMinor))
|
||||
return "v12"
|
||||
}
|
||||
return "v13"
|
||||
return "v12"
|
||||
}
|
||||
|
||||
@@ -284,8 +284,18 @@ func GetGPUInfo() GpuInfoList {
|
||||
gpuInfo.MinimumMemory = cudaMinimumMemory
|
||||
gpuInfo.DriverMajor = driverMajor
|
||||
gpuInfo.DriverMinor = driverMinor
|
||||
variant := cudaVariant(gpuInfo)
|
||||
|
||||
// Start with our bundled libraries
|
||||
if variant != "" {
|
||||
variantPath := filepath.Join(LibOllamaPath, "cuda_"+variant)
|
||||
if _, err := os.Stat(variantPath); err == nil {
|
||||
// Put the variant directory first in the search path to avoid runtime linking to the wrong library
|
||||
gpuInfo.DependencyPath = append([]string{variantPath}, gpuInfo.DependencyPath...)
|
||||
}
|
||||
}
|
||||
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
|
||||
gpuInfo.Variant = variant
|
||||
|
||||
if int(memInfo.major) < cudaComputeMajorMin || (int(memInfo.major) == cudaComputeMajorMin && int(memInfo.minor) < cudaComputeMinorMin) {
|
||||
unsupportedGPUs = append(unsupportedGPUs,
|
||||
@@ -323,24 +333,6 @@ func GetGPUInfo() GpuInfoList {
|
||||
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
|
||||
cudaGPUs = append(cudaGPUs, gpuInfo)
|
||||
}
|
||||
// Second pass on NVIDIA GPUs to set lowest common denominator variant and DependencyPaths
|
||||
variant := cudaVariant(cudaGPUs)
|
||||
var variantPath string
|
||||
// Start with our bundled libraries
|
||||
if variant != "" {
|
||||
variantPath = filepath.Join(LibOllamaPath, "cuda_"+variant)
|
||||
if _, err := os.Stat(variantPath); err != nil {
|
||||
variantPath = ""
|
||||
}
|
||||
}
|
||||
|
||||
for i := range cudaGPUs {
|
||||
cudaGPUs[i].Variant = variant
|
||||
if variantPath != "" {
|
||||
// Put the variant directory first in the search path to avoid runtime linking to the wrong library
|
||||
cudaGPUs[i].DependencyPath = append([]string{variantPath}, cudaGPUs[i].DependencyPath...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Intel
|
||||
@@ -379,15 +371,6 @@ func GetGPUInfo() GpuInfoList {
|
||||
}
|
||||
|
||||
rocmGPUs, err = AMDGetGPUInfo()
|
||||
|
||||
// The ID field is used in context of the filtered set of GPUS
|
||||
// so we have to replace any of these numeric IDs with their
|
||||
// placement in this set of GPUs
|
||||
for i := range rocmGPUs {
|
||||
if _, err := strconv.Atoi(rocmGPUs[i].ID); err == nil {
|
||||
rocmGPUs[i].ID = strconv.Itoa(i)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
bootstrapErrors = append(bootstrapErrors, err)
|
||||
}
|
||||
@@ -697,16 +680,23 @@ func getVerboseState() C.uint16_t {
|
||||
|
||||
// Given the list of GPUs this instantiation is targeted for,
|
||||
// figure out the visible devices environment variable
|
||||
func (l GpuInfoList) GetVisibleDevicesEnv() []string {
|
||||
//
|
||||
// If different libraries are detected, the first one is what we use
|
||||
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
|
||||
if len(l) == 0 {
|
||||
return nil
|
||||
return "", ""
|
||||
}
|
||||
vd := []string{}
|
||||
// Only filter the AMD GPUs at this level, let all NVIDIA devices through
|
||||
if tmp := rocmGetVisibleDevicesEnv(l); tmp != "" {
|
||||
vd = append(vd, tmp)
|
||||
switch l[0].Library {
|
||||
case "cuda":
|
||||
return cudaGetVisibleDevicesEnv(l)
|
||||
case "rocm":
|
||||
return rocmGetVisibleDevicesEnv(l)
|
||||
case "oneapi":
|
||||
return oneapiGetVisibleDevicesEnv(l)
|
||||
default:
|
||||
slog.Debug("no filter required for library " + l[0].Library)
|
||||
return "", ""
|
||||
}
|
||||
return vd
|
||||
}
|
||||
|
||||
func GetSystemInfo() SystemInfo {
|
||||
|
||||
@@ -62,9 +62,9 @@ func GetCPUMem() (memInfo, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (l GpuInfoList) GetVisibleDevicesEnv() []string {
|
||||
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
|
||||
// No-op on darwin
|
||||
return nil
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func GetSystemInfo() SystemInfo {
|
||||
|
||||
21
discover/gpu_oneapi.go
Normal file
21
discover/gpu_oneapi.go
Normal file
@@ -0,0 +1,21 @@
|
||||
//go:build linux || windows
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func oneapiGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||
ids := []string{}
|
||||
for _, info := range gpuInfo {
|
||||
if info.Library != "oneapi" {
|
||||
// TODO shouldn't happen if things are wired correctly...
|
||||
slog.Debug("oneapiGetVisibleDevicesEnv skipping over non-sycl device", "library", info.Library)
|
||||
continue
|
||||
}
|
||||
ids = append(ids, info.ID)
|
||||
}
|
||||
return "ONEAPI_DEVICE_SELECTOR", "level_zero:" + strings.Join(ids, ",")
|
||||
}
|
||||
@@ -27,8 +27,8 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
|
||||
// Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
|
||||
DependencyPath []string `json:"lib_path,omitempty"`
|
||||
|
||||
// Extra environment variables specific to the GPU as list of [key=value]
|
||||
EnvWorkarounds []string `json:"envs,omitempty"`
|
||||
// Extra environment variables specific to the GPU as list of [key,value]
|
||||
EnvWorkarounds [][2]string `json:"envs,omitempty"`
|
||||
|
||||
// Set to true if we can NOT reliably discover FreeMemory. A value of true indicates
|
||||
// the FreeMemory is best effort, and may over or under report actual memory usage
|
||||
@@ -36,10 +36,9 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
|
||||
UnreliableFreeMemory bool
|
||||
|
||||
// GPU information
|
||||
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
|
||||
filterID int //nolint:unused,nolintlint // AMD Workaround: The numeric ID of the device used to filter out other devices
|
||||
Name string `json:"name"` // user friendly name if available
|
||||
Compute string `json:"compute"` // Compute Capability or gfx
|
||||
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
|
||||
Name string `json:"name"` // user friendly name if available
|
||||
Compute string `json:"compute"` // Compute Capability or gfx
|
||||
|
||||
// Driver Information - TODO no need to put this on each GPU
|
||||
DriverMajor int `json:"driver_major,omitempty"`
|
||||
|
||||
@@ -1708,7 +1708,6 @@ Advanced parameters:
|
||||
- `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
|
||||
- `dimensions`: number of dimensions for the embedding
|
||||
|
||||
### Examples
|
||||
|
||||
|
||||
@@ -11,10 +11,6 @@ Then build and run Ollama from the root directory of the repository:
|
||||
go run . serve
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> Ollama includes native code compiled with CGO. From time to time these data structures can change and CGO can get out of sync resulting in unexpected crashes. You can force a full build of the native code by running `go clean -cache` first.
|
||||
|
||||
|
||||
## macOS (Apple Silicon)
|
||||
|
||||
macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required.
|
||||
|
||||
75
docs/docs.json
Normal file
75
docs/docs.json
Normal file
@@ -0,0 +1,75 @@
|
||||
{
|
||||
"$schema": "https://mintlify.com/docs.json",
|
||||
"theme": "mint",
|
||||
"background": {
|
||||
"color": {
|
||||
"light": "#ffffff",
|
||||
"dark": "#000000"
|
||||
}
|
||||
},
|
||||
"appearance": {
|
||||
"default": "light"
|
||||
},
|
||||
"styling": {
|
||||
"codeblocks": "system"
|
||||
},
|
||||
"contextual": {
|
||||
"options": ["copy", "chatgpt", "claude", "view"]
|
||||
},
|
||||
"fonts": {
|
||||
"heading": {
|
||||
"family": "Inter"
|
||||
},
|
||||
"body": {
|
||||
"family": "Inter"
|
||||
}
|
||||
},
|
||||
"name": "Ollama",
|
||||
"colors": {
|
||||
"primary": "#000",
|
||||
"light": "#b5b5b5",
|
||||
"dark": "#fff"
|
||||
},
|
||||
"favicon": "/ollama.png",
|
||||
"logo": {
|
||||
"light": "/ollama.png",
|
||||
"dark": "/favicon.svg"
|
||||
},
|
||||
"navigation": {
|
||||
"tabs": [
|
||||
{
|
||||
"tab": "Documentation",
|
||||
"groups": [
|
||||
{
|
||||
"group": "Home",
|
||||
"pages": ["index", "quickstart", "faq", "troubleshooting"]
|
||||
},
|
||||
{
|
||||
"group": "Platforms",
|
||||
"pages": ["linux", "windows", "docker"]
|
||||
},
|
||||
{
|
||||
"group": "Features",
|
||||
"pages": [
|
||||
"modelfile",
|
||||
"apis",
|
||||
"openai",
|
||||
"import",
|
||||
"gpu",
|
||||
"benchmark"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"tab": "Development",
|
||||
"groups": [
|
||||
{
|
||||
"group": " ",
|
||||
"pages": ["development", "examples", "template"]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -11,13 +11,12 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
## Manual install
|
||||
|
||||
> [!NOTE]
|
||||
> If you are upgrading from a prior version, you **MUST** remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
|
||||
> If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
|
||||
|
||||
Download and extract the package:
|
||||
|
||||
```shell
|
||||
curl -LO https://ollama.com/download/ollama-linux-amd64.tgz
|
||||
sudo rm -rf /usr/lib/ollama
|
||||
sudo tar -C /usr -xzf ollama-linux-amd64.tgz
|
||||
```
|
||||
|
||||
|
||||
@@ -92,9 +92,6 @@ If none of those resolve the problem, gather additional information and file an
|
||||
- Set `CUDA_ERROR_LEVEL=50` and try again to get more diagnostic logs
|
||||
- Check dmesg for any errors `sudo dmesg | grep -i nvrm` and `sudo dmesg | grep -i nvidia`
|
||||
|
||||
You may get more details for initialization failures by enabling debug prints in the uvm driver. You should only use this temporarily while troubleshooting
|
||||
- `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm uvm_debug_prints=1`
|
||||
|
||||
|
||||
## AMD GPU Discovery
|
||||
|
||||
|
||||
@@ -134,17 +134,6 @@ func LoadTimeout() (loadTimeout time.Duration) {
|
||||
return loadTimeout
|
||||
}
|
||||
|
||||
func Remotes() []string {
|
||||
var r []string
|
||||
raw := strings.TrimSpace(Var("OLLAMA_REMOTES"))
|
||||
if raw == "" {
|
||||
r = []string{"ollama.com"}
|
||||
} else {
|
||||
r = strings.Split(raw, ",")
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func Bool(k string) func() bool {
|
||||
return func() bool {
|
||||
if s := Var(k); s != "" {
|
||||
@@ -196,6 +185,8 @@ var (
|
||||
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
|
||||
// Auth enables authentication between the Ollama client and server
|
||||
UseAuth = Bool("OLLAMA_AUTH")
|
||||
// Enable the new memory estimation logic
|
||||
NewMemoryEstimates = Bool("OLLAMA_NEW_ESTIMATES")
|
||||
)
|
||||
|
||||
func String(s string) func() string {
|
||||
@@ -281,7 +272,7 @@ func AsMap() map[string]EnvVar {
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
|
||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
||||
"OLLAMA_NEW_ESTIMATES": {"OLLAMA_NEW_ESTIMATES", NewMemoryEstimates(), "Enable the new memory estimation logic"},
|
||||
|
||||
// Informational
|
||||
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
||||
|
||||
204
fs/ggml/ggml.go
204
fs/ggml/ggml.go
@@ -7,11 +7,9 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/util/bufioutil"
|
||||
)
|
||||
|
||||
@@ -57,28 +55,10 @@ func (kv KV) EmbeddingLength() uint64 {
|
||||
return uint64(kv.Uint("embedding_length"))
|
||||
}
|
||||
|
||||
func (kv KV) HeadCount() []uint64 {
|
||||
headCountDefault := uint32(1)
|
||||
headCount := kv.UintOrArrayValueAsArray("attention.head_count", headCountDefault)
|
||||
if len(headCount) == 1 {
|
||||
headCountDefault = headCount[0]
|
||||
}
|
||||
nLayers := int(kv.BlockCount())
|
||||
if len(headCount) > nLayers {
|
||||
slog.Warn("got more elements of attention.head_count than layers", "len(headCount)", len(headCount), "layers", nLayers)
|
||||
}
|
||||
out := make([]uint64, nLayers)
|
||||
for i := range nLayers {
|
||||
if i >= len(headCount) {
|
||||
out[i] = uint64(headCountDefault)
|
||||
} else {
|
||||
out[i] = uint64(headCount[i])
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (kv KV) HeadCountMax() uint64 {
|
||||
// TODO(drifkin): using the max value can cause an overestimation. In the
|
||||
// future if array values become more popular, we can adapt the more invasive
|
||||
// <https://github.com/ollama/ollama/pull/10225>
|
||||
return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1))
|
||||
}
|
||||
|
||||
@@ -86,27 +66,6 @@ func (kv KV) HeadCountMin() uint64 {
|
||||
return uint64(kv.UintOrMinArrayValue("attention.head_count", 1))
|
||||
}
|
||||
|
||||
func (kv KV) HeadCountKV() []uint64 {
|
||||
headCountKVDefault := uint32(1)
|
||||
headCountKV := kv.UintOrArrayValueAsArray("attention.head_count_kv", headCountKVDefault)
|
||||
if len(headCountKV) == 1 {
|
||||
headCountKVDefault = headCountKV[0]
|
||||
}
|
||||
nLayers := int(kv.BlockCount())
|
||||
if len(headCountKV) > nLayers {
|
||||
slog.Warn("got more elements of attention.head_count than layers", "len(headCountKV)", len(headCountKV), "layers", nLayers)
|
||||
}
|
||||
out := make([]uint64, nLayers)
|
||||
for i := range nLayers {
|
||||
if i >= len(headCountKV) {
|
||||
out[i] = uint64(headCountKVDefault)
|
||||
} else {
|
||||
out[i] = uint64(headCountKV[i])
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (kv KV) HeadCountKVMax() uint64 {
|
||||
return uint64(kv.UintOrMaxArrayValue("attention.head_count_kv", 1))
|
||||
}
|
||||
@@ -139,26 +98,6 @@ func (kv KV) ChatTemplate() string {
|
||||
return kv.String("tokenizer.chat_template")
|
||||
}
|
||||
|
||||
// ssm architecture parameters
|
||||
|
||||
func (kv KV) SSMConvKernel() uint64 {
|
||||
return uint64(kv.Uint("ssm.conv_kernel"))
|
||||
}
|
||||
|
||||
func (kv KV) SSMInnerSize() uint64 {
|
||||
return uint64(kv.Uint("ssm.inner_size"))
|
||||
}
|
||||
|
||||
func (kv KV) SSMStateSize() uint64 {
|
||||
return uint64(kv.Uint("ssm.state_size"))
|
||||
}
|
||||
|
||||
func (kv KV) SSMGroupCount() uint64 {
|
||||
return uint64(kv.Uint("ssm.group_count"))
|
||||
}
|
||||
|
||||
// general types
|
||||
|
||||
func (kv KV) String(key string, defaultValue ...string) string {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, "")...)
|
||||
return val
|
||||
@@ -190,27 +129,22 @@ func (kv KV) UintOrMinArrayValue(key string, defaultValue uint32) uint32 {
|
||||
}
|
||||
|
||||
func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) {
|
||||
arrVal := kv.UintOrArrayValueAsArray(key, defaultValue)
|
||||
return slices.Min(arrVal), slices.Max(arrVal)
|
||||
}
|
||||
|
||||
func (kv KV) UintOrArrayValueAsArray(key string, defaultValue uint32) []uint32 {
|
||||
if u32, ok := keyValue(kv, key, uint32(0)); ok {
|
||||
return []uint32{u32}
|
||||
return u32, u32
|
||||
} else if u32s, ok := keyValue(kv, key, &array[uint32]{}); ok {
|
||||
return u32s.values
|
||||
min := slices.Min(u32s.values)
|
||||
max := slices.Max(u32s.values)
|
||||
return min, max
|
||||
} else if i32s, ok := keyValue(kv, key, &array[int32]{}); ok {
|
||||
dst := make([]uint32, len(i32s.values))
|
||||
for i, v := range i32s.values {
|
||||
if v < 0 {
|
||||
slog.Warn("array values are unexpectedly negative", "key", key, "i", i, "v", v)
|
||||
}
|
||||
dst[i] = uint32(v)
|
||||
min := slices.Min(i32s.values)
|
||||
max := slices.Max(i32s.values)
|
||||
if min < 0 || max < 0 {
|
||||
slog.Warn("array values are unexpectedly negative", "key", key, "min", min, "max", max)
|
||||
}
|
||||
return dst
|
||||
return uint32(min), uint32(max)
|
||||
}
|
||||
|
||||
return []uint32{defaultValue}
|
||||
return defaultValue, defaultValue
|
||||
}
|
||||
|
||||
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
|
||||
@@ -243,7 +177,6 @@ func (kv KV) OllamaEngineRequired() bool {
|
||||
"gemma3",
|
||||
"gemma3n",
|
||||
"mistral3",
|
||||
"qwen3",
|
||||
"llama4",
|
||||
"mllama",
|
||||
"qwen25vl",
|
||||
@@ -342,7 +275,7 @@ type Tensor struct {
|
||||
|
||||
func (t Tensor) block() (n int) {
|
||||
if _, err := fmt.Sscanf(t.Name, "blk.%d.", &n); err != nil {
|
||||
return math.MaxInt
|
||||
return -1
|
||||
}
|
||||
|
||||
return
|
||||
@@ -355,24 +288,24 @@ func (t Tensor) blockSize() uint64 {
|
||||
func (t TensorType) BlockSize() uint64 {
|
||||
switch t {
|
||||
case
|
||||
TensorTypeF32,
|
||||
TensorTypeF16,
|
||||
TensorTypeI8,
|
||||
TensorTypeI16,
|
||||
TensorTypeI32,
|
||||
TensorTypeI64,
|
||||
TensorTypeF64,
|
||||
TensorTypeBF16:
|
||||
0, // F32
|
||||
1, // F16
|
||||
24, // I8
|
||||
25, // I16
|
||||
26, // I32
|
||||
27, // I64
|
||||
28, // F64
|
||||
30: // BF16
|
||||
return 1
|
||||
case
|
||||
TensorTypeQ4_0,
|
||||
TensorTypeQ4_1,
|
||||
TensorTypeQ5_0,
|
||||
TensorTypeQ5_1,
|
||||
TensorTypeQ8_0,
|
||||
TensorTypeQ8_1,
|
||||
tensorTypeIQ4_NL,
|
||||
4, TensorTypeMXFP4:
|
||||
2, // Q4_0
|
||||
3, // Q4_1
|
||||
4, // MXFP4
|
||||
6, // Q5_0
|
||||
7, // Q5_1
|
||||
8, // Q8_0
|
||||
9, // Q8_1
|
||||
20: // IQ4_NL
|
||||
return 32
|
||||
default:
|
||||
return 256
|
||||
@@ -395,6 +328,8 @@ func (t TensorType) TypeSize() uint64 {
|
||||
return 2 + blockSize/2
|
||||
case TensorTypeQ4_1:
|
||||
return 2 + 2 + blockSize/2
|
||||
case TensorTypeMXFP4, 39:
|
||||
return 1 + blockSize/2
|
||||
case TensorTypeQ5_0:
|
||||
return 2 + 4 + blockSize/2
|
||||
case TensorTypeQ5_1:
|
||||
@@ -445,8 +380,6 @@ func (t TensorType) TypeSize() uint64 {
|
||||
return blockSize/8 + blockSize/16 + blockSize/32
|
||||
case TensorTypeBF16:
|
||||
return 2
|
||||
case 4, TensorTypeMXFP4:
|
||||
return 1 + blockSize/2
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
@@ -546,14 +479,12 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
context *= uint64(numParallel)
|
||||
|
||||
embedding := f.KV().EmbeddingLength()
|
||||
heads := f.KV().HeadCountMax()
|
||||
headsArr := f.KV().HeadCount()
|
||||
headsKV := f.KV().HeadCountKVMax()
|
||||
headsKVArr := f.KV().HeadCountKV()
|
||||
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
|
||||
|
||||
embeddingHeads := f.KV().EmbeddingHeadCountMax()
|
||||
@@ -563,51 +494,12 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
layers := f.Tensors().GroupLayers()
|
||||
|
||||
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
||||
|
||||
// Default for models unless special-cased below. These defaults mirror the
|
||||
// cache usage in llama.cpp under the assumption that models without special
|
||||
// cases below will use the llamarunner and caching will be handled by the
|
||||
// llama.cpp layer.
|
||||
//
|
||||
// This also assumes that a layer without heads or headsKV set is recurrent
|
||||
// which is usually the case. Some models (eg nemotronh) use "blocks" in
|
||||
// place of layers where some are MLP blocks that don't have any cache.
|
||||
// Models like this will need a special case below to be accurately
|
||||
// estimated.
|
||||
var kvTotal uint64
|
||||
kv = make([]uint64, f.KV().BlockCount())
|
||||
kvSizeAttn := uint64(0)
|
||||
kvSizeRecurrent := uint64(0)
|
||||
for i := range kv {
|
||||
headsL := headsArr[i]
|
||||
headsKVL := headsKVArr[i]
|
||||
if headsL > 0 && headsKVL > 0 {
|
||||
// full attention layer
|
||||
// NOTE: Assumes uniform values for all attn layers
|
||||
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKVL) * bytesPerElement)
|
||||
kvSizeAttn += kv[i]
|
||||
} else {
|
||||
// recurrent layer
|
||||
ssmDConv := f.KV().SSMConvKernel()
|
||||
ssmDState := f.KV().SSMStateSize()
|
||||
ssmDInner := f.KV().SSMInnerSize()
|
||||
ssmNGroups := f.KV().SSMGroupCount()
|
||||
nEmbdR := uint64(0)
|
||||
if ssmDConv > 0 {
|
||||
nEmbdR = (ssmDConv - 1) * (ssmDInner + 2*ssmNGroups*ssmDState)
|
||||
}
|
||||
nEmbdS := ssmDState * ssmDInner
|
||||
|
||||
// recurrent always uses F32 in llama.cpp backend
|
||||
// https://github.com/ggml-org/llama.cpp/blob/master/src/llama-model.cpp#L18644
|
||||
bytesPerElementRecurrent := kvCacheBytesPerElement("f32")
|
||||
|
||||
kv[i] = (nEmbdR + nEmbdS) * uint64(bytesPerElementRecurrent)
|
||||
kvSizeRecurrent += kv[i]
|
||||
}
|
||||
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
kvTotal += kv[i]
|
||||
}
|
||||
slog.Debug("default cache size estimate", "attention MiB", float32(kvSizeAttn)/(1024.*1024.), "attention bytes", kvSizeAttn, "recurrent MiB", float32(kvSizeRecurrent)/(1024.*1024.), "recurrent bytes", kvSizeRecurrent)
|
||||
|
||||
switch f.KV().Architecture() {
|
||||
case "llama", "llama4":
|
||||
@@ -785,12 +677,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
kv[i] *= context
|
||||
}
|
||||
}
|
||||
|
||||
partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
|
||||
if useFlashAttention {
|
||||
// rough estimate of graph size with flash attention on
|
||||
partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
@@ -865,16 +752,12 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
|
||||
|
||||
// SupportsKVCacheType checks if the requested cache type is supported
|
||||
func (f GGML) SupportsKVCacheType(cacheType string) bool {
|
||||
if cacheType == "" || cacheType == "f16" {
|
||||
return true
|
||||
}
|
||||
|
||||
if arch := f.KV().Architecture(); slices.Contains([]string{"gptoss", "gpt-oss"}, arch) {
|
||||
// gpt-oss uses attention with sinks which does not support quantized cache types
|
||||
slog.Warn("model only supports non-quantized cache types", "model", arch)
|
||||
return false
|
||||
slog.Warn("model only supports non-quantized cache types ", "mode", arch)
|
||||
return cacheType == "f16"
|
||||
}
|
||||
return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
|
||||
return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
|
||||
}
|
||||
|
||||
// SupportsFlashAttention checks if the model supports flash attention
|
||||
@@ -884,23 +767,12 @@ func (f GGML) SupportsFlashAttention() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check head counts match and are non-zero
|
||||
headCountK := f.KV().EmbeddingHeadCountK()
|
||||
headCountV := f.KV().EmbeddingHeadCountV()
|
||||
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
|
||||
}
|
||||
|
||||
// FlashAttention checks if the model should enable flash attention
|
||||
func (f GGML) FlashAttention() bool {
|
||||
return slices.Contains([]string{
|
||||
"gptoss", "gpt-oss",
|
||||
}, f.KV().String("general.architecture"))
|
||||
}
|
||||
|
||||
// kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
|
||||
func kvCacheBytesPerElement(cacheType string) float64 {
|
||||
switch cacheType {
|
||||
@@ -908,8 +780,6 @@ func kvCacheBytesPerElement(cacheType string) float64 {
|
||||
return 1 // 1/2 of fp16
|
||||
case "q4_0":
|
||||
return 0.5 // 1/4 of fp16
|
||||
case "f32":
|
||||
return 4 // f32 (default for recurrent)
|
||||
default:
|
||||
return 2 // f16 (default)
|
||||
}
|
||||
|
||||
@@ -533,15 +533,12 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
}
|
||||
}
|
||||
|
||||
slices.SortStableFunc(
|
||||
ts,
|
||||
func(a, b *Tensor) int {
|
||||
return cmp.Or(
|
||||
cmp.Compare(a.block(), b.block()),
|
||||
cmp.Compare(a.Name, b.Name),
|
||||
)
|
||||
},
|
||||
)
|
||||
slices.SortStableFunc(ts, func(a, b *Tensor) int {
|
||||
if i, j := a.block(), b.block(); i > 0 && j > 0 {
|
||||
return cmp.Compare(i, j)
|
||||
}
|
||||
return cmp.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
var s uint64
|
||||
for i := range ts {
|
||||
|
||||
@@ -11,24 +11,24 @@ import (
|
||||
)
|
||||
|
||||
func TestWriteGGUF(t *testing.T) {
|
||||
b := bytes.NewBuffer(make([]byte, 2*3))
|
||||
r := rand.New(rand.NewPCG(0, 0))
|
||||
for range 8 {
|
||||
t.Run("shuffle", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ts := []*Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: b},
|
||||
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: b},
|
||||
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.1.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.2.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.3.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.4.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.5.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
|
||||
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
|
||||
}
|
||||
|
||||
rand.Shuffle(len(ts), func(i, j int) {
|
||||
r.Shuffle(len(ts), func(i, j int) {
|
||||
ts[i], ts[j] = ts[j], ts[i]
|
||||
})
|
||||
|
||||
@@ -63,14 +63,14 @@ func TestWriteGGUF(t *testing.T) {
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(Tensors{
|
||||
Offset: 592,
|
||||
Offset: 608,
|
||||
items: []*Tensor{
|
||||
{Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.0.ffn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.1.ffn_down.weight", Offset: 96, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.1.ffn_up.weight", Offset: 128, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.2.ffn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.0.attn_norm.weight", Offset: 0, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.1.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.2.attn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.3.attn_norm.weight", Offset: 96, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.4.attn_norm.weight", Offset: 128, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.5.attn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
|
||||
{Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
|
||||
{Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
|
||||
{Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},
|
||||
|
||||
@@ -146,6 +146,8 @@ func (ftype FileType) ToTensorType() TensorType {
|
||||
return TensorTypeQ4_0
|
||||
case fileTypeQ4_1:
|
||||
return TensorTypeQ4_1
|
||||
case fileTypeMXFP4:
|
||||
return TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2
|
||||
case FileTypeQ8_0:
|
||||
return TensorTypeQ8_0
|
||||
case fileTypeQ5_0:
|
||||
@@ -174,8 +176,6 @@ func (ftype FileType) ToTensorType() TensorType {
|
||||
return TensorTypeQ2_K
|
||||
case FileTypeBF16:
|
||||
return TensorTypeBF16
|
||||
case fileTypeMXFP4:
|
||||
return TensorTypeMXFP4
|
||||
default:
|
||||
slog.Warn("unsupported file type", "type", ftype)
|
||||
return 0 // F32
|
||||
@@ -191,8 +191,8 @@ const (
|
||||
TensorTypeF16
|
||||
TensorTypeQ4_0
|
||||
TensorTypeQ4_1
|
||||
tensorTypeQ4_2
|
||||
tensorTypeQ4_3 // unused by GGML
|
||||
TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2
|
||||
tensorTypeQ4_3 // unused by GGML
|
||||
TensorTypeQ5_0
|
||||
TensorTypeQ5_1
|
||||
TensorTypeQ8_0
|
||||
@@ -226,7 +226,6 @@ const (
|
||||
tensorTypeIQ4_NL_4_4 // unused by GGML
|
||||
tensorTypeIQ4_NL_4_8 // unused by GGML
|
||||
tensorTypeIQ4_NL_8_8 // unused by GGML
|
||||
TensorTypeMXFP4
|
||||
)
|
||||
|
||||
// ParseFileType parses the provided GGUF file type
|
||||
@@ -319,7 +318,7 @@ func (t TensorType) String() string {
|
||||
return "F64"
|
||||
case TensorTypeBF16:
|
||||
return "BF16"
|
||||
case 4, TensorTypeMXFP4:
|
||||
case TensorTypeMXFP4:
|
||||
return "MXFP4"
|
||||
default:
|
||||
return "unknown"
|
||||
|
||||
@@ -2,13 +2,10 @@
|
||||
|
||||
This directory contains integration tests to exercise Ollama end-to-end to verify behavior
|
||||
|
||||
By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...` Some tests require additional tags to enable to allow scoped testing to keep the duration reasonable. For example, testing a broad set of models requires `-tags=integration,models` and a longer timeout (~60m or more depending on the speed of your GPU.). To view the current set of tag combinations use `find integration -type f | xargs grep "go:build"`
|
||||
By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...`
|
||||
|
||||
|
||||
The integration tests have 2 modes of operating.
|
||||
|
||||
1. By default, they will start the server on a random port, run the tests, and then shutdown the server.
|
||||
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote based on your `OLLAMA_HOST` environment variable
|
||||
|
||||
> [!IMPORTANT]
|
||||
> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.
|
||||
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote
|
||||
|
||||
@@ -390,7 +390,7 @@ func TestAPIEmbeddings(t *testing.T) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
req := api.EmbeddingRequest{
|
||||
Model: libraryEmbedModels[0],
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the sky blue?",
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
@@ -410,99 +410,3 @@ func TestAPIEmbeddings(t *testing.T) {
|
||||
t.Errorf("zero length embedding response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIToolCalling(t *testing.T) {
|
||||
initialTimeout := 60 * time.Second
|
||||
streamTimeout := 30 * time.Second
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
modelName := "qwen3:0.6b"
|
||||
if err := PullIfMissing(ctx, client, modelName); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather in a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req := api.ChatRequest{
|
||||
Model: modelName,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Call get_weather with location set to San Francisco.",
|
||||
},
|
||||
},
|
||||
Tools: tools,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
},
|
||||
}
|
||||
|
||||
stallTimer := time.NewTimer(initialTimeout)
|
||||
var gotToolCall bool
|
||||
var lastToolCall api.ToolCall
|
||||
|
||||
fn := func(response api.ChatResponse) error {
|
||||
if len(response.Message.ToolCalls) > 0 {
|
||||
gotToolCall = true
|
||||
lastToolCall = response.Message.ToolCalls[len(response.Message.ToolCalls)-1]
|
||||
}
|
||||
if !stallTimer.Reset(streamTimeout) {
|
||||
return fmt.Errorf("stall was detected while streaming response, aborting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
stream := true
|
||||
req.Stream = &stream
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
genErr = client.Chat(ctx, &req, fn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
t.Errorf("tool-calling chat never started. Timed out after: %s", initialTimeout.String())
|
||||
case <-done:
|
||||
if genErr != nil {
|
||||
t.Fatalf("chat failed: %v", genErr)
|
||||
}
|
||||
|
||||
if !gotToolCall {
|
||||
t.Fatalf("expected at least one tool call, got none")
|
||||
}
|
||||
|
||||
if lastToolCall.Function.Name != "get_weather" {
|
||||
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
|
||||
}
|
||||
|
||||
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
|
||||
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for tool-calling chat")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBlueSky(t *testing.T) {
|
||||
@@ -36,8 +37,8 @@ func TestUnicode(t *testing.T) {
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
// DeepSeek has a Unicode tokenizer regex, making it a unicode torture test
|
||||
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", // TODO is there an ollama-engine model we can switch to and keep the coverage?
|
||||
Prompt: "天空为什么是蓝色的?", // Why is the sky blue?
|
||||
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K",
|
||||
Prompt: "天空为什么是蓝色的?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
@@ -49,20 +50,8 @@ func TestUnicode(t *testing.T) {
|
||||
}
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
slog.Info("loading", "model", req.Model)
|
||||
err := client.Generate(ctx, &api.GenerateRequest{Model: req.Model}, func(response api.GenerateResponse) error { return nil })
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", req.Model, err)
|
||||
}
|
||||
skipIfNotGPULoaded(ctx, t, client, req.Model, 100)
|
||||
|
||||
DoGenerate(ctx, t, client, req, []string{
|
||||
"散射", // scattering
|
||||
"频率", // frequency
|
||||
}, 120*time.Second, 120*time.Second)
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
DoGenerate(ctx, t, client, req, []string{"散射", "频率"}, 120*time.Second, 120*time.Second)
|
||||
}
|
||||
|
||||
func TestExtendedUnicodeOutput(t *testing.T) {
|
||||
@@ -80,9 +69,7 @@ func TestExtendedUnicodeOutput(t *testing.T) {
|
||||
}
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
DoGenerate(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
|
||||
}
|
||||
|
||||
@@ -97,9 +84,7 @@ func TestUnicodeModelDir(t *testing.T) {
|
||||
}
|
||||
|
||||
modelDir, err := os.MkdirTemp("", "ollama_埃")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(modelDir)
|
||||
slog.Info("unicode", "OLLAMA_MODELS", modelDir)
|
||||
|
||||
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
@@ -77,21 +79,21 @@ func TestMultiModelStress(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// All models compatible with ollama-engine
|
||||
smallModels := []string{
|
||||
"llama3.2:1b",
|
||||
"qwen3:0.6b",
|
||||
"gemma2:2b",
|
||||
"deepseek-r1:1.5b", // qwen2 arch
|
||||
"gemma3:270m",
|
||||
"gemma:2b",
|
||||
"deepseek-r1:1.5b",
|
||||
"starcoder2:3b",
|
||||
}
|
||||
mediumModels := []string{
|
||||
"llama3.2:3b", // ~3.4G
|
||||
"qwen3:8b", // ~6.6G
|
||||
"gpt-oss:20b", // ~15G
|
||||
"deepseek-r1:7b", // ~5.6G
|
||||
"gemma3:4b", // ~5.8G
|
||||
"gemma2:9b", // ~8.1G
|
||||
"qwen3:8b",
|
||||
"llama2",
|
||||
"deepseek-r1:7b",
|
||||
"mistral",
|
||||
"dolphin-mistral",
|
||||
"gemma:7b",
|
||||
"codellama:7b",
|
||||
}
|
||||
|
||||
var chosenModels []string
|
||||
@@ -112,16 +114,13 @@ func TestMultiModelStress(t *testing.T) {
|
||||
|
||||
// Make sure all the models are pulled before we get started
|
||||
for _, model := range chosenModels {
|
||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, PullIfMissing(ctx, client, model))
|
||||
}
|
||||
|
||||
// Determine how many models we can load in parallel before we exceed VRAM
|
||||
// The intent is to go 1 over what can fit so we force the scheduler to thrash
|
||||
targetLoadCount := 0
|
||||
slog.Info("Loading models to find how many can fit in VRAM before overflowing")
|
||||
chooseModels:
|
||||
for i, model := range chosenModels {
|
||||
req := &api.GenerateRequest{Model: model}
|
||||
slog.Info("loading", "model", model)
|
||||
@@ -143,13 +142,6 @@ chooseModels:
|
||||
slog.Info("found model load capacity", "target", targetLoadCount, "current", loaded, "chosen", chosenModels[:targetLoadCount])
|
||||
break
|
||||
}
|
||||
// Effectively limit model count to 2 on CPU only systems to avoid thrashing and timeouts
|
||||
for _, m := range models.Models {
|
||||
if m.SizeVRAM == 0 {
|
||||
slog.Info("model running on CPU", "name", m.Name, "target", targetLoadCount, "chosen", chosenModels[:targetLoadCount])
|
||||
break chooseModels
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if targetLoadCount == len(chosenModels) {
|
||||
|
||||
@@ -22,7 +22,7 @@ func TestLongInputContext(t *testing.T) {
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: smol,
|
||||
Model: "llama2",
|
||||
Prompt: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
@@ -36,7 +36,7 @@ func TestLongInputContext(t *testing.T) {
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("PullIfMissing failed: %v", err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
|
||||
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia"}, 120*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
func TestContextExhaustion(t *testing.T) {
|
||||
@@ -49,8 +49,8 @@ func TestContextExhaustion(t *testing.T) {
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: smol,
|
||||
Prompt: "Write me a story in english with a lot of emojis",
|
||||
Model: "llama2",
|
||||
Prompt: "Write me a story with a ton of emojis?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
@@ -63,10 +63,10 @@ func TestContextExhaustion(t *testing.T) {
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("PullIfMissing failed: %v", err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water"}, 120*time.Second, 10*time.Second)
|
||||
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
// Send multiple generate requests with prior context and ensure the response is coherant and expected
|
||||
// Send multiple requests with prior context and ensure the response is coherant and expected
|
||||
func TestGenerateWithHistory(t *testing.T) {
|
||||
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
||||
req, resp := GenerateRequests()
|
||||
@@ -111,56 +111,5 @@ func TestGenerateWithHistory(t *testing.T) {
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Send multiple chat requests with prior context and ensure the response is coherant and expected
|
||||
func TestChatWithHistory(t *testing.T) {
|
||||
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
||||
req, resp := ChatRequests()
|
||||
numParallel := 2
|
||||
iterLimit := 2
|
||||
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// Get the server running (if applicable) warm the model up with a single initial empty request
|
||||
slog.Info("loading", "model", modelOverride)
|
||||
err := client.Generate(ctx,
|
||||
&api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
|
||||
func(response api.GenerateResponse) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", modelOverride, err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numParallel)
|
||||
for i := range numParallel {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
k := i % len(req)
|
||||
req[k].Model = modelOverride
|
||||
for j := 0; j < iterLimit; j++ {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
slog.Info("exceeded soft timeout, winding down test")
|
||||
return
|
||||
}
|
||||
slog.Info("Starting", "thread", i, "iter", j)
|
||||
// On slower GPUs it can take a while to process the concurrent requests
|
||||
// so we allow a much longer initial timeout
|
||||
assistant := DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
|
||||
if assistant == nil {
|
||||
t.Fatalf("didn't get an assistant response for context")
|
||||
}
|
||||
req[k].Messages = append(req[k].Messages,
|
||||
*assistant,
|
||||
api.Message{Role: "user", Content: "tell me more!"},
|
||||
)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
@@ -39,14 +38,14 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
req := api.EmbeddingRequest{
|
||||
Model: "all-minilm",
|
||||
Prompt: "why is the sky blue?",
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Model: "all-minilm",
|
||||
Prompt: "why is the sky blue?",
|
||||
}
|
||||
|
||||
res, err := embeddingTestHelper(ctx, client, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if len(res.Embedding) != 384 {
|
||||
@@ -74,8 +73,9 @@ func TestAllMiniLMEmbed(t *testing.T) {
|
||||
}
|
||||
|
||||
res, err := embedTestHelper(ctx, client, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if len(res.Embeddings) != 1 {
|
||||
@@ -111,8 +111,9 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
|
||||
}
|
||||
|
||||
res, err := embedTestHelper(ctx, client, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if len(res.Embeddings) != 2 {
|
||||
@@ -154,135 +155,93 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
|
||||
truncTrue, truncFalse := true, false
|
||||
|
||||
want, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
type testReq struct {
|
||||
Name string
|
||||
Request api.EmbedRequest
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
request api.EmbedRequest
|
||||
check func(*api.EmbedResponse, error)
|
||||
}{
|
||||
reqs := []testReq{
|
||||
{
|
||||
name: "target truncation",
|
||||
request: api.EmbedRequest{
|
||||
Name: "Target Truncation",
|
||||
Request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why",
|
||||
},
|
||||
check: func(got *api.EmbedResponse, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "default truncate",
|
||||
request: api.EmbedRequest{
|
||||
Name: "Default Truncate",
|
||||
Request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Options: map[string]any{"num_ctx": 3},
|
||||
},
|
||||
check: func(got *api.EmbedResponse, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
Options: map[string]any{"num_ctx": 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "explicit truncate",
|
||||
request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncTrue,
|
||||
Options: map[string]any{"num_ctx": 3},
|
||||
},
|
||||
check: func(got *api.EmbedResponse, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "truncate error",
|
||||
request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncFalse,
|
||||
Options: map[string]any{"num_ctx": 3},
|
||||
},
|
||||
check: func(res *api.EmbedResponse, err error) {
|
||||
if err.Error() != "input exceeds maximum context length" {
|
||||
t.Fatalf("expected truncation error, got: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "input after truncate error",
|
||||
request: api.EmbedRequest{
|
||||
Name: "Explicit Truncate",
|
||||
Request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncTrue,
|
||||
Options: map[string]any{"num_ctx": 1},
|
||||
},
|
||||
check: func(res *api.EmbedResponse, err error) {
|
||||
if err.Error() != "input after truncation exceeds maximum context length" {
|
||||
t.Fatalf("expected truncation error, got: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "input after truncate error",
|
||||
request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncTrue,
|
||||
Options: map[string]any{"num_ctx": 0},
|
||||
},
|
||||
check: func(res *api.EmbedResponse, err error) {
|
||||
if err.Error() != "input after truncation exceeds maximum context length" {
|
||||
t.Fatalf("expected truncation error, got: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, req := range cases {
|
||||
t.Run(req.name, func(t *testing.T) {
|
||||
req.check(embedTestHelper(ctx, client, t, req.request))
|
||||
})
|
||||
res := make(map[string]*api.EmbedResponse)
|
||||
|
||||
for _, req := range reqs {
|
||||
response, err := embedTestHelper(ctx, client, t, req.Request)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
res[req.Name] = response
|
||||
}
|
||||
|
||||
if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
|
||||
t.Fatal("expected default request to truncate correctly")
|
||||
}
|
||||
|
||||
if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
|
||||
t.Fatal("expected default request and truncate true request to be the same")
|
||||
}
|
||||
|
||||
// check that truncate set to false returns an error if context length is exceeded
|
||||
_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncFalse,
|
||||
Options: map[string]any{"num_ctx": 1},
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
|
||||
t.Helper()
|
||||
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
t.Fatalf("failed to pull model %s: %v", req.Model, err)
|
||||
}
|
||||
|
||||
return client.Embeddings(ctx, &req)
|
||||
response, err := client.Embeddings(ctx, &req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
||||
t.Helper()
|
||||
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
t.Fatalf("failed to pull model %s: %v", req.Model, err)
|
||||
}
|
||||
|
||||
return client.Embed(ctx, &req)
|
||||
response, err := client.Embed(ctx, &req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVisionModels(t *testing.T) {
|
||||
@@ -31,9 +32,7 @@ func TestVisionModels(t *testing.T) {
|
||||
for _, v := range testCases {
|
||||
t.Run(v.model, func(t *testing.T) {
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
req := api.GenerateRequest{
|
||||
Model: v.model,
|
||||
Prompt: "what does the text in this image say?",
|
||||
@@ -53,9 +52,7 @@ func TestVisionModels(t *testing.T) {
|
||||
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
|
||||
resp := "the ollam"
|
||||
defer cleanup()
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
// llava models on CPU can be quite slow to start
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
||||
})
|
||||
@@ -65,9 +62,7 @@ func TestVisionModels(t *testing.T) {
|
||||
func TestIntegrationSplitBatch(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
req := api.GenerateRequest{
|
||||
Model: "gemma3:4b",
|
||||
// Fill up a chunk of the batch so the image will partially spill over into the next one
|
||||
@@ -89,9 +84,7 @@ func TestIntegrationSplitBatch(t *testing.T) {
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
// llava models on CPU can be quite slow to start,
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
|
||||
}
|
||||
|
||||
47
integration/llm_test.go
Normal file
47
integration/llm_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
|
||||
// package to avoid circular dependencies
|
||||
|
||||
var (
|
||||
stream = false
|
||||
req = [2]api.GenerateRequest{
|
||||
{
|
||||
Model: smol,
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: smol,
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
},
|
||||
}
|
||||
resp = [2][]string{
|
||||
{"sunlight", "scattering", "interact"},
|
||||
{"england", "english", "massachusetts", "pilgrims"},
|
||||
}
|
||||
)
|
||||
|
||||
func TestIntegrationSimple(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
||||
defer cancel()
|
||||
GenerateTestHelper(ctx, t, req[0], resp[0])
|
||||
}
|
||||
@@ -13,12 +13,12 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestMaxQueue(t *testing.T) {
|
||||
t.Skip("this test needs to be re-evaluated to use a proper embedding model")
|
||||
|
||||
if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
|
||||
t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size")
|
||||
return
|
||||
@@ -45,9 +45,7 @@ func TestMaxQueue(t *testing.T) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
|
||||
// Context for the worker threads so we can shut them down
|
||||
// embedCtx, embedCancel := context.WithCancel(ctx)
|
||||
@@ -91,9 +89,7 @@ func TestMaxQueue(t *testing.T) {
|
||||
switch {
|
||||
case genErr == nil:
|
||||
successCount++
|
||||
if len(resp.Embedding) < 5 { // somewhat arbitrary, but sufficient to be reasonable
|
||||
t.Fatalf("embeddings shorter than expected: %d", len(resp.Embedding))
|
||||
}
|
||||
require.Greater(t, len(resp.Embedding), 5) // somewhat arbitrary, but sufficient to be reasonable
|
||||
case errors.Is(genErr, context.Canceled):
|
||||
canceledCount++
|
||||
case strings.Contains(genErr.Error(), "busy"):
|
||||
@@ -101,9 +97,7 @@ func TestMaxQueue(t *testing.T) {
|
||||
case strings.Contains(genErr.Error(), "connection reset by peer"):
|
||||
resetByPeerCount++
|
||||
default:
|
||||
if genErr != nil {
|
||||
t.Fatalf("%d request failed", i)
|
||||
}
|
||||
require.NoError(t, genErr, "%d request failed", i)
|
||||
}
|
||||
|
||||
slog.Info("embed finished", "id", i)
|
||||
@@ -114,13 +108,8 @@ func TestMaxQueue(t *testing.T) {
|
||||
embedwg.Wait()
|
||||
|
||||
slog.Info("embeds completed", "success", successCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
|
||||
if resetByPeerCount != 0 {
|
||||
t.Fatalf("Connections reset by peer, have you updated your fd and socket limits? %d", resetByPeerCount)
|
||||
}
|
||||
if busyCount == 0 {
|
||||
t.Fatalf("no requests hit busy error but some should have")
|
||||
}
|
||||
if canceledCount > 0 {
|
||||
t.Fatalf("no requests should have been canceled due to timeout %d", canceledCount)
|
||||
}
|
||||
require.Equal(t, resetByPeerCount, 0, "Connections reset by peer, have you updated your fd and socket limits?")
|
||||
require.True(t, busyCount > 0, "no requests hit busy error but some should have")
|
||||
require.True(t, canceledCount == 0, "no requests should have been canceled due to timeout")
|
||||
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -26,11 +25,11 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/app/lifecycle"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
smol = "llama3.2:1b"
|
||||
stream = false
|
||||
smol = "llama3.2:1b"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -436,9 +435,7 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
|
||||
}
|
||||
lifecycle.ServerLogFile = fp.Name()
|
||||
fp.Close()
|
||||
if err := startServer(t, ctx, testEndpoint); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, startServer(t, ctx, testEndpoint))
|
||||
}
|
||||
|
||||
return client, testEndpoint, func() {
|
||||
@@ -471,9 +468,7 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
|
||||
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
if err := PullIfMissing(ctx, client, genReq.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
|
||||
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
@@ -502,22 +497,6 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
var response string
|
||||
verify := func() {
|
||||
// Verify the response contains the expected data
|
||||
response = buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
if buf.Len() == 0 {
|
||||
@@ -530,17 +509,20 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
||||
slog.Warn("model is too large for the target test system", "model", genReq.Model, "error", genErr)
|
||||
return context
|
||||
}
|
||||
if genErr != nil {
|
||||
t.Fatalf("%s failed with %s request prompt %s", genErr, genReq.Model, genReq.Prompt)
|
||||
require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
verify()
|
||||
require.True(t, atLeastOne, "%s: none of %v found in %s", genReq.Model, anyResp, response)
|
||||
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
|
||||
case <-ctx.Done():
|
||||
// On slow systems, we might timeout before some models finish rambling, so check what we have so far to see
|
||||
// if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass
|
||||
// if they are still generating valid responses
|
||||
slog.Warn("outer test context done while waiting for generate")
|
||||
verify()
|
||||
t.Error("outer test context done while waiting for generate")
|
||||
}
|
||||
return context
|
||||
}
|
||||
@@ -561,7 +543,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
}, {
|
||||
Model: smol,
|
||||
Prompt: "how do rainbows form? Be brief but factual in your reply",
|
||||
Prompt: "what is the origin of the US thanksgiving holiday? Be brief but factual in your reply",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
}, {
|
||||
@@ -579,104 +561,17 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
[][]string{
|
||||
{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
|
||||
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
|
||||
{"water", "droplet", "refracted", "reflect", "color", "spectrum"},
|
||||
{"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states"},
|
||||
{"fourth", "july", "declaration", "independence"},
|
||||
{"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor"},
|
||||
{"nitrogen", "oxygen", "carbon", "dioxide"},
|
||||
}
|
||||
}
|
||||
|
||||
func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) *api.Message {
|
||||
stallTimer := time.NewTimer(initialTimeout)
|
||||
var buf bytes.Buffer
|
||||
role := "assistant"
|
||||
fn := func(response api.ChatResponse) error {
|
||||
// fmt.Print(".")
|
||||
role = response.Message.Role
|
||||
buf.Write([]byte(response.Message.Content))
|
||||
if !stallTimer.Reset(streamTimeout) {
|
||||
return errors.New("stall was detected while streaming response, aborting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
stream := true
|
||||
req.Stream = &stream
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
genErr = client.Chat(ctx, &req, fn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
var response string
|
||||
verify := func() {
|
||||
// Verify the response contains the expected data
|
||||
response = buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
if buf.Len() == 0 {
|
||||
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
|
||||
} else {
|
||||
t.Errorf("generate stalled. Response so far:%s", buf.String())
|
||||
}
|
||||
case <-done:
|
||||
if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") {
|
||||
slog.Warn("model is too large for the target test system", "model", req.Model, "error", genErr)
|
||||
return nil
|
||||
}
|
||||
if genErr != nil {
|
||||
t.Fatalf("%s failed with %s request prompt %v", genErr, req.Model, req.Messages)
|
||||
}
|
||||
verify()
|
||||
slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
|
||||
case <-ctx.Done():
|
||||
// On slow systems, we might timeout before some models finish rambling, so check what we have so far to see
|
||||
// if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass
|
||||
// if they are still generating valid responses
|
||||
slog.Warn("outer test context done while waiting for chat")
|
||||
verify()
|
||||
}
|
||||
return &api.Message{Role: role, Content: buf.String()}
|
||||
}
|
||||
|
||||
func ChatRequests() ([]api.ChatRequest, [][]string) {
|
||||
genReqs, results := GenerateRequests()
|
||||
reqs := make([]api.ChatRequest, len(genReqs))
|
||||
// think := api.ThinkValue{Value: "low"}
|
||||
for i := range reqs {
|
||||
reqs[i].Model = genReqs[i].Model
|
||||
reqs[i].Stream = genReqs[i].Stream
|
||||
reqs[i].KeepAlive = genReqs[i].KeepAlive
|
||||
// reqs[i].Think = &think
|
||||
reqs[i].Messages = []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: genReqs[i].Prompt,
|
||||
},
|
||||
}
|
||||
}
|
||||
return reqs, results
|
||||
}
|
||||
|
||||
func skipUnderMinVRAM(t *testing.T, gb uint64) {
|
||||
// TODO use info API in the future
|
||||
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
|
||||
maxVram, err := strconv.ParseUint(s, 10, 64)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
// Don't hammer on small VRAM cards...
|
||||
if maxVram < gb*format.GibiByte {
|
||||
t.Skip("skipping with small VRAM to avoid timeouts")
|
||||
@@ -684,39 +579,6 @@ func skipUnderMinVRAM(t *testing.T, gb uint64) {
|
||||
}
|
||||
}
|
||||
|
||||
// Skip if the target model isn't X% GPU loaded to avoid excessive runtime
|
||||
func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, model string, minPercent int) {
|
||||
models, err := client.ListRunning(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to list running models: %s", err)
|
||||
}
|
||||
loaded := []string{}
|
||||
for _, m := range models.Models {
|
||||
loaded = append(loaded, m.Name)
|
||||
if m.Name != model {
|
||||
continue
|
||||
}
|
||||
gpuPercent := 0
|
||||
switch {
|
||||
case m.SizeVRAM == 0:
|
||||
gpuPercent = 0
|
||||
case m.SizeVRAM == m.Size:
|
||||
gpuPercent = 100
|
||||
case m.SizeVRAM > m.Size || m.Size == 0:
|
||||
t.Logf("unexpected size detected: %d", m.SizeVRAM)
|
||||
default:
|
||||
sizeCPU := m.Size - m.SizeVRAM
|
||||
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 110)
|
||||
gpuPercent = int(100 - cpuPercent)
|
||||
}
|
||||
if gpuPercent < minPercent {
|
||||
t.Skip(fmt.Sprintf("test requires minimum %d%% GPU load, but model %s only has %d%%", minPercent, model, gpuPercent))
|
||||
}
|
||||
return
|
||||
}
|
||||
t.Skip(fmt.Sprintf("model %s not loaded - actually loaded: %v", model, loaded))
|
||||
}
|
||||
|
||||
func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
|
||||
deadline, hasDeadline := t.Deadline()
|
||||
if !hasDeadline {
|
||||
|
||||
@@ -378,7 +378,9 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||
maskTensor := ctx.Input().FromFloatSlice(mask, length, batchSize)
|
||||
|
||||
if c.config.MaskDType != ml.DTypeF32 {
|
||||
maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
|
||||
out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
|
||||
ctx.Forward(maskTensor.Copy(ctx, out))
|
||||
maskTensor = out
|
||||
}
|
||||
|
||||
return maskTensor
|
||||
|
||||
3
llama/llama.cpp/src/llama-context.cpp
vendored
3
llama/llama.cpp/src/llama-context.cpp
vendored
@@ -962,7 +962,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
const int64_t n_vocab = vocab.n_tokens();
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
const bool output_all = false;
|
||||
// when computing embeddings, all tokens are output
|
||||
const bool output_all = cparams.embeddings;
|
||||
|
||||
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
|
||||
@@ -515,34 +515,33 @@ func (c *MtmdContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32,
|
||||
}
|
||||
nChunks := C.mtmd_input_chunks_size(ic)
|
||||
numEmbed := llamaContext.Model().NEmbd()
|
||||
embed := make([][]float32, 0)
|
||||
lastChunkSize := 0
|
||||
for i := range int(nChunks) {
|
||||
chunk := C.mtmd_input_chunks_get(ic, C.size_t(i))
|
||||
numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk))
|
||||
slog.Debug("chunk tokens", "index", i, "numTokens", numTokens)
|
||||
lastChunkSize = numTokens
|
||||
|
||||
// Encode the chunk
|
||||
if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) {
|
||||
return nil, errors.New("unable to encode mtmd image chunk")
|
||||
}
|
||||
|
||||
// Get the embeddings for this chunk
|
||||
chunkEmbed := make([][]float32, numTokens)
|
||||
chunkEmbd := C.mtmd_get_output_embd(c.c)
|
||||
if nil == chunkEmbd {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extend the embedding array for each token
|
||||
s := unsafe.Slice((*float32)(chunkEmbd), numTokens*numEmbed)
|
||||
rows := make([]float32, len(s))
|
||||
copy(rows, s)
|
||||
for i := range numTokens {
|
||||
chunkEmbed[i] = rows[i*numEmbed : (i+1)*numEmbed]
|
||||
}
|
||||
embed = append(embed, chunkEmbed...)
|
||||
}
|
||||
slog.Debug("image embeddings", "totalEmbeddings", len(embed))
|
||||
|
||||
// Get the embeddings
|
||||
embed := make([][]float32, lastChunkSize)
|
||||
embd := C.mtmd_get_output_embd(c.c)
|
||||
if nil == embd {
|
||||
return nil, errors.New("failed to get image embedding")
|
||||
}
|
||||
|
||||
// Extend the embedding array for each token
|
||||
s := unsafe.Slice((*float32)(embd), numEmbed*lastChunkSize)
|
||||
rows := make([]float32, len(s))
|
||||
copy(rows, s)
|
||||
for i := range lastChunkSize {
|
||||
embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
|
||||
}
|
||||
|
||||
return embed, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ checks.
|
||||
1 file changed, 18 insertions(+)
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index 57eae461..c7f9dc3a 100644
|
||||
index 57eae461..9db0c8b5 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -2671,12 +2671,24 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Michael Yang <git@mxy.ng>
|
||||
Date: Mon, 18 Aug 2025 16:58:39 -0700
|
||||
Subject: [PATCH] decode: disable output_all
|
||||
|
||||
---
|
||||
src/llama-context.cpp | 3 +--
|
||||
1 file changed, 1 insertion(+), 2 deletions(-)
|
||||
|
||||
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
|
||||
index 26a5cf9c..6ece5263 100644
|
||||
--- a/src/llama-context.cpp
|
||||
+++ b/src/llama-context.cpp
|
||||
@@ -962,8 +962,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
const int64_t n_vocab = vocab.n_tokens();
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
- // when computing embeddings, all tokens are output
|
||||
- const bool output_all = cparams.embeddings;
|
||||
+ const bool output_all = false;
|
||||
|
||||
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
@@ -1,130 +0,0 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Jesse Gross <jesse@ollama.com>
|
||||
Date: Wed, 27 Aug 2025 14:39:48 -0700
|
||||
Subject: [PATCH] ggml: Enable resetting backend devices
|
||||
|
||||
Touching a CUDA device causes the allocation of a primary context
|
||||
with CUDA data structures (~300 MB of VRAM). If a device is
|
||||
unused then it can be reset to free these data structures.
|
||||
---
|
||||
ggml/include/ggml-backend.h | 1 +
|
||||
ggml/src/ggml-backend-impl.h | 4 ++++
|
||||
ggml/src/ggml-backend.cpp | 8 ++++++++
|
||||
ggml/src/ggml-cuda/ggml-cuda.cu | 17 +++++++++++++++--
|
||||
ggml/src/ggml-cuda/vendors/hip.h | 1 +
|
||||
5 files changed, 29 insertions(+), 2 deletions(-)
|
||||
|
||||
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
|
||||
index b602a7c78..fda5ceb24 100644
|
||||
--- a/ggml/include/ggml-backend.h
|
||||
+++ b/ggml/include/ggml-backend.h
|
||||
@@ -167,6 +167,7 @@ extern "C" {
|
||||
GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props);
|
||||
GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params);
|
||||
+ GGML_API void ggml_backend_dev_reset(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
|
||||
diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
|
||||
index 81749a5a3..6f10c353b 100644
|
||||
--- a/ggml/src/ggml-backend-impl.h
|
||||
+++ b/ggml/src/ggml-backend-impl.h
|
||||
@@ -178,6 +178,10 @@ extern "C" {
|
||||
ggml_backend_event_t (*event_new) (ggml_backend_dev_t dev);
|
||||
void (*event_free) (ggml_backend_dev_t dev, ggml_backend_event_t event);
|
||||
void (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event);
|
||||
+
|
||||
+ // (optional) reset device, clearing existing allocations and context
|
||||
+ // the caller must ensure that there are no outstanding buffers, as these will become invalid
|
||||
+ void (*reset)(ggml_backend_dev_t dev);
|
||||
};
|
||||
|
||||
struct ggml_backend_device {
|
||||
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
|
||||
index 05a842ed5..6556943b0 100644
|
||||
--- a/ggml/src/ggml-backend.cpp
|
||||
+++ b/ggml/src/ggml-backend.cpp
|
||||
@@ -477,6 +477,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par
|
||||
return device->iface.init_backend(device, params);
|
||||
}
|
||||
|
||||
+void ggml_backend_dev_reset(ggml_backend_dev_t device) {
|
||||
+ if (device->iface.reset == NULL) {
|
||||
+ return;
|
||||
+ }
|
||||
+
|
||||
+ device->iface.reset(device);
|
||||
+}
|
||||
+
|
||||
ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {
|
||||
return device->iface.get_buffer_type(device);
|
||||
}
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index c7f9dc3a5..e43fde523 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -103,6 +103,11 @@ int ggml_cuda_get_device() {
|
||||
return id;
|
||||
}
|
||||
|
||||
+void ggml_cuda_reset_device(int device) {
|
||||
+ ggml_cuda_set_device(device);
|
||||
+ CUDA_CHECK(cudaDeviceReset());
|
||||
+}
|
||||
+
|
||||
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
|
||||
ggml_cuda_set_device(device);
|
||||
cudaError_t err;
|
||||
@@ -3243,7 +3248,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
|
||||
props->description = ggml_backend_cuda_device_get_description(dev);
|
||||
props->id = ggml_backend_cuda_device_get_id(dev);
|
||||
props->type = ggml_backend_cuda_device_get_type(dev);
|
||||
- ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||
+
|
||||
+ // Memory reporting is disabled to avoid allocation of a CUDA primary context (~300 MB per device).
|
||||
+ // If you need the memory data, call ggml_backend_dev_memory() explicitly.
|
||||
+ props->memory_total = props->memory_free = 0;
|
||||
|
||||
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
|
||||
#ifdef GGML_CUDA_NO_PEER_COPY
|
||||
@@ -3700,6 +3708,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g
|
||||
CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
|
||||
}
|
||||
|
||||
+static void ggml_backend_cuda_device_reset(ggml_backend_dev_t dev) {
|
||||
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
||||
+ ggml_cuda_reset_device(ctx->device);
|
||||
+}
|
||||
+
|
||||
static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
|
||||
/* .get_name = */ ggml_backend_cuda_device_get_name,
|
||||
/* .get_description = */ ggml_backend_cuda_device_get_description,
|
||||
@@ -3716,6 +3729,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
|
||||
/* .event_new = */ ggml_backend_cuda_device_event_new,
|
||||
/* .event_free = */ ggml_backend_cuda_device_event_free,
|
||||
/* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize,
|
||||
+ /* .reset = */ ggml_backend_cuda_device_reset,
|
||||
};
|
||||
|
||||
// backend reg
|
||||
@@ -3835,7 +3849,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
||||
dev_ctx->device = i;
|
||||
dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
|
||||
|
||||
- ggml_cuda_set_device(i);
|
||||
cudaDeviceProp prop;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
|
||||
dev_ctx->description = prop.name;
|
||||
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
|
||||
index c31f31923..cf22e60d2 100644
|
||||
--- a/ggml/src/ggml-cuda/vendors/hip.h
|
||||
+++ b/ggml/src/ggml-cuda/vendors/hip.h
|
||||
@@ -40,6 +40,7 @@
|
||||
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
||||
#define cudaDeviceProp hipDeviceProp_t
|
||||
+#define cudaDeviceReset hipDeviceReset
|
||||
#define cudaDeviceSynchronize hipDeviceSynchronize
|
||||
#define cudaError_t hipError_t
|
||||
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
||||
@@ -1,28 +0,0 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Daniel Hiltgen <daniel@ollama.com>
|
||||
Date: Fri, 29 Aug 2025 16:53:08 -0700
|
||||
Subject: [PATCH] harden uncaught exception registration
|
||||
|
||||
---
|
||||
ggml/src/ggml.cpp | 8 ++++++--
|
||||
1 file changed, 6 insertions(+), 2 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml.cpp b/ggml/src/ggml.cpp
|
||||
index 0d388d45..f5bcb446 100644
|
||||
--- a/ggml/src/ggml.cpp
|
||||
+++ b/ggml/src/ggml.cpp
|
||||
@@ -19,8 +19,12 @@ static bool ggml_uncaught_exception_init = []{
|
||||
return false;
|
||||
}
|
||||
const auto prev{std::get_terminate()};
|
||||
- GGML_ASSERT(prev != ggml_uncaught_exception);
|
||||
- previous_terminate_handler = prev;
|
||||
+ // GGML_ASSERT(prev != ggml_uncaught_exception);
|
||||
+ if (prev != ggml_uncaught_exception) {
|
||||
+ previous_terminate_handler = prev;
|
||||
+ } else {
|
||||
+ GGML_LOG_WARN("%s double registration of ggml_uncaught_exception\n", __func__);
|
||||
+ }
|
||||
std::set_terminate(ggml_uncaught_exception);
|
||||
return true;
|
||||
}();
|
||||
@@ -30,7 +30,7 @@ func pickBestFullFitByLibrary(f *ggml.GGML, modelPath string, projectors []strin
|
||||
// Try to pack into as few GPUs as possible, starting from 1 GPU
|
||||
for numGPUs := 1; numGPUs <= len(sgl); numGPUs++ {
|
||||
gpuSubset := sgl[:numGPUs]
|
||||
ok, estimatedVRAM := predictServerFit(gpuSubset, f, adapters, projectors, opts, numParallel)
|
||||
ok, estimatedVRAM := PredictServerFit(gpuSubset, f, adapters, projectors, opts, numParallel)
|
||||
|
||||
if ok {
|
||||
slog.Info("new model will fit in available VRAM across minimum required GPUs, loading",
|
||||
@@ -48,7 +48,7 @@ func pickBestFullFitByLibrary(f *ggml.GGML, modelPath string, projectors []strin
|
||||
// - try subsets of GPUs instead of just falling back to 1 or all in a family
|
||||
|
||||
// Now try all the GPUS (OLLAMA_SCHED_SPREAD is set)
|
||||
if ok, estimatedVRAM := predictServerFit(sgl, f, adapters, projectors, opts, numParallel); ok {
|
||||
if ok, estimatedVRAM := PredictServerFit(sgl, f, adapters, projectors, opts, numParallel); ok {
|
||||
slog.Info("new model will fit in available VRAM, loading",
|
||||
"model", modelPath,
|
||||
"library", sgl[0].Library,
|
||||
@@ -71,7 +71,7 @@ func pickBestPartialFitByLibrary(f *ggml.GGML, projectors []string, adapters []s
|
||||
var bestEstimate uint64
|
||||
var bestFit int
|
||||
for i, gl := range byLibrary {
|
||||
_, estimatedVRAM := predictServerFit(gl, f, adapters, projectors, opts, numParallel)
|
||||
_, estimatedVRAM := PredictServerFit(gl, f, adapters, projectors, opts, numParallel)
|
||||
if estimatedVRAM > bestEstimate {
|
||||
bestEstimate = estimatedVRAM
|
||||
bestFit = i
|
||||
@@ -81,7 +81,7 @@ func pickBestPartialFitByLibrary(f *ggml.GGML, projectors []string, adapters []s
|
||||
}
|
||||
|
||||
// This algorithm looks for a complete fit to determine if we need to unload other models
|
||||
func predictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
|
||||
func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
|
||||
// Split up the GPUs by type and try them
|
||||
var estimatedVRAM uint64
|
||||
for _, gpus := range allGpus.ByLibrary() {
|
||||
@@ -97,10 +97,6 @@ func predictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, proj
|
||||
return true, estimatedVRAM
|
||||
}
|
||||
}
|
||||
|
||||
if len(gpus) == 1 && gpus[0].Library == "cpu" && estimate.TotalSize <= gpus[0].FreeMemory {
|
||||
return true, estimatedVRAM
|
||||
}
|
||||
}
|
||||
return false, estimatedVRAM
|
||||
}
|
||||
@@ -195,19 +191,17 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
||||
slog.Warn("model missing blk.0 layer size")
|
||||
}
|
||||
|
||||
useFlashAttention := (envconfig.FlashAttention() || f.FlashAttention()) &&
|
||||
discover.GetGPUInfo().FlashAttentionSupported() &&
|
||||
f.SupportsFlashAttention()
|
||||
|
||||
var kvct string
|
||||
if useFlashAttention {
|
||||
if envconfig.FlashAttention() &&
|
||||
discover.GetGPUInfo().FlashAttentionSupported() &&
|
||||
f.SupportsFlashAttention() {
|
||||
requested := strings.ToLower(envconfig.KvCacheType())
|
||||
if f.SupportsKVCacheType(requested) {
|
||||
if requested != "" && f.SupportsKVCacheType(requested) {
|
||||
kvct = requested
|
||||
}
|
||||
}
|
||||
|
||||
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct, useFlashAttention)
|
||||
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct)
|
||||
|
||||
if len(kv) > 0 {
|
||||
layerSize += kv[0]
|
||||
|
||||
@@ -148,11 +148,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
||||
var textProcessor model.TextProcessor
|
||||
var err error
|
||||
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
|
||||
if len(projectors) == 0 {
|
||||
textProcessor, err = model.NewTextProcessor(modelPath)
|
||||
} else {
|
||||
err = errors.New("split vision models aren't supported")
|
||||
}
|
||||
textProcessor, err = model.NewTextProcessor(modelPath)
|
||||
if err != nil {
|
||||
// To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner
|
||||
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
|
||||
@@ -165,6 +161,11 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
||||
}
|
||||
}
|
||||
|
||||
newEstimates := textProcessor != nil && envconfig.NewMemoryEstimates()
|
||||
if newEstimates {
|
||||
slog.Info("enabling new memory estimates")
|
||||
}
|
||||
|
||||
// Verify the requested context size is <= the model training size
|
||||
trainCtx := f.KV().ContextLength()
|
||||
if opts.NumCtx > int(trainCtx) && trainCtx > 0 {
|
||||
@@ -172,8 +173,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
||||
opts.NumCtx = int(trainCtx)
|
||||
}
|
||||
|
||||
opts.NumBatch = min(opts.NumBatch, opts.NumCtx)
|
||||
|
||||
loadRequest := LoadRequest{LoraPath: adapters, KvSize: opts.NumCtx * numParallel, BatchSize: opts.NumBatch, Parallel: numParallel, MultiUserCache: envconfig.MultiUserCache()}
|
||||
|
||||
defaultThreads := discover.GetSystemInfo().GetOptimalThreadCount()
|
||||
@@ -196,11 +195,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
||||
// This will disable flash attention unless all GPUs on the system support it, even if we end up selecting a subset
|
||||
// that can handle it.
|
||||
fa := envconfig.FlashAttention()
|
||||
if f.FlashAttention() {
|
||||
slog.Info("model wants flash attention")
|
||||
fa = true
|
||||
}
|
||||
|
||||
if fa && !gpus.FlashAttentionSupported() {
|
||||
slog.Warn("flash attention enabled but not supported by gpu")
|
||||
fa = false
|
||||
@@ -219,7 +213,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
||||
|
||||
// Flash Attention also supports kv cache quantization
|
||||
// Enable if the requested and kv cache type is supported by the model
|
||||
if f.SupportsKVCacheType(kvct) {
|
||||
if kvct != "" && f.SupportsKVCacheType(kvct) {
|
||||
loadRequest.KvCacheType = kvct
|
||||
} else {
|
||||
slog.Warn("kv cache type not supported by model", "type", kvct)
|
||||
@@ -361,28 +355,23 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
||||
|
||||
s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator)))
|
||||
|
||||
envWorkarounds := []string{}
|
||||
envWorkarounds := [][2]string{}
|
||||
for _, gpu := range gpus {
|
||||
envWorkarounds = append(envWorkarounds, gpu.EnvWorkarounds...)
|
||||
}
|
||||
// Always filter down the set of GPUs in case there are any unsupported devices that might crash
|
||||
envWorkarounds = append(envWorkarounds, gpus.GetVisibleDevicesEnv()...)
|
||||
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
|
||||
|
||||
// Update or add the path variable with our adjusted version
|
||||
pathNeeded := true
|
||||
envWorkaroundDone := make([]bool, len(envWorkarounds))
|
||||
for i := range s.cmd.Env {
|
||||
cmp := strings.SplitN(s.cmd.Env[i], "=", 2)
|
||||
if strings.EqualFold(cmp[0], pathEnv) {
|
||||
s.cmd.Env[i] = pathEnv + "=" + pathEnvVal
|
||||
pathNeeded = false
|
||||
} else if len(envWorkarounds) != 0 {
|
||||
for j, kv := range envWorkarounds {
|
||||
tmp := strings.SplitN(kv, "=", 2)
|
||||
if strings.EqualFold(cmp[0], tmp[0]) {
|
||||
s.cmd.Env[i] = kv
|
||||
envWorkaroundDone[j] = true
|
||||
for _, kv := range envWorkarounds {
|
||||
if strings.EqualFold(cmp[0], kv[0]) {
|
||||
s.cmd.Env[i] = kv[0] + "=" + kv[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -390,11 +379,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
||||
if pathNeeded {
|
||||
s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal)
|
||||
}
|
||||
for i, done := range envWorkaroundDone {
|
||||
if !done {
|
||||
s.cmd.Env = append(s.cmd.Env, envWorkarounds[i])
|
||||
}
|
||||
}
|
||||
|
||||
slog.Info("starting runner", "cmd", s.cmd)
|
||||
slog.Debug("subprocess", "", filteredEnv(s.cmd.Env))
|
||||
@@ -432,7 +416,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
||||
}
|
||||
}()
|
||||
|
||||
if textProcessor != nil {
|
||||
if newEstimates {
|
||||
return &ollamaServer{llmServer: s}, nil
|
||||
} else {
|
||||
return &llamaServer{llmServer: s, ggml: f}, nil
|
||||
@@ -508,7 +492,6 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi
|
||||
if !requireFull {
|
||||
g = pickBestPartialFitByLibrary(s.ggml, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, gpus, s.numParallel)
|
||||
} else {
|
||||
slog.Info("model requires more memory than is currently available, evicting a model to make space", "estimate", s.estimate)
|
||||
return ErrLoadRequiredFull
|
||||
}
|
||||
}
|
||||
@@ -541,6 +524,10 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi
|
||||
}
|
||||
}
|
||||
|
||||
if requireFull && len(gpus) == 1 && gpus[0].Library == "cpu" && s.estimate.TotalSize > gpus[0].FreeMemory {
|
||||
return ErrLoadRequiredFull
|
||||
}
|
||||
|
||||
slog.Info("offload", "", s.estimate)
|
||||
|
||||
s.gpus = gpus
|
||||
@@ -664,9 +651,7 @@ func (s *ollamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requ
|
||||
if !success {
|
||||
s.initModel(ctx, LoadRequest{}, LoadOperationClose)
|
||||
}
|
||||
if s.mem != nil {
|
||||
s.mem.Log(slog.LevelInfo)
|
||||
}
|
||||
s.mem.Log(slog.LevelInfo)
|
||||
}()
|
||||
|
||||
slog.Info("loading model", "model layers", s.totalLayers, "requested", s.options.NumGPU)
|
||||
@@ -679,12 +664,8 @@ func (s *ollamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requ
|
||||
|
||||
if !(len(gpus) == 1 && gpus[0].Library == "cpu") {
|
||||
for _, gpu := range gpus {
|
||||
available := gpu.FreeMemory - envconfig.GpuOverhead() - gpu.MinimumMemory
|
||||
if gpu.FreeMemory < envconfig.GpuOverhead()+gpu.MinimumMemory {
|
||||
available = 0
|
||||
}
|
||||
slog.Info("gpu memory", "id", gpu.ID,
|
||||
"available", format.HumanBytes2(available),
|
||||
"available", format.HumanBytes2(gpu.FreeMemory-envconfig.GpuOverhead()-gpu.MinimumMemory),
|
||||
"free", format.HumanBytes2(gpu.FreeMemory),
|
||||
"minimum", format.HumanBytes2(gpu.MinimumMemory),
|
||||
"overhead", format.HumanBytes2(envconfig.GpuOverhead()))
|
||||
@@ -866,7 +847,7 @@ func (s *ollamaServer) createLayout(systemInfo discover.SystemInfo, systemGPUs d
|
||||
}
|
||||
layers[i] += memory.CPU.Weights[i].Size
|
||||
layers[i] += memory.CPU.Cache[i].Size
|
||||
logutil.Trace("layer to assign", "layer", i, "size", format.HumanBytes2(layers[i]))
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "layer to assign", "layer", i, "size", format.HumanBytes2(layers[i]))
|
||||
}
|
||||
|
||||
gpuLayers := ml.GPULayersList{}
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
package logutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
)
|
||||
|
||||
const LevelTrace slog.Level = -8
|
||||
@@ -30,19 +27,3 @@ func NewLogger(w io.Writer, level slog.Level) *slog.Logger {
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
type key string
|
||||
|
||||
func Trace(msg string, args ...any) {
|
||||
TraceContext(context.WithValue(context.TODO(), key("skip"), 1), msg, args...)
|
||||
}
|
||||
|
||||
func TraceContext(ctx context.Context, msg string, args ...any) {
|
||||
if logger := slog.Default(); logger.Enabled(ctx, LevelTrace) {
|
||||
skip, _ := ctx.Value(key("skip")).(int)
|
||||
pc, _, _, _ := runtime.Caller(1 + skip)
|
||||
record := slog.NewRecord(time.Now(), LevelTrace, msg, pc)
|
||||
record.Add(args...)
|
||||
logger.Handler().Handle(ctx, record)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -266,7 +266,7 @@ func (m DeviceMemory) LogValue() slog.Value {
|
||||
// allocation is guaranteed to be provided so that if it failed, the caller can
|
||||
// accommodate that to make forward progress.
|
||||
type BackendMemory struct {
|
||||
// InputWeights are always located on the CPU and cannot be moved
|
||||
// InputsWeights are always located on the CPU and cannot be moved
|
||||
InputWeights Memory
|
||||
|
||||
// CPU model components are located in system memory. This does not
|
||||
@@ -372,7 +372,6 @@ type Context interface {
|
||||
|
||||
Forward(...Tensor) Context
|
||||
Compute(...Tensor)
|
||||
ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun
|
||||
|
||||
// Reserve is analogous to Compute but rather than executing a
|
||||
// graph, simply preallocates memory. Typically called with a
|
||||
@@ -397,13 +396,10 @@ type Tensor interface {
|
||||
|
||||
Shape() []int
|
||||
DType() DType
|
||||
Cast(ctx Context, dtype DType) Tensor
|
||||
|
||||
Bytes() []byte
|
||||
Floats() []float32
|
||||
|
||||
SetValueFromIntSlice(s []int32)
|
||||
|
||||
Neg(ctx Context) Tensor
|
||||
Add(ctx Context, t2 Tensor) Tensor
|
||||
Sub(ctx Context, t2 Tensor) Tensor
|
||||
@@ -416,7 +412,6 @@ type Tensor interface {
|
||||
AddID(ctx Context, t2, ids Tensor) Tensor
|
||||
|
||||
Softmax(ctx Context) Tensor
|
||||
L2Norm(ctx Context, eps float32) Tensor
|
||||
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
|
||||
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
|
||||
Scale(ctx Context, s float64) Tensor
|
||||
@@ -430,13 +425,12 @@ type Tensor interface {
|
||||
Sin(ctx Context) Tensor
|
||||
Cos(ctx Context) Tensor
|
||||
Tanh(ctx Context) Tensor
|
||||
GELU(ctx Context, up ...Tensor) Tensor
|
||||
SILU(ctx Context, up ...Tensor) Tensor
|
||||
RELU(ctx Context, up ...Tensor) Tensor
|
||||
GELU(ctx Context) Tensor
|
||||
QuickGELU(ctx Context) Tensor
|
||||
SILU(ctx Context) Tensor
|
||||
RELU(ctx Context) Tensor
|
||||
Sigmoid(ctx Context) Tensor
|
||||
|
||||
// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
|
||||
SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
|
||||
SwiGLU(ctx Context, up Tensor, alpha, limit float32) Tensor
|
||||
|
||||
Reshape(ctx Context, shape ...int) Tensor
|
||||
View(ctx Context, offset int, shape ...int) Tensor
|
||||
|
||||
@@ -82,7 +82,6 @@ type Backend struct {
|
||||
// to the name that is used by the model definition
|
||||
tensorLoadTargets map[string][]string
|
||||
|
||||
schedMu sync.Mutex // Only one Compute can run at a time
|
||||
sched C.ggml_backend_sched_t
|
||||
schedBackends []C.ggml_backend_t
|
||||
schedBufts []C.ggml_backend_buffer_type_t
|
||||
@@ -271,7 +270,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
tt := C.ggml_new_tensor(ctxs[bt], kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0])))
|
||||
C.ggml_set_name(tt, cname)
|
||||
|
||||
logutil.Trace("created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
|
||||
|
||||
size := pad(C.ggml_backend_buft_get_alloc_size(bt, tt), C.ggml_backend_buft_get_alignment(bt))
|
||||
if layer == -1 {
|
||||
@@ -378,7 +377,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
}
|
||||
|
||||
for bs := range maps.Values(bbs) {
|
||||
logutil.Trace("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)),
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)),
|
||||
"size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
|
||||
}
|
||||
|
||||
@@ -536,7 +535,6 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
|
||||
const BS = 17 // MXFP4 block size
|
||||
bts := make([]byte, 8*BS*format.KibiByte) // ~128k block aligned
|
||||
var s uint64
|
||||
var tmp [16]byte
|
||||
for s < t.Size() {
|
||||
// Stop if either the parent context has been canceled or if any of the other tensors returned an error
|
||||
if err := ctx.Err(); err != nil {
|
||||
@@ -548,13 +546,37 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
|
||||
return err
|
||||
}
|
||||
for j := range n / BS {
|
||||
for i := 1; i < 9; i++ {
|
||||
// transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
|
||||
a, b := bts[j*BS+i], bts[j*BS+i+8]
|
||||
tmp[2*(i-1)] = (a & 0x0F) | (b << 4)
|
||||
tmp[2*(i-1)+1] = (a >> 4) | (b & 0xF0)
|
||||
for i := 1; i < BS; i++ {
|
||||
// swap nibbles
|
||||
t_lo := bts[j*BS+i] & 0x0F
|
||||
t_hi := bts[j*BS+i] & 0xF0
|
||||
bts[j*BS+i] = (t_lo << 4) | (t_hi >> 4)
|
||||
}
|
||||
// transform aaaa...bbbb... to abababab...
|
||||
oi := 0
|
||||
tmp := [16]byte{}
|
||||
for i := 1; i < 9; i++ {
|
||||
blk_a0 := bts[j*BS+i] & 0xF0
|
||||
blk_a1 := bts[j*BS+i] << 4
|
||||
blk_b0 := bts[j*BS+i+8] >> 4
|
||||
blk_b1 := bts[j*BS+i+8] & 0x0F
|
||||
// swap once more
|
||||
out0 := blk_a0 | blk_b0
|
||||
out1 := blk_a1 | blk_b1
|
||||
out_h0 := out0 & 0xF0
|
||||
out_l0 := out0 & 0x0F
|
||||
out_h1 := out1 & 0xF0
|
||||
out_l1 := out1 & 0x0F
|
||||
out0 = (out_h0 >> 4) | (out_l0 << 4)
|
||||
out1 = (out_h1 >> 4) | (out_l1 << 4)
|
||||
tmp[oi] = out0
|
||||
oi++
|
||||
tmp[oi] = out1
|
||||
oi++
|
||||
}
|
||||
for i := range tmp {
|
||||
bts[j*BS+i+1] = tmp[i]
|
||||
}
|
||||
copy(bts[j*BS+1:j*BS+17], tmp[:])
|
||||
}
|
||||
|
||||
for _, tt := range tts {
|
||||
@@ -630,18 +652,6 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
|
||||
})
|
||||
}
|
||||
|
||||
// Cleanup any backend state from devices that we didn't end up using
|
||||
nextDevice:
|
||||
for _, d := range append(gpus, append(accels, cpus...)...) {
|
||||
for _, backend := range b.schedBackends {
|
||||
if d == C.ggml_backend_get_device(backend) {
|
||||
continue nextDevice
|
||||
}
|
||||
}
|
||||
|
||||
C.ggml_backend_dev_reset(d)
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -759,15 +769,6 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
|
||||
}
|
||||
|
||||
func (c *Context) Compute(tensors ...ml.Tensor) {
|
||||
c.ComputeWithNotify(nil, tensors...)
|
||||
}
|
||||
|
||||
func (c *Context) ComputeWithNotify(cb func(), tensors ...ml.Tensor) {
|
||||
c.b.schedMu.Lock()
|
||||
defer c.b.schedMu.Unlock()
|
||||
if cb != nil {
|
||||
go cb()
|
||||
}
|
||||
if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS {
|
||||
panic(fmt.Errorf("error computing ggml graph: %v", status))
|
||||
}
|
||||
@@ -811,7 +812,7 @@ func (c *Context) Reserve() {
|
||||
}
|
||||
}
|
||||
|
||||
logutil.Trace("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])),
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])),
|
||||
"buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), "size", format.HumanBytes2(uint64(bufferStatus.size)))
|
||||
}
|
||||
|
||||
@@ -842,7 +843,23 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
|
||||
panic("set Input or Layer before creating tensors")
|
||||
}
|
||||
|
||||
cdtype := ggmlDType(dtype)
|
||||
var cdtype uint32
|
||||
switch dtype {
|
||||
case ml.DTypeF32:
|
||||
cdtype = C.GGML_TYPE_F32
|
||||
case ml.DTypeF16:
|
||||
cdtype = C.GGML_TYPE_F16
|
||||
case ml.DTypeQ80:
|
||||
cdtype = C.GGML_TYPE_Q8_0
|
||||
case ml.DTypeQ40:
|
||||
cdtype = C.GGML_TYPE_Q4_0
|
||||
case ml.DTypeI32:
|
||||
cdtype = C.GGML_TYPE_I32
|
||||
case ml.DTypeMXFP4:
|
||||
cdtype = C.GGML_TYPE_MXFP4
|
||||
default:
|
||||
panic("unsupported dtype")
|
||||
}
|
||||
|
||||
if len(shape) < 1 || shape[0] == 0 {
|
||||
var shape C.int64_t = 0
|
||||
@@ -1020,12 +1037,6 @@ func (t *Tensor) Floats() (data []float32) {
|
||||
return
|
||||
}
|
||||
|
||||
func (t *Tensor) SetValueFromIntSlice(s []int32) {
|
||||
if len(s) > 0 {
|
||||
C.ggml_backend_tensor_set(t.t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.t))
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) DType() ml.DType {
|
||||
switch t.t._type {
|
||||
case C.GGML_TYPE_F32:
|
||||
@@ -1045,32 +1056,6 @@ func (t *Tensor) DType() ml.DType {
|
||||
}
|
||||
}
|
||||
|
||||
func ggmlDType(dtype ml.DType) uint32 {
|
||||
switch dtype {
|
||||
case ml.DTypeF32:
|
||||
return C.GGML_TYPE_F32
|
||||
case ml.DTypeF16:
|
||||
return C.GGML_TYPE_F16
|
||||
case ml.DTypeQ80:
|
||||
return C.GGML_TYPE_Q8_0
|
||||
case ml.DTypeQ40:
|
||||
return C.GGML_TYPE_Q4_0
|
||||
case ml.DTypeI32:
|
||||
return C.GGML_TYPE_I32
|
||||
case ml.DTypeMXFP4:
|
||||
return C.GGML_TYPE_MXFP4
|
||||
default:
|
||||
panic("unsupported dtype")
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Cast(ctx ml.Context, dtype ml.DType) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_cast(ctx.(*Context).ctx, t.t, ggmlDType(dtype)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
@@ -1205,13 +1190,6 @@ func (t *Tensor) AddID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) L2Norm(ctx ml.Context, eps float32) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_l2_norm(ctx.(*Context).ctx, t.t, C.float(eps)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
|
||||
tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
|
||||
if w != nil {
|
||||
@@ -1431,46 +1409,35 @@ func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
|
||||
if len(t2) > 0 {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_geglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
|
||||
}
|
||||
}
|
||||
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
|
||||
if len(t2) > 0 {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_swiglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
|
||||
}
|
||||
func (t *Tensor) QuickGELU(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) RELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
|
||||
if len(t2) > 0 {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_reglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
|
||||
}
|
||||
}
|
||||
func (t *Tensor) RELU(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) SILUAlphaLimit(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
|
||||
func (t *Tensor) SwiGLU(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_swiglu_oai(ctx.(*Context).ctx, t.t, up.(*Tensor).t, C.float(alpha), C.float(limit)),
|
||||
|
||||
1
ml/backend/ggml/ggml/include/ggml-backend.h
vendored
1
ml/backend/ggml/ggml/include/ggml-backend.h
vendored
@@ -167,7 +167,6 @@ extern "C" {
|
||||
GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props);
|
||||
GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params);
|
||||
GGML_API void ggml_backend_dev_reset(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
|
||||
|
||||
4
ml/backend/ggml/ggml/src/ggml-backend-impl.h
vendored
4
ml/backend/ggml/ggml/src/ggml-backend-impl.h
vendored
@@ -178,10 +178,6 @@ extern "C" {
|
||||
ggml_backend_event_t (*event_new) (ggml_backend_dev_t dev);
|
||||
void (*event_free) (ggml_backend_dev_t dev, ggml_backend_event_t event);
|
||||
void (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event);
|
||||
|
||||
// (optional) reset device, clearing existing allocations and context
|
||||
// the caller must ensure that there are no outstanding buffers, as these will become invalid
|
||||
void (*reset)(ggml_backend_dev_t dev);
|
||||
};
|
||||
|
||||
struct ggml_backend_device {
|
||||
|
||||
8
ml/backend/ggml/ggml/src/ggml-backend.cpp
vendored
8
ml/backend/ggml/ggml/src/ggml-backend.cpp
vendored
@@ -477,14 +477,6 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par
|
||||
return device->iface.init_backend(device, params);
|
||||
}
|
||||
|
||||
void ggml_backend_dev_reset(ggml_backend_dev_t device) {
|
||||
if (device->iface.reset == NULL) {
|
||||
return;
|
||||
}
|
||||
|
||||
device->iface.reset(device);
|
||||
}
|
||||
|
||||
ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {
|
||||
return device->iface.get_buffer_type(device);
|
||||
}
|
||||
|
||||
17
ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
vendored
17
ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
vendored
@@ -103,11 +103,6 @@ int ggml_cuda_get_device() {
|
||||
return id;
|
||||
}
|
||||
|
||||
void ggml_cuda_reset_device(int device) {
|
||||
ggml_cuda_set_device(device);
|
||||
CUDA_CHECK(cudaDeviceReset());
|
||||
}
|
||||
|
||||
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
|
||||
ggml_cuda_set_device(device);
|
||||
cudaError_t err;
|
||||
@@ -3248,10 +3243,7 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
|
||||
props->description = ggml_backend_cuda_device_get_description(dev);
|
||||
props->id = ggml_backend_cuda_device_get_id(dev);
|
||||
props->type = ggml_backend_cuda_device_get_type(dev);
|
||||
|
||||
// Memory reporting is disabled to avoid allocation of a CUDA primary context (~300 MB per device).
|
||||
// If you need the memory data, call ggml_backend_dev_memory() explicitly.
|
||||
props->memory_total = props->memory_free = 0;
|
||||
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||
|
||||
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
|
||||
#ifdef GGML_CUDA_NO_PEER_COPY
|
||||
@@ -3708,11 +3700,6 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g
|
||||
CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
|
||||
}
|
||||
|
||||
static void ggml_backend_cuda_device_reset(ggml_backend_dev_t dev) {
|
||||
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
||||
ggml_cuda_reset_device(ctx->device);
|
||||
}
|
||||
|
||||
static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
|
||||
/* .get_name = */ ggml_backend_cuda_device_get_name,
|
||||
/* .get_description = */ ggml_backend_cuda_device_get_description,
|
||||
@@ -3729,7 +3716,6 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
|
||||
/* .event_new = */ ggml_backend_cuda_device_event_new,
|
||||
/* .event_free = */ ggml_backend_cuda_device_event_free,
|
||||
/* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize,
|
||||
/* .reset = */ ggml_backend_cuda_device_reset,
|
||||
};
|
||||
|
||||
// backend reg
|
||||
@@ -3849,6 +3835,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
||||
dev_ctx->device = i;
|
||||
dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
|
||||
|
||||
ggml_cuda_set_device(i);
|
||||
cudaDeviceProp prop;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
|
||||
dev_ctx->description = prop.name;
|
||||
|
||||
@@ -40,7 +40,6 @@
|
||||
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
||||
#define cudaDeviceProp hipDeviceProp_t
|
||||
#define cudaDeviceReset hipDeviceReset
|
||||
#define cudaDeviceSynchronize hipDeviceSynchronize
|
||||
#define cudaError_t hipError_t
|
||||
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
||||
|
||||
8
ml/backend/ggml/ggml/src/ggml.cpp
vendored
8
ml/backend/ggml/ggml/src/ggml.cpp
vendored
@@ -19,12 +19,8 @@ static bool ggml_uncaught_exception_init = []{
|
||||
return false;
|
||||
}
|
||||
const auto prev{std::get_terminate()};
|
||||
// GGML_ASSERT(prev != ggml_uncaught_exception);
|
||||
if (prev != ggml_uncaught_exception) {
|
||||
previous_terminate_handler = prev;
|
||||
} else {
|
||||
GGML_LOG_WARN("%s double registration of ggml_uncaught_exception\n", __func__);
|
||||
}
|
||||
GGML_ASSERT(prev != ggml_uncaught_exception);
|
||||
previous_terminate_handler = prev;
|
||||
std::set_terminate(ggml_uncaught_exception);
|
||||
return true;
|
||||
}();
|
||||
|
||||
@@ -26,7 +26,6 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
|
||||
}
|
||||
|
||||
func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
|
||||
ctx.Forward(query)
|
||||
if key != nil && value != nil {
|
||||
if query.Dim(0) != key.Dim(0) {
|
||||
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
|
||||
@@ -40,7 +39,6 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
|
||||
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
|
||||
}
|
||||
|
||||
ctx.Forward(key, value)
|
||||
if cache != nil {
|
||||
cache.Put(ctx, key, value)
|
||||
}
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
package pooling
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
type Type uint32
|
||||
|
||||
const (
|
||||
TypeNone Type = iota
|
||||
TypeMean
|
||||
TypeCLS
|
||||
TypeLast
|
||||
)
|
||||
|
||||
func (t Type) String() string {
|
||||
switch t {
|
||||
case TypeMean:
|
||||
return "Mean"
|
||||
case TypeCLS:
|
||||
return "CLS"
|
||||
case TypeLast:
|
||||
return "Last"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
|
||||
switch t {
|
||||
case TypeMean:
|
||||
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
|
||||
return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
case TypeCLS:
|
||||
return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
|
||||
case TypeLast:
|
||||
hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0))
|
||||
return hiddenStates
|
||||
default:
|
||||
panic("unknown pooling type")
|
||||
}
|
||||
}
|
||||
@@ -1,79 +0,0 @@
|
||||
package pooling_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/discover"
|
||||
fsggml "github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/backend/ggml"
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
)
|
||||
|
||||
func setup(tb testing.TB, n int) ml.Backend {
|
||||
tb.Helper()
|
||||
|
||||
f, err := os.CreateTemp(tb.TempDir(), "*.bin")
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := fsggml.WriteGGUF(f, fsggml.KV{
|
||||
"general.architecture": "test",
|
||||
"test.block_count": uint32(1),
|
||||
}, []*fsggml.Tensor{
|
||||
{Name: "blk.0.weight", Shape: []uint64{1}, WriterTo: bytes.NewBuffer(make([]byte, 4))},
|
||||
}); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
var gpuLayers ml.GPULayersList
|
||||
if gpus := discover.GetGPUInfo(); len(gpus) > 0 {
|
||||
gpuLayers = append(gpuLayers, ml.GPULayers{
|
||||
ID: gpus[0].ID,
|
||||
Layers: slices.Collect(func(yield func(int) bool) {
|
||||
for i := range n {
|
||||
if !yield(i) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
b, err := ggml.New(f.Name(), ml.BackendParams{AllocMemory: true, GPULayers: gpuLayers})
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func TestForward(t *testing.T) {
|
||||
cases := map[pooling.Type][]float32{
|
||||
pooling.TypeMean: {4, 5, 6, 7, 8, 9, 10, 11},
|
||||
pooling.TypeCLS: {0, 1, 2, 3, 4, 5, 6, 7},
|
||||
pooling.TypeLast: {8, 9, 10, 11, 12, 13, 14, 15},
|
||||
}
|
||||
for typ, want := range cases {
|
||||
t.Run(typ.String(), func(t *testing.T) {
|
||||
b := setup(t, 99)
|
||||
defer b.Close()
|
||||
|
||||
ctx := b.NewContext()
|
||||
defer ctx.Close()
|
||||
|
||||
tt := ctx.Input().Arange(0, 16, 1, ml.DTypeF32).Reshape(ctx, 8, 2)
|
||||
tt = typ.Forward(ctx, tt)
|
||||
|
||||
ctx.Forward(tt).Compute(tt)
|
||||
if diff := cmp.Diff(want, tt.Floats()); diff != "" {
|
||||
t.Error(diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"fmt"
|
||||
"iter"
|
||||
"log/slog"
|
||||
@@ -108,7 +109,7 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
r = 0x0143
|
||||
case r <= 0x0020:
|
||||
r = r + 0x0100
|
||||
case r >= 0x007f && r <= 0x00a0:
|
||||
case r >= 0x007e && r <= 0x00a0:
|
||||
r = r + 0x00a2
|
||||
}
|
||||
|
||||
@@ -201,11 +202,12 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
}
|
||||
}
|
||||
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids)
|
||||
|
||||
if addSpecial && len(ids) > 0 {
|
||||
ids = bpe.vocab.addSpecials(ids)
|
||||
}
|
||||
|
||||
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
@@ -241,6 +243,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
@@ -207,36 +207,6 @@ func TestLlama(t *testing.T) {
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("roundtriping 0x00-0xFF", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for b := 0x00; b <= 0xFF; b++ {
|
||||
input := string(rune(b))
|
||||
ids, err := tokenizer.Encode(input, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to encode rune 0x%02X: %v", b, err)
|
||||
continue
|
||||
}
|
||||
|
||||
decoded, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
t.Errorf("failed to decode rune 0x%02X: %v", b, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if b == 0x00 {
|
||||
if len(decoded) != 0 {
|
||||
t.Errorf("Decode(Encode(0x00)) should be empty, got %v", ids)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if decoded != input {
|
||||
t.Errorf("rune 0x%02X failed roundtrip: got %q, want %q", b, decoded, input)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkBytePairEncoding(b *testing.B) {
|
||||
|
||||
@@ -54,9 +54,10 @@ type Batch struct {
|
||||
// Inputs is the input tokens, including placeholders for multimodal inputs.
|
||||
Inputs ml.Tensor
|
||||
|
||||
// Outputs are the set of indicies into Inputs for which output data should
|
||||
// be returned.
|
||||
Outputs ml.Tensor
|
||||
// Multimodal is a set of multimodal embeddings previously created by
|
||||
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
|
||||
// models or for batches without multimodal elements.
|
||||
Multimodal []MultimodalIndex
|
||||
|
||||
// Positions is the position for each Input, relative to its sequence. Equal
|
||||
// in length to Inputs.
|
||||
@@ -65,8 +66,7 @@ type Batch struct {
|
||||
// Sequences is the sequence for each Input. Equal in length to Inputs.
|
||||
Sequences []int
|
||||
|
||||
// Multimodal is a set of multimodal embeddings previously created by
|
||||
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
|
||||
// models or for batches without multimodal elements.
|
||||
Multimodal []MultimodalIndex
|
||||
// Outputs are the set of indicies into Inputs for which output data should
|
||||
// be returned.
|
||||
Outputs []int32
|
||||
}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"log/slog"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
@@ -20,15 +22,10 @@ import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
_ "github.com/ollama/ollama/ml/backend"
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoVisionModel = errors.New("this model is missing data required for image input")
|
||||
ErrUnsupportedModel = errors.New("model not supported")
|
||||
ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
|
||||
)
|
||||
var ErrNoVisionModel = errors.New("this model is missing data required for image input")
|
||||
|
||||
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
|
||||
type Model interface {
|
||||
@@ -67,7 +64,7 @@ type MultimodalProcessor interface {
|
||||
// This function is also responsible for updating MultimodalHash for any Multimodal
|
||||
// that is modified to ensure that there is a unique hash value that accurately
|
||||
// represents the contents.
|
||||
PostTokenize([]*input.Input) ([]*input.Input, error)
|
||||
PostTokenize([]input.Input) ([]input.Input, error)
|
||||
}
|
||||
|
||||
// Base implements the common fields and methods for all models
|
||||
@@ -107,12 +104,19 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m, err := modelForArch(b.Config())
|
||||
arch := b.Config().Architecture()
|
||||
f, ok := models[arch]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported model architecture %q", arch)
|
||||
}
|
||||
|
||||
m, err := f(b.Config())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
base := Base{b: b, config: m.Config()}
|
||||
|
||||
v := reflect.ValueOf(m)
|
||||
v.Elem().Set(populateFields(base, v.Elem()))
|
||||
return m, nil
|
||||
@@ -124,38 +128,30 @@ func NewTextProcessor(s string) (TextProcessor, error) {
|
||||
return nil, err
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
meta, err := fsggml.Decode(r, -1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return getTextProcessor(meta.KV())
|
||||
}
|
||||
|
||||
m, err := modelForArch(meta.KV())
|
||||
func getTextProcessor(kv fsggml.KV) (TextProcessor, error) {
|
||||
arch := kv.Architecture()
|
||||
f, ok := models[arch]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported model architecture %q", arch)
|
||||
}
|
||||
m, err := f(kv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tp, ok := m.(TextProcessor)
|
||||
if !ok {
|
||||
return nil, ErrUnsupportedTokenizer
|
||||
return nil, fmt.Errorf("%v is not a TextProcessor", m)
|
||||
}
|
||||
return tp, nil
|
||||
}
|
||||
|
||||
func modelForArch(c fs.Config) (Model, error) {
|
||||
arch := c.Architecture()
|
||||
if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
|
||||
arch = arch + "_embed"
|
||||
}
|
||||
|
||||
f, ok := models[arch]
|
||||
if !ok {
|
||||
return nil, ErrUnsupportedModel
|
||||
}
|
||||
|
||||
return f(c)
|
||||
}
|
||||
|
||||
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
||||
t := v.Type()
|
||||
|
||||
@@ -202,7 +198,7 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
||||
names := fn(tagsCopy)
|
||||
for _, name := range names {
|
||||
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
|
||||
logutil.Trace("found tensor", "", tensor)
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "found tensor", "", tensor)
|
||||
vv.Set(reflect.ValueOf(tensor))
|
||||
break
|
||||
}
|
||||
@@ -243,7 +239,7 @@ func setPointer(base Base, v reflect.Value, tags []Tag) {
|
||||
vv = vv.Elem()
|
||||
}
|
||||
|
||||
vv = reflect.Indirect(vv)
|
||||
vv = vv.Elem()
|
||||
if v.IsNil() {
|
||||
vv = reflect.New(v.Type().Elem()).Elem()
|
||||
}
|
||||
@@ -282,7 +278,7 @@ func canNil(t reflect.Type) bool {
|
||||
t.Kind() == reflect.Slice
|
||||
}
|
||||
|
||||
func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
|
||||
func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
|
||||
if len(batch.Positions) != len(batch.Sequences) {
|
||||
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
|
||||
}
|
||||
@@ -291,6 +287,8 @@ func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
|
||||
return nil, errors.New("batch size cannot be less than 1")
|
||||
}
|
||||
|
||||
batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs))
|
||||
|
||||
cache := m.Config().Cache
|
||||
if cache != nil {
|
||||
err := cache.StartForward(ctx, batch, false)
|
||||
@@ -304,7 +302,7 @@ func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx.Forward(t)
|
||||
ctx.Forward(t).Compute(t)
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/backend/ggml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
func TestParseTags(t *testing.T) {
|
||||
@@ -147,58 +148,39 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelForArch(t *testing.T) {
|
||||
type fakeModel struct {
|
||||
Model
|
||||
func TestGetTextProcessor(t *testing.T) {
|
||||
tp, err := getTextProcessor(fsggml.KV{})
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
} else if !strings.Contains(err.Error(), "unsupported model architecture") {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
} else if tp != nil {
|
||||
t.Error("expected nil tp")
|
||||
}
|
||||
|
||||
type fakeEmbeddingModel struct {
|
||||
Model
|
||||
models["dummy"] = func(fs.Config) (Model, error) {
|
||||
return notTextProcessorModel{}, nil
|
||||
}
|
||||
|
||||
models["model"] = func(c fs.Config) (Model, error) { return fakeModel{}, nil }
|
||||
models["model_embed"] = func(c fs.Config) (Model, error) { return fakeEmbeddingModel{}, nil }
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
config fs.Config
|
||||
want any
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "model",
|
||||
config: fsggml.KV{
|
||||
"general.architecture": "model",
|
||||
},
|
||||
want: fakeModel{},
|
||||
},
|
||||
{
|
||||
name: "embedding",
|
||||
config: fsggml.KV{
|
||||
"general.architecture": "model",
|
||||
"model.pooling_type": uint32(1),
|
||||
},
|
||||
want: fakeEmbeddingModel{},
|
||||
},
|
||||
{
|
||||
name: "unsupported",
|
||||
config: fsggml.KV{
|
||||
"general.architecture": "unsupported",
|
||||
},
|
||||
err: ErrUnsupportedModel,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := modelForArch(tt.config)
|
||||
if !errors.Is(err, tt.err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||
t.Errorf("modelForArch() returned unexpected values (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
tp, err = getTextProcessor(fsggml.KV{"general.architecture": "dummy"})
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
} else if !strings.Contains(err.Error(), "not a TextProcessor") {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
} else if tp != nil {
|
||||
t.Error("expected nil tp")
|
||||
}
|
||||
}
|
||||
|
||||
type notTextProcessorModel struct{}
|
||||
|
||||
func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (notTextProcessorModel) Backend() ml.Backend {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (notTextProcessorModel) Config() config {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
@@ -1,181 +0,0 @@
|
||||
package bert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TypeEmbedding *nn.Embedding `gguf:"token_types"`
|
||||
PositionEmbedding *nn.Embedding `gguf:"position_embd"`
|
||||
TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"`
|
||||
|
||||
Layers []EncoderLayer `gguf:"blk"`
|
||||
|
||||
Options
|
||||
}
|
||||
|
||||
// Forward implements model.Model.
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize))
|
||||
hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))))
|
||||
hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options)
|
||||
}
|
||||
|
||||
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
|
||||
if m.normalize {
|
||||
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
|
||||
}
|
||||
|
||||
return hiddenStates, nil
|
||||
}
|
||||
|
||||
type EncoderLayer struct {
|
||||
*Attention
|
||||
AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"`
|
||||
|
||||
*MLP
|
||||
MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"`
|
||||
}
|
||||
|
||||
func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||
// Attention
|
||||
residual := hiddenStates
|
||||
hiddenStates = e.Attention.Forward(ctx, hiddenStates, opts)
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
|
||||
// MLP
|
||||
residual = hiddenStates
|
||||
hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
type Attention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
QueryNorm *nn.LayerNorm `gguf:"attn_q_norm"`
|
||||
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
KeyNorm *nn.LayerNorm `gguf:"attn_k_norm"`
|
||||
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||
batchSize := hiddenStates.Dim(1)
|
||||
|
||||
query := a.Query.Forward(ctx, hiddenStates)
|
||||
if a.QueryNorm != nil {
|
||||
query = a.QueryNorm.Forward(ctx, query, opts.eps)
|
||||
}
|
||||
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
|
||||
|
||||
key := a.Key.Forward(ctx, hiddenStates)
|
||||
if a.KeyNorm != nil {
|
||||
key = a.KeyNorm.Forward(ctx, key, opts.eps)
|
||||
}
|
||||
key = key.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
|
||||
|
||||
value := a.Value.Forward(ctx, hiddenStates)
|
||||
value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
return a.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||
return m.Down.Forward(ctx, m.Up.Forward(ctx, hiddenStates).GELU(ctx))
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
hiddenSize,
|
||||
numHeads,
|
||||
numKVHeads,
|
||||
keyLength,
|
||||
valueLength int
|
||||
poolingType pooling.Type
|
||||
eps float32
|
||||
normalize bool
|
||||
}
|
||||
|
||||
func (o Options) headDim() int {
|
||||
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
var processor model.TextProcessor
|
||||
switch c.String("tokenizer.ggml.model", "bert") {
|
||||
case "bert":
|
||||
processor = model.NewWordPiece(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.cls_token_id"),
|
||||
c.Uint("tokenizer.ggml.bos_token_id"),
|
||||
)),
|
||||
},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
|
||||
EOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.separator_token_id"),
|
||||
//nolint:misspell
|
||||
// NOTE: "seperator_token_id" is a typo in model metadata but we need to
|
||||
// support it for compatibility.
|
||||
c.Uint("tokenizer.ggml.seperator_token_id"),
|
||||
c.Uint("tokenizer.ggml.eos_token_id"),
|
||||
)),
|
||||
},
|
||||
},
|
||||
)
|
||||
default:
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
return &Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]EncoderLayer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
eps: c.Float("attention.layer_norm_epsilon"),
|
||||
poolingType: pooling.Type(c.Uint("pooling_type")),
|
||||
normalize: c.Bool("normalize_embeddings", true),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("bert", New)
|
||||
model.Register("bert_embed", New)
|
||||
}
|
||||
@@ -24,7 +24,7 @@ type Options struct {
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
model.SentencePieceModel
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -40,7 +40,7 @@ const (
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
SentencePieceModel: model.NewSentencePieceModel(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
@@ -63,7 +63,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
attnValLen: int(c.Uint("attention.value_length")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base", 10000.0),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1.0),
|
||||
ropeScale: c.Float("rope.freq_scale", 1.0),
|
||||
attnLogitSoftcap: c.Float("attn_logit_softcapping"),
|
||||
finalLogitSoftcap: c.Float("final_logit_softcapping"),
|
||||
},
|
||||
@@ -88,7 +88,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
|
||||
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
if opts.largeModelScaling {
|
||||
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
|
||||
@@ -98,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
|
||||
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
|
||||
@@ -128,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil
|
||||
return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, m.Options.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
@@ -138,7 +138,7 @@ type MLP struct {
|
||||
}
|
||||
|
||||
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
@@ -176,6 +176,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
|
||||
@@ -192,7 +193,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
var lastLayerOutputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
lastLayerOutputs = batch.Outputs
|
||||
lastLayerOutputs = outputs
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type embedModel struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
|
||||
*TextModel
|
||||
poolingType pooling.Type
|
||||
|
||||
Dense [2]*nn.Linear `gguf:"dense"`
|
||||
}
|
||||
|
||||
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
|
||||
for _, dense := range m.Dense {
|
||||
hiddenStates = dense.Forward(ctx, hiddenStates)
|
||||
}
|
||||
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
|
||||
return hiddenStates, nil
|
||||
}
|
||||
|
||||
func newEmbedModel(c fs.Config) (model.Model, error) {
|
||||
m := &embedModel{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{
|
||||
int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
||||
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
|
||||
},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
),
|
||||
TextModel: newTextModel(c),
|
||||
poolingType: pooling.Type(c.Uint("pooling_type", 0)),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewWrapperCache(
|
||||
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
|
||||
kvcache.NewCausalCache(m.Shift),
|
||||
)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
@@ -16,9 +16,9 @@ import (
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
model.SentencePieceModel
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*VisionModel `gguf:"v,vision"`
|
||||
*TextModel
|
||||
|
||||
*MultiModalProjector `gguf:"mm"`
|
||||
@@ -55,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
SentencePieceModel: model.NewSentencePieceModel(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
@@ -112,8 +112,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||
return []input.Multimodal{{Tensor: visionOutputs}}, nil
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
var result []input.Input
|
||||
|
||||
for _, inp := range inputs {
|
||||
if len(inp.Multimodal) == 0 {
|
||||
@@ -122,17 +122,17 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
inputMultimodal := inp.Multimodal[0].Tensor
|
||||
|
||||
result = append(result,
|
||||
&input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
|
||||
&input.Input{Token: 255999}, // "<start_of_image>""
|
||||
&input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
|
||||
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
|
||||
input.Input{Token: 255999}, // "<start_of_image>""
|
||||
input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
|
||||
)
|
||||
|
||||
// add image token placeholders
|
||||
result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
|
||||
result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
|
||||
|
||||
result = append(result,
|
||||
&input.Input{Token: 256000}, // <end_of_image>
|
||||
&input.Input{Token: 108}, // "\n\n"
|
||||
input.Input{Token: 256000}, // <end_of_image>
|
||||
input.Input{Token: 108}, // "\n\n"
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -141,11 +141,12 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("gemma3", New)
|
||||
model.Register("gemma3_embed", newEmbedModel)
|
||||
}
|
||||
|
||||
@@ -53,10 +53,7 @@ func newTextModel(c fs.Config) *TextModel {
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
||||
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
|
||||
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
|
||||
ropeScale: 1,
|
||||
// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
|
||||
// (8 instead of 1)
|
||||
// ropeScale: c.Float("rope.scaling.factor", 1.0),
|
||||
ropeScale: c.Float("rope.freq_scale", 1.0),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -87,7 +84,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
|
||||
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
|
||||
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
if opts.largeModelScaling {
|
||||
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
|
||||
@@ -98,7 +95,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
|
||||
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
|
||||
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
|
||||
@@ -116,7 +113,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
|
||||
ropeBase = m.TextConfig.ropeGlobalBase
|
||||
}
|
||||
|
||||
return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
|
||||
return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
type TextMLP struct {
|
||||
@@ -126,7 +123,7 @@ type TextMLP struct {
|
||||
}
|
||||
|
||||
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
@@ -162,10 +159,8 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
|
||||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
|
||||
|
||||
// set image embeddings
|
||||
@@ -196,12 +191,12 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
|
||||
|
||||
var lastLayerOutputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
lastLayerOutputs = batch.Outputs
|
||||
lastLayerOutputs = outputs
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
|
||||
}
|
||||
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
return hiddenState
|
||||
return m.Output.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
model.SentencePieceModel
|
||||
|
||||
*TextModel
|
||||
}
|
||||
@@ -23,7 +23,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
TextModel: newTextModel(c),
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
SentencePieceModel: model.NewSentencePieceModel(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
|
||||
@@ -83,7 +83,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
|
||||
|
||||
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
|
||||
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
||||
hiddenStates = hiddenStates.Rows(ctx, batch.Outputs)
|
||||
hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)))
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
@@ -95,7 +95,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
|
||||
ropeBase = m.ropeBaseLocal
|
||||
}
|
||||
|
||||
return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
|
||||
return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
type TextScaledWordEmbedding struct {
|
||||
@@ -170,7 +170,8 @@ func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, position
|
||||
}
|
||||
|
||||
active = d.PerLayerInputGate.Forward(ctx, active)
|
||||
active = active.GELU(ctx, perLayerInput)
|
||||
active = active.GELU(ctx)
|
||||
active = active.Mul(ctx, perLayerInput)
|
||||
|
||||
active = d.PerLayerProjection.Forward(ctx, active)
|
||||
active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)
|
||||
@@ -256,14 +257,14 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten
|
||||
query := attn.Query.Forward(ctx, hiddenStates)
|
||||
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
|
||||
query = attn.QueryNorm.Forward(ctx, query, opts.eps)
|
||||
query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
var key, value ml.Tensor
|
||||
if !sharedKV {
|
||||
key = attn.Key.Forward(ctx, hiddenStates)
|
||||
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
|
||||
key = attn.KeyNorm.Forward(ctx, key, opts.eps)
|
||||
key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
value = attn.Value.Forward(ctx, hiddenStates)
|
||||
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
|
||||
@@ -291,7 +292,7 @@ func (mlp TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, activationSpa
|
||||
hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx)
|
||||
}
|
||||
|
||||
hiddenStates = hiddenStates.GELU(ctx, upStates)
|
||||
hiddenStates = hiddenStates.GELU(ctx).Mul(ctx, upStates)
|
||||
hiddenStates = mlp.Down.Forward(ctx, hiddenStates)
|
||||
return hiddenStates
|
||||
}
|
||||
@@ -349,7 +350,7 @@ func newTextModel(c fs.Config) *TextModel {
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
||||
ropeBase: c.Float("rope.freq_base", 1_000_000),
|
||||
ropeBaseLocal: c.Float("rope.freq_base_local", 10_000),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1.0),
|
||||
ropeScale: c.Float("rope.freq_scale", 1.0),
|
||||
|
||||
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
|
||||
activationSparsityScale: c.Floats("activation_sparsity_scale"),
|
||||
|
||||
@@ -41,8 +41,8 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err
|
||||
}
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.TransformerBlocks)-1 {
|
||||
outputs = batch.Outputs
|
||||
if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 {
|
||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
}
|
||||
|
||||
hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options)
|
||||
@@ -210,7 +210,7 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates, one ml.Tensor, opts *
|
||||
up = mlp.Up.Forward(ctx, hiddenStates, selectedExperts)
|
||||
}
|
||||
|
||||
hiddenStates = gate.SILUAlphaLimit(ctx, up, 1.702, 7)
|
||||
hiddenStates = gate.SwiGLU(ctx, up, 1.702, 7)
|
||||
|
||||
experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
|
||||
experts = experts.Mul(ctx, routingWeights)
|
||||
|
||||
@@ -2,6 +2,7 @@ package llama
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
@@ -22,60 +23,51 @@ type Options struct {
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
model.BytePairEncoding
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
Options
|
||||
*Options
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
if c.Uint("expert_count") > 0 {
|
||||
// TODO: support mixtures of experts
|
||||
return nil, model.ErrUnsupportedModel
|
||||
// This model currently only supports the gpt2 tokenizer
|
||||
if c.String("tokenizer.ggml.model") == "llama" {
|
||||
return nil, fmt.Errorf("unsupported tokenizer: llama")
|
||||
}
|
||||
|
||||
var processor model.TextProcessor
|
||||
vocabulary := model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
// Best effort detection of library/deepseek-coder model(s) which are incompatible
|
||||
if c.String("general.name") == "deepseek-ai" {
|
||||
return nil, fmt.Errorf("unsupported model: %s", c.String("general.name"))
|
||||
}
|
||||
switch c.String("tokenizer.ggml.model") {
|
||||
case "gpt2":
|
||||
processor = model.NewBytePairEncoding(
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
&vocabulary,
|
||||
)
|
||||
case "llama":
|
||||
processor = model.NewSentencePiece(&vocabulary)
|
||||
default:
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: &Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
headDim: int(c.Uint("attention.key_length")),
|
||||
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base", 1e5),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -106,8 +98,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
|
||||
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
|
||||
@@ -116,7 +108,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
|
||||
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
|
||||
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
@@ -126,7 +118,7 @@ type MLP struct {
|
||||
}
|
||||
|
||||
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
@@ -168,10 +160,10 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs = batch.Outputs
|
||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, &m.Options)
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)
|
||||
}
|
||||
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
|
||||
@@ -18,7 +18,7 @@ type Model struct {
|
||||
model.BytePairEncoding
|
||||
ImageProcessor
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*VisionModel `gguf:"v,vision"`
|
||||
*Projector `gguf:"mm"`
|
||||
*TextModel
|
||||
}
|
||||
@@ -134,16 +134,16 @@ type separator struct {
|
||||
y bool
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
var result []input.Input
|
||||
for _, inp := range inputs {
|
||||
if len(inp.Multimodal) == 0 {
|
||||
result = append(result, inp)
|
||||
continue
|
||||
}
|
||||
|
||||
var imageInputs []*input.Input
|
||||
imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_start|>
|
||||
var imageInputs []input.Input
|
||||
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|>
|
||||
|
||||
for i, mm := range inp.Multimodal {
|
||||
patchesPerChunk := mm.Tensor.Dim(1)
|
||||
@@ -151,20 +151,20 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
if i < len(inp.Multimodal)-1 {
|
||||
separator := mm.Data.(*separator)
|
||||
|
||||
imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
|
||||
imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)
|
||||
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
|
||||
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
|
||||
|
||||
if separator.x {
|
||||
imageInputs = append(imageInputs, &input.Input{Token: 200084}) // <|tile_x_separator|>
|
||||
imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
|
||||
}
|
||||
if separator.y {
|
||||
imageInputs = append(imageInputs, &input.Input{Token: 200085}) // <|tile_y_separator|>
|
||||
imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
|
||||
}
|
||||
} else {
|
||||
imageInputs = append(imageInputs, &input.Input{Token: 200090}) // <|image|>
|
||||
imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
|
||||
imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)
|
||||
imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_end|>
|
||||
imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
|
||||
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
|
||||
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
|
||||
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|>
|
||||
}
|
||||
}
|
||||
|
||||
@@ -176,7 +176,9 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -33,8 +33,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
if useRope {
|
||||
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
}
|
||||
|
||||
if opts.useQKNorm {
|
||||
@@ -58,14 +58,14 @@ type TextMLP struct {
|
||||
}
|
||||
|
||||
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
type TextExperts struct {
|
||||
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
|
||||
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
|
||||
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
|
||||
Gate *nn.Linear `gguf:"ffn_gate_exps"`
|
||||
Up *nn.Linear `gguf:"ffn_up_exps"`
|
||||
Down *nn.Linear `gguf:"ffn_down_exps"`
|
||||
}
|
||||
|
||||
func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
@@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
|
||||
hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
|
||||
hiddenStates = hiddenStates.Mul(ctx, scores)
|
||||
|
||||
upStates := e.Up.Forward(ctx, hiddenStates, experts)
|
||||
gateStates := e.Gate.Forward(ctx, hiddenStates, experts)
|
||||
downStates := e.Down.Forward(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
|
||||
upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts)
|
||||
gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts)
|
||||
downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
|
||||
|
||||
nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
|
||||
for i := 1; i < opts.numExpertsUsed; i++ {
|
||||
@@ -96,7 +96,7 @@ type TextSharedExpert struct {
|
||||
}
|
||||
|
||||
func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
@@ -196,7 +196,7 @@ func newTextModel(c fs.Config) *TextModel {
|
||||
numExpertsUsed: int(c.Uint("expert_used_count")),
|
||||
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
interleaveLayerStep: int(c.Uint("interleave_moe_layer_step", 1)),
|
||||
noRopeInterval: int(c.Uint("no_rope_interval", 4)),
|
||||
@@ -248,5 +248,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil
|
||||
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ type Model struct {
|
||||
model.BytePairEncoding
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
*VisionModel `gguf:"v,vision"`
|
||||
*MultiModalProjector `gguf:"mm"`
|
||||
|
||||
ImageProcessor
|
||||
@@ -133,22 +133,22 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||
// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
|
||||
// Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings
|
||||
// that can be processed together.
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
var result []input.Input
|
||||
for _, inp := range inputs {
|
||||
if len(inp.Multimodal) == 0 {
|
||||
result = append(result, inp)
|
||||
} else {
|
||||
for i, row := range inp.Multimodal {
|
||||
// [IMG]
|
||||
result = append(result, &input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
|
||||
result = append(result, slices.Repeat([]*input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
|
||||
result = append(result, input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
|
||||
result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
|
||||
if i == len(inp.Multimodal)-1 {
|
||||
// [IMG_END]
|
||||
result = append(result, &input.Input{Token: 13})
|
||||
result = append(result, input.Input{Token: 13})
|
||||
} else {
|
||||
// [IMG_BREAK]
|
||||
result = append(result, &input.Input{Token: 12})
|
||||
result = append(result, input.Input{Token: 12})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -159,8 +159,9 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -40,11 +40,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale)
|
||||
q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale)
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale)
|
||||
k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale)
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
@@ -55,7 +55,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale), nil
|
||||
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale), nil
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
@@ -65,7 +65,7 @@ type MLP struct {
|
||||
}
|
||||
|
||||
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
@@ -132,7 +132,7 @@ func newTextModel(c fs.Config) *TextModel {
|
||||
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ type VisionMLP struct {
|
||||
}
|
||||
|
||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*VisionModel `gguf:"v,vision"`
|
||||
*TextModel
|
||||
|
||||
Projector *nn.Linear `gguf:"mm.0"`
|
||||
@@ -90,7 +90,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||
return []input.Multimodal{{Tensor: projectedOutputs}}, nil
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
for i := range inputs {
|
||||
if inputs[i].Multimodal != nil {
|
||||
inputs[i].Token = 128256 // <|image|>
|
||||
@@ -107,9 +107,10 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
}
|
||||
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
// TODO: attention mask, cross attention mask
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -26,11 +26,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenState)
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
|
||||
key := sa.Key.Forward(ctx, hiddenState)
|
||||
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
@@ -45,7 +45,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
// This will only get called for layers in the cache, which are just the self attention layers
|
||||
if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
|
||||
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil
|
||||
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil
|
||||
}
|
||||
|
||||
return key, nil
|
||||
@@ -58,7 +58,7 @@ type TextMLP struct {
|
||||
}
|
||||
|
||||
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
@@ -244,7 +244,7 @@ func newTextModel(c fs.Config) *TextModel {
|
||||
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
crossAttentionLayers: c.Ints("attention.cross_attention_layers"),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
_ "github.com/ollama/ollama/model/models/bert"
|
||||
_ "github.com/ollama/ollama/model/models/gemma2"
|
||||
_ "github.com/ollama/ollama/model/models/gemma3"
|
||||
_ "github.com/ollama/ollama/model/models/gemma3n"
|
||||
|
||||
@@ -43,8 +43,8 @@ func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
|
||||
value := attn.Value.Forward(ctx, hiddenStates)
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
|
||||
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
|
||||
@@ -59,7 +59,7 @@ type MLP struct {
|
||||
}
|
||||
|
||||
func (mlp MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
@@ -111,7 +111,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs = batch.Outputs
|
||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
}
|
||||
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)
|
||||
@@ -124,7 +124,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
|
||||
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
|
||||
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
@@ -160,7 +160,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
headDim: int(c.Uint("attention.key_length")),
|
||||
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ type Model struct {
|
||||
model.BytePairEncoding
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
*VisionModel `gguf:"v,vision"`
|
||||
|
||||
ImageProcessor
|
||||
}
|
||||
@@ -89,8 +89,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||
}
|
||||
|
||||
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
var result []input.Input
|
||||
|
||||
var (
|
||||
imageToken int32 = 151655
|
||||
@@ -112,16 +112,16 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
return nil, fmt.Errorf("failed to encode image prompt: %w", err)
|
||||
}
|
||||
for i := range pre {
|
||||
result = append(result, &input.Input{Token: pre[i]})
|
||||
result = append(result, input.Input{Token: pre[i]})
|
||||
}
|
||||
|
||||
patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1)
|
||||
|
||||
// First add the vision start token
|
||||
result = append(result, &input.Input{Token: visionStartToken})
|
||||
result = append(result, input.Input{Token: visionStartToken})
|
||||
|
||||
// Add the image token with the multimodal tensor data at the first position
|
||||
result = append(result, &input.Input{
|
||||
result = append(result, input.Input{
|
||||
Token: imageToken,
|
||||
Multimodal: inp.Multimodal,
|
||||
MultimodalHash: inp.MultimodalHash,
|
||||
@@ -129,9 +129,9 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
})
|
||||
|
||||
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
|
||||
result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
|
||||
result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
|
||||
|
||||
result = append(result, &input.Input{Token: visionEndToken})
|
||||
result = append(result, input.Input{Token: visionEndToken})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,8 +140,9 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache)
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -38,7 +38,7 @@ func NewTextModel(c fs.Config) *TextModel {
|
||||
originalContextLength: int(c.Uint("context_length", 128000)),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -60,11 +60,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
|
||||
q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
|
||||
k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
@@ -78,7 +78,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
|
||||
// Shift applies rotary position embeddings to the key tensor for causal attention caching
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil
|
||||
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
// MLP implements the feed-forward network component with SwiGLU activation
|
||||
@@ -90,7 +90,7 @@ type MLP struct {
|
||||
|
||||
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
// Apply SwiGLU activation gating
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
// Project back to hidden dimension
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
@@ -100,7 +100,8 @@ type VisionMLP struct {
|
||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
// Using activation as specified in config (likely GELU or SiLU/Swish)
|
||||
gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
|
||||
hiddenStates = gateOutput.SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
upOutput := mlp.Up.Forward(ctx, hiddenStates)
|
||||
hiddenStates = gateOutput.SILU(ctx).Mul(ctx, upOutput)
|
||||
|
||||
return mlp.Down.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
package qwen3
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type embedModel struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
|
||||
*Model
|
||||
poolingType pooling.Type
|
||||
}
|
||||
|
||||
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenStates, err := m.forward(ctx, batch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
|
||||
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
|
||||
return hiddenStates, nil
|
||||
}
|
||||
|
||||
func newEmbed(c fs.Config) (model.Model, error) {
|
||||
layers := make([]Layer, c.Uint("block_count"))
|
||||
for i := range layers {
|
||||
layers[i].MLP = &dense{}
|
||||
}
|
||||
m := embedModel{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
),
|
||||
Model: &Model{
|
||||
Layers: layers,
|
||||
Options: &Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
keyLength: int(c.Uint("attention.key_length")),
|
||||
valueLength: int(c.Uint("attention.value_length")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
numExperts: int(c.Uint("expert_count")),
|
||||
numExpertsUsed: int(c.Uint("expert_used_count")),
|
||||
normTopKProb: c.Bool("norm_top_k_prob", true),
|
||||
},
|
||||
},
|
||||
poolingType: pooling.Type(c.Uint("pooling_type")),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewCausalCache(m.Shift)
|
||||
return &m, nil
|
||||
}
|
||||
@@ -30,10 +30,10 @@ func (o Options) headDim() int {
|
||||
}
|
||||
|
||||
type Attention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
@@ -52,8 +52,8 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
|
||||
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
|
||||
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
|
||||
|
||||
query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
|
||||
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
|
||||
@@ -65,10 +65,10 @@ type MLP interface {
|
||||
}
|
||||
|
||||
type sparse struct {
|
||||
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
||||
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
|
||||
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
|
||||
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
|
||||
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
||||
Gate *nn.Linear `gguf:"ffn_gate_exps"`
|
||||
Up *nn.Linear `gguf:"ffn_up_exps"`
|
||||
Down *nn.Linear `gguf:"ffn_down_exps"`
|
||||
}
|
||||
|
||||
func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||
@@ -87,9 +87,13 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options
|
||||
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
|
||||
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates, selectedExperts))
|
||||
upStates := mlp.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
|
||||
|
||||
experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
|
||||
hiddenStates = mlp.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
|
||||
hiddenStates = hiddenStates.SILU(ctx)
|
||||
hiddenStates = hiddenStates.Mul(ctx, upStates)
|
||||
|
||||
experts := mlp.Down.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
|
||||
experts = experts.Mul(ctx, routingWeights)
|
||||
|
||||
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
|
||||
@@ -107,8 +111,7 @@ type dense struct {
|
||||
}
|
||||
|
||||
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).
|
||||
SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
@@ -151,39 +154,29 @@ type Model struct {
|
||||
*Options
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenStates, err := m.forward(ctx, batch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
// Forward implements model.Model.
|
||||
func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
if m.Cache != nil {
|
||||
m.Cache.SetLayer(i)
|
||||
}
|
||||
m.Cache.SetLayer(i)
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs = batch.Outputs
|
||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
}
|
||||
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
|
||||
}
|
||||
|
||||
return m.OutputNorm.Forward(ctx, hiddenStates, m.eps), nil
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
|
||||
return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
var _ model.Model = (*Model)(nil)
|
||||
@@ -223,7 +216,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
valueLength: int(c.Uint("attention.value_length")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
numExperts: int(c.Uint("expert_count")),
|
||||
numExpertsUsed: int(c.Uint("expert_used_count")),
|
||||
normTopKProb: c.Bool("norm_top_k_prob", true),
|
||||
@@ -237,5 +230,4 @@ func New(c fs.Config) (model.Model, error) {
|
||||
func init() {
|
||||
model.Register("qwen3", New)
|
||||
model.Register("qwen3moe", New)
|
||||
model.Register("qwen3_embed", newEmbed)
|
||||
}
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type Parser interface {
|
||||
Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error)
|
||||
HasToolSupport() bool
|
||||
HasThinkingSupport() bool
|
||||
}
|
||||
|
||||
func ParserForName(name string) Parser {
|
||||
switch name {
|
||||
case "qwen3-coder":
|
||||
parser := &Qwen3CoderParser{}
|
||||
return parser
|
||||
case "passthrough":
|
||||
return &PassthroughParser{}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type PassthroughParser struct{}
|
||||
|
||||
func (p *PassthroughParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
return s, "", nil, nil
|
||||
}
|
||||
|
||||
func (p *PassthroughParser) HasToolSupport() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *PassthroughParser) HasThinkingSupport() bool {
|
||||
return false
|
||||
}
|
||||
@@ -1,447 +0,0 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type qwenParserState int
|
||||
|
||||
const (
|
||||
toolOpenTag = "<tool_call>"
|
||||
toolCloseTag = "</tool_call>"
|
||||
)
|
||||
|
||||
const (
|
||||
qwenParserState_LookingForToolStart qwenParserState = iota
|
||||
qwenParserState_CollectingToolContent
|
||||
)
|
||||
|
||||
type Qwen3CoderParser struct {
|
||||
state qwenParserState
|
||||
acc strings.Builder
|
||||
}
|
||||
|
||||
func (p *Qwen3CoderParser) HasToolSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *Qwen3CoderParser) HasThinkingSupport() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Qwen3CoderParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.acc.WriteString(s)
|
||||
|
||||
events := p.parseEvents()
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
var sb strings.Builder
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case qwenEventRawToolCall:
|
||||
toolCall, err := parseToolCall(event, tools)
|
||||
if err != nil {
|
||||
slog.Warn("qwen tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
case qwenEventContent:
|
||||
// TODO(drifkin): if the same turn contains multiple interleaved content
|
||||
// events, we naively append them together here. See the note below about
|
||||
// `qwenEvent`s for more details
|
||||
sb.WriteString(event.content)
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String(), "", toolCalls, nil
|
||||
}
|
||||
|
||||
func (p *Qwen3CoderParser) parseEvents() []qwenEvent {
|
||||
var all []qwenEvent
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []qwenEvent
|
||||
events, keepLooping = eat(p)
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(all) > 0 {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "qwen events parsed", "events", all, "state", p.state, "acc", p.acc.String())
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
// we use some internal event types in order to communicate between `Add` and
|
||||
// `eat`. We do this to support interleaving content and parallel tool calls in
|
||||
// the parser, even though qwen3-coder isn't supposed to do this. Our API
|
||||
// doesn't currently support models outputting multiple messages in a turn, so
|
||||
// we wouldn't be able to represent it yet, but there's no reason to prevent the
|
||||
// parser from supporting it, especially for future models if they end up using
|
||||
// a similar format.
|
||||
type qwenEvent interface {
|
||||
isQwenEvent()
|
||||
}
|
||||
|
||||
type qwenEventRawToolCall struct {
|
||||
raw string
|
||||
}
|
||||
|
||||
type qwenEventContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
func (qwenEventContent) isQwenEvent() {}
|
||||
func (qwenEventRawToolCall) isQwenEvent() {}
|
||||
|
||||
// eat consumes the parser's buffer, and returns a list of any unambiguous
|
||||
// events from the current parser state. If the parser transitions to another
|
||||
// state, it may have additional events to emit on the next call, which is what
|
||||
// the second return value indicates
|
||||
func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) {
|
||||
var events []qwenEvent
|
||||
|
||||
switch p.state {
|
||||
case qwenParserState_LookingForToolStart:
|
||||
if strings.Contains(p.acc.String(), toolOpenTag) {
|
||||
// we found a full tool open tag, so we can emit the content before the
|
||||
// tag, being sure to trim any trailing whitespace
|
||||
split := strings.SplitN(p.acc.String(), toolOpenTag, 2)
|
||||
before := split[0]
|
||||
before = strings.TrimRightFunc(before, unicode.IsSpace)
|
||||
if len(before) > 0 {
|
||||
events = append(events, qwenEventContent{content: before})
|
||||
}
|
||||
after := split[1]
|
||||
p.acc.Reset()
|
||||
p.acc.WriteString(after)
|
||||
p.state = qwenParserState_CollectingToolContent
|
||||
return events, true
|
||||
} else if overlap := overlap(p.acc.String(), toolOpenTag); overlap > 0 {
|
||||
// we found a partial tool open tag, so we can emit the unambiguous part,
|
||||
// which is the (trailing-whitespace trimmed) content before the partial
|
||||
// tool open tag
|
||||
beforePartialTag := p.acc.String()[:len(p.acc.String())-overlap]
|
||||
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
|
||||
unambiguous := p.acc.String()[:ambiguousStart]
|
||||
ambiguous := p.acc.String()[ambiguousStart:]
|
||||
p.acc.Reset()
|
||||
p.acc.WriteString(ambiguous)
|
||||
events = append(events, qwenEventContent{content: unambiguous})
|
||||
return events, false
|
||||
} else {
|
||||
// we found content that is entirely not a tool call. We should withhold
|
||||
// any trailing whitespace in case this is the end of the content
|
||||
whitespaceLen := trailingWhitespaceLen(p.acc.String())
|
||||
ambiguousStart := len(p.acc.String()) - whitespaceLen
|
||||
unambiguous := p.acc.String()[:ambiguousStart]
|
||||
ambiguous := p.acc.String()[ambiguousStart:]
|
||||
p.acc.Reset()
|
||||
p.acc.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, qwenEventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
case qwenParserState_CollectingToolContent:
|
||||
if strings.Contains(p.acc.String(), toolCloseTag) {
|
||||
split := strings.SplitN(p.acc.String(), toolCloseTag, 2)
|
||||
before := split[0]
|
||||
if len(before) == 0 {
|
||||
slog.Warn("qwen tool call closing tag found but no content before it")
|
||||
}
|
||||
// remove any whitespace between the tool call and any content after it
|
||||
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
||||
p.acc.Reset()
|
||||
p.acc.WriteString(after)
|
||||
events = append(events, qwenEventRawToolCall{raw: before})
|
||||
p.state = qwenParserState_LookingForToolStart
|
||||
return events, true
|
||||
} else {
|
||||
// note that we don't need to check the overlap here because we only plan
|
||||
// on parsing the tool call once we see the full closing tag. We don't
|
||||
// stream back the unparsed tool content, so there's no need to be eager
|
||||
// here
|
||||
return events, false
|
||||
}
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(drifkin): move this to a shared location
|
||||
// longest overlap between suffix of s and prefix of delim
|
||||
func overlap(s, delim string) int {
|
||||
max := min(len(delim), len(s))
|
||||
for i := max; i > 0; i-- {
|
||||
if strings.HasSuffix(s, delim[:i]) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func trailingWhitespaceLen(s string) int {
|
||||
for i := len(s) - 1; i >= 0; i-- {
|
||||
if !unicode.IsSpace(rune(s[i])) {
|
||||
return len(s) - i - 1
|
||||
}
|
||||
}
|
||||
return len(s)
|
||||
}
|
||||
|
||||
type XMLFunctionCall struct {
|
||||
XMLName xml.Name `xml:"function"`
|
||||
Name string `xml:"name,attr"`
|
||||
Parameters []XMLParameter `xml:"parameter"`
|
||||
}
|
||||
|
||||
type XMLParameter struct {
|
||||
Name string `xml:"name,attr"`
|
||||
Value string `xml:",chardata"`
|
||||
}
|
||||
|
||||
// parseToolCall parses a raw tool call string into an api.ToolCall.
|
||||
// The raw string follows an xml-like format, here's an example:
|
||||
//
|
||||
// <function=get_current_temperature>
|
||||
// <parameter=location>
|
||||
// San Francisco
|
||||
// </parameter>
|
||||
// <parameter=unit>
|
||||
// celsius
|
||||
// </parameter>
|
||||
// </function>
|
||||
func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
|
||||
toolCall := api.ToolCall{}
|
||||
|
||||
xmlString := transformToXML(raw.raw)
|
||||
|
||||
var functionCall XMLFunctionCall
|
||||
err := xml.Unmarshal([]byte(xmlString), &functionCall)
|
||||
if err != nil {
|
||||
return api.ToolCall{}, err
|
||||
}
|
||||
|
||||
toolCall.Function = api.ToolCallFunction{
|
||||
Name: functionCall.Name,
|
||||
}
|
||||
|
||||
// Find the matching tool to get parameter types
|
||||
var matchedTool *api.Tool
|
||||
for i := range tools {
|
||||
if tools[i].Function.Name == functionCall.Name {
|
||||
matchedTool = &tools[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
|
||||
for _, parameter := range functionCall.Parameters {
|
||||
// Look up the parameter type if we found the tool
|
||||
var paramType api.PropertyType
|
||||
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
|
||||
if prop, ok := matchedTool.Function.Parameters.Properties[parameter.Name]; ok {
|
||||
paramType = prop.Type
|
||||
}
|
||||
}
|
||||
|
||||
toolCall.Function.Arguments[parameter.Name] = parseValue(parameter.Value, paramType)
|
||||
}
|
||||
|
||||
return toolCall, nil
|
||||
}
|
||||
|
||||
// parseValue converts a raw string value to the appropriate type based on the parameter type specification.
|
||||
//
|
||||
// For union types (multiple types in PropertyType, which we support but doesn't
|
||||
// seem as though the reference parser does type coercion with those types in
|
||||
// mind) we use a type precedence approach:
|
||||
// 1. null - checked first regardless of declared types (matches reference implementation)
|
||||
// 2. boolean - only "true"/"false" are valid booleans
|
||||
// 3. integer - must parse as a whole number
|
||||
// 4. number - must parse as numeric (returns int if no decimal part)
|
||||
// 5. array - must parse as valid JSON array
|
||||
// 6. object - must parse as valid JSON object
|
||||
// 7. string - always succeeds (least specific type)
|
||||
//
|
||||
// This precedence ensures we return the most specific type that successfully parses,
|
||||
// following the principle of least surprise. For example, with PropertyType{"string", "number"},
|
||||
// "123" becomes 123 (number), while "hello" becomes "hello" (string).
|
||||
func parseValue(raw string, paramType api.PropertyType) any {
|
||||
// first remove a single leading newlines, and a single trailing newline (if
|
||||
// they exist). This follows the reference implementation
|
||||
raw = strings.TrimPrefix(raw, "\n")
|
||||
raw = strings.TrimSuffix(raw, "\n")
|
||||
|
||||
// Check for null first (case-insensitive) - this takes precedence over any type
|
||||
if strings.ToLower(raw) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If no type is specified, default to string
|
||||
if len(paramType) == 0 {
|
||||
return raw
|
||||
}
|
||||
|
||||
// Check if any of the specified types match, using type precedence
|
||||
// Order: boolean -> integer -> number -> array -> object -> string
|
||||
typeSet := make(map[string]bool)
|
||||
for _, t := range paramType {
|
||||
typeSet[t] = true
|
||||
}
|
||||
|
||||
// Try boolean first (most restrictive)
|
||||
if typeSet["boolean"] {
|
||||
lower := strings.ToLower(raw)
|
||||
switch lower {
|
||||
case "true":
|
||||
return true
|
||||
case "false":
|
||||
return false
|
||||
}
|
||||
// If not a valid boolean but boolean is the only type, return false (matching reference)
|
||||
if len(paramType) == 1 {
|
||||
return false
|
||||
}
|
||||
// Otherwise try other types
|
||||
}
|
||||
|
||||
// Try integer
|
||||
if typeSet["integer"] {
|
||||
if i, err := strconv.ParseInt(raw, 10, 64); err == nil {
|
||||
// Return as int if it fits in int32, otherwise int64
|
||||
if i >= math.MinInt32 && i <= math.MaxInt32 {
|
||||
return int(i)
|
||||
}
|
||||
return i
|
||||
}
|
||||
// If integer is the only type and parsing failed, fall back to string
|
||||
if len(paramType) == 1 {
|
||||
return raw
|
||||
}
|
||||
}
|
||||
|
||||
// Try number (float)
|
||||
if typeSet["number"] {
|
||||
if f, err := strconv.ParseFloat(raw, 64); err == nil {
|
||||
// If the number has no decimal part, return as int (matching reference)
|
||||
if f == math.Trunc(f) {
|
||||
i := int64(f)
|
||||
if i >= math.MinInt32 && i <= math.MaxInt32 {
|
||||
return int(i)
|
||||
}
|
||||
return i
|
||||
}
|
||||
return f
|
||||
}
|
||||
// If number is the only type and parsing failed, fall back to string
|
||||
if len(paramType) == 1 {
|
||||
return raw
|
||||
}
|
||||
}
|
||||
|
||||
// Try array
|
||||
if typeSet["array"] {
|
||||
var arr []interface{}
|
||||
if err := json.Unmarshal([]byte(raw), &arr); err == nil {
|
||||
return arr
|
||||
}
|
||||
// If array is the only type and parsing failed, fall back to string
|
||||
if len(paramType) == 1 {
|
||||
return raw
|
||||
}
|
||||
}
|
||||
|
||||
// Try object
|
||||
if typeSet["object"] {
|
||||
var obj map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(raw), &obj); err == nil {
|
||||
return obj
|
||||
}
|
||||
// If object is the only type and parsing failed, fall back to string
|
||||
if len(paramType) == 1 {
|
||||
return raw
|
||||
}
|
||||
}
|
||||
|
||||
// String always succeeds (or if "string" is in the type set)
|
||||
if typeSet["string"] {
|
||||
return raw
|
||||
}
|
||||
|
||||
// If we get here, none of the types matched and string wasn't an option
|
||||
// We return string as a fallback. The reference implementation will attempt
|
||||
// to parse the value as a python literal, but we purposefully don't support
|
||||
// that
|
||||
return raw
|
||||
}
|
||||
|
||||
var (
|
||||
qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`)
|
||||
qwenXMLTagRegex = regexp.MustCompile(`</?(?:function|parameter)(?:\s+name="[^"]*")?>`)
|
||||
)
|
||||
|
||||
// transformToXML transforms a raw qwen tool call with xml-like tags into valid
|
||||
// xml so that it can be parsed by any xml parser
|
||||
func transformToXML(raw string) string {
|
||||
// take the form `<tag=abc>` and transform it to `<tag name="abc">`, taking
|
||||
// care to properly escape the string that becomes the attribute value
|
||||
transformed := qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string {
|
||||
groups := qwenTagRegex.FindStringSubmatch(match)
|
||||
tag := groups[1]
|
||||
var escapedValue strings.Builder
|
||||
xml.EscapeText(&escapedValue, []byte(groups[2]))
|
||||
return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String())
|
||||
})
|
||||
|
||||
// Walk the resulting string, escaping any character data that sits between the
|
||||
// xml tags we just emitted
|
||||
var out strings.Builder
|
||||
lastIdx := 0
|
||||
for _, loc := range qwenXMLTagRegex.FindAllStringIndex(transformed, -1) {
|
||||
if loc[0] > lastIdx {
|
||||
escapeTextNode(&out, transformed[lastIdx:loc[0]])
|
||||
}
|
||||
out.WriteString(transformed[loc[0]:loc[1]])
|
||||
lastIdx = loc[1]
|
||||
}
|
||||
if lastIdx < len(transformed) {
|
||||
escapeTextNode(&out, transformed[lastIdx:])
|
||||
}
|
||||
|
||||
return out.String()
|
||||
}
|
||||
|
||||
// escapeTextNode escapes XML character data without altering other characters
|
||||
// like newlines or tabs (which is why we don't use xml.EscapeText for this)
|
||||
func escapeTextNode(sb *strings.Builder, s string) {
|
||||
for _, r := range s {
|
||||
switch r {
|
||||
case '&':
|
||||
sb.WriteString("&")
|
||||
case '<':
|
||||
sb.WriteString("<")
|
||||
case '>':
|
||||
sb.WriteString(">")
|
||||
default:
|
||||
sb.WriteRune(r)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,878 +0,0 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// tool creates a test tool with the given name and properties
|
||||
func tool(name string, props map[string]api.ToolProperty) api.Tool {
|
||||
t := api.Tool{Type: "function", Function: api.ToolFunction{Name: name}}
|
||||
t.Function.Parameters.Type = "object"
|
||||
t.Function.Parameters.Properties = props
|
||||
return t
|
||||
}
|
||||
|
||||
func TestQwenParserStreaming(t *testing.T) {
|
||||
type step struct {
|
||||
input string
|
||||
wantEvents []qwenEvent
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
steps []step
|
||||
only bool
|
||||
}{
|
||||
{
|
||||
desc: "simple message streamed word by word",
|
||||
steps: []step{
|
||||
{
|
||||
input: "hi",
|
||||
wantEvents: []qwenEvent{qwenEventContent{content: "hi"}},
|
||||
},
|
||||
{
|
||||
input: " there",
|
||||
wantEvents: []qwenEvent{qwenEventContent{content: " there"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "content before tool call",
|
||||
steps: []step{
|
||||
{
|
||||
input: "hi there<tool_call>",
|
||||
wantEvents: []qwenEvent{qwenEventContent{content: "hi there"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "multiple tool calls in one message",
|
||||
steps: []step{
|
||||
{
|
||||
input: "before1<tool_call>in tool call</tool_call>after1<tool_call>in tool call 2</tool_call>after2",
|
||||
wantEvents: []qwenEvent{
|
||||
qwenEventContent{content: "before1"},
|
||||
qwenEventRawToolCall{raw: "in tool call"},
|
||||
qwenEventContent{content: "after1"},
|
||||
qwenEventRawToolCall{raw: "in tool call 2"},
|
||||
qwenEventContent{content: "after2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool calls with split tags",
|
||||
steps: []step{
|
||||
{
|
||||
input: "before<tool",
|
||||
wantEvents: []qwenEvent{
|
||||
qwenEventContent{content: "before"},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "_call>in tool call</tool",
|
||||
wantEvents: []qwenEvent{},
|
||||
},
|
||||
{
|
||||
input: "_call>af",
|
||||
wantEvents: []qwenEvent{
|
||||
qwenEventRawToolCall{raw: "in tool call"},
|
||||
qwenEventContent{content: "af"},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "ter",
|
||||
wantEvents: []qwenEvent{
|
||||
qwenEventContent{content: "ter"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "trailing whitespace between content and tool call",
|
||||
steps: []step{
|
||||
{
|
||||
input: "abc\n<tool_call>def</tool_call>",
|
||||
wantEvents: []qwenEvent{
|
||||
qwenEventContent{content: "abc"},
|
||||
qwenEventRawToolCall{raw: "def"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "trailing whitespace between tool call and content",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<tool_call>abc</tool_call>\ndef",
|
||||
wantEvents: []qwenEvent{
|
||||
qwenEventRawToolCall{raw: "abc"},
|
||||
qwenEventContent{content: "def"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "empty content before tool call",
|
||||
steps: []step{
|
||||
{
|
||||
input: "\n<tool_call>abc</tool_call>",
|
||||
wantEvents: []qwenEvent{
|
||||
qwenEventRawToolCall{raw: "abc"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "partial tool open tag fakeout",
|
||||
steps: []step{
|
||||
{
|
||||
input: "abc\n<tool_call",
|
||||
wantEvents: []qwenEvent{
|
||||
// \n should not be emitted yet because `<tool_call` might be a tool
|
||||
// open tag, in which case the whitespace should be trimmed
|
||||
qwenEventContent{content: "abc"},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: " fakeout",
|
||||
wantEvents: []qwenEvent{
|
||||
qwenEventContent{content: "\n<tool_call fakeout"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "token-by-token whitespace handling",
|
||||
steps: []step{
|
||||
{
|
||||
input: "a",
|
||||
wantEvents: []qwenEvent{
|
||||
qwenEventContent{content: "a"},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "\n",
|
||||
wantEvents: []qwenEvent{},
|
||||
},
|
||||
{
|
||||
input: "b",
|
||||
wantEvents: []qwenEvent{
|
||||
qwenEventContent{content: "\nb"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
anyOnlies := false
|
||||
for _, tc := range cases {
|
||||
if tc.only {
|
||||
anyOnlies = true
|
||||
}
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
if anyOnlies && !tc.only {
|
||||
continue
|
||||
}
|
||||
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
parser := Qwen3CoderParser{}
|
||||
|
||||
for i, step := range tc.steps {
|
||||
parser.acc.WriteString(step.input)
|
||||
gotEvents := parser.parseEvents()
|
||||
|
||||
if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
|
||||
// avoid deep equal on empty vs. nil slices
|
||||
continue
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotEvents, step.wantEvents) {
|
||||
t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenToolParser(t *testing.T) {
|
||||
type step struct {
|
||||
name string
|
||||
rawToolCall string
|
||||
tools []api.Tool
|
||||
wantToolCall api.ToolCall
|
||||
}
|
||||
|
||||
steps := []step{
|
||||
{
|
||||
name: "simple tool call",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `<function=get_current_temperature>
|
||||
<parameter=location>
|
||||
San Francisco
|
||||
</parameter>
|
||||
<parameter=unit>
|
||||
celsius
|
||||
</parameter>
|
||||
</function>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_temperature",
|
||||
Arguments: map[string]any{
|
||||
"location": "San Francisco",
|
||||
"unit": "celsius",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "names with spaces",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `<function=get current temperature>
|
||||
<parameter=location with spaces>
|
||||
San Francisco
|
||||
</parameter>
|
||||
<parameter=unit with spaces>
|
||||
celsius
|
||||
</parameter>
|
||||
</function>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get current temperature",
|
||||
Arguments: map[string]any{
|
||||
"location with spaces": "San Francisco",
|
||||
"unit with spaces": "celsius",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// this mirrors the reference implementation's behavior, but unclear if it
|
||||
// ever happens. If so, then we should probably remove them instead, this
|
||||
// test is to just document the current behavior and test that we don't get
|
||||
// xml errors
|
||||
{
|
||||
name: "names with quotes",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `<function="get current temperature">
|
||||
<parameter="location with spaces">
|
||||
San Francisco
|
||||
</parameter>
|
||||
<parameter="unit with spaces">
|
||||
"celsius"
|
||||
</parameter>
|
||||
</function>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "\"get current temperature\"",
|
||||
Arguments: map[string]any{
|
||||
"\"location with spaces\"": "San Francisco",
|
||||
"\"unit with spaces\"": "\"celsius\"",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with typed parameters",
|
||||
tools: []api.Tool{
|
||||
tool("calculate", map[string]api.ToolProperty{
|
||||
"x": {Type: api.PropertyType{"number"}},
|
||||
"y": {Type: api.PropertyType{"integer"}},
|
||||
"enabled": {Type: api.PropertyType{"boolean"}},
|
||||
"items": {Type: api.PropertyType{"array"}},
|
||||
}),
|
||||
},
|
||||
rawToolCall: `<function=calculate>
|
||||
<parameter=x>
|
||||
3.14
|
||||
</parameter>
|
||||
<parameter=y>
|
||||
42
|
||||
</parameter>
|
||||
<parameter=enabled>
|
||||
true
|
||||
</parameter>
|
||||
<parameter=items>
|
||||
["a", "b", "c"]
|
||||
</parameter>
|
||||
</function>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "calculate",
|
||||
Arguments: map[string]any{
|
||||
"x": 3.14,
|
||||
"y": 42,
|
||||
"enabled": true,
|
||||
"items": []any{"a", "b", "c"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// regression test for <https://github.com/ollama/ollama/issues/12357>
|
||||
{
|
||||
name: "ampersands in parameter values",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `<function=exec>
|
||||
<parameter=command>
|
||||
ls && echo "done"
|
||||
</parameter>
|
||||
</function>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "exec",
|
||||
Arguments: map[string]any{
|
||||
"command": "ls && echo \"done\"",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "angle brackets in parameter values",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `<function=exec>
|
||||
<parameter=command>
|
||||
ls && echo "a > b and a < b"
|
||||
</parameter>
|
||||
</function>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "exec",
|
||||
Arguments: map[string]any{
|
||||
"command": "ls && echo \"a > b and a < b\"",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, step := range steps {
|
||||
gotToolCall, err := parseToolCall(qwenEventRawToolCall{raw: step.rawToolCall}, step.tools)
|
||||
if err != nil {
|
||||
t.Errorf("step %d (%s): %v", i, step.name, err)
|
||||
}
|
||||
if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
|
||||
t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenToolCallValueParsing(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
raw string
|
||||
paramType api.PropertyType
|
||||
want any
|
||||
}{
|
||||
{
|
||||
desc: "default string value (no type specified)",
|
||||
paramType: api.PropertyType{},
|
||||
raw: "some-string",
|
||||
want: "some-string",
|
||||
},
|
||||
{
|
||||
desc: "trim a single leading and trailing newline",
|
||||
paramType: api.PropertyType{},
|
||||
raw: "\nsome-string\n",
|
||||
want: "some-string",
|
||||
},
|
||||
{
|
||||
desc: "trim at most one leading and trailing newline",
|
||||
paramType: api.PropertyType{},
|
||||
raw: "\n\nsome-string\n\n",
|
||||
want: "\nsome-string\n",
|
||||
},
|
||||
{
|
||||
desc: "newline really has to be the first character to be trimmed",
|
||||
paramType: api.PropertyType{},
|
||||
raw: " \nsome-string\n ",
|
||||
want: " \nsome-string\n ",
|
||||
},
|
||||
{
|
||||
desc: "numeric type",
|
||||
paramType: api.PropertyType{"number"},
|
||||
raw: "123",
|
||||
want: 123,
|
||||
},
|
||||
// Integer parsing tests
|
||||
{
|
||||
desc: "integer type",
|
||||
paramType: api.PropertyType{"integer"},
|
||||
raw: "42",
|
||||
want: 42,
|
||||
},
|
||||
{
|
||||
desc: "negative integer",
|
||||
paramType: api.PropertyType{"integer"},
|
||||
raw: "-100",
|
||||
want: -100,
|
||||
},
|
||||
{
|
||||
desc: "zero integer",
|
||||
paramType: api.PropertyType{"integer"},
|
||||
raw: "0",
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
desc: "integer with leading zeros",
|
||||
paramType: api.PropertyType{"integer"},
|
||||
raw: "007",
|
||||
want: 7,
|
||||
},
|
||||
{
|
||||
desc: "large integer",
|
||||
paramType: api.PropertyType{"integer"},
|
||||
raw: "2147483648", // Just beyond int32 max
|
||||
want: int64(2147483648),
|
||||
},
|
||||
// Float/number parsing tests
|
||||
{
|
||||
desc: "float type",
|
||||
paramType: api.PropertyType{"number"},
|
||||
raw: "3.14",
|
||||
want: 3.14,
|
||||
},
|
||||
{
|
||||
desc: "negative float",
|
||||
paramType: api.PropertyType{"number"},
|
||||
raw: "-273.15",
|
||||
want: -273.15,
|
||||
},
|
||||
{
|
||||
desc: "float without decimal part",
|
||||
paramType: api.PropertyType{"number"},
|
||||
raw: "100.0",
|
||||
want: 100,
|
||||
},
|
||||
{
|
||||
desc: "scientific notation positive",
|
||||
paramType: api.PropertyType{"number"},
|
||||
raw: "1.23e5",
|
||||
want: 123000, // Will be int since it has no decimal part
|
||||
},
|
||||
{
|
||||
desc: "scientific notation negative",
|
||||
paramType: api.PropertyType{"number"},
|
||||
raw: "1.5e-3",
|
||||
want: 0.0015,
|
||||
},
|
||||
{
|
||||
desc: "very small float",
|
||||
paramType: api.PropertyType{"number"},
|
||||
raw: "0.00000001",
|
||||
want: 0.00000001,
|
||||
},
|
||||
// String parsing tests
|
||||
{
|
||||
desc: "explicit string type",
|
||||
paramType: api.PropertyType{"string"},
|
||||
raw: "hello world",
|
||||
want: "hello world",
|
||||
},
|
||||
{
|
||||
desc: "string with special characters",
|
||||
paramType: api.PropertyType{"string"},
|
||||
raw: "/usr/local/bin/test-file_v2.0.sh",
|
||||
want: "/usr/local/bin/test-file_v2.0.sh",
|
||||
},
|
||||
{
|
||||
desc: "string with quotes",
|
||||
paramType: api.PropertyType{"string"},
|
||||
raw: `He said "hello" to me`,
|
||||
want: `He said "hello" to me`,
|
||||
},
|
||||
{
|
||||
desc: "multiline string",
|
||||
paramType: api.PropertyType{"string"},
|
||||
raw: "line one\nline two\nline three",
|
||||
want: "line one\nline two\nline three",
|
||||
},
|
||||
{
|
||||
desc: "empty string",
|
||||
paramType: api.PropertyType{"string"},
|
||||
raw: "",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
desc: "string that looks like a number",
|
||||
paramType: api.PropertyType{"string"},
|
||||
raw: "12345",
|
||||
want: "12345",
|
||||
},
|
||||
// Boolean parsing tests
|
||||
{
|
||||
desc: "boolean true",
|
||||
paramType: api.PropertyType{"boolean"},
|
||||
raw: "true",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
desc: "boolean false",
|
||||
paramType: api.PropertyType{"boolean"},
|
||||
raw: "false",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
desc: "boolean case insensitive true",
|
||||
paramType: api.PropertyType{"boolean"},
|
||||
raw: "True",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
desc: "boolean case insensitive false",
|
||||
paramType: api.PropertyType{"boolean"},
|
||||
raw: "FALSE",
|
||||
want: false,
|
||||
},
|
||||
// Null parsing tests
|
||||
{
|
||||
desc: "null value lowercase",
|
||||
paramType: api.PropertyType{"string"},
|
||||
raw: "null",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
desc: "null value case insensitive",
|
||||
paramType: api.PropertyType{"integer"},
|
||||
raw: "NULL",
|
||||
want: nil,
|
||||
},
|
||||
// Array parsing tests
|
||||
{
|
||||
desc: "array of strings",
|
||||
paramType: api.PropertyType{"array"},
|
||||
raw: `["foo", "bar", "baz"]`,
|
||||
want: []any{"foo", "bar", "baz"},
|
||||
},
|
||||
{
|
||||
desc: "array of numbers",
|
||||
paramType: api.PropertyType{"array"},
|
||||
raw: `[1, 2.5, 3]`,
|
||||
want: []any{float64(1), 2.5, float64(3)},
|
||||
},
|
||||
{
|
||||
desc: "array of mixed types",
|
||||
paramType: api.PropertyType{"array"},
|
||||
raw: `["string", 123, true, null]`,
|
||||
want: []any{"string", float64(123), true, nil},
|
||||
},
|
||||
{
|
||||
desc: "empty array",
|
||||
paramType: api.PropertyType{"array"},
|
||||
raw: `[]`,
|
||||
want: []any{},
|
||||
},
|
||||
// Object parsing tests
|
||||
{
|
||||
desc: "simple object",
|
||||
paramType: api.PropertyType{"object"},
|
||||
raw: `{"key": "value", "number": 42}`,
|
||||
want: map[string]any{"key": "value", "number": float64(42)},
|
||||
},
|
||||
{
|
||||
desc: "nested object",
|
||||
paramType: api.PropertyType{"object"},
|
||||
raw: `{"outer": {"inner": "value"}}`,
|
||||
want: map[string]any{"outer": map[string]any{"inner": "value"}},
|
||||
},
|
||||
{
|
||||
desc: "empty object",
|
||||
paramType: api.PropertyType{"object"},
|
||||
raw: `{}`,
|
||||
want: map[string]any{},
|
||||
},
|
||||
// Error cases and fallback behavior
|
||||
{
|
||||
desc: "invalid integer falls back to string",
|
||||
paramType: api.PropertyType{"integer"},
|
||||
raw: "not-a-number",
|
||||
want: "not-a-number",
|
||||
},
|
||||
{
|
||||
desc: "invalid float falls back to string",
|
||||
paramType: api.PropertyType{"number"},
|
||||
raw: "3.14.159",
|
||||
want: "3.14.159",
|
||||
},
|
||||
{
|
||||
desc: "invalid boolean falls back to false",
|
||||
paramType: api.PropertyType{"boolean"},
|
||||
raw: "yes",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
desc: "invalid JSON array falls back to string",
|
||||
paramType: api.PropertyType{"array"},
|
||||
raw: "[1, 2, unclosed",
|
||||
want: "[1, 2, unclosed",
|
||||
},
|
||||
{
|
||||
desc: "invalid JSON object falls back to string",
|
||||
paramType: api.PropertyType{"object"},
|
||||
raw: `{"key": unclosed`,
|
||||
want: `{"key": unclosed`,
|
||||
},
|
||||
// Edge cases
|
||||
{
|
||||
desc: "integer overflow should use int64",
|
||||
paramType: api.PropertyType{"integer"},
|
||||
raw: "2147483648", // Beyond int32 max
|
||||
want: int64(2147483648),
|
||||
},
|
||||
{
|
||||
desc: "float with many decimal places",
|
||||
paramType: api.PropertyType{"number"},
|
||||
raw: "3.141592653589793",
|
||||
want: 3.141592653589793,
|
||||
},
|
||||
{
|
||||
desc: "string with JSON-like content",
|
||||
paramType: api.PropertyType{"string"},
|
||||
raw: `{"this": "is", "just": "a string"}`,
|
||||
want: `{"this": "is", "just": "a string"}`,
|
||||
},
|
||||
{
|
||||
desc: "whitespace-only string",
|
||||
paramType: api.PropertyType{"string"},
|
||||
raw: " ",
|
||||
want: " ",
|
||||
},
|
||||
// Unknown parameter (no type specified in tools)
|
||||
{
|
||||
desc: "parameter not in tool definition defaults to string",
|
||||
paramType: api.PropertyType{},
|
||||
raw: "some value",
|
||||
want: "some value",
|
||||
},
|
||||
// Union type tests
|
||||
{
|
||||
desc: "string or number union - valid number",
|
||||
paramType: api.PropertyType{"string", "number"},
|
||||
raw: "42.5",
|
||||
want: 42.5,
|
||||
},
|
||||
{
|
||||
desc: "string or number union - non-numeric string",
|
||||
paramType: api.PropertyType{"string", "number"},
|
||||
raw: "hello",
|
||||
want: "hello",
|
||||
},
|
||||
{
|
||||
desc: "number or string union - valid number (order shouldn't matter)",
|
||||
paramType: api.PropertyType{"number", "string"},
|
||||
raw: "42.5",
|
||||
want: 42.5,
|
||||
},
|
||||
{
|
||||
desc: "integer or null union - valid integer",
|
||||
paramType: api.PropertyType{"integer", "null"},
|
||||
raw: "123",
|
||||
want: 123,
|
||||
},
|
||||
{
|
||||
desc: "integer or null union - null value",
|
||||
paramType: api.PropertyType{"integer", "null"},
|
||||
raw: "null",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
desc: "null or integer union - null value (order shouldn't matter)",
|
||||
paramType: api.PropertyType{"null", "integer"},
|
||||
raw: "null",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
desc: "boolean or string union - valid boolean",
|
||||
paramType: api.PropertyType{"boolean", "string"},
|
||||
raw: "true",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
desc: "boolean or string union - non-boolean becomes string",
|
||||
paramType: api.PropertyType{"boolean", "string"},
|
||||
raw: "yes",
|
||||
want: "yes",
|
||||
},
|
||||
{
|
||||
desc: "string or boolean union - valid boolean (precedence test)",
|
||||
paramType: api.PropertyType{"string", "boolean"},
|
||||
raw: "false",
|
||||
want: false, // Should be boolean, not string "false"
|
||||
},
|
||||
{
|
||||
desc: "integer or number union - integer value",
|
||||
paramType: api.PropertyType{"integer", "number"},
|
||||
raw: "42",
|
||||
want: 42,
|
||||
},
|
||||
{
|
||||
desc: "integer or number union - float value",
|
||||
paramType: api.PropertyType{"integer", "number"},
|
||||
raw: "42.5",
|
||||
want: 42.5,
|
||||
},
|
||||
{
|
||||
desc: "number or integer union - integer value (precedence test)",
|
||||
paramType: api.PropertyType{"number", "integer"},
|
||||
raw: "42",
|
||||
want: 42, // Should try integer first due to precedence
|
||||
},
|
||||
{
|
||||
desc: "array or object union - valid array",
|
||||
paramType: api.PropertyType{"array", "object"},
|
||||
raw: `[1, 2, 3]`,
|
||||
want: []any{float64(1), float64(2), float64(3)},
|
||||
},
|
||||
{
|
||||
desc: "array or object union - valid object",
|
||||
paramType: api.PropertyType{"array", "object"},
|
||||
raw: `{"key": "value"}`,
|
||||
want: map[string]any{"key": "value"},
|
||||
},
|
||||
{
|
||||
desc: "object or array union - valid array (precedence test)",
|
||||
paramType: api.PropertyType{"object", "array"},
|
||||
raw: `[1, 2, 3]`,
|
||||
want: []any{float64(1), float64(2), float64(3)},
|
||||
},
|
||||
{
|
||||
desc: "complex multi-type union - null",
|
||||
paramType: api.PropertyType{"string", "number", "boolean", "null"},
|
||||
raw: "null",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
desc: "complex multi-type union - boolean",
|
||||
paramType: api.PropertyType{"string", "number", "boolean", "null"},
|
||||
raw: "true",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
desc: "complex multi-type union - number",
|
||||
paramType: api.PropertyType{"string", "number", "boolean", "null"},
|
||||
raw: "3.14",
|
||||
want: 3.14,
|
||||
},
|
||||
{
|
||||
desc: "complex multi-type union - string",
|
||||
paramType: api.PropertyType{"string", "number", "boolean", "null"},
|
||||
raw: "hello",
|
||||
want: "hello",
|
||||
},
|
||||
{
|
||||
desc: "integer string union - integer string becomes integer",
|
||||
paramType: api.PropertyType{"integer", "string"},
|
||||
raw: "123",
|
||||
want: 123,
|
||||
},
|
||||
{
|
||||
desc: "string integer union - integer string becomes integer (precedence)",
|
||||
paramType: api.PropertyType{"string", "integer"},
|
||||
raw: "123",
|
||||
want: 123, // Integer has higher precedence than string
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := parseValue(tc.raw, tc.paramType)
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Errorf("got %v (type %T), want %v (type %T)", got, got, tc.want, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenXMLTransform(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
raw string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
desc: "simple example",
|
||||
raw: `<function=get_current_temperature>
|
||||
<parameter=location>
|
||||
San Francisco
|
||||
</parameter>
|
||||
<parameter=unit>
|
||||
celsius
|
||||
</parameter>
|
||||
</function>`,
|
||||
want: `<function name="get_current_temperature">
|
||||
<parameter name="location">
|
||||
San Francisco
|
||||
</parameter>
|
||||
<parameter name="unit">
|
||||
celsius
|
||||
</parameter>
|
||||
</function>`,
|
||||
},
|
||||
// even though quotes aren't expected in these tags, we have these tests to
|
||||
// make sure they're escaped so they don't blow up the xml parser in case
|
||||
// they happen
|
||||
{
|
||||
desc: "names with quotes",
|
||||
raw: `<function="get current temperature">
|
||||
<parameter="location with spaces">
|
||||
San Francisco
|
||||
</parameter>
|
||||
<parameter="unit with spaces">
|
||||
celsius
|
||||
</parameter>
|
||||
</function>`,
|
||||
want: `<function name=""get current temperature"">
|
||||
<parameter name=""location with spaces"">
|
||||
San Francisco
|
||||
</parameter>
|
||||
<parameter name=""unit with spaces"">
|
||||
celsius
|
||||
</parameter>
|
||||
</function>`,
|
||||
},
|
||||
{
|
||||
desc: "ampersands in parameter values",
|
||||
raw: `<function=get_current_temperature>
|
||||
<parameter=location>
|
||||
San Francisco & San Jose
|
||||
</parameter>
|
||||
</function>`,
|
||||
want: `<function name="get_current_temperature">
|
||||
<parameter name="location">
|
||||
San Francisco & San Jose
|
||||
</parameter>
|
||||
</function>`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
got := transformToXML(tc.raw)
|
||||
if got != tc.want {
|
||||
t.Errorf("got %q, want %q", got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrailingWhitespaceLen(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
s string
|
||||
want int
|
||||
}{
|
||||
{desc: "no whitespace", s: "abc", want: 0},
|
||||
{desc: "trailing whitespace", s: "abc ", want: 1},
|
||||
{desc: "trailing whitespace with newlines", s: "abc \n", want: 2},
|
||||
{desc: "only whitespace", s: " \n ", want: 4},
|
||||
{desc: "leading whitespace doesn't count", s: " \n abc", want: 0},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
got := trailingWhitespaceLen(tc.s)
|
||||
if got != tc.want {
|
||||
t.Errorf("got %d, want %d", got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,217 +0,0 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
var (
|
||||
imStartTag = "<|im_start|>"
|
||||
imEndTag = "<|im_end|>"
|
||||
)
|
||||
|
||||
// renderAdditionalKeys renders all JSON fields except the ones in handledKeys
|
||||
// This follows the same approach from the reference implementation, which gives
|
||||
// a particular key ordering
|
||||
func renderAdditionalKeys(obj any, handledKeys map[string]bool) string {
|
||||
data, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
for key, value := range m {
|
||||
if handledKeys[key] {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if value is a map or array (needs JSON serialization)
|
||||
switch v := value.(type) {
|
||||
case map[string]any, []any:
|
||||
jsonBytes, _ := json.Marshal(v)
|
||||
// TODO(drifkin): it would be nice to format the JSON here similarly to
|
||||
// python's default json.dumps behavior (spaces after commas and colons).
|
||||
// This would let us be byte-for-byte compatible with the reference
|
||||
// implementation for most common inputs
|
||||
jsonStr := string(jsonBytes)
|
||||
sb.WriteString("\n<" + key + ">" + jsonStr + "</" + key + ">")
|
||||
case nil:
|
||||
continue
|
||||
default:
|
||||
// Simple types, convert to string
|
||||
sb.WriteString("\n<" + key + ">" + fmt.Sprintf("%v", value) + "</" + key + ">")
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func Qwen3CoderRenderer(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
|
||||
// filter out system messages and choose the first (if any) to win
|
||||
var systemMessage string
|
||||
var filteredMessages []api.Message
|
||||
for _, message := range messages {
|
||||
if message.Role != "system" {
|
||||
filteredMessages = append(filteredMessages, message)
|
||||
continue
|
||||
}
|
||||
|
||||
if systemMessage == "" {
|
||||
systemMessage = message.Content
|
||||
}
|
||||
}
|
||||
|
||||
if systemMessage != "" || len(tools) > 0 {
|
||||
sb.WriteString(imStartTag + "system\n")
|
||||
|
||||
// if we have tools but no system message, match the reference implementation by providing a default system message
|
||||
if systemMessage == "" {
|
||||
systemMessage = "You are Qwen, a helpful AI assistant that can interact with a computer to solve tasks."
|
||||
}
|
||||
|
||||
sb.WriteString(systemMessage)
|
||||
|
||||
if len(tools) > 0 {
|
||||
sb.WriteString("\n\n# Tools\n\nYou have access to the following functions:\n\n")
|
||||
sb.WriteString("<tools>")
|
||||
for _, tool := range tools {
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString("<function>\n")
|
||||
sb.WriteString("<name>" + tool.Function.Name + "</name>")
|
||||
if tool.Function.Description != "" {
|
||||
sb.WriteString("\n<description>" + tool.Function.Description + "</description>")
|
||||
}
|
||||
sb.WriteString("\n<parameters>")
|
||||
|
||||
for name, prop := range tool.Function.Parameters.Properties {
|
||||
sb.WriteString("\n<parameter>")
|
||||
sb.WriteString("\n<name>" + name + "</name>")
|
||||
|
||||
if len(prop.Type) > 0 {
|
||||
// TODO(!!!)(drifkin): we should match the reference implementation for
|
||||
// more complex types here instead of using this format
|
||||
sb.WriteString("\n<type>" + prop.ToTypeScriptType() + "</type>")
|
||||
}
|
||||
|
||||
if prop.Description != "" {
|
||||
sb.WriteString("\n<description>" + prop.Description + "</description>")
|
||||
}
|
||||
|
||||
// Render any additional keys not already handled
|
||||
handledKeys := map[string]bool{
|
||||
"type": true,
|
||||
"description": true,
|
||||
}
|
||||
sb.WriteString(renderAdditionalKeys(prop, handledKeys))
|
||||
|
||||
sb.WriteString("\n</parameter>")
|
||||
}
|
||||
|
||||
// Render extra keys for parameters (everything except 'type' and 'properties')
|
||||
paramHandledKeys := map[string]bool{
|
||||
"type": true,
|
||||
"properties": true,
|
||||
}
|
||||
sb.WriteString(renderAdditionalKeys(tool.Function.Parameters, paramHandledKeys))
|
||||
|
||||
sb.WriteString("\n</parameters>")
|
||||
sb.WriteString("\n</function>")
|
||||
}
|
||||
sb.WriteString("\n</tools>")
|
||||
sb.WriteString("\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>")
|
||||
}
|
||||
|
||||
sb.WriteString(imEndTag + "\n")
|
||||
}
|
||||
|
||||
for i, message := range filteredMessages {
|
||||
lastMessage := i == len(filteredMessages)-1
|
||||
prefill := lastMessage && message.Role == "assistant"
|
||||
switch message.Role {
|
||||
case "assistant":
|
||||
if len(message.ToolCalls) > 0 {
|
||||
sb.WriteString(imStartTag + "assistant\n")
|
||||
if message.Content != "" {
|
||||
sb.WriteString(message.Content + "\n")
|
||||
}
|
||||
for _, toolCall := range message.ToolCalls {
|
||||
sb.WriteString("\n<tool_call>\n<function=" + toolCall.Function.Name + ">")
|
||||
for name, value := range toolCall.Function.Arguments {
|
||||
valueStr := formatToolCallArgument(value)
|
||||
sb.WriteString("\n<parameter=" + name + ">\n" + valueStr + "\n</parameter>")
|
||||
}
|
||||
sb.WriteString("\n</function>\n</tool_call>")
|
||||
}
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
} else {
|
||||
sb.WriteString(imStartTag + "assistant\n")
|
||||
sb.WriteString(message.Content)
|
||||
if !prefill {
|
||||
sb.WriteString(imEndTag + "\n")
|
||||
}
|
||||
}
|
||||
case "tool":
|
||||
// consecutive tool responses should share a single `<im_start>user`, but
|
||||
// have their own <tool_response> tags
|
||||
|
||||
// only start a new user block if this is the first tool response
|
||||
if i == 0 || filteredMessages[i-1].Role != "tool" {
|
||||
sb.WriteString(imStartTag + "user\n")
|
||||
}
|
||||
|
||||
sb.WriteString("<tool_response>\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("\n</tool_response>\n")
|
||||
|
||||
// close the user block only if this is the last tool response
|
||||
if i == len(filteredMessages)-1 || filteredMessages[i+1].Role != "tool" {
|
||||
sb.WriteString(imEndTag + "\n")
|
||||
}
|
||||
default:
|
||||
sb.WriteString(imStartTag + message.Role + "\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString(imEndTag + "\n")
|
||||
}
|
||||
|
||||
if lastMessage && !prefill {
|
||||
sb.WriteString(imStartTag + "assistant\n")
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
func formatToolCallArgument(value any) string {
|
||||
if value == nil {
|
||||
return "null"
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []byte:
|
||||
return string(v)
|
||||
}
|
||||
|
||||
if reflect.TypeOf(value) != nil {
|
||||
kind := reflect.TypeOf(value).Kind()
|
||||
if kind == reflect.Map || kind == reflect.Slice || kind == reflect.Array {
|
||||
if marshalled, err := json.Marshal(value); err == nil {
|
||||
return string(marshalled)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%v", value)
|
||||
}
|
||||
@@ -1,338 +0,0 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestQwen3CoderRenderer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msgs []api.Message
|
||||
tools []api.Tool
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello, how are you?"},
|
||||
},
|
||||
expected: `<|im_start|>system
|
||||
You are a helpful assistant.<|im_end|>
|
||||
<|im_start|>user
|
||||
Hello, how are you?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "with tools and response",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant with access to tools."},
|
||||
{Role: "user", Content: "What is the weather like in San Francisco?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "I'll check the weather in San Francisco for you.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{
|
||||
"unit": "fahrenheit",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "{\"location\": \"San Francisco, CA\", \"temperature\": 68, \"condition\": \"partly cloudy\", \"humidity\": 65, \"wind_speed\": 12}", ToolName: "get_weather"},
|
||||
{Role: "user", Content: "That sounds nice! What about New York?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather in a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Required: []string{"unit"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"unit": {Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, Description: "The unit of temperature"},
|
||||
// TODO(drifkin): add multiple params back once we have predictable
|
||||
// order via some sort of ordered map type (see
|
||||
// <https://github.com/ollama/ollama/issues/12244>)
|
||||
/*
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA"},
|
||||
*/
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
expected: `<|im_start|>system
|
||||
You are a helpful assistant with access to tools.
|
||||
|
||||
# Tools
|
||||
|
||||
You have access to the following functions:
|
||||
|
||||
<tools>
|
||||
<function>
|
||||
<name>get_weather</name>
|
||||
<description>Get the current weather in a given location</description>
|
||||
<parameters>
|
||||
<parameter>
|
||||
<name>unit</name>
|
||||
<type>string</type>
|
||||
<description>The unit of temperature</description>
|
||||
<enum>["celsius","fahrenheit"]</enum>
|
||||
</parameter>
|
||||
<required>["unit"]</required>
|
||||
</parameters>
|
||||
</function>
|
||||
</tools>
|
||||
|
||||
If you choose to call a function ONLY reply in the following format with NO suffix:
|
||||
|
||||
<tool_call>
|
||||
<function=example_function_name>
|
||||
<parameter=example_parameter_1>
|
||||
value_1
|
||||
</parameter>
|
||||
<parameter=example_parameter_2>
|
||||
This is the value for the second parameter
|
||||
that can span
|
||||
multiple lines
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
|
||||
<IMPORTANT>
|
||||
Reminder:
|
||||
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
|
||||
- Required parameters MUST be specified
|
||||
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
|
||||
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
|
||||
</IMPORTANT><|im_end|>
|
||||
<|im_start|>user
|
||||
What is the weather like in San Francisco?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
I'll check the weather in San Francisco for you.
|
||||
|
||||
<tool_call>
|
||||
<function=get_weather>
|
||||
<parameter=unit>
|
||||
fahrenheit
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
<tool_response>
|
||||
{"location": "San Francisco, CA", "temperature": 68, "condition": "partly cloudy", "humidity": 65, "wind_speed": 12}
|
||||
</tool_response>
|
||||
<|im_end|>
|
||||
<|im_start|>user
|
||||
That sounds nice! What about New York?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "parallel tool calls",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant with access to tools."},
|
||||
{Role: "user", Content: "call double(1) and triple(2)"},
|
||||
{Role: "assistant", Content: "I'll call double(1) and triple(2) for you.", ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "double", Arguments: map[string]any{"number": "1"}}},
|
||||
{Function: api.ToolCallFunction{Name: "triple", Arguments: map[string]any{"number": "2"}}},
|
||||
}},
|
||||
{Role: "tool", Content: "{\"number\": 2}", ToolName: "double"},
|
||||
{Role: "tool", Content: "{\"number\": 6}", ToolName: "triple"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
|
||||
"number": {Type: api.PropertyType{"string"}, Description: "The number to double"},
|
||||
}}}},
|
||||
{Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
|
||||
"number": {Type: api.PropertyType{"string"}, Description: "The number to triple"},
|
||||
}}}},
|
||||
},
|
||||
expected: `<|im_start|>system
|
||||
You are a helpful assistant with access to tools.
|
||||
|
||||
# Tools
|
||||
|
||||
You have access to the following functions:
|
||||
|
||||
<tools>
|
||||
<function>
|
||||
<name>double</name>
|
||||
<description>Double a number</description>
|
||||
<parameters>
|
||||
<parameter>
|
||||
<name>number</name>
|
||||
<type>string</type>
|
||||
<description>The number to double</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
</function>
|
||||
<function>
|
||||
<name>triple</name>
|
||||
<description>Triple a number</description>
|
||||
<parameters>
|
||||
<parameter>
|
||||
<name>number</name>
|
||||
<type>string</type>
|
||||
<description>The number to triple</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
</function>
|
||||
</tools>
|
||||
|
||||
If you choose to call a function ONLY reply in the following format with NO suffix:
|
||||
|
||||
<tool_call>
|
||||
<function=example_function_name>
|
||||
<parameter=example_parameter_1>
|
||||
value_1
|
||||
</parameter>
|
||||
<parameter=example_parameter_2>
|
||||
This is the value for the second parameter
|
||||
that can span
|
||||
multiple lines
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
|
||||
<IMPORTANT>
|
||||
Reminder:
|
||||
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
|
||||
- Required parameters MUST be specified
|
||||
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
|
||||
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
|
||||
</IMPORTANT><|im_end|>
|
||||
<|im_start|>user
|
||||
call double(1) and triple(2)<|im_end|>
|
||||
<|im_start|>assistant
|
||||
I'll call double(1) and triple(2) for you.
|
||||
|
||||
<tool_call>
|
||||
<function=double>
|
||||
<parameter=number>
|
||||
1
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
<function=triple>
|
||||
<parameter=number>
|
||||
2
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
<tool_response>
|
||||
{"number": 2}
|
||||
</tool_response>
|
||||
<tool_response>
|
||||
{"number": 6}
|
||||
</tool_response>
|
||||
<|im_end|>
|
||||
<|im_start|>assistant
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "prefill",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Tell me something interesting."},
|
||||
{Role: "assistant", Content: "I'll tell you something interesting about cats"},
|
||||
},
|
||||
expected: `<|im_start|>system
|
||||
You are a helpful assistant.<|im_end|>
|
||||
<|im_start|>user
|
||||
Tell me something interesting.<|im_end|>
|
||||
<|im_start|>assistant
|
||||
I'll tell you something interesting about cats`,
|
||||
},
|
||||
{
|
||||
name: "complex tool call arguments should remain json encoded",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "call tool"},
|
||||
{Role: "assistant", ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{
|
||||
Name: "echo",
|
||||
Arguments: map[string]any{
|
||||
"payload": map[string]any{"foo": "bar"},
|
||||
},
|
||||
}},
|
||||
}},
|
||||
{Role: "tool", Content: "{\"payload\": {\"foo\": \"bar\"}}", ToolName: "echo"},
|
||||
},
|
||||
expected: `<|im_start|>user
|
||||
call tool<|im_end|>
|
||||
<|im_start|>assistant
|
||||
|
||||
<tool_call>
|
||||
<function=echo>
|
||||
<parameter=payload>
|
||||
{"foo":"bar"}
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
<tool_response>
|
||||
{"payload": {"foo": "bar"}}
|
||||
</tool_response>
|
||||
<|im_end|>
|
||||
<|im_start|>assistant
|
||||
`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rendered, err := Qwen3CoderRenderer(tt.msgs, tt.tools, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatToolCallArgument(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
arg any
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "string",
|
||||
arg: "foo",
|
||||
// notice no quotes around the string
|
||||
expected: "foo",
|
||||
},
|
||||
{
|
||||
name: "map",
|
||||
arg: map[string]any{"foo": "bar"},
|
||||
expected: "{\"foo\":\"bar\"}",
|
||||
},
|
||||
{
|
||||
name: "number",
|
||||
arg: 1,
|
||||
expected: "1",
|
||||
},
|
||||
{
|
||||
name: "boolean",
|
||||
arg: true,
|
||||
expected: "true",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := formatToolCallArgument(tt.arg)
|
||||
if got != tt.expected {
|
||||
t.Errorf("formatToolCallArgument(%v) = %v, want %v", tt.arg, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type rendererFunc func([]api.Message, []api.Tool, *api.ThinkValue) (string, error)
|
||||
|
||||
func RenderWithRenderer(name string, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
|
||||
renderer := rendererForName(name)
|
||||
if renderer == nil {
|
||||
return "", fmt.Errorf("unknown renderer %q", name)
|
||||
}
|
||||
return renderer(msgs, tools, think)
|
||||
}
|
||||
|
||||
func rendererForName(name string) rendererFunc {
|
||||
switch name {
|
||||
case "qwen3-coder":
|
||||
return Qwen3CoderRenderer
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
@@ -12,19 +13,19 @@ import (
|
||||
|
||||
const spmWhitespaceSep = "▁"
|
||||
|
||||
type SentencePiece struct {
|
||||
type SentencePieceModel struct {
|
||||
maxTokenLen int
|
||||
vocab *Vocabulary
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*SentencePiece)(nil)
|
||||
var _ TextProcessor = (*SentencePieceModel)(nil)
|
||||
|
||||
func (spm SentencePiece) Vocabulary() *Vocabulary {
|
||||
func (spm SentencePieceModel) Vocabulary() *Vocabulary {
|
||||
return spm.vocab
|
||||
}
|
||||
|
||||
func NewSentencePiece(vocab *Vocabulary) SentencePiece {
|
||||
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
||||
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
||||
|
||||
counter := map[int]int{}
|
||||
var maxTokenLen int
|
||||
@@ -38,21 +39,21 @@ func NewSentencePiece(vocab *Vocabulary) SentencePiece {
|
||||
}
|
||||
}
|
||||
|
||||
logutil.Trace("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
|
||||
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
|
||||
"max token len", maxTokenLen)
|
||||
|
||||
return SentencePiece{
|
||||
return SentencePieceModel{
|
||||
maxTokenLen: maxTokenLen,
|
||||
vocab: vocab,
|
||||
}
|
||||
}
|
||||
|
||||
func (spm SentencePiece) Is(id int32, special Special) bool {
|
||||
func (spm SentencePieceModel) Is(id int32, special Special) bool {
|
||||
return spm.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
fragments := []fragment{{value: s}}
|
||||
for _, special := range spm.vocab.SpecialVocabulary() {
|
||||
id := spm.vocab.Encode(special)
|
||||
@@ -181,11 +182,12 @@ func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
}
|
||||
}
|
||||
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids)
|
||||
|
||||
if addSpecial && len(ids) > 0 {
|
||||
ids = spm.vocab.addSpecials(ids)
|
||||
}
|
||||
|
||||
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
@@ -218,7 +220,7 @@ func (q *queue) Pop() interface{} {
|
||||
return item
|
||||
}
|
||||
|
||||
func (spm SentencePiece) Decode(ids []int32) (string, error) {
|
||||
func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for _, id := range ids {
|
||||
data := spm.vocab.Decode(id)
|
||||
@@ -244,6 +246,6 @@ func (spm SentencePiece) Decode(ids []int32) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
logutil.Trace("decoded", "ids", ids, "string", sb.String())
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "ids", ids, "string", sb.String())
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user