Mirror of https://github.com/ollama/ollama.git
Compare commits: v0.11.9...pdevine/pa (31 commits)
Commits (SHA1):
c10a40db99, 93c64ea1b1, 3f6642f6fc, 6f7117145f, 92b96d54ef, 9d56e63dbf, 053092185e, 44a6792873, e4ce68311a, 26214125e8, 61fb912ca4, aba1575315, eb10390de9, feb18cd710, 8a7e2055d2, 29ddfc2cab, 71cb86af3e, 5198956372, 17a023f34b, 8d6fffaead, 20b53eaa72, 6745182885, f810ec741c, e119783e66, 1a558f98e2, 7b91c9ce51, 950d33aa30, 9714e38dd0, 4378ae4ffa, 5994e8e8fd, b3e6120736
.github/workflows/release.yaml (vendored) — 28 lines changed
@@ -65,14 +65,36 @@ jobs:
          arch: amd64
          preset: 'CUDA 12'
          install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
+         cuda-components:
+           - '"cudart"'
+           - '"nvcc"'
+           - '"cublas"'
+           - '"cublas_dev"'
          cuda-version: '12.8'
          flags: ''
+         runner_dir: 'cuda_v12'
+       - os: windows
+         arch: amd64
+         preset: 'CUDA 13'
+         install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
+         cuda-components:
+           - '"cudart"'
+           - '"nvcc"'
+           - '"cublas"'
+           - '"cublas_dev"'
+           - '"crt"'
+           - '"nvvm"'
+           - '"nvptxcompiler"'
+         cuda-version: '13.0'
+         flags: ''
+         runner_dir: 'cuda_v13'
        - os: windows
          arch: amd64
          preset: 'ROCm 6'
          install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
          rocm-version: '6.2'
          flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
+         runner_dir: ''
      runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
      environment: release
      env:

@@ -96,7 +118,7 @@ jobs:
        $ErrorActionPreference = "Stop"
        if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
          Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
-         $subpackages = @("cudart", "nvcc", "cublas", "cublas_dev") | Foreach-Object {"${_}_${{ matrix.cuda-version }}"}
+         $subpackages = @(${{ join(matrix.cuda-components, ', ') }}) | Foreach-Object {"${_}_${{ matrix.cuda-version }}"}
          Start-Process -FilePath .\install.exe -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait
        }

@@ -138,7 +160,7 @@ jobs:
      run: |
        Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
        Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
-       cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
+       cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} -DOLLAMA_RUNNER_DIR="${{ matrix.runner_dir }}"
        cmake --build --parallel --preset "${{ matrix.preset }}"
        cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8
      env:

@@ -232,7 +254,7 @@ jobs:
          case "$COMPONENT" in
            bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
            lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
-           lib/ollama/cuda_sbsa) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
+           lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
            lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
            lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
            lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
.github/workflows/test.yaml (vendored) — 16 lines changed
@@ -46,7 +46,7 @@ jobs:
        include:
          - preset: CPU
          - preset: CUDA
-           container: nvidia/cuda:12.8.1-devel-ubuntu22.04
+           container: nvidia/cuda:13.0.0-devel-ubuntu22.04
            flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
          - preset: ROCm
            container: rocm/dev-ubuntu-22.04:6.1.2

@@ -78,8 +78,17 @@ jobs:
        include:
          - preset: CPU
          - preset: CUDA
-           install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
+           install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
            flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
+           cuda-components:
+             - '"cudart"'
+             - '"nvcc"'
+             - '"cublas"'
+             - '"cublas_dev"'
+             - '"crt"'
+             - '"nvvm"'
+             - '"nvptxcompiler"'
+           cuda-version: '13.0'
          - preset: ROCm
            install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
            flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'

@@ -102,7 +111,8 @@ jobs:
        $ErrorActionPreference = "Stop"
        if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
          Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
-         Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_12.8", "nvcc_12.8", "cublas_12.8", "cublas_dev_12.8")) -NoNewWindow -Wait
+         $subpackages = @(${{ join(matrix.cuda-components, ', ') }}) | Foreach-Object {"${_}_${{ matrix.cuda-version }}"}
+         Start-Process -FilePath .\install.exe -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait
        }

        $cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path
CMakeLists.txt

@@ -38,7 +38,7 @@ if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
endif()

set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
-set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama)
+set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama/${OLLAMA_RUNNER_DIR})

set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})

@@ -81,7 +81,7 @@ if(CMAKE_CUDA_COMPILER)
    add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
    install(TARGETS ggml-cuda
        RUNTIME_DEPENDENCIES
-           DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR}
+           DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
            PRE_INCLUDE_REGEXES cublas cublasLt cudart
            PRE_EXCLUDE_REGEXES ".*"
        RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA
CMakePresets.json

@@ -18,6 +18,14 @@
        {
            "name": "CUDA",
            "inherits": [ "Default" ]
        },
+       {
+           "name": "CUDA 11",
+           "inherits": [ "CUDA" ],
+           "cacheVariables": {
+               "CMAKE_CUDA_ARCHITECTURES": "50-virtual;60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual",
+               "CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
+           }
+       },
        {
            "name": "CUDA 12",
            "inherits": [ "CUDA" ],

@@ -26,6 +34,14 @@
                "CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
            }
        },
+       {
+           "name": "CUDA 13",
+           "inherits": [ "CUDA" ],
+           "cacheVariables": {
+               "CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;110-virtual;120-virtual;121-virtual",
+               "CMAKE_CUDA_FLAGS": "-t 2"
+           }
+       },
        {
            "name": "JetPack 5",
            "inherits": [ "CUDA" ],

@@ -72,11 +88,21 @@
            "configurePreset": "CUDA",
            "targets": [ "ggml-cuda" ]
        },
+       {
+           "name": "CUDA 11",
+           "inherits": [ "CUDA" ],
+           "configurePreset": "CUDA 11"
+       },
        {
            "name": "CUDA 12",
            "inherits": [ "CUDA" ],
            "configurePreset": "CUDA 12"
        },
+       {
+           "name": "CUDA 13",
+           "inherits": [ "CUDA" ],
+           "configurePreset": "CUDA 13"
+       },
        {
            "name": "JetPack 5",
            "inherits": [ "CUDA" ],
Dockerfile — 30 lines changed
@@ -39,15 +39,35 @@ RUN --mount=type=cache,target=/root/.ccache \
    && cmake --build --parallel --preset 'CPU' \
    && cmake --install build --component CPU --strip --parallel 8

+FROM base AS cuda-11
+ARG CUDA11VERSION=11.8
+RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
+ENV PATH=/usr/local/cuda-11/bin:$PATH
+RUN --mount=type=cache,target=/root/.ccache \
+    cmake --preset 'CUDA 11' -DOLLAMA_RUNNER_DIR="cuda_v11" \
+    && cmake --build --parallel --preset 'CUDA 11' \
+    && cmake --install build --component CUDA --strip --parallel 8
+
FROM base AS cuda-12
ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
ENV PATH=/usr/local/cuda-12/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \
-    cmake --preset 'CUDA 12' \
+    cmake --preset 'CUDA 12' -DOLLAMA_RUNNER_DIR="cuda_v12" \
    && cmake --build --parallel --preset 'CUDA 12' \
    && cmake --install build --component CUDA --strip --parallel 8

+FROM base AS cuda-13
+ARG CUDA13VERSION=13.0
+RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
+ENV PATH=/usr/local/cuda-13/bin:$PATH
+RUN --mount=type=cache,target=/root/.ccache \
+    cmake --preset 'CUDA 13' -DOLLAMA_RUNNER_DIR="cuda_v13" \
+    && cmake --build --parallel --preset 'CUDA 13' \
+    && cmake --install build --component CUDA --strip --parallel 8
+
FROM base AS rocm-6
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \

@@ -92,10 +112,14 @@ RUN --mount=type=cache,target=/root/.cache/go-build \
    go build -trimpath -buildmode=pie -o /bin/ollama .

FROM --platform=linux/amd64 scratch AS amd64
-COPY --from=cuda-12 dist/lib/ollama /lib/ollama
+# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
+COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
+COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/

FROM --platform=linux/arm64 scratch AS arm64
-COPY --from=cuda-12 dist/lib/ollama /lib/ollama/cuda_sbsa
+# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
+COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
+COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/
COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6
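As a usage note (not part of the diff): with BuildKit, any one of these stages can be built in isolation for local testing, e.g. `docker build --target cuda-13 .`, which compiles just that runner without assembling the final image.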
README.md

@@ -413,6 +413,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
+- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
+- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)

### Cloud
api/types.go

@@ -388,8 +388,12 @@ type EmbedRequest struct {
	// this request.
	KeepAlive *Duration `json:"keep_alive,omitempty"`

	// Truncate truncates the input to fit the model's max sequence length.
	Truncate *bool `json:"truncate,omitempty"`

+	// Dimensions truncates the output embedding to the specified dimension.
+	Dimensions int `json:"dimensions,omitempty"`
+
	// Options lists model-specific options.
	Options map[string]any `json:"options"`
}
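To make the new field concrete, here is a minimal client-side sketch using the Go API client; the model name and the 256-dimension value are illustrative assumptions, not values taken from this changeset:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	truncate := true
	resp, err := client.Embed(context.Background(), &api.EmbedRequest{
		Model:      "all-minilm",           // illustrative model name
		Input:      "why is the sky blue?",
		Truncate:   &truncate,              // return an error instead of truncating when false
		Dimensions: 256,                    // illustrative: ask for a 256-dim output vector
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("dimensions returned:", len(resp.Embeddings[0]))
}
```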
@@ -56,10 +56,8 @@ func ensureThinkingSupport(ctx context.Context, client *api.Client, name string)
	if err != nil {
		return
	}
-	for _, cap := range resp.Capabilities {
-		if cap == model.CapabilityThinking {
-			return
-		}
-	}
+	if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
+		return
+	}
	fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
}
@@ -28,6 +28,7 @@ type bertModel struct {
	LayerNormEPS     float32 `json:"layer_norm_eps"`
	LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
	NormEpsilon      float32 `json:"norm_epsilon"`
+	normalizeEmbeddings bool

	PoolingType uint32
}

@@ -54,9 +55,11 @@ func (p *bertModel) parseMore(fsys fs.FS) error {

	var pooling string
	for _, m := range modules {
-		if m.Type == "sentence_transformers.models.Pooling" {
+		switch m.Type {
+		case "sentence_transformers.models.Pooling":
			pooling = m.Path
-			break
+		case "sentence_transformers.models.Normalize":
+			p.normalizeEmbeddings = true
		}
	}

@@ -90,6 +93,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
	kv["general.architecture"] = "bert"
	kv["bert.attention.causal"] = false
	kv["bert.pooling_type"] = p.PoolingType
+	kv["bert.normalize_embeddings"] = p.normalizeEmbeddings

	kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)
@@ -43,14 +43,15 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
			}
		}
		return "sbsa"
	}

-	// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers
-	if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
-		// The detected driver is older than Feb 2023
-		slog.Warn("old CUDA driver detected - please upgrade to a newer driver", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
-		return "v11"
+	if gpuInfo.DriverMajor < 13 {
+		// The detected driver is older than 580 (Aug 2025)
+		// Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance
+		if gpuInfo.computeMajor > 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor >= 5) {
+			slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
+		}
+		return "v12"
	}
-	return "v12"
+	return "v13"
}
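Concretely: a driver reporting CUDA 12.x now gets the v12 runner (with an upgrade warning logged when the GPU's compute capability is 7.5 or higher, since such GPUs could use v13), a driver reporting 13.0 or newer gets v13, and the old v11 fallback for pre-12.1 drivers is gone.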
docs/api.md

@@ -1708,6 +1708,7 @@ Advanced parameters:
- `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
+- `dimensions`: number of dimensions for the embedding

### Examples
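A request exercising the new parameter could look like the following; the model name and dimension count are illustrative:

```shell
curl http://localhost:11434/api/embed -d '{
  "model": "all-minilm",
  "input": "Why is the sky blue?",
  "dimensions": 256
}'
```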
docs/development.md

@@ -11,6 +11,10 @@ Then build and run Ollama from the root directory of the repository:
go run . serve
```

+> [!NOTE]
+> Ollama includes native code compiled with CGO. From time to time these data structures can change and CGO can get out of sync resulting in unexpected crashes. You can force a full build of the native code by running `go clean -cache` first.
+
## macOS (Apple Silicon)

macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required.
docs/linux.md

@@ -11,12 +11,13 @@ curl -fsSL https://ollama.com/install.sh | sh
## Manual install

> [!NOTE]
-> If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
+> If you are upgrading from a prior version, you **MUST** remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.

Download and extract the package:

```shell
curl -LO https://ollama.com/download/ollama-linux-amd64.tgz
+sudo rm -rf /usr/lib/ollama
sudo tar -C /usr -xzf ollama-linux-amd64.tgz
```
docs/troubleshooting.md

@@ -92,6 +92,9 @@ If none of those resolve the problem, gather additional information and file an
- Set `CUDA_ERROR_LEVEL=50` and try again to get more diagnostic logs
- Check dmesg for any errors `sudo dmesg | grep -i nvrm` and `sudo dmesg | grep -i nvidia`

+You may get more details for initialization failures by enabling debug prints in the uvm driver. You should only use this temporarily while troubleshooting
+- `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm uvm_debug_prints=1`
+
## AMD GPU Discovery
envconfig/config.go

@@ -185,8 +185,6 @@ var (
	ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
	// Auth enables authentication between the Ollama client and server
	UseAuth = Bool("OLLAMA_AUTH")
-	// Enable the new memory estimation logic
-	NewMemoryEstimates = Bool("OLLAMA_NEW_ESTIMATES")
)

func String(s string) func() string {

@@ -272,7 +270,6 @@ func AsMap() map[string]EnvVar {
	"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
	"OLLAMA_CONTEXT_LENGTH":  {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
	"OLLAMA_NEW_ENGINE":      {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
-	"OLLAMA_NEW_ESTIMATES":   {"OLLAMA_NEW_ESTIMATES", NewMemoryEstimates(), "Enable the new memory estimation logic"},

	// Informational
	"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
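Note the direction of this change: `OLLAMA_NEW_ESTIMATES` is removed rather than added, so the new memory estimation logic is no longer gated behind an environment variable; as the NewLlamaServer hunks further down show, it is now keyed off whether the model runs on the new engine (`textProcessor != nil`).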
fs/ggml/ggml.go — 149 lines changed
@@ -57,10 +57,28 @@ func (kv KV) EmbeddingLength() uint64 {
	return uint64(kv.Uint("embedding_length"))
}

+func (kv KV) HeadCount() []uint64 {
+	headCountDefault := uint32(1)
+	headCount := kv.UintOrArrayValueAsArray("attention.head_count", headCountDefault)
+	if len(headCount) == 1 {
+		headCountDefault = headCount[0]
+	}
+	nLayers := int(kv.BlockCount())
+	if len(headCount) > nLayers {
+		slog.Warn("got more elements of attention.head_count than layers", "len(headCount)", len(headCount), "layers", nLayers)
+	}
+	out := make([]uint64, nLayers)
+	for i := range nLayers {
+		if i >= len(headCount) {
+			out[i] = uint64(headCountDefault)
+		} else {
+			out[i] = uint64(headCount[i])
+		}
+	}
+	return out
+}
+
func (kv KV) HeadCountMax() uint64 {
	// TODO(drifkin): using the max value can cause an overestimation. In the
	// future if array values become more popular, we can adapt the more invasive
	// <https://github.com/ollama/ollama/pull/10225>
	return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1))
}

@@ -68,6 +86,27 @@ func (kv KV) HeadCountMin() uint64 {
	return uint64(kv.UintOrMinArrayValue("attention.head_count", 1))
}

+func (kv KV) HeadCountKV() []uint64 {
+	headCountKVDefault := uint32(1)
+	headCountKV := kv.UintOrArrayValueAsArray("attention.head_count_kv", headCountKVDefault)
+	if len(headCountKV) == 1 {
+		headCountKVDefault = headCountKV[0]
+	}
+	nLayers := int(kv.BlockCount())
+	if len(headCountKV) > nLayers {
+		slog.Warn("got more elements of attention.head_count than layers", "len(headCountKV)", len(headCountKV), "layers", nLayers)
+	}
+	out := make([]uint64, nLayers)
+	for i := range nLayers {
+		if i >= len(headCountKV) {
+			out[i] = uint64(headCountKVDefault)
+		} else {
+			out[i] = uint64(headCountKV[i])
+		}
+	}
+	return out
+}
+
func (kv KV) HeadCountKVMax() uint64 {
	return uint64(kv.UintOrMaxArrayValue("attention.head_count_kv", 1))
}

@@ -100,6 +139,26 @@ func (kv KV) ChatTemplate() string {
	return kv.String("tokenizer.chat_template")
}

+// ssm architecture parameters
+
+func (kv KV) SSMConvKernel() uint64 {
+	return uint64(kv.Uint("ssm.conv_kernel"))
+}
+
+func (kv KV) SSMInnerSize() uint64 {
+	return uint64(kv.Uint("ssm.inner_size"))
+}
+
+func (kv KV) SSMStateSize() uint64 {
+	return uint64(kv.Uint("ssm.state_size"))
+}
+
+func (kv KV) SSMGroupCount() uint64 {
+	return uint64(kv.Uint("ssm.group_count"))
+}
+
+// general types
+
func (kv KV) String(key string, defaultValue ...string) string {
	val, _ := keyValue(kv, key, append(defaultValue, "")...)
	return val

@@ -131,22 +190,27 @@ func (kv KV) UintOrMinArrayValue(key string, defaultValue uint32) uint32 {
}

func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) {
+	arrVal := kv.UintOrArrayValueAsArray(key, defaultValue)
+	return slices.Min(arrVal), slices.Max(arrVal)
+}
+
+func (kv KV) UintOrArrayValueAsArray(key string, defaultValue uint32) []uint32 {
	if u32, ok := keyValue(kv, key, uint32(0)); ok {
-		return u32, u32
+		return []uint32{u32}
	} else if u32s, ok := keyValue(kv, key, &array[uint32]{}); ok {
-		min := slices.Min(u32s.values)
-		max := slices.Max(u32s.values)
-		return min, max
+		return u32s.values
	} else if i32s, ok := keyValue(kv, key, &array[int32]{}); ok {
-		min := slices.Min(i32s.values)
-		max := slices.Max(i32s.values)
-		if min < 0 || max < 0 {
-			slog.Warn("array values are unexpectedly negative", "key", key, "min", min, "max", max)
+		dst := make([]uint32, len(i32s.values))
+		for i, v := range i32s.values {
+			if v < 0 {
+				slog.Warn("array values are unexpectedly negative", "key", key, "i", i, "v", v)
+			}
+			dst[i] = uint32(v)
		}
-		return uint32(min), uint32(max)
+		return dst
	}

-	return defaultValue, defaultValue
+	return []uint32{defaultValue}
}

func (kv KV) Strings(key string, defaultValue ...[]string) []string {
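A worked example of the new accessors: with block_count = 4, a scalar attention.head_count of 8 expands to [8, 8, 8, 8], while an explicit array such as [8, 8, 0, 8] passes through unchanged — a zero marks a block with no attention heads, which the cache estimation below treats as recurrent.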
@@ -486,7 +550,9 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
	embedding := f.KV().EmbeddingLength()
	heads := f.KV().HeadCountMax()
+	headsArr := f.KV().HeadCount()
	headsKV := f.KV().HeadCountKVMax()
+	headsKVArr := f.KV().HeadCountKV()
	vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)

	embeddingHeads := f.KV().EmbeddingHeadCountMax()

@@ -496,12 +562,51 @@
	layers := f.Tensors().GroupLayers()

	bytesPerElement := kvCacheBytesPerElement(kvCacheType)

+	// Default for models unless special-cased below. These defaults mirror the
+	// cache usage in llama.cpp under the assumption that models without special
+	// cases below will use the llamarunner and caching will be handled by the
+	// llama.cpp layer.
+	//
+	// This also assumes that a layer without heads or headsKV set is recurrent
+	// which is usually the case. Some models (eg nemotronh) use "blocks" in
+	// place of layers where some are MLP blocks that don't have any cache.
+	// Models like this will need a special case below to be accurately
+	// estimated.
+	var kvTotal uint64
	kv = make([]uint64, f.KV().BlockCount())
+	kvSizeAttn := uint64(0)
+	kvSizeRecurrent := uint64(0)
	for i := range kv {
-		kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
+		headsL := headsArr[i]
+		headsKVL := headsKVArr[i]
+		if headsL > 0 && headsKVL > 0 {
+			// full attention layer
+			// NOTE: Assumes uniform values for all attn layers
+			kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKVL) * bytesPerElement)
+			kvSizeAttn += kv[i]
+		} else {
+			// recurrent layer
+			ssmDConv := f.KV().SSMConvKernel()
+			ssmDState := f.KV().SSMStateSize()
+			ssmDInner := f.KV().SSMInnerSize()
+			ssmNGroups := f.KV().SSMGroupCount()
+			nEmbdR := uint64(0)
+			if ssmDConv > 0 {
+				nEmbdR = (ssmDConv - 1) * (ssmDInner + 2*ssmNGroups*ssmDState)
+			}
+			nEmbdS := ssmDState * ssmDInner
+
+			// recurrent always uses F32 in llama.cpp backend
+			// https://github.com/ggml-org/llama.cpp/blob/master/src/llama-model.cpp#L18644
+			bytesPerElementRecurrent := kvCacheBytesPerElement("f32")
+
+			kv[i] = (nEmbdR + nEmbdS) * uint64(bytesPerElementRecurrent)
+			kvSizeRecurrent += kv[i]
+		}
+		kvTotal += kv[i]
	}
+	slog.Debug("default cache size estimate", "attention MiB", float32(kvSizeAttn)/(1024.*1024.), "attention bytes", kvSizeAttn, "recurrent MiB", float32(kvSizeRecurrent)/(1024.*1024.), "recurrent bytes", kvSizeRecurrent)

	switch f.KV().Architecture() {
	case "llama", "llama4":
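To sanity-check the two per-layer formulas above, here is a standalone sketch that plugs in hypothetical dimensions; every constant below is made up for illustration (real values come from the GGUF metadata):

```go
package main

import "fmt"

func main() {
	const (
		contextLen = 8192 // cache slots (context length x parallel sequences)
		headDimK   = 128  // embeddingHeadsK
		headDimV   = 128  // embeddingHeadsV
		headsKV    = 8    // KV heads on an attention layer
		bytesF16   = 2.0  // f16 cache element size

		dConv   = 4    // ssm.conv_kernel
		dState  = 128  // ssm.state_size
		dInner  = 4096 // ssm.inner_size
		nGroups = 8    // ssm.group_count
	)

	// Full attention layer: context * (k_dim + v_dim) * kv_heads * bytes/element.
	attn := uint64(float64(contextLen*(headDimK+headDimV)*headsKV) * bytesF16)

	// Recurrent layer: conv state + ssm state, always stored as f32 (4 bytes/element).
	nEmbdR := uint64((dConv - 1) * (dInner + 2*nGroups*dState))
	nEmbdS := uint64(dState * dInner)
	recurrent := (nEmbdR + nEmbdS) * 4

	fmt.Printf("attention layer: %.1f MiB\n", float64(attn)/(1024*1024))      // 32.0 MiB
	fmt.Printf("recurrent layer: %.1f MiB\n", float64(recurrent)/(1024*1024)) // ~2.1 MiB
}
```

Note how the recurrent state is independent of context length — exactly why hybrid models need the per-layer split rather than a single attention-style estimate.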
@@ -759,12 +864,16 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {

// SupportsKVCacheType checks if the requested cache type is supported
func (f GGML) SupportsKVCacheType(cacheType string) bool {
+	if cacheType == "" || cacheType == "f16" {
+		return true
+	}
+
	if arch := f.KV().Architecture(); slices.Contains([]string{"gptoss", "gpt-oss"}, arch) {
		// gpt-oss uses attention with sinks which does not support quantized cache types
-		slog.Warn("model only supports non-quantized cache types ", "mode", arch)
-		return cacheType == "f16"
+		slog.Warn("model only supports non-quantized cache types", "model", arch)
+		return false
	}
-	return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
+	return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
}

// SupportsFlashAttention checks if the model supports flash attention
@@ -774,6 +883,10 @@ func (f GGML) SupportsFlashAttention() bool {
		return false
	}

+	if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
+		return false
+	}
+
	// Check head counts match and are non-zero
	headCountK := f.KV().EmbeddingHeadCountK()
	headCountV := f.KV().EmbeddingHeadCountV()

@@ -794,6 +907,8 @@ func kvCacheBytesPerElement(cacheType string) float64 {
		return 1 // 1/2 of fp16
	case "q4_0":
		return 0.5 // 1/4 of fp16
+	case "f32":
+		return 4 // f32 (default for recurrent)
	default:
		return 2 // f16 (default)
	}
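With the empty string and "f16" now accepted unconditionally, callers can drop their `requested != ""` guards (see the estimateGPULayers hunk further down); selecting a quantized cache still works the same way from the user's side, e.g. `OLLAMA_FLASH_ATTENTION=1 OLLAMA_KV_CACHE_TYPE=q8_0 ollama serve`.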
@@ -410,3 +410,99 @@ func TestAPIEmbeddings(t *testing.T) {
		t.Errorf("zero length embedding response")
	}
}

+func TestAPIToolCalling(t *testing.T) {
+	initialTimeout := 60 * time.Second
+	streamTimeout := 30 * time.Second
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+	defer cancel()
+
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	modelName := "qwen3:0.6b"
+	if err := PullIfMissing(ctx, client, modelName); err != nil {
+		t.Fatalf("pull failed %s", err)
+	}
+
+	tools := []api.Tool{
+		{
+			Type: "function",
+			Function: api.ToolFunction{
+				Name:        "get_weather",
+				Description: "Get the current weather in a given location",
+				Parameters: api.ToolFunctionParameters{
+					Type:     "object",
+					Required: []string{"location"},
+					Properties: map[string]api.ToolProperty{
+						"location": {
+							Type:        api.PropertyType{"string"},
+							Description: "The city and state, e.g. San Francisco, CA",
+						},
+					},
+				},
+			},
+		},
+	}
+
+	req := api.ChatRequest{
+		Model: modelName,
+		Messages: []api.Message{
+			{
+				Role:    "user",
+				Content: "Call get_weather with location set to San Francisco.",
+			},
+		},
+		Tools: tools,
+		Options: map[string]any{
+			"temperature": 0,
+		},
+	}
+
+	stallTimer := time.NewTimer(initialTimeout)
+	var gotToolCall bool
+	var lastToolCall api.ToolCall
+
+	fn := func(response api.ChatResponse) error {
+		if len(response.Message.ToolCalls) > 0 {
+			gotToolCall = true
+			lastToolCall = response.Message.ToolCalls[len(response.Message.ToolCalls)-1]
+		}
+		if !stallTimer.Reset(streamTimeout) {
+			return fmt.Errorf("stall was detected while streaming response, aborting")
+		}
+		return nil
+	}
+
+	stream := true
+	req.Stream = &stream
+	done := make(chan int)
+	var genErr error
+	go func() {
+		genErr = client.Chat(ctx, &req, fn)
+		done <- 0
+	}()
+
+	select {
+	case <-stallTimer.C:
+		t.Errorf("tool-calling chat never started. Timed out after: %s", initialTimeout.String())
+	case <-done:
+		if genErr != nil {
+			t.Fatalf("chat failed: %v", genErr)
+		}
+
+		if !gotToolCall {
+			t.Fatalf("expected at least one tool call, got none")
+		}
+
+		if lastToolCall.Function.Name != "get_weather" {
+			t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
+		}
+
+		if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
+			t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
+		}
+	case <-ctx.Done():
+		t.Error("outer test context done while waiting for tool-calling chat")
+	}
+}
@@ -121,6 +121,7 @@ func TestMultiModelStress(t *testing.T) {
	// The intent is to go 1 over what can fit so we force the scheduler to thrash
	targetLoadCount := 0
	slog.Info("Loading models to find how many can fit in VRAM before overflowing")
+chooseModels:
	for i, model := range chosenModels {
		req := &api.GenerateRequest{Model: model}
		slog.Info("loading", "model", model)

@@ -142,6 +143,13 @@ func TestMultiModelStress(t *testing.T) {
				slog.Info("found model load capacity", "target", targetLoadCount, "current", loaded, "chosen", chosenModels[:targetLoadCount])
				break
			}
+			// Effectively limit model count to 2 on CPU only systems to avoid thrashing and timeouts
+			for _, m := range models.Models {
+				if m.SizeVRAM == 0 {
+					slog.Info("model running on CPU", "name", m.Name, "target", targetLoadCount, "chosen", chosenModels[:targetLoadCount])
+					break chooseModels
+				}
+			}
		}
	}
	if targetLoadCount == len(chosenModels) {
@@ -36,7 +36,7 @@ func TestLongInputContext(t *testing.T) {
	if err := PullIfMissing(ctx, client, req.Model); err != nil {
		t.Fatalf("PullIfMissing failed: %v", err)
	}
-	DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
+	DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
}

func TestContextExhaustion(t *testing.T) {

@@ -50,7 +50,7 @@ func TestContextExhaustion(t *testing.T) {
	// Set up the test data
	req := api.GenerateRequest{
		Model:  smol,
-		Prompt: "Write me a story with a ton of emojis?",
+		Prompt: "Write me a story in english with a lot of emojis",
		Stream: &stream,
		Options: map[string]any{
			"temperature": 0,
@@ -38,8 +38,9 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
	defer cleanup()

	req := api.EmbeddingRequest{
-		Model:  "all-minilm",
-		Prompt: "why is the sky blue?",
+		Model:     "all-minilm",
+		Prompt:    "why is the sky blue?",
+		KeepAlive: &api.Duration{Duration: 10 * time.Second},
	}

	res, err := embeddingTestHelper(ctx, client, t, req)
@@ -502,6 +502,22 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
		done <- 0
	}()

+	var response string
+	verify := func() {
+		// Verify the response contains the expected data
+		response = buf.String()
+		atLeastOne := false
+		for _, resp := range anyResp {
+			if strings.Contains(strings.ToLower(response), resp) {
+				atLeastOne = true
+				break
+			}
+		}
+		if !atLeastOne {
+			t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
+		}
+	}
+
	select {
	case <-stallTimer.C:
		if buf.Len() == 0 {

@@ -517,21 +533,14 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
		if genErr != nil {
			t.Fatalf("%s failed with %s request prompt %s", genErr, genReq.Model, genReq.Prompt)
		}
-		// Verify the response contains the expected data
-		response := buf.String()
-		atLeastOne := false
-		for _, resp := range anyResp {
-			if strings.Contains(strings.ToLower(response), resp) {
-				atLeastOne = true
-				break
-			}
-		}
-		if !atLeastOne {
-			t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
-		}
+		verify()
		slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
	case <-ctx.Done():
-		t.Error("outer test context done while waiting for generate")
+		// On slow systems, we might timeout before some models finish rambling, so check what we have so far to see
+		// if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass
+		// if they are still generating valid responses
+		slog.Warn("outer test context done while waiting for generate")
+		verify()
	}
	return context
}

@@ -552,7 +561,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
		KeepAlive: &api.Duration{Duration: 10 * time.Second},
	}, {
		Model:     smol,
-		Prompt:    "what is the origin of the US thanksgiving holiday? Be brief but factual in your reply",
+		Prompt:    "how do rainbows form? Be brief but factual in your reply",
		Stream:    &stream,
		KeepAlive: &api.Duration{Duration: 10 * time.Second},
	}, {

@@ -570,9 +579,9 @@
		[][]string{
			{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
			{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
-			{"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states", "cultural", "hardship", "autumn", "festival"},
+			{"water", "droplet", "refracted", "reflect", "color", "spectrum"},
			{"fourth", "july", "declaration", "independence"},
-			{"nitrogen", "oxygen", "carbon", "dioxide"},
+			{"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor"},
		}
}

@@ -599,6 +608,22 @@ func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatR
		done <- 0
	}()

+	var response string
+	verify := func() {
+		// Verify the response contains the expected data
+		response = buf.String()
+		atLeastOne := false
+		for _, resp := range anyResp {
+			if strings.Contains(strings.ToLower(response), resp) {
+				atLeastOne = true
+				break
+			}
+		}
+		if !atLeastOne {
+			t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
+		}
+	}
+
	select {
	case <-stallTimer.C:
		if buf.Len() == 0 {

@@ -614,23 +639,14 @@ func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatR
		if genErr != nil {
			t.Fatalf("%s failed with %s request prompt %v", genErr, req.Model, req.Messages)
		}

-		// Verify the response contains the expected data
-		response := buf.String()
-		atLeastOne := false
-		for _, resp := range anyResp {
-			if strings.Contains(strings.ToLower(response), resp) {
-				atLeastOne = true
-				break
-			}
-		}
-		if !atLeastOne {
-			t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
-		}
-
+		verify()
		slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
	case <-ctx.Done():
-		t.Error("outer test context done while waiting for generate")
+		// On slow systems, we might timeout before some models finish rambling, so check what we have so far to see
+		// if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass
+		// if they are still generating valid responses
+		slog.Warn("outer test context done while waiting for chat")
+		verify()
	}
	return &api.Message{Role: role, Content: buf.String()}
}
@@ -515,33 +515,34 @@ func (c *MtmdContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32,
	}
	nChunks := C.mtmd_input_chunks_size(ic)
	numEmbed := llamaContext.Model().NEmbd()
-	lastChunkSize := 0
+	embed := make([][]float32, 0)
	for i := range int(nChunks) {
		chunk := C.mtmd_input_chunks_get(ic, C.size_t(i))
		numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk))
-		lastChunkSize = numTokens
		slog.Debug("chunk tokens", "index", i, "numTokens", numTokens)

		// Encode the chunk
		if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) {
			return nil, errors.New("unable to encode mtmd image chunk")
		}
-	}

-	// Get the embeddings
-	embed := make([][]float32, lastChunkSize)
-	embd := C.mtmd_get_output_embd(c.c)
-	if nil == embd {
-		return nil, errors.New("failed to get image embedding")
-	}
+		// Get the embeddings for this chunk
+		chunkEmbed := make([][]float32, numTokens)
+		chunkEmbd := C.mtmd_get_output_embd(c.c)
+		if nil == chunkEmbd {
+			continue
+		}

-	// Extend the embedding array for each token
-	s := unsafe.Slice((*float32)(embd), numEmbed*lastChunkSize)
-	rows := make([]float32, len(s))
-	copy(rows, s)
-	for i := range lastChunkSize {
-		embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
+		// Extend the embedding array for each token
+		s := unsafe.Slice((*float32)(chunkEmbd), numTokens*numEmbed)
+		rows := make([]float32, len(s))
+		copy(rows, s)
+		for i := range numTokens {
+			chunkEmbed[i] = rows[i*numEmbed : (i+1)*numEmbed]
+		}
+		embed = append(embed, chunkEmbed...)
	}

	slog.Debug("image embeddings", "totalEmbeddings", len(embed))
	return embed, nil
}
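The practical effect of this rewrite: an image that mtmd splits into several chunks now contributes embedding rows for every chunk token. The old code sized the result from the last chunk only and fetched the output embedding once after the loop, so earlier chunks were silently dropped.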
@@ -202,7 +202,7 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
	var kvct string
	if useFlashAttention {
		requested := strings.ToLower(envconfig.KvCacheType())
-		if requested != "" && f.SupportsKVCacheType(requested) {
+		if f.SupportsKVCacheType(requested) {
			kvct = requested
		}
	}
@@ -148,7 +148,11 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
	var textProcessor model.TextProcessor
	var err error
	if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
-		textProcessor, err = model.NewTextProcessor(modelPath)
+		if len(projectors) == 0 {
+			textProcessor, err = model.NewTextProcessor(modelPath)
+		} else {
+			err = errors.New("split vision models aren't supported")
+		}
		if err != nil {
			// To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner
			slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)

@@ -161,11 +165,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
		}
	}

-	newEstimates := textProcessor != nil && envconfig.NewMemoryEstimates()
-	if newEstimates {
-		slog.Info("enabling new memory estimates")
-	}
-
	// Verify the requested context size is <= the model training size
	trainCtx := f.KV().ContextLength()
	if opts.NumCtx > int(trainCtx) && trainCtx > 0 {

@@ -173,6 +172,8 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
		opts.NumCtx = int(trainCtx)
	}

+	opts.NumBatch = min(opts.NumBatch, opts.NumCtx)
+
	loadRequest := LoadRequest{LoraPath: adapters, KvSize: opts.NumCtx * numParallel, BatchSize: opts.NumBatch, Parallel: numParallel, MultiUserCache: envconfig.MultiUserCache()}

	defaultThreads := discover.GetSystemInfo().GetOptimalThreadCount()

@@ -218,7 +219,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
	// Flash Attention also supports kv cache quantization
	// Enable if the requested and kv cache type is supported by the model
-	if kvct != "" && f.SupportsKVCacheType(kvct) {
+	if f.SupportsKVCacheType(kvct) {
		loadRequest.KvCacheType = kvct
	} else {
		slog.Warn("kv cache type not supported by model", "type", kvct)

@@ -431,7 +432,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
		}
	}()

-	if newEstimates {
+	if textProcessor != nil {
		return &ollamaServer{llmServer: s}, nil
	} else {
		return &llamaServer{llmServer: s, ggml: f}, nil
@@ -416,6 +416,7 @@ type Tensor interface {
	AddID(ctx Context, t2, ids Tensor) Tensor

	Softmax(ctx Context) Tensor
+	L2Norm(ctx Context, eps float32) Tensor
	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
	RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
	Scale(ctx Context, s float64) Tensor
@@ -1205,6 +1205,13 @@ func (t *Tensor) AddID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
	}
}

+func (t *Tensor) L2Norm(ctx ml.Context, eps float32) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_l2_norm(ctx.(*Context).ctx, t.t, C.float(eps)),
+	}
+}
+
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
	tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
	if w != nil {
ml/nn/pooling/pooling.go — new file, 36 lines
package pooling

import (
	"github.com/ollama/ollama/ml"
)

type Type uint32

const (
	TypeNone Type = iota
	TypeMean
	TypeCLS
	TypeLast
	TypeRank

	TypeUnknown     = 0xFFFFFFFE
	TypeUnspecified = 0xFFFFFFFF
)

func Pooling(ctx ml.Context, hiddenStates ml.Tensor, poolingType Type) ml.Tensor {
	switch poolingType {
	case TypeNone:
		return hiddenStates
	case TypeMean:
		hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
		return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	case TypeCLS:
		return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
	case TypeLast:
		panic("not implemented")
	case TypeRank:
		panic("not implemented")
	default:
		panic("not implemented")
	}
}
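To pin down what TypeMean computes, here is an equivalent standalone sketch over plain slices — the tensor version above expresses the same per-dimension mean across tokens via the Permute/Mean/Permute dance:

```go
package main

import "fmt"

// meanPool collapses a [tokens][dims] matrix of hidden states into one
// [dims] vector by averaging each dimension across tokens, mirroring
// pooling.TypeMean.
func meanPool(hidden [][]float32) []float32 {
	out := make([]float32, len(hidden[0]))
	for _, tok := range hidden {
		for i, v := range tok {
			out[i] += v
		}
	}
	for i := range out {
		out[i] /= float32(len(hidden))
	}
	return out
}

func main() {
	// Three token embeddings, two dimensions each.
	hidden := [][]float32{{1, 2}, {3, 4}, {5, 6}}
	fmt.Println(meanPool(hidden)) // [3 4]
}
```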
@@ -201,12 +201,11 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
		}
	}

-	logutil.Trace("encoded", "string", s, "ids", ids)
-
	if addSpecial && len(ids) > 0 {
		ids = bpe.vocab.addSpecials(ids)
	}

+	logutil.Trace("encoded", "string", s, "ids", ids)
	return ids, nil
}
@@ -54,10 +54,9 @@ type Batch struct {
	// Inputs is the input tokens, including placeholders for multimodal inputs.
	Inputs ml.Tensor

-	// Multimodal is a set of multimodal embeddings previously created by
-	// EncodeMultimodal, along with an index into Inputs. Unused for text-only
-	// models or for batches without multimodal elements.
-	Multimodal []MultimodalIndex
+	// Outputs are the set of indicies into Inputs for which output data should
+	// be returned.
+	Outputs ml.Tensor

	// Positions is the position for each Input, relative to its sequence. Equal
	// in length to Inputs.

@@ -66,7 +65,8 @@ type Batch struct {
	// Sequences is the sequence for each Input. Equal in length to Inputs.
	Sequences []int

-	// Outputs are the set of indicies into Inputs for which output data should
-	// be returned.
-	Outputs []int32
+	// Multimodal is a set of multimodal embeddings previously created by
+	// EncodeMultimodal, along with an index into Inputs. Unused for text-only
+	// models or for batches without multimodal elements.
+	Multimodal []MultimodalIndex
}
@@ -5,6 +5,7 @@ import (
	"fmt"
	_ "image/jpeg"
	_ "image/png"
+	"math"
	"os"
	"reflect"
	"strconv"

@@ -23,7 +24,11 @@ import (
	"github.com/ollama/ollama/model/input"
)

-var ErrNoVisionModel = errors.New("this model is missing data required for image input")
+var (
+	ErrNoVisionModel        = errors.New("this model is missing data required for image input")
+	ErrUnsupportedModel     = errors.New("model not supported")
+	ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
+)

// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface {

@@ -103,6 +108,10 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
	}

	arch := b.Config().Architecture()
+	if b.Config().Uint("pooling_type", math.MaxUint32) != math.MaxUint32 {
+		arch = arch + "_embed"
+	}
+
	f, ok := models[arch]
	if !ok {
		return nil, fmt.Errorf("unsupported model architecture %q", arch)

@@ -237,7 +246,7 @@ func setPointer(base Base, v reflect.Value, tags []Tag) {
		vv = vv.Elem()
	}

-	vv = vv.Elem()
+	vv = reflect.Indirect(vv)
	if v.IsNil() {
		vv = reflect.New(v.Type().Elem()).Elem()
	}
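For example, a BERT checkpoint whose GGUF metadata includes pooling_type now resolves to the "bert_embed" registration, while the same architecture without pooling metadata keeps loading as plain "bert" — both names are registered by the bert package introduced below.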
model/models/bert/model.go — new file, 181 lines
package bert

import (
	"cmp"
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/pooling"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)

type Model struct {
	model.Base
	model.TextProcessor

	TokenEmbedding     *nn.Embedding `gguf:"token_embd"`
	TypeEmbedding      *nn.Embedding `gguf:"token_types"`
	PositionEmbedding  *nn.Embedding `gguf:"position_embd"`
	TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"`

	Layers []EncoderLayer `gguf:"blk"`

	Options
}

// Forward implements model.Model.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
	hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
	hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize))
	hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))))
	hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)

	for _, layer := range m.Layers {
		hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options)
	}

	hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
	if m.normalize {
		hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
	}

	return hiddenStates, nil
}

type EncoderLayer struct {
	*Attention
	AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"`

	*MLP
	MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"`
}

func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
	// Attention
	residual := hiddenStates
	hiddenStates = e.Attention.Forward(ctx, hiddenStates, opts)
	hiddenStates = hiddenStates.Add(ctx, residual)
	hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)

	// MLP
	residual = hiddenStates
	hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
	hiddenStates = hiddenStates.Add(ctx, residual)
	hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps)

	return hiddenStates
}

type Attention struct {
	Query     *nn.Linear    `gguf:"attn_q"`
	QueryNorm *nn.LayerNorm `gguf:"attn_q_norm"`

	Key     *nn.Linear    `gguf:"attn_k"`
	KeyNorm *nn.LayerNorm `gguf:"attn_k_norm"`

	Value *nn.Linear `gguf:"attn_v"`

	Output *nn.Linear `gguf:"attn_output"`
}

func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
	batchSize := hiddenStates.Dim(1)

	query := a.Query.Forward(ctx, hiddenStates)
	if a.QueryNorm != nil {
		query = a.QueryNorm.Forward(ctx, query, opts.eps)
	}
	query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)

	key := a.Key.Forward(ctx, hiddenStates)
	if a.KeyNorm != nil {
		key = a.KeyNorm.Forward(ctx, key, opts.eps)
	}
	key = key.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)

	value := a.Value.Forward(ctx, hiddenStates)
	value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)

	attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
	return a.Output.Forward(ctx, attention)
}

type MLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
}

func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
	return m.Down.Forward(ctx, m.Up.Forward(ctx, hiddenStates).GELU(ctx))
}

type Options struct {
	hiddenSize,
	numHeads,
	numKVHeads,
	keyLength,
	valueLength int
	poolingType pooling.Type
	eps         float32
	normalize   bool
}

func (o Options) headDim() int {
	return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
}

func New(c fs.Config) (model.Model, error) {
	var processor model.TextProcessor
	switch c.String("tokenizer.ggml.model", "bert") {
	case "bert":
		processor = model.NewWordPiece(
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Scores: c.Floats("tokenizer.ggml.scores"),
				Types:  c.Ints("tokenizer.ggml.token_type"),
				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
				BOS: []int32{
					int32(cmp.Or(
						c.Uint("tokenizer.ggml.cls_token_id"),
						c.Uint("tokenizer.ggml.bos_token_id"),
					)),
				},
				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
				EOS: []int32{
					int32(cmp.Or(
						c.Uint("tokenizer.ggml.separator_token_id"),
						//nolint:misspell
						// NOTE: "seperator_token_id" is a typo in model metadata but we need to
						// support it for compatibility.
						c.Uint("tokenizer.ggml.seperator_token_id"),
						c.Uint("tokenizer.ggml.eos_token_id"),
					)),
				},
			},
		)
	default:
		return nil, model.ErrUnsupportedTokenizer
	}

	return &Model{
		TextProcessor: processor,
		Layers:        make([]EncoderLayer, c.Uint("block_count")),
		Options: Options{
			hiddenSize:  int(c.Uint("embedding_length")),
			numHeads:    int(c.Uint("attention.head_count")),
			numKVHeads:  int(c.Uint("attention.head_count_kv")),
			eps:         c.Float("attention.layer_norm_epsilon"),
			poolingType: pooling.Type(c.Uint("pooling_type")),
			normalize:   c.Bool("normalize_embeddings", true),
		},
	}, nil
}

func init() {
	model.Register("bert", New)
	model.Register("bert_embed", New)
}
@@ -24,7 +24,7 @@ type Options struct {

 type Model struct {
 	model.Base
-	model.SentencePieceModel
+	model.SentencePiece

 	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
 	Layers         []Layer       `gguf:"blk"`
@@ -40,7 +40,7 @@ const (

 func New(c fs.Config) (model.Model, error) {
 	m := Model{
-		SentencePieceModel: model.NewSentencePieceModel(
+		SentencePiece: model.NewSentencePiece(
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Scores: c.Floats("tokenizer.ggml.scores"),
@@ -176,7 +176,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten

 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))

 	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
@@ -193,7 +192,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {

 		var lastLayerOutputs ml.Tensor
 		if i == len(m.Layers)-1 {
-			lastLayerOutputs = outputs
+			lastLayerOutputs = batch.Outputs
 		}

 		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
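Note: the recurring change in these model hunks is that input.Batch.Outputs is now an ml.Tensor materialized once by the runner, rather than an []int32 slice that every model re-uploaded itself. A sketch of the before/after shape of the pattern (simplified from the hunks above):

	// before: every model copied the output indices to the device on its own
	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))

	// after: the runner builds the tensor once and models consume it directly
	outputs := batch.Outputs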
model/models/gemma3/embed.go (new file, 62 lines)
@@ -0,0 +1,62 @@
package gemma3

import (
	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/pooling"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)

type embedModel struct {
	model.Base
	model.SentencePiece

	*TextModel
	poolingType pooling.Type

	Dense [2]*nn.Linear `gguf:"dense"`
}

func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
	hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
	hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
	for _, dense := range m.Dense {
		hiddenStates = dense.Forward(ctx, hiddenStates)
	}
	hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
	return hiddenStates, nil
}

func newEmbedModel(c fs.Config) (model.Model, error) {
	m := &embedModel{
		SentencePiece: model.NewSentencePiece(
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Scores: c.Floats("tokenizer.ggml.scores"),
				Types:  c.Ints("tokenizer.ggml.token_type"),
				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
				BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
				EOS: append(
					[]int32{
						int32(c.Uint("tokenizer.ggml.eos_token_id")),
						int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
					},
					c.Ints("tokenizer.ggml.eos_token_ids")...,
				),
			},
		),
		TextModel:   newTextModel(c),
		poolingType: pooling.Type(c.Uint("pooling_type", 0)),
	}

	m.Cache = kvcache.NewWrapperCache(
		kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
		kvcache.NewCausalCache(m.Shift),
	)

	return m, nil
}
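Note: the embedding head above follows a common recipe: pool the per-token states into one vector per sequence, project through the dense layers, then L2-normalize so downstream cosine similarity reduces to a dot product. A plain-slice sketch of the final normalization step under that reading (the 1e-12 epsilon mirrors the L2Norm call; the helper name is illustrative, not the ml API):

	package main

	import (
		"fmt"
		"math"
	)

	// l2Normalize scales vec to unit length, guarding against division by zero.
	func l2Normalize(vec []float32, eps float64) []float32 {
		var sum float64
		for _, v := range vec {
			sum += float64(v) * float64(v)
		}
		scale := float32(1.0 / math.Max(math.Sqrt(sum), eps))
		out := make([]float32, len(vec))
		for i, v := range vec {
			out[i] = v * scale
		}
		return out
	}

	func main() {
		fmt.Println(l2Normalize([]float32{3, 4}, 1e-12)) // [0.6 0.8]
	}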
@@ -16,7 +16,7 @@ import (

 type Model struct {
 	model.Base
-	model.SentencePieceModel
+	model.SentencePiece

 	*VisionModel `gguf:"v"`
 	*TextModel
@@ -55,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i

 func New(c fs.Config) (model.Model, error) {
 	m := Model{
-		SentencePieceModel: model.NewSentencePieceModel(
+		SentencePiece: model.NewSentencePiece(
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Scores: c.Floats("tokenizer.ggml.scores"),
@@ -141,12 +141,11 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 }

 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
-
-	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
+	hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
+	return m.Output.Forward(ctx, hiddenStates), nil
 }

 func init() {
 	model.Register("gemma3", New)
+	model.Register("gemma3_embed", newEmbedModel)
 }
@@ -159,8 +159,10 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
 	return hiddenState.Add(ctx, residual)
 }

-func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
-	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
+	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
+
+	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))

 	// set image embeddings
@@ -191,12 +193,12 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor

 		var lastLayerOutputs ml.Tensor
 		if i == len(m.Layers)-1 {
-			lastLayerOutputs = outputs
+			lastLayerOutputs = batch.Outputs
 		}

 		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
 	}

 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
-	return m.Output.Forward(ctx, hiddenState)
+	return hiddenState
 }
@@ -10,7 +10,7 @@ import (

 type Model struct {
 	model.Base
-	model.SentencePieceModel
+	model.SentencePiece

 	*TextModel
 }
@@ -23,7 +23,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 func New(c fs.Config) (model.Model, error) {
 	m := Model{
 		TextModel: newTextModel(c),
-		SentencePieceModel: model.NewSentencePieceModel(
+		SentencePiece: model.NewSentencePiece(
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Scores: c.Floats("tokenizer.ggml.scores"),

@@ -83,7 +83,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac

 	hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
 	hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
-	hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)))
+	hiddenStates = hiddenStates.Rows(ctx, batch.Outputs)

 	hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
 	return m.Output.Forward(ctx, hiddenStates), nil
@@ -41,8 +41,8 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err
 		}

 		var outputs ml.Tensor
-		if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 {
-			outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+		if i == len(m.TransformerBlocks)-1 {
+			outputs = batch.Outputs
 		}

 		hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options)

@@ -160,7 +160,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {

 		var outputs ml.Tensor
 		if i == len(m.Layers)-1 {
-			outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+			outputs = batch.Outputs
 		}

 		hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)
@@ -176,9 +176,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {

 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
-
-	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
 }

 func init() {

@@ -159,9 +159,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {

 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))

-	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
 }

 func init() {
@@ -107,10 +107,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	}

 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))

 	// TODO: attention mask, cross attention mask
-	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
 }

 func init() {
@@ -1,6 +1,7 @@
 package models

 import (
+	_ "github.com/ollama/ollama/model/models/bert"
 	_ "github.com/ollama/ollama/model/models/gemma2"
 	_ "github.com/ollama/ollama/model/models/gemma3"
 	_ "github.com/ollama/ollama/model/models/gemma3n"
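Note: the blank import pulls the new bert package in purely for its side effects: each model package's init calls model.Register, and importing it populates a global constructor table. A minimal self-contained sketch of the pattern (names illustrative; in the real layout the init lives in the model package and the blank import triggers it):

	package main

	import "fmt"

	var registry = map[string]func() string{}

	func register(name string, fn func() string) { registry[name] = fn }

	func init() { register("bert", func() string { return "bert model" }) }

	func main() {
		fmt.Println(registry["bert"]()) // bert model
	}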
@@ -111,7 +111,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {

 		var outputs ml.Tensor
 		if i == len(m.Layers)-1 {
-			outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+			outputs = batch.Outputs
 		}

 		hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)

@@ -140,9 +140,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {

 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))

-	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache)
 }

 func init() {

@@ -165,7 +165,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {

 		var outputs ml.Tensor
 		if i == len(m.Layers)-1 {
-			outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+			outputs = batch.Outputs
 		}

 		hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
@@ -12,18 +12,18 @@ import (

 const spmWhitespaceSep = "▁"

-type SentencePieceModel struct {
+type SentencePiece struct {
 	maxTokenLen int
 	vocab       *Vocabulary
 }

-var _ TextProcessor = (*SentencePieceModel)(nil)
+var _ TextProcessor = (*SentencePiece)(nil)

-func (spm SentencePieceModel) Vocabulary() *Vocabulary {
+func (spm SentencePiece) Vocabulary() *Vocabulary {
 	return spm.vocab
 }

-func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
+func NewSentencePiece(vocab *Vocabulary) SentencePiece {
 	logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])

 	counter := map[int]int{}
@@ -42,17 +42,17 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
 		"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
 		"max token len", maxTokenLen)

-	return SentencePieceModel{
+	return SentencePiece{
 		maxTokenLen: maxTokenLen,
 		vocab:       vocab,
 	}
 }

-func (spm SentencePieceModel) Is(id int32, special Special) bool {
+func (spm SentencePiece) Is(id int32, special Special) bool {
 	return spm.vocab.Is(id, special)
 }

-func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
+func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
 	fragments := []fragment{{value: s}}
 	for _, special := range spm.vocab.SpecialVocabulary() {
 		id := spm.vocab.Encode(special)
@@ -181,12 +181,11 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
 		}
 	}

-	logutil.Trace("encoded", "string", s, "ids", ids)
-
 	if addSpecial && len(ids) > 0 {
 		ids = spm.vocab.addSpecials(ids)
 	}

+	logutil.Trace("encoded", "string", s, "ids", ids)
 	return ids, nil
 }
@@ -219,7 +218,7 @@ func (q *queue) Pop() interface{} {
 	return item
 }

-func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
+func (spm SentencePiece) Decode(ids []int32) (string, error) {
 	var sb strings.Builder
 	for _, id := range ids {
 		data := spm.vocab.Decode(id)
@@ -12,7 +12,7 @@ import (
 	"github.com/ollama/ollama/convert/sentencepiece"
 )

-func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
+func loadSentencePieceVocab(t *testing.T) SentencePiece {
 	t.Helper()

 	bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
@@ -45,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
 		}
 	}

-	return NewSentencePieceModel(&v)
+	return NewSentencePiece(&v)
 }

 func TestSentencePieceEncode(t *testing.T) {
@@ -115,7 +115,7 @@ func TestSentencePieceEncode(t *testing.T) {
 	})
 }

-func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
+func TestSentencePieceDecodeByteTokens(t *testing.T) {
 	vocab := &Vocabulary{
 		Values: []string{
 			"normal",
@@ -134,7 +134,7 @@ func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
 		Scores: []float32{0, 0, 0, 0, 0},
 	}

-	spm := NewSentencePieceModel(vocab)
+	spm := NewSentencePiece(vocab)

 	tests := []struct {
 		name string
@@ -49,7 +49,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 {
 			slog.Warn("adding bos token to prompt which already has it", "id", v.BOS)
 		}

-		slog.Debug("adding bos token to prompt", "id", v.BOS)
+		slog.Debug("adding bos token to prompt", "id", v.BOS[0])
 		ids = append([]int32{v.BOS[0]}, ids...)
 	}
@@ -58,7 +58,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 {
 			slog.Warn("adding eos token to prompt which already has it", "id", v.EOS)
 		}

-		slog.Debug("adding eos token to prompt", "id", v.EOS)
+		slog.Debug("adding eos token to prompt", "id", v.EOS[0])
 		ids = append(ids, v.EOS[0])
 	}
model/wordpiece.go (new file, 167 lines)
@@ -0,0 +1,167 @@
package model

import (
	"fmt"
	"iter"
	"strings"
	"unicode"

	"github.com/ollama/ollama/logutil"
)

type WordPiece struct {
	vocab *Vocabulary
}

// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
// This differs from the original WordPiece scheme, which instead marks subwords with "##".
const ggmlPrefix = "▁"

var wordPieceReplacer = strings.NewReplacer(
	" .", ".",
	" ?", "?",
	" !", "!",
	" ,", ",",
	" ' ", "'",
	" n't", "n't",
	" 'm", "'m",
	" do not", " don't",
	" 's", "'s",
	" 've", "'ve",
	" 're", "'re",
)

// Decode implements TextProcessor.
func (wpm WordPiece) Decode(ids []int32) (string, error) {
	var sb strings.Builder
	for i, id := range ids {
		if id < 0 || int(id) >= len(wpm.vocab.Values) {
			return "", fmt.Errorf("invalid token id: %d", id)
		}

		var separator string
		piece := wpm.vocab.Values[id]
		if i > 0 &&
			(strings.HasPrefix(piece, ggmlPrefix) ||
				(strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) {
			separator = " "
		}

		sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix)))
	}

	return sb.String(), nil
}

// words splits a string into words, treating CJK characters as separate words.
// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
func (wpm WordPiece) words(s string) iter.Seq[string] {
	return func(yield func(string) bool) {
		runes := make([]rune, 0, len(s)*3)
		for _, r := range s {
			switch {
			case r >= 0x4E00 && r <= 0x9FFF,
				r >= 0x3400 && r <= 0x4DBF,
				r >= 0x20000 && r <= 0x2A6DF,
				r >= 0x2A700 && r <= 0x2B73F,
				r >= 0x2B740 && r <= 0x2B81F,
				r >= 0x2B820 && r <= 0x2CEAF,
				r >= 0xF900 && r <= 0xFAFF,
				r >= 0x2F800 && r <= 0x2FA1F:
				runes = append(runes, ' ', r, ' ')
			default:
				runes = append(runes, r)
			}
		}

		for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) {
			// split on punctuation, keeping each punctuation character as its own word
			var start int
			for start < len(w) {
				end := strings.IndexFunc(w[start:], unicode.IsPunct)
				if end < 0 {
					end = len(w) - start
				} else if end == 0 {
					end = 1
				}

				if !yield(w[start : start+end]) {
					return
				}

				start += end
			}
		}
	}
}

// Encode implements TextProcessor.
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
	var ids []int32

	// TODO: use [UNK] from config
	unk := wpm.vocab.Encode("[UNK]")
	for word := range wpm.words(s) {
		var start int
		var pieces []int32
		for start < len(word) {
			end := len(word)

			var piece int32
			for start < end {
				subword := word[start:end]
				if start == 0 {
					subword = ggmlPrefix + subword
				}

				// TODO: some models might not want [ToLower]
				piece = wpm.vocab.Encode(strings.ToLower(subword))
				if piece >= 0 {
					break
				}

				end--
			}

			if piece < 0 {
				// unknown token; fall back to [UNK] for the whole word
				pieces = pieces[:0]
				break
			}

			pieces = append(pieces, piece)
			start = end
		}

		if len(pieces) > 0 {
			ids = append(ids, pieces...)
		} else {
			ids = append(ids, unk)
		}
	}

	if addSpecial && len(ids) > 0 {
		ids = wpm.vocab.addSpecials(ids)
	}

	logutil.Trace("encoded", "string", s, "ids", ids)
	return ids, nil
}

// Is implements TextProcessor.
func (wpm WordPiece) Is(id int32, special Special) bool {
	return wpm.vocab.Is(id, special)
}

// Vocabulary implements TextProcessor.
func (wpm WordPiece) Vocabulary() *Vocabulary {
	return wpm.vocab
}

var _ TextProcessor = (*WordPiece)(nil)

func NewWordPiece(vocab *Vocabulary) WordPiece {
	return WordPiece{
		vocab: vocab,
	}
}
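Note: the Encode loop above is the standard WordPiece greedy longest-match-first algorithm: for each word, find the longest vocabulary entry that prefixes the remaining text, emit it, and continue from where the match ended; if no prefix matches at some position, the whole word becomes [UNK]. A standalone sketch with a plain map vocabulary (illustrative only, not the ollama API):

	package main

	import "fmt"

	// encodeWord greedily matches the longest known prefix at each position.
	func encodeWord(word string, vocab map[string]int32, unk int32) []int32 {
		var pieces []int32
		start := 0
		for start < len(word) {
			matched := false
			for end := len(word); end > start; end-- {
				if id, ok := vocab[word[start:end]]; ok {
					pieces = append(pieces, id)
					start = end
					matched = true
					break
				}
			}
			if !matched {
				return []int32{unk} // no prefix matched: the whole word is unknown
			}
		}
		return pieces
	}

	func main() {
		vocab := map[string]int32{"hell": 1, "o": 2, "hello": 3}
		fmt.Println(encodeWord("helloo", vocab, 0)) // [3 2]
	}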
model/wordpiece_test.go (new file, 51 lines)
@@ -0,0 +1,51 @@
package model

import (
	"slices"
	"testing"

	"github.com/google/go-cmp/cmp"
)

func TestWordPiece(t *testing.T) {
	wpm := NewWordPiece(
		&Vocabulary{
			Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
			AddBOS: true,
			AddEOS: true,
			BOS:    []int32{1},
			EOS:    []int32{2},
		})

	ids, err := wpm.Encode("Hello world!", true)
	if err != nil {
		t.Fatal(err)
	}

	if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
		t.Errorf("unexpected ids (-want +got):\n%s", diff)
	}

	words, err := wpm.Decode(ids)
	if err != nil {
		t.Fatal(err)
	}

	if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
		t.Errorf("unexpected words (-want +got):\n%s", diff)
	}
}

func TestWordPieceWords(t *testing.T) {
	var wpm WordPiece

	basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
	if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
		t.Errorf("unexpected words (-want +got):\n%s", diff)
	}

	chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
	if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
		t.Errorf("unexpected words (-want +got):\n%s", diff)
	}
}
@@ -76,8 +76,9 @@ type JsonSchema struct {
 }

 type EmbedRequest struct {
-	Input any    `json:"input"`
-	Model string `json:"model"`
+	Input      any    `json:"input"`
+	Model      string `json:"model"`
+	Dimensions int    `json:"dimensions,omitempty"`
 }

 type StreamOptions struct {
@@ -1005,7 +1006,7 @@ func EmbeddingsMiddleware() gin.HandlerFunc {
 		}

 		var b bytes.Buffer
-		if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
+		if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil {
 			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
 			return
 		}
parser/parser.go (115 lines changed)
@@ -62,14 +62,15 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
 	for _, c := range f.Commands {
 		switch c.Name {
 		case "model":
-			path, err := expandPath(c.Args, relativeDir)
+			name := c.Args.(string)
+			path, err := expandPath(name, relativeDir)
 			if err != nil {
 				return nil, err
 			}

 			digestMap, err := fileDigestMap(path)
 			if errors.Is(err, os.ErrNotExist) {
-				req.From = c.Args
+				req.From = name
 				continue
 			} else if err != nil {
 				return nil, err
@@ -83,7 +84,8 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
 			}
 		}
 		case "adapter":
-			path, err := expandPath(c.Args, relativeDir)
+			adapter := c.Args.(string)
+			path, err := expandPath(adapter, relativeDir)
 			if err != nil {
 				return nil, err
 			}
@@ -95,21 +97,25 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)

 			req.Adapters = digestMap
 		case "template":
-			req.Template = c.Args
+			template := c.Args.(string)
+			req.Template = template
 		case "system":
-			req.System = c.Args
+			system := c.Args.(string)
+			req.System = system
 		case "license":
-			licenses = append(licenses, c.Args)
+			license := c.Args.(string)
+			licenses = append(licenses, license)
 		case "message":
-			role, msg, _ := strings.Cut(c.Args, ": ")
-			messages = append(messages, api.Message{Role: role, Content: msg})
-		default:
+			msg := c.Args.(*Message)
+			messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
+		case "parameter":
 			if slices.Contains(deprecatedParameters, c.Name) {
-				fmt.Printf("warning: parameter %s is deprecated\n", c.Name)
+				fmt.Printf("warning: parameter '%s' is deprecated\n", c.Name)
 				break
 			}

-			ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}})
+			param := c.Args.(*Parameter)
+			ps, err := api.FormatParams(map[string][]string{param.Name: {param.Value}})
 			if err != nil {
 				return nil, err
 			}
@@ -123,6 +129,8 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
 				params[k] = v
 			}
 		}
+		default:
+			return nil, fmt.Errorf("warning: unknown command '%s'", c.Name)
 		}
 	}

@@ -246,7 +254,7 @@ func filesForModel(path string) ([]string, error) {
 	for _, match := range matches {
 		if ct, err := detectContentType(match); err != nil {
 			return nil, err
-		} else if ct != contentType {
+		} else if len(contentType) > 0 && ct != contentType {
 			return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
 		}
 	}
@@ -255,7 +263,8 @@ func filesForModel(path string) ([]string, error) {
 	}

 	var files []string
-	if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 {
+	// some safetensors files do not properly match "application/octet-stream", so skip checking their contentType
+	if st, _ := glob(filepath.Join(path, "*.safetensors"), ""); len(st) > 0 {
 		// safetensors files might be unresolved git lfs references; skip if they are
 		// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
 		files = append(files, st...)
@@ -311,7 +320,17 @@ func filesForModel(path string) ([]string, error) {

 type Command struct {
 	Name string
-	Args string
+	Args any
 }

+type Parameter struct {
+	Name  string
+	Value string
+}
+
+type Message struct {
+	Role    string
+	Content string
+}
+
 func (c Command) String() string {
@@ -320,12 +339,16 @@ func (c Command) String() string {
 	case "model":
 		fmt.Fprintf(&sb, "FROM %s", c.Args)
 	case "license", "template", "system", "adapter":
-		fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
+		data := c.Args.(string)
+		fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(data))
 	case "message":
-		role, message, _ := strings.Cut(c.Args, ": ")
-		fmt.Fprintf(&sb, "MESSAGE %s %s", role, quote(message))
+		data := c.Args.(*Message)
+		fmt.Fprintf(&sb, "MESSAGE %s %s", data.Role, quote(data.Content))
+	case "parameter":
+		data := c.Args.(*Parameter)
+		fmt.Fprintf(&sb, "PARAMETER %s %s", data.Name, quote(data.Value))
 	default:
-		fmt.Fprintf(&sb, "PARAMETER %s %s", c.Name, quote(c.Args))
+		fmt.Printf("unknown command '%s'\n", c.Name)
 	}

 	return sb.String()
@@ -365,7 +388,6 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
 	var curr state
 	var currLine int = 1
 	var b bytes.Buffer
-	var role string

 	var f Modelfile

@@ -412,6 +434,7 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
 			case "parameter":
+				// transition to stateParameter which sets command name
 				next = stateParameter
 				cmd.Name = s
 			case "message":
 				// transition to stateMessage which validates the message role
 				next = stateMessage
@@ -420,16 +443,37 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
 				cmd.Name = s
 			}
 		case stateParameter:
-			cmd.Name = b.String()
+			s, ok := unquote(strings.TrimSpace(b.String()))
+			if !ok || isSpace(r) {
+				if _, err := b.WriteRune(r); err != nil {
+					return nil, err
+				}
+
+				continue
+			}
+
+			cmd.Args = &Parameter{
+				Name: s,
+			}
 		case stateMessage:
-			if !isValidMessageRole(b.String()) {
+			s, ok := unquote(strings.TrimSpace(b.String()))
+			if !ok || isSpace(r) {
+				if _, err := b.WriteRune(r); err != nil {
+					return nil, err
+				}
+
+				continue
+			}
+
+			if !isValidMessageRole(s) {
 				return nil, &ParserError{
 					LineNumber: currLine,
 					Msg:        errInvalidMessageRole.Error(),
 				}
 			}

-			role = b.String()
+			cmd.Args = &Message{
+				Role: s,
+			}
 		case stateComment, stateNil:
 			// pass
 		case stateValue:
@@ -442,12 +486,16 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
 				continue
 			}

-			if role != "" {
-				s = role + ": " + s
-				role = ""
+			switch cmd.Name {
+			case "parameter":
+				p := cmd.Args.(*Parameter)
+				p.Value = s
+			case "message":
+				m := cmd.Args.(*Message)
+				m.Content = s
+			default:
+				cmd.Args = s
 			}

-			cmd.Args = s
 			f.Commands = append(f.Commands, cmd)
 		}

@@ -472,11 +520,16 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
 			return nil, io.ErrUnexpectedEOF
 		}

-		if role != "" {
-			s = role + ": " + s
+		switch cmd.Name {
+		case "parameter":
+			c := cmd.Args.(*Parameter)
+			c.Value = s
+		case "message":
+			c := cmd.Args.(*Message)
+			c.Content = s
+		default:
+			cmd.Args = s
 		}

-		cmd.Args = s
 		f.Commands = append(f.Commands, cmd)
 	default:
 		return nil, io.ErrUnexpectedEOF
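Note: with Args changed from string to any, Command becomes a small tagged union: plain strings for model/template/system/license, *Parameter for PARAMETER, *Message for MESSAGE. A sketch of consuming it safely with a type switch (a hypothetical helper in the parser package, not part of the patch; assumes fmt is imported):

	// describe renders a parsed Command for logging; the type switch recovers
	// the concrete payload stored in Args.
	func describe(c Command) string {
		switch v := c.Args.(type) {
		case string:
			return fmt.Sprintf("%s %q", c.Name, v)
		case *Parameter:
			return fmt.Sprintf("parameter %s=%s", v.Name, v.Value)
		case *Message:
			return fmt.Sprintf("message [%s] %s", v.Role, v.Content)
		default:
			return fmt.Sprintf("%s (unknown payload %T)", c.Name, v)
		}
	}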
@@ -47,8 +47,8 @@ TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
 		{Name: "model", Args: "model1"},
 		{Name: "adapter", Args: "adapter1"},
 		{Name: "license", Args: "MIT"},
-		{Name: "param1", Args: "value1"},
-		{Name: "param2", Args: "value2"},
+		{Name: "parameter", Args: &Parameter{"param1", "value1"}},
+		{Name: "parameter", Args: &Parameter{"param2", "value2"}},
 		{Name: "template", Args: "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>"},
 	}

@@ -80,8 +80,8 @@ TEMPLATE """ {{ if .System }}<|start_header_id|>system<|end_header_id|>
 		{Name: "model", Args: " model 1"},
 		{Name: "adapter", Args: "adapter3"},
 		{Name: "license", Args: "MIT "},
-		{Name: "param1", Args: "value1"},
-		{Name: "param2", Args: "value2"},
+		{Name: "parameter", Args: &Parameter{"param1", "value1"}},
+		{Name: "parameter", Args: &Parameter{"param2", "value2"}},
 		{Name: "template", Args: " {{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|> "},
 	}

@@ -101,7 +101,7 @@ func TestParseFileFrom(t *testing.T) {
 		},
 		{
 			"FROM \"FOO BAR\"\nPARAMETER param1 value1",
-			[]Command{{Name: "model", Args: "FOO BAR"}, {Name: "param1", Args: "value1"}},
+			[]Command{{Name: "model", Args: "FOO BAR"}, {Name: "parameter", Args: &Parameter{"param1", "value1"}}},
 			nil,
 		},
 		{
@@ -149,12 +149,12 @@ func TestParseFileFrom(t *testing.T) {
 		},
 		{
 			"PARAMETER param1 value1\nFROM foo",
-			[]Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}},
+			[]Command{{Name: "parameter", Args: &Parameter{"param1", "value1"}}, {Name: "model", Args: "foo"}},
 			nil,
 		},
 		{
 			"PARAMETER what the \nFROM lemons make lemonade ",
-			[]Command{{Name: "what", Args: "the"}, {Name: "model", Args: "lemons make lemonade"}},
+			[]Command{{Name: "parameter", Args: &Parameter{"what", "the"}}, {Name: "model", Args: "lemons make lemonade"}},
 			nil,
 		},
 	}
@@ -211,7 +211,7 @@ MESSAGE system You are a file parser. Always parse things.
 `,
 			[]Command{
 				{Name: "model", Args: "foo"},
-				{Name: "message", Args: "system: You are a file parser. Always parse things."},
+				{Name: "message", Args: &Message{"system", "You are a file parser. Always parse things."}},
 			},
 			nil,
 		},
@@ -221,7 +221,7 @@ FROM foo
 MESSAGE system You are a file parser. Always parse things.`,
 			[]Command{
 				{Name: "model", Args: "foo"},
-				{Name: "message", Args: "system: You are a file parser. Always parse things."},
+				{Name: "message", Args: &Message{"system", "You are a file parser. Always parse things."}},
 			},
 			nil,
 		},
@@ -234,9 +234,9 @@ MESSAGE assistant Hello, I want to parse all the things!
 `,
 			[]Command{
 				{Name: "model", Args: "foo"},
-				{Name: "message", Args: "system: You are a file parser. Always parse things."},
-				{Name: "message", Args: "user: Hey there!"},
-				{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
+				{Name: "message", Args: &Message{"system", "You are a file parser. Always parse things."}},
+				{Name: "message", Args: &Message{"user", "Hey there!"}},
+				{Name: "message", Args: &Message{"assistant", "Hello, I want to parse all the things!"}},
 			},
 			nil,
 		},
@@ -244,12 +244,12 @@ MESSAGE assistant Hello, I want to parse all the things!
 		{
 			`
FROM foo
MESSAGE system """
-You are a multiline file parser. Always parse things.
+You are a multiline file "parser". Always parse things.
"""
`,
 			[]Command{
 				{Name: "model", Args: "foo"},
-				{Name: "message", Args: "system: \nYou are a multiline file parser. Always parse things.\n"},
+				{Name: "message", Args: &Message{"system", "\nYou are a multiline file \"parser\". Always parse things.\n"}},
 			},
 			nil,
 		},
@@ -514,7 +514,7 @@ func TestParseFileParameters(t *testing.T) {

 			assert.Equal(t, []Command{
 				{Name: "model", Args: "foo"},
-				{Name: v.name, Args: v.value},
+				{Name: "parameter", Args: &Parameter{v.name, v.value}},
 			}, modelfile.Commands)
 		})
 	}
@@ -617,8 +617,8 @@ SYSTEM You are a utf16 file.

 	expected := []Command{
 		{Name: "model", Args: "bob"},
-		{Name: "param1", Args: "1"},
-		{Name: "param2", Args: "4096"},
+		{Name: "parameter", Args: &Parameter{"param1", "1"}},
+		{Name: "parameter", Args: &Parameter{"param2", "4096"}},
 		{Name: "system", Args: "You are a utf16 file."},
 	}
@@ -34,8 +34,8 @@ type InputCache struct {
 func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
 	numCtx := kvSize / int32(numSlots)

-	if numCtx < 1 {
-		return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
+	if int(numCtx) < batchSize {
+		return nil, fmt.Errorf("kv size must be at least as large as batch size * parallel (kv: %v batch: %v parallel: %v)", kvSize, batchSize, numSlots)
 	}

 	slots := make([]InputCacheSlot, numSlots)
@@ -70,11 +70,9 @@ func kvCacheTypeFromStr(s string) ml.DType {
 }

 func (c *InputCache) Close() {
-	if c == nil {
-		return
+	if c != nil && c.cache != nil {
+		c.cache.Close()
 	}
-
-	c.cache.Close()
 }

 // Locking: Operations on InputCacheSlot (including finding one
@@ -95,7 +93,7 @@ type InputCacheSlot struct {
 	lastUsed time.Time
 }

-func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*input.Input, error) {
+func (c *InputCache) LoadCacheSlot(prompt []*input.Input, cachePrompt bool) (*InputCacheSlot, []*input.Input, error) {
 	var slot *InputCacheSlot
 	var numPast int32
 	var err error
@@ -113,6 +111,10 @@ func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*i
 		return nil, nil, err
 	}

+	if !cachePrompt {
+		numPast = 0
+	}
+
 	slot.InUse = true
 	slot.lastUsed = time.Now()

@@ -393,7 +393,7 @@ func TestLoadCacheSlot(t *testing.T) {

 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
+			slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt, true)

 			// Check error state
 			if (err != nil) != tt.wantErr {
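Note on cachePrompt: numPast counts prompt tokens already resident in the slot's cache, and the remaining-prompt slice starts after them. Zeroing it forces the full prompt through a fresh forward pass, which embedding requests need since pooling must observe every token. A simplified sketch of the slicing logic under that reading (names simplified from the real types):

	// remainingPrompt returns the suffix of prompt that still needs a forward pass.
	// With cachePrompt == false, numPast is zeroed and the whole prompt is returned.
	func remainingPrompt(prompt []int32, numPast int32, cachePrompt bool) []int32 {
		if !cachePrompt {
			numPast = 0
		}
		return prompt[numPast:]
	}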
@@ -11,13 +11,13 @@ import (
 	"image"
 	"log"
 	"log/slog"
+	"math"
 	"net"
 	"net/http"
 	"os"
-	"reflect"
 	"regexp"
 	"runtime"
 	"runtime/debug"
 	"strconv"
 	"strings"
 	"sync"
|
||||
func (s *Server) run(ctx context.Context) {
|
||||
s.ready.Wait()
|
||||
|
||||
supportsAsync := s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32
|
||||
|
||||
var activeBatch batchState
|
||||
for {
|
||||
select {
|
||||
@@ -418,7 +420,12 @@ func (s *Server) run(ctx context.Context) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
go s.computeBatch(activeBatch)
|
||||
|
||||
if supportsAsync {
|
||||
go s.computeBatch(activeBatch)
|
||||
} else {
|
||||
s.computeBatch(activeBatch)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
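Note: math.MaxUint32 here is a sentinel default, so the comparison reads as "no pooling_type key in the model metadata". Models without pooling (generation models) keep the asynchronous batch pipeline; pooled embedding models run computeBatch synchronously. A minimal standalone sketch of the sentinel-default pattern (cfgUint stands in for the backend config lookup and is an assumption, not the real fs.Config API):

	package main

	import (
		"fmt"
		"math"
	)

	func cfgUint(kv map[string]uint32, key string, def uint32) uint32 {
		if v, ok := kv[key]; ok {
			return v
		}
		return def
	}

	func main() {
		gen := map[string]uint32{}                    // no pooling_type -> generation model
		embed := map[string]uint32{"pooling_type": 1} // pooling configured -> embedding model

		for _, kv := range []map[string]uint32{gen, embed} {
			supportsAsync := cfgUint(kv, "pooling_type", math.MaxUint32) == math.MaxUint32
			fmt.Println(supportsAsync) // true, then false
		}
	}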
@@ -429,12 +436,12 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
 	// before setting up the next batch so the seqs inputs are ready to receive their
 	// token values and we get the correct input pointers for the batchInputs
 	if pendingBatch.ctx != nil {
-		slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch waiting for compute to start", "pendingBatch.id", pendingBatch.id)
+		logutil.Trace("forwardBatch waiting for compute to start", "pendingBatch.id", pendingBatch.id)
 		<-pendingBatch.computeStartedCh
-		slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch compute started, setting up next batch", "pendingBatch.id", pendingBatch.id, "id", s.batchID)
+		logutil.Trace("forwardBatch compute started, setting up next batch", "pendingBatch.id", pendingBatch.id, "id", s.batchID)
 		nextBatch.inputsReadyCh = pendingBatch.outputsReadyCh // Chain the outputs from the pending batch to the next inputs batch
 	} else {
-		slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no pending batch detected", "batchID", s.batchID)
+		logutil.Trace("forwardBatch no pending batch detected", "batchID", s.batchID)
 		// No pendingBatch, so the inputs will be ready in the seqs immediately
 		nextBatch.inputsReadyCh = make(chan struct{}, 1)
 		nextBatch.inputsReadyCh <- struct{}{}
@@ -460,6 +467,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er

 	// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
 	var batchInputs []*input.Input
+	var batchOutputs []int32
 	var batch input.Batch

 	resumeSeq := -1
@@ -542,11 +550,11 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
 			batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
 			batch.Sequences = append(batch.Sequences, seq.cache.Id)

-			seq.iBatch = len(batch.Outputs)
-			if i+1 == len(seq.inputs) {
-				batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
+			seq.iBatch = len(batchOutputs)
+			if i+1 == len(seq.inputs) || seq.embeddingOnly {
+				batchOutputs = append(batchOutputs, int32(len(batchInputs)-1))
 			}
-			slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
+			logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
 			seq.pendingInputs = append(seq.pendingInputs, inp)
 		}

@@ -560,7 +568,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
 	}

 	if len(batchInputs) == 0 {
-		slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no batchInputs, going idle", "batchID", s.batchID)
+		logutil.Trace("forwardBatch no batchInputs, going idle", "batchID", s.batchID)
 		nextBatch.ctx.Close()
 		nextBatch.ctx = nil
 		return
@@ -569,6 +577,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er

 	// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
 	batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
+	batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs))
 	nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
 	if err != nil {
 		err = fmt.Errorf("failed to build graph: %w", err)
@@ -589,14 +598,14 @@ func (s *Server) computeBatch(activeBatch batchState) {
 	defer activeBatch.ctx.Close()

 	// Wait until inputs are ready
-	slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: waiting for inputs to be ready", "batchID", activeBatch.id)
+	logutil.Trace("computeBatch: waiting for inputs to be ready", "batchID", activeBatch.id)
 	<-activeBatch.inputsReadyCh
-	slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: inputs are ready", "batchID", activeBatch.id)
+	logutil.Trace("computeBatch: inputs are ready", "batchID", activeBatch.id)

 	// Once we complete, signal the next batch of inputs are ready
 	// This will unblock the next computeBatch, or forwardBatch if new seqs come in
 	defer func() {
-		slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: outputs are ready", "batchID", activeBatch.id)
+		logutil.Trace("computeBatch: outputs are ready", "batchID", activeBatch.id)
 		activeBatch.outputsReadyCh <- struct{}{}
 	}()

@@ -626,7 +635,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
 		// Detect if the sequence we're processing has already been completed and replaced
 		// with a new sequence
 		if seq != activeBatch.seqs[i] {
-			slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
+			logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
 			continue
 		}
@@ -666,18 +675,19 @@ func (s *Server) computeBatch(activeBatch batchState) {
 	activeBatch.batch.Inputs.SetValueFromIntSlice(batchInputs)
 	activeBatch.ctx.ComputeWithNotify(
 		func() {
-			slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: signaling computeStartedCh", "batchID", activeBatch.id)
+			logutil.Trace("computeBatch: signaling computeStartedCh", "batchID", activeBatch.id)
 			activeBatch.computeStartedCh <- struct{}{}
 		},
 		activeBatch.modelOutput)
-	logits := activeBatch.modelOutput.Floats()
-
-	slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: logits ready", "batchID", activeBatch.id)
+	outputs := activeBatch.modelOutput.Floats()
+
+	logutil.Trace("computeBatch: logits ready", "batchID", activeBatch.id)

 	s.mu.Lock()
 	defer s.mu.Unlock()

-	slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: decoding", "batchID", activeBatch.id)
+	logutil.Trace("computeBatch: decoding", "batchID", activeBatch.id)
 	for i, seq := range s.seqs {
 		if seq == nil || nextBatchTokens[i] == nil {
 			continue
@@ -689,16 +699,15 @@ func (s *Server) computeBatch(activeBatch batchState) {

 		// if done processing the prompt, generate an embedding and return
 		if seq.embeddingOnly {
-			// TODO(jessegross): Embedding support
-			slog.Warn("generation of embedding outputs not yet supported", "id", activeBatch.id, "seqIdx", i)
+			seq.embedding <- outputs
 			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}

 		// sample a token
-		vocabSize := len(logits) / len(activeBatch.batch.Outputs)
-		slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(logits), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
-		token, err := seq.sampler.Sample(logits[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
+		vocabSize := len(outputs) / activeBatch.batch.Outputs.Dim(0)
+		logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
+		token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
 		if err != nil {
 			s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
 			return
@@ -711,7 +720,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
 			// TODO (jmorganca): we should send this back
 			// as it's important for the /api/generate context
 			// seq.responses <- piece
-			slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: EOS", "batchID", activeBatch.id, "seqIdx", i)
+			logutil.Trace("computeBatch: EOS", "batchID", activeBatch.id, "seqIdx", i)
 			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}
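Note: the model output arrives as one flat float slice with vocabSize entries per output row, so each sequence's logits are recovered by slicing at its iBatch offset. A minimal sketch of the indexing used above:

	// rowLogits returns row i from a flat logits buffer that holds nRows rows.
	func rowLogits(outputs []float32, nRows, i int) []float32 {
		vocabSize := len(outputs) / nRows
		return outputs[i*vocabSize : (i+1)*vocabSize]
	}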
@@ -834,7 +843,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	found := false
 	for i, sq := range s.seqs {
 		if sq == nil {
-			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
 			if err != nil {
 				s.mu.Unlock()
 				s.seqsSem.Release(1)
@@ -890,6 +899,67 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	}
 }

+func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
+	if s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 {
+		http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
+		return
+	}
+
+	var req llm.EmbeddingRequest
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
+		return
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+	seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
+	if err != nil {
+		http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
+		return
+	}
+
+	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
+		if errors.Is(err, context.Canceled) {
+			slog.Info("aborting embedding request due to client closing the connection")
+		} else {
+			http.Error(w, fmt.Sprintf("failed to acquire semaphore: %v", err), http.StatusInternalServerError)
+		}
+		return
+	}
+
+	s.mu.Lock()
+	found := false
+	for i, sq := range s.seqs {
+		if sq == nil {
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
+			if err != nil {
+				s.mu.Unlock()
+				s.seqsSem.Release(1)
+				http.Error(w, fmt.Sprintf("failed to load cache: %v", err), http.StatusInternalServerError)
+				return
+			}
+
+			s.seqs[i] = seq
+			s.cond.Signal()
+			found = true
+			break
+		}
+	}
+	s.mu.Unlock()
+
+	if !found {
+		s.seqsSem.Release(1)
+		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
+		return
+	}
+
+	if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
+		Embedding: <-seq.embedding,
+	}); err != nil {
+		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
+	}
+}
+
 func (s *Server) health(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 	if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
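Note: the new handler gives the runner a working POST /embedding endpoint: reject models without pooling, decode an llm.EmbeddingRequest, run the prompt as an embedding-only sequence with prompt caching disabled, and reply with an llm.EmbeddingResponse. A fragment sketching a direct call (imports omitted; the port and the "content"/"embedding" JSON field names are assumptions for illustration, taken from the struct fields rather than confirmed wire tags):

	body := strings.NewReader(`{"content": "hello world"}`)
	resp, err := http.Post("http://127.0.0.1:8080/embedding", "application/json", body)
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	var out struct {
		Embedding []float32 `json:"embedding"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		log.Fatal(err)
	}
	log.Println(len(out.Embedding))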
@@ -978,12 +1048,8 @@ func (s *Server) reserveWorstCaseGraph() error {
 		batch.Positions[i] = int32(i)
 	}

-	batch.Outputs = make([]int32, s.parallel)
-	for i := range batch.Outputs {
-		batch.Outputs[i] = int32(i)
-	}
-
 	batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
+	batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel)

 	cache := s.model.Config().Cache
 	if cache != nil {
@@ -1017,9 +1083,13 @@ func (s *Server) allocModel(
 	// Convert memory allocation panics to errors
 	defer func() {
 		if r := recover(); r != nil {
 			debug.PrintStack()
 			if err, ok := r.(error); ok {
-				panicErr = err
+				var noMem ml.ErrNoMem
+				if errors.As(err, &noMem) {
+					panicErr = noMem
+				} else {
+					panic(r)
+				}
 			} else {
 				panic(r)
 			}
@@ -1206,10 +1276,7 @@ func Execute(args []string) error {
 	mux := http.NewServeMux()
-	// TODO: support embeddings
 	mux.HandleFunc("POST /load", server.load)
-	mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) {
-		http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
-	})
-
+	mux.HandleFunc("POST /embedding", server.embeddings)
 	mux.HandleFunc("POST /completion", server.completion)
 	mux.HandleFunc("GET /health", server.health)
@@ -78,7 +78,7 @@ function checkEnv() {
 }


-function buildOllama() {
+function buildCPU() {
     mkdir -Force -path "${script:DIST_DIR}\"
     if ($script:ARCH -ne "arm64") {
         Remove-Item -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}"
@@ -90,20 +90,72 @@ function buildOllama() {
         if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
         & cmake --install build --component CPU --strip
         if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
     }
 }

+function buildCUDA11() {
+    # CUDA v11 claims to be compatible with MSVC 2022, but the latest updates are no longer compatible.
+    # 19.40 is the last compiler version that works, while recent updates are 19.43,
+    # so this pins to MSVC 2019 for best compatibility.
+    mkdir -Force -path "${script:DIST_DIR}\"
+    if ($script:ARCH -ne "arm64") {
+        $hashEnv = @{}
+        Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
-        if ("$script:CUDA_DIRS".Contains("v12")) {
-            $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12")) { $v12="$_" }}
-            $env:CUDAToolkit_ROOT=$hashEnv[$v12]
-            write-host "Building CUDA v12 backend libraries"
-            & cmake --fresh --preset "CUDA 12" --install-prefix $script:DIST_DIR
+        if ("$script:CUDA_DIRS".Contains("v11")) {
+            $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V11")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }}
+            write-host "Building CUDA v11 backend libraries $cuda"
+            $env:CUDAToolkit_ROOT=$cuda
+            & cmake --fresh --preset "CUDA 11" -T cuda="$cuda" -DCMAKE_CUDA_COMPILER="$cuda\bin\nvcc.exe" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v11"
+            if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
+            & cmake --build --preset "CUDA 11" --config Release --parallel $script:JOBS
+            if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
+            & cmake --install build --component "CUDA" --strip
+            if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
+        }
+    }
+}
+
+function buildCUDA12() {
+    mkdir -Force -path "${script:DIST_DIR}\"
+    if ($script:ARCH -ne "arm64") {
+        $hashEnv = @{}
+        Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
+        if ("$script:CUDA_DIRS".Contains("v12.8")) {
+            $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12_8")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }}
+            write-host "Building CUDA v12 backend libraries $cuda"
+            $env:CUDAToolkit_ROOT=$cuda
+            & cmake --fresh --preset "CUDA 12" -T cuda="$cuda" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v12"
+            if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
+            & cmake --build --preset "CUDA 12" --config Release --parallel $script:JOBS
+            if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
+            & cmake --install build --component "CUDA" --strip
+            if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
+        }
+    }
+}
+
+function buildCUDA13() {
+    mkdir -Force -path "${script:DIST_DIR}\"
+    if ($script:ARCH -ne "arm64") {
+        $hashEnv = @{}
+        Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
+        if ("$script:CUDA_DIRS".Contains("v13")) {
+            $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V13")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }}
+            $env:CUDAToolkit_ROOT=$cuda
+            write-host "Building CUDA v13 backend libraries $cuda"
+            & cmake --fresh --preset "CUDA 13" -T cuda="$cuda" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v13"
+            if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
+            & cmake --build --preset "CUDA 13" --config Release --parallel $script:JOBS
+            if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
+            & cmake --install build --component "CUDA" --strip
+            if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
+        }
+    }
+}
+
+function buildROCm() {
+    mkdir -Force -path "${script:DIST_DIR}\"
+    if ($script:ARCH -ne "arm64") {
+        if ($env:HIP_PATH) {
+            write-host "Building ROCm backend libraries"
+            if (-Not (get-command -ErrorAction silent ninja)) {
@@ -129,6 +181,10 @@ function buildOllama() {
         if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
     }
 }
+
+function buildOllama() {
+    mkdir -Force -path "${script:DIST_DIR}\"
     write-host "Building ollama CLI"
     & go build -trimpath -ldflags "-s -w -X=github.com/ollama/ollama/version.Version=$script:VERSION -X=github.com/ollama/ollama/server.mode=release" .
     if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
@@ -236,6 +292,10 @@ function distZip() {
 checkEnv
 try {
     if ($($args.count) -eq 0) {
+        buildCPU
+        buildCUDA12
+        buildCUDA13
+        buildROCm
         buildOllama
         buildApp
         gatherDependencies
@@ -488,7 +488,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 	}

 	truncate := true
-
 	if req.Truncate != nil && !*req.Truncate {
 		truncate = false
 	}
@@ -555,7 +554,16 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 			return
 		}

+		if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
+			ctxLen--
+		}
+
+		if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
+			ctxLen--
+		}
+
 		tokens = tokens[:ctxLen]
+
 		s, err = r.Detokenize(c.Request.Context(), tokens)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -576,7 +584,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 			if err != nil {
 				return err
 			}
-			embeddings[i] = normalize(embedding)
+
+			// TODO: this first normalization should be done by the model
+			embedding = normalize(embedding)
+			if req.Dimensions > 0 && req.Dimensions < len(embedding) {
+				embedding = normalize(embedding[:req.Dimensions])
+			}
+			embeddings[i] = embedding
 			return nil
 		})
 	}
@@ -602,11 +615,7 @@ func normalize(vec []float32) []float32 {
 		sum += v * v
 	}

-	norm := float32(0.0)
-	if sum > 0 {
-		norm = float32(1.0 / math.Sqrt(float64(sum)))
-	}
-
+	norm := float32(1.0 / max(math.Sqrt(float64(sum)), 1e-12))
 	for i := range vec {
 		vec[i] *= norm
 	}
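Note: the Dimensions handling follows the Matryoshka-style truncation used by OpenAI's dimensions parameter: truncating a unit vector breaks its unit length, so the handler normalizes again after the cut. A sketch of the step in isolation (normalize is the package's own helper shown above; the wrapper name is illustrative):

	// truncateEmbedding returns the first dims components of a normalized
	// embedding, renormalized so the result is unit length again.
	func truncateEmbedding(embedding []float32, dims int) []float32 {
		if dims > 0 && dims < len(embedding) {
			embedding = normalize(embedding[:dims])
		}
		return embedding
	}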