mirror of
https://github.com/ollama/ollama.git
synced 2026-01-13 18:09:59 -05:00
Compare commits
46 Commits
grace/deep
...
llama-upda
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
961fdb69d5 | ||
|
|
9aff392e6c | ||
|
|
d3c1524fa4 | ||
|
|
7cc2a653f2 | ||
|
|
2584940016 | ||
|
|
c6d4c0c7f2 | ||
|
|
1ef4241727 | ||
|
|
68fafd3002 | ||
|
|
2b2cda7a2b | ||
|
|
3cfe9fe146 | ||
|
|
a23b559b4c | ||
|
|
33ee7168ba | ||
|
|
34d0c55ea5 | ||
|
|
53a5a9e9ae | ||
|
|
e30e08a7d6 | ||
|
|
12e2b3514a | ||
|
|
626af2d809 | ||
|
|
76912c062a | ||
|
|
6c3faafed2 | ||
|
|
e51dead636 | ||
|
|
d087e46bd1 | ||
|
|
37f6f3af24 | ||
|
|
e1bdc23dd2 | ||
|
|
2e78653ff9 | ||
|
|
f5f74e12c1 | ||
|
|
18fdcc94e5 | ||
|
|
7ad036992f | ||
|
|
172b5924af | ||
|
|
8852220f59 | ||
|
|
7325791599 | ||
|
|
522c11a763 | ||
|
|
0fadeffaee | ||
|
|
49a9c9ba6a | ||
|
|
1c094038bc | ||
|
|
a013693f80 | ||
|
|
f6a016f49d | ||
|
|
45c4739374 | ||
|
|
2dd029de12 | ||
|
|
903b1fc97f | ||
|
|
89eb795293 | ||
|
|
7e3ea813c1 | ||
|
|
7b95087b9d | ||
|
|
971d62595a | ||
|
|
ffbe8e076d | ||
|
|
2c639431b1 | ||
|
|
aacd1cb394 |
7
.github/workflows/release.yaml
vendored
7
.github/workflows/release.yaml
vendored
@@ -68,6 +68,7 @@ jobs:
|
||||
name: bundles-darwin
|
||||
path: |
|
||||
dist/*.tgz
|
||||
dist/*.tar.zst
|
||||
dist/*.zip
|
||||
dist/*.dmg
|
||||
|
||||
@@ -392,13 +393,13 @@ jobs:
|
||||
done
|
||||
- run: |
|
||||
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
|
||||
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
|
||||
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | zstd --ultra -22 -T0 >$(basename ${ARCHIVE//.*/}.tar.zst);
|
||||
done
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: bundles-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.target }}
|
||||
path: |
|
||||
*.tgz
|
||||
*.tar.zst
|
||||
|
||||
# Build each Docker variant (OS, arch, and flavor) separately. Using QEMU is unreliable and slower.
|
||||
docker-build-push:
|
||||
@@ -531,7 +532,7 @@ jobs:
|
||||
- name: Upload release artifacts
|
||||
run: |
|
||||
pids=()
|
||||
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.exe dist/*.dmg ; do
|
||||
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg ; do
|
||||
echo "Uploading $payload"
|
||||
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
|
||||
pids[$!]=$!
|
||||
|
||||
@@ -2,6 +2,22 @@ cmake_minimum_required(VERSION 3.21)
|
||||
|
||||
project(Ollama C CXX)
|
||||
|
||||
# Handle cross-compilation on macOS: when CMAKE_OSX_ARCHITECTURES is set to a
|
||||
# single architecture different from the host, override CMAKE_SYSTEM_PROCESSOR
|
||||
# to match. This is necessary because CMAKE_SYSTEM_PROCESSOR defaults to the
|
||||
# host architecture, but downstream projects (like MLX) use it to detect the
|
||||
# target architecture.
|
||||
if(CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES ";")
|
||||
# Single architecture specified
|
||||
if(CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
|
||||
message(STATUS "Cross-compiling for x86_64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to x86_64")
|
||||
set(CMAKE_SYSTEM_PROCESSOR "x86_64")
|
||||
elseif(CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
|
||||
message(STATUS "Cross-compiling for arm64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to arm64")
|
||||
set(CMAKE_SYSTEM_PROCESSOR "arm64")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
include(CheckLanguage)
|
||||
include(GNUInstallDirs)
|
||||
|
||||
@@ -12,7 +28,7 @@ set(BUILD_SHARED_LIBS ON)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
set(CMAKE_CXX_EXTENSIONS ON) # Recent versions of MLX Requires gnu++17 extensions to compile properly
|
||||
|
||||
set(GGML_BUILD ON)
|
||||
set(GGML_SHARED ON)
|
||||
@@ -54,6 +70,13 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cp
|
||||
|
||||
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
|
||||
|
||||
# Define GGML version variables for shared library SOVERSION
|
||||
# These are required by ggml/src/CMakeLists.txt for proper library versioning
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 0)
|
||||
set(GGML_VERSION_PATCH 0)
|
||||
set(GGML_VERSION "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
set(GGML_CPU ON)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
|
||||
set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE)
|
||||
@@ -140,14 +163,48 @@ if(CMAKE_HIP_COMPILER)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
find_package(Vulkan)
|
||||
if(Vulkan_FOUND)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
||||
install(TARGETS ggml-vulkan
|
||||
RUNTIME_DEPENDENCIES
|
||||
PRE_INCLUDE_REGEXES vulkan
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
)
|
||||
if(NOT APPLE)
|
||||
find_package(Vulkan)
|
||||
if(Vulkan_FOUND)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
||||
install(TARGETS ggml-vulkan
|
||||
RUNTIME_DEPENDENCIES
|
||||
PRE_INCLUDE_REGEXES vulkan
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
option(MLX_ENGINE "Enable MLX backend" OFF)
|
||||
|
||||
if(MLX_ENGINE)
|
||||
message(STATUS "Setting up MLX (this takes a while...)")
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
|
||||
|
||||
# Find CUDA toolkit if MLX is built with CUDA support
|
||||
find_package(CUDAToolkit)
|
||||
|
||||
install(TARGETS mlx mlxc
|
||||
RUNTIME_DEPENDENCIES
|
||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
)
|
||||
|
||||
# Manually install cudart and cublas since they might not be picked up as direct dependencies
|
||||
if(CUDAToolkit_FOUND)
|
||||
file(GLOB CUDART_LIBS
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
|
||||
if(CUDART_LIBS)
|
||||
install(FILES ${CUDART_LIBS}
|
||||
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
@@ -41,7 +41,7 @@
|
||||
"inherits": [ "CUDA" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
|
||||
"CMAKE_CUDA_FLAGS": "-t 2",
|
||||
"CMAKE_CUDA_FLAGS": "-t 4",
|
||||
"OLLAMA_RUNNER_DIR": "cuda_v13"
|
||||
}
|
||||
},
|
||||
@@ -83,6 +83,28 @@
|
||||
"cacheVariables": {
|
||||
"OLLAMA_RUNNER_DIR": "vulkan"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "MLX",
|
||||
"inherits": [ "Default" ],
|
||||
"cacheVariables": {
|
||||
"MLX_ENGINE": "ON",
|
||||
"OLLAMA_RUNNER_DIR": "mlx"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 12",
|
||||
"inherits": [ "MLX", "CUDA 12" ],
|
||||
"cacheVariables": {
|
||||
"OLLAMA_RUNNER_DIR": "mlx_cuda_v12"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 13",
|
||||
"inherits": [ "MLX", "CUDA 13" ],
|
||||
"cacheVariables": {
|
||||
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
|
||||
}
|
||||
}
|
||||
],
|
||||
"buildPresets": [
|
||||
@@ -140,6 +162,21 @@
|
||||
"name": "Vulkan",
|
||||
"targets": [ "ggml-vulkan" ],
|
||||
"configurePreset": "Vulkan"
|
||||
},
|
||||
{
|
||||
"name": "MLX",
|
||||
"targets": [ "mlx", "mlxc" ],
|
||||
"configurePreset": "MLX"
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 12",
|
||||
"targets": [ "mlx", "mlxc" ],
|
||||
"configurePreset": "MLX CUDA 12"
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 13",
|
||||
"targets": [ "mlx", "mlxc" ],
|
||||
"configurePreset": "MLX CUDA 13"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
33
Dockerfile
33
Dockerfile
@@ -131,8 +131,36 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'Vulkan' \
|
||||
&& cmake --build --parallel --preset 'Vulkan' \
|
||||
&& cmake --install build --component Vulkan --strip --parallel 8
|
||||
&& cmake --install build --component Vulkan --strip --parallel 8
|
||||
|
||||
FROM base AS mlx
|
||||
ARG CUDA13VERSION=13.0
|
||||
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} \
|
||||
&& dnf install -y openblas-devel lapack-devel \
|
||||
&& dnf install -y libcudnn9-cuda-13 libcudnn9-devel-cuda-13 \
|
||||
&& dnf install -y libnccl libnccl-devel
|
||||
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
||||
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
|
||||
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
|
||||
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
|
||||
ARG PARALLEL
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
COPY x/ml/backend/mlx x/ml/backend/mlx
|
||||
COPY go.mod go.sum .
|
||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||
ENV PATH=/usr/local/go/bin:$PATH
|
||||
RUN go mod download
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
|
||||
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
|
||||
COPY . .
|
||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
|
||||
FROM base AS build
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
@@ -153,6 +181,7 @@ FROM --platform=linux/amd64 scratch AS amd64
|
||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=vulkan dist/lib/ollama /lib/ollama/
|
||||
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
|
||||
|
||||
FROM --platform=linux/arm64 scratch AS arm64
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
@@ -171,7 +200,7 @@ COPY --from=build /bin/ollama /bin/ollama
|
||||
|
||||
FROM ubuntu:24.04
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y ca-certificates libvulkan1 \
|
||||
&& apt-get install -y ca-certificates libvulkan1 libopenblas0 \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
COPY --from=archive /bin /usr/bin
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
UPSTREAM=https://github.com/ggml-org/llama.cpp.git
|
||||
WORKDIR=llama/vendor
|
||||
FETCH_HEAD=17f7f4baad8b3a716ee139da7bb56ae984e8c0fa
|
||||
FETCH_HEAD=b1377188784f9aea26b8abde56d4aee8c733eec7
|
||||
|
||||
.PHONY: help
|
||||
help:
|
||||
|
||||
778
anthropic/anthropic.go
Normal file
778
anthropic/anthropic.go
Normal file
@@ -0,0 +1,778 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Error types matching Anthropic API
|
||||
type Error struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Type string `json:"type"` // always "error"
|
||||
Error Error `json:"error"`
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
}
|
||||
|
||||
// NewError creates a new ErrorResponse with the appropriate error type based on HTTP status code
|
||||
func NewError(code int, message string) ErrorResponse {
|
||||
var etype string
|
||||
switch code {
|
||||
case http.StatusBadRequest:
|
||||
etype = "invalid_request_error"
|
||||
case http.StatusUnauthorized:
|
||||
etype = "authentication_error"
|
||||
case http.StatusForbidden:
|
||||
etype = "permission_error"
|
||||
case http.StatusNotFound:
|
||||
etype = "not_found_error"
|
||||
case http.StatusTooManyRequests:
|
||||
etype = "rate_limit_error"
|
||||
case http.StatusServiceUnavailable, 529:
|
||||
etype = "overloaded_error"
|
||||
default:
|
||||
etype = "api_error"
|
||||
}
|
||||
|
||||
return ErrorResponse{
|
||||
Type: "error",
|
||||
Error: Error{Type: etype, Message: message},
|
||||
RequestID: generateID("req"),
|
||||
}
|
||||
}
|
||||
|
||||
// Request types
|
||||
|
||||
// MessagesRequest represents an Anthropic Messages API request
|
||||
type MessagesRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Messages []MessageParam `json:"messages"`
|
||||
System any `json:"system,omitempty"` // string or []ContentBlock
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
Metadata *Metadata `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// MessageParam represents a message in the request
|
||||
type MessageParam struct {
|
||||
Role string `json:"role"` // "user" or "assistant"
|
||||
Content any `json:"content"` // string or []ContentBlock
|
||||
}
|
||||
|
||||
// ContentBlock represents a content block in a message.
|
||||
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
|
||||
// only when set, which is required for SDK streaming accumulation.
|
||||
type ContentBlock struct {
|
||||
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
|
||||
|
||||
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
|
||||
Text *string `json:"text,omitempty"`
|
||||
|
||||
// For image blocks
|
||||
Source *ImageSource `json:"source,omitempty"`
|
||||
|
||||
// For tool_use blocks
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
|
||||
// For tool_result blocks
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content any `json:"content,omitempty"` // string or []ContentBlock
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
|
||||
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
|
||||
Thinking *string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
}
|
||||
|
||||
// ImageSource represents the source of an image
|
||||
type ImageSource struct {
|
||||
Type string `json:"type"` // "base64" or "url"
|
||||
MediaType string `json:"media_type,omitempty"`
|
||||
Data string `json:"data,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
}
|
||||
|
||||
// Tool represents a tool definition
|
||||
type Tool struct {
|
||||
Type string `json:"type,omitempty"` // "custom" for user-defined tools
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema,omitempty"`
|
||||
}
|
||||
|
||||
// ToolChoice controls how the model uses tools
|
||||
type ToolChoice struct {
|
||||
Type string `json:"type"` // "auto", "any", "tool", "none"
|
||||
Name string `json:"name,omitempty"`
|
||||
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
|
||||
}
|
||||
|
||||
// ThinkingConfig controls extended thinking
|
||||
type ThinkingConfig struct {
|
||||
Type string `json:"type"` // "enabled" or "disabled"
|
||||
BudgetTokens int `json:"budget_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// Metadata for the request
|
||||
type Metadata struct {
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
// Response types
|
||||
|
||||
// MessagesResponse represents an Anthropic Messages API response
|
||||
type MessagesResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"` // "message"
|
||||
Role string `json:"role"` // "assistant"
|
||||
Model string `json:"model"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
StopSequence string `json:"stop_sequence,omitempty"`
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
// Usage contains token usage information
|
||||
type Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
// Streaming event types
|
||||
|
||||
// MessageStartEvent is sent at the start of streaming
|
||||
type MessageStartEvent struct {
|
||||
Type string `json:"type"` // "message_start"
|
||||
Message MessagesResponse `json:"message"`
|
||||
}
|
||||
|
||||
// ContentBlockStartEvent signals the start of a content block
|
||||
type ContentBlockStartEvent struct {
|
||||
Type string `json:"type"` // "content_block_start"
|
||||
Index int `json:"index"`
|
||||
ContentBlock ContentBlock `json:"content_block"`
|
||||
}
|
||||
|
||||
// ContentBlockDeltaEvent contains incremental content updates
|
||||
type ContentBlockDeltaEvent struct {
|
||||
Type string `json:"type"` // "content_block_delta"
|
||||
Index int `json:"index"`
|
||||
Delta Delta `json:"delta"`
|
||||
}
|
||||
|
||||
// Delta represents an incremental update
|
||||
type Delta struct {
|
||||
Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta"
|
||||
Text string `json:"text,omitempty"`
|
||||
PartialJSON string `json:"partial_json,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
}
|
||||
|
||||
// ContentBlockStopEvent signals the end of a content block
|
||||
type ContentBlockStopEvent struct {
|
||||
Type string `json:"type"` // "content_block_stop"
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
// MessageDeltaEvent contains updates to the message
|
||||
type MessageDeltaEvent struct {
|
||||
Type string `json:"type"` // "message_delta"
|
||||
Delta MessageDelta `json:"delta"`
|
||||
Usage DeltaUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// MessageDelta contains stop information
|
||||
type MessageDelta struct {
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
StopSequence string `json:"stop_sequence,omitempty"`
|
||||
}
|
||||
|
||||
// DeltaUsage contains cumulative token usage
|
||||
type DeltaUsage struct {
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
// MessageStopEvent signals the end of the message
|
||||
type MessageStopEvent struct {
|
||||
Type string `json:"type"` // "message_stop"
|
||||
}
|
||||
|
||||
// PingEvent is a keepalive event
|
||||
type PingEvent struct {
|
||||
Type string `json:"type"` // "ping"
|
||||
}
|
||||
|
||||
// StreamErrorEvent is an error during streaming
|
||||
type StreamErrorEvent struct {
|
||||
Type string `json:"type"` // "error"
|
||||
Error Error `json:"error"`
|
||||
}
|
||||
|
||||
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
|
||||
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||
var messages []api.Message
|
||||
|
||||
if r.System != nil {
|
||||
switch sys := r.System.(type) {
|
||||
case string:
|
||||
if sys != "" {
|
||||
messages = append(messages, api.Message{Role: "system", Content: sys})
|
||||
}
|
||||
case []any:
|
||||
// System can be an array of content blocks
|
||||
var content strings.Builder
|
||||
for _, block := range sys {
|
||||
if blockMap, ok := block.(map[string]any); ok {
|
||||
if blockMap["type"] == "text" {
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
content.WriteString(text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if content.Len() > 0 {
|
||||
messages = append(messages, api.Message{Role: "system", Content: content.String()})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, msg := range r.Messages {
|
||||
converted, err := convertMessage(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages = append(messages, converted...)
|
||||
}
|
||||
|
||||
options := make(map[string]any)
|
||||
|
||||
options["num_predict"] = r.MaxTokens
|
||||
|
||||
if r.Temperature != nil {
|
||||
options["temperature"] = *r.Temperature
|
||||
}
|
||||
|
||||
if r.TopP != nil {
|
||||
options["top_p"] = *r.TopP
|
||||
}
|
||||
|
||||
if r.TopK != nil {
|
||||
options["top_k"] = *r.TopK
|
||||
}
|
||||
|
||||
if len(r.StopSequences) > 0 {
|
||||
options["stop"] = r.StopSequences
|
||||
}
|
||||
|
||||
var tools api.Tools
|
||||
for _, t := range r.Tools {
|
||||
tool, err := convertTool(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
|
||||
var think *api.ThinkValue
|
||||
if r.Thinking != nil && r.Thinking.Type == "enabled" {
|
||||
think = &api.ThinkValue{Value: true}
|
||||
}
|
||||
|
||||
stream := r.Stream
|
||||
|
||||
return &api.ChatRequest{
|
||||
Model: r.Model,
|
||||
Messages: messages,
|
||||
Options: options,
|
||||
Stream: &stream,
|
||||
Tools: tools,
|
||||
Think: think,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
|
||||
func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
var messages []api.Message
|
||||
role := strings.ToLower(msg.Role)
|
||||
|
||||
switch content := msg.Content.(type) {
|
||||
case string:
|
||||
messages = append(messages, api.Message{Role: role, Content: content})
|
||||
|
||||
case []any:
|
||||
var textContent strings.Builder
|
||||
var images []api.ImageData
|
||||
var toolCalls []api.ToolCall
|
||||
var thinking string
|
||||
var toolResults []api.Message
|
||||
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid content block format")
|
||||
}
|
||||
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
switch blockType {
|
||||
case "text":
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
textContent.WriteString(text)
|
||||
}
|
||||
|
||||
case "image":
|
||||
source, ok := blockMap["source"].(map[string]any)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid image source")
|
||||
}
|
||||
|
||||
sourceType, _ := source["type"].(string)
|
||||
if sourceType == "base64" {
|
||||
data, _ := source["data"].(string)
|
||||
decoded, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||
}
|
||||
images = append(images, decoded)
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
|
||||
}
|
||||
// URL images would need to be fetched - skip for now
|
||||
|
||||
case "tool_use":
|
||||
id, ok := blockMap["id"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("tool_use block missing required 'id' field")
|
||||
}
|
||||
name, ok := blockMap["name"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("tool_use block missing required 'name' field")
|
||||
}
|
||||
tc := api.ToolCall{
|
||||
ID: id,
|
||||
Function: api.ToolCallFunction{
|
||||
Name: name,
|
||||
},
|
||||
}
|
||||
if input, ok := blockMap["input"].(map[string]any); ok {
|
||||
tc.Function.Arguments = mapToArgs(input)
|
||||
}
|
||||
toolCalls = append(toolCalls, tc)
|
||||
|
||||
case "tool_result":
|
||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||
var resultContent string
|
||||
|
||||
switch c := blockMap["content"].(type) {
|
||||
case string:
|
||||
resultContent = c
|
||||
case []any:
|
||||
for _, cb := range c {
|
||||
if cbMap, ok := cb.(map[string]any); ok {
|
||||
if cbMap["type"] == "text" {
|
||||
if text, ok := cbMap["text"].(string); ok {
|
||||
resultContent += text
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: resultContent,
|
||||
ToolCallID: toolUseID,
|
||||
})
|
||||
|
||||
case "thinking":
|
||||
if t, ok := blockMap["thinking"].(string); ok {
|
||||
thinking = t
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
|
||||
m := api.Message{
|
||||
Role: role,
|
||||
Content: textContent.String(),
|
||||
Images: images,
|
||||
ToolCalls: toolCalls,
|
||||
Thinking: thinking,
|
||||
}
|
||||
messages = append(messages, m)
|
||||
}
|
||||
|
||||
// Add tool results as separate messages
|
||||
messages = append(messages, toolResults...)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid message content type: %T", content)
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// convertTool converts an Anthropic Tool to an Ollama api.Tool
|
||||
func convertTool(t Tool) (api.Tool, error) {
|
||||
var params api.ToolFunctionParameters
|
||||
if len(t.InputSchema) > 0 {
|
||||
if err := json.Unmarshal(t.InputSchema, ¶ms); err != nil {
|
||||
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return api.Tool{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
Parameters: params,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
|
||||
func ToMessagesResponse(id string, r api.ChatResponse) MessagesResponse {
|
||||
var content []ContentBlock
|
||||
|
||||
if r.Message.Thinking != "" {
|
||||
content = append(content, ContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: ptr(r.Message.Thinking),
|
||||
})
|
||||
}
|
||||
|
||||
if r.Message.Content != "" {
|
||||
content = append(content, ContentBlock{
|
||||
Type: "text",
|
||||
Text: ptr(r.Message.Content),
|
||||
})
|
||||
}
|
||||
|
||||
for _, tc := range r.Message.ToolCalls {
|
||||
content = append(content, ContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Input: tc.Function.Arguments,
|
||||
})
|
||||
}
|
||||
|
||||
stopReason := mapStopReason(r.DoneReason, len(r.Message.ToolCalls) > 0)
|
||||
|
||||
return MessagesResponse{
|
||||
ID: id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: r.Model,
|
||||
Content: content,
|
||||
StopReason: stopReason,
|
||||
Usage: Usage{
|
||||
InputTokens: r.Metrics.PromptEvalCount,
|
||||
OutputTokens: r.Metrics.EvalCount,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// mapStopReason converts Ollama done_reason to Anthropic stop_reason
|
||||
func mapStopReason(reason string, hasToolCalls bool) string {
|
||||
if hasToolCalls {
|
||||
return "tool_use"
|
||||
}
|
||||
|
||||
switch reason {
|
||||
case "stop":
|
||||
return "end_turn"
|
||||
case "length":
|
||||
return "max_tokens"
|
||||
default:
|
||||
if reason != "" {
|
||||
return "stop_sequence"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
|
||||
type StreamConverter struct {
|
||||
ID string
|
||||
Model string
|
||||
firstWrite bool
|
||||
contentIndex int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
thinkingStarted bool
|
||||
thinkingDone bool
|
||||
textStarted bool
|
||||
toolCallsSent map[string]bool
|
||||
}
|
||||
|
||||
func NewStreamConverter(id, model string) *StreamConverter {
|
||||
return &StreamConverter{
|
||||
ID: id,
|
||||
Model: model,
|
||||
firstWrite: true,
|
||||
toolCallsSent: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// StreamEvent represents a streaming event to be sent to the client
|
||||
type StreamEvent struct {
|
||||
Event string
|
||||
Data any
|
||||
}
|
||||
|
||||
// Process converts an Ollama ChatResponse to Anthropic streaming events
|
||||
func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
var events []StreamEvent
|
||||
|
||||
if c.firstWrite {
|
||||
c.firstWrite = false
|
||||
c.inputTokens = r.Metrics.PromptEvalCount
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_start",
|
||||
Data: MessageStartEvent{
|
||||
Type: "message_start",
|
||||
Message: MessagesResponse{
|
||||
ID: c.ID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: c.Model,
|
||||
Content: []ContentBlock{},
|
||||
Usage: Usage{
|
||||
InputTokens: c.inputTokens,
|
||||
OutputTokens: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if r.Message.Thinking != "" && !c.thinkingDone {
|
||||
if !c.thinkingStarted {
|
||||
c.thinkingStarted = true
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_start",
|
||||
Data: ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: c.contentIndex,
|
||||
ContentBlock: ContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: ptr(""),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: ContentBlockDeltaEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: c.contentIndex,
|
||||
Delta: Delta{
|
||||
Type: "thinking_delta",
|
||||
Thinking: r.Message.Thinking,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if r.Message.Content != "" {
|
||||
if c.thinkingStarted && !c.thinkingDone {
|
||||
c.thinkingDone = true
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
c.contentIndex++
|
||||
}
|
||||
|
||||
if !c.textStarted {
|
||||
c.textStarted = true
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_start",
|
||||
Data: ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: c.contentIndex,
|
||||
ContentBlock: ContentBlock{
|
||||
Type: "text",
|
||||
Text: ptr(""),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: ContentBlockDeltaEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: c.contentIndex,
|
||||
Delta: Delta{
|
||||
Type: "text_delta",
|
||||
Text: r.Message.Content,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
for _, tc := range r.Message.ToolCalls {
|
||||
if c.toolCallsSent[tc.ID] {
|
||||
continue
|
||||
}
|
||||
|
||||
if c.textStarted {
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
c.contentIndex++
|
||||
c.textStarted = false
|
||||
}
|
||||
|
||||
argsJSON, err := json.Marshal(tc.Function.Arguments)
|
||||
if err != nil {
|
||||
slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_start",
|
||||
Data: ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: c.contentIndex,
|
||||
ContentBlock: ContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Input: map[string]any{},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: ContentBlockDeltaEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: c.contentIndex,
|
||||
Delta: Delta{
|
||||
Type: "input_json_delta",
|
||||
PartialJSON: string(argsJSON),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
|
||||
c.toolCallsSent[tc.ID] = true
|
||||
c.contentIndex++
|
||||
}
|
||||
|
||||
if r.Done {
|
||||
if c.textStarted {
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
} else if c.thinkingStarted && !c.thinkingDone {
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
c.outputTokens = r.Metrics.EvalCount
|
||||
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_delta",
|
||||
Data: MessageDeltaEvent{
|
||||
Type: "message_delta",
|
||||
Delta: MessageDelta{
|
||||
StopReason: stopReason,
|
||||
},
|
||||
Usage: DeltaUsage{
|
||||
OutputTokens: c.outputTokens,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_stop",
|
||||
Data: MessageStopEvent{
|
||||
Type: "message_stop",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
// generateID generates a unique ID with the given prefix using crypto/rand
|
||||
func generateID(prefix string) string {
|
||||
b := make([]byte, 12)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// Fallback to time-based ID if crypto/rand fails
|
||||
return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
|
||||
}
|
||||
return fmt.Sprintf("%s_%x", prefix, b)
|
||||
}
|
||||
|
||||
// GenerateMessageID generates a unique message ID
|
||||
func GenerateMessageID() string {
|
||||
return generateID("msg")
|
||||
}
|
||||
|
||||
// ptr returns a pointer to the given string value
|
||||
func ptr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
// mapToArgs converts a map to ToolCallFunctionArguments
|
||||
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
953
anthropic/anthropic_test.go
Normal file
953
anthropic/anthropic_test.go
Normal file
@@ -0,0 +1,953 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||
)
|
||||
|
||||
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
|
||||
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_Basic(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Model != "test-model" {
|
||||
t.Errorf("expected model 'test-model', got %q", result.Model)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
|
||||
t.Errorf("unexpected message: %+v", result.Messages[0])
|
||||
}
|
||||
|
||||
if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 1024 {
|
||||
t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
System: "You are a helpful assistant.",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." {
|
||||
t.Errorf("unexpected system message: %+v", result.Messages[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
System: []any{
|
||||
map[string]any{"type": "text", "text": "You are helpful."},
|
||||
map[string]any{"type": "text", "text": " Be concise."},
|
||||
},
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Content != "You are helpful. Be concise." {
|
||||
t.Errorf("unexpected system message content: %q", result.Messages[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithOptions(t *testing.T) {
|
||||
temp := 0.7
|
||||
topP := 0.9
|
||||
topK := 40
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 2048,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Temperature: &temp,
|
||||
TopP: &topP,
|
||||
TopK: &topK,
|
||||
StopSequences: []string{"\n", "END"},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Options["temperature"] != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"])
|
||||
}
|
||||
if result.Options["top_p"] != 0.9 {
|
||||
t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"])
|
||||
}
|
||||
if result.Options["top_k"] != 40 {
|
||||
t.Errorf("expected top_k 40, got %v", result.Options["top_k"])
|
||||
}
|
||||
if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" {
|
||||
t.Errorf("stop sequences mismatch: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithImage(t *testing.T) {
|
||||
imgData, _ := base64.StdEncoding.DecodeString(testImage)
|
||||
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{"type": "text", "text": "What's in this image?"},
|
||||
map[string]any{
|
||||
"type": "image",
|
||||
"source": map[string]any{
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": testImage,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Content != "What's in this image?" {
|
||||
t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content)
|
||||
}
|
||||
|
||||
if len(result.Messages[0].Images) != 1 {
|
||||
t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images))
|
||||
}
|
||||
|
||||
if string(result.Messages[0].Images[0]) != string(imgData) {
|
||||
t.Error("image data mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithToolUse(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "What's the weather in Paris?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "call_123",
|
||||
"name": "get_weather",
|
||||
"input": map[string]any{"location": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if len(result.Messages[1].ToolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls))
|
||||
}
|
||||
|
||||
tc := result.Messages[1].ToolCalls[0]
|
||||
if tc.ID != "call_123" {
|
||||
t.Errorf("expected tool call ID 'call_123', got %q", tc.ID)
|
||||
}
|
||||
if tc.Function.Name != "get_weather" {
|
||||
t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithToolResult(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "call_123",
|
||||
"content": "The weather in Paris is sunny, 22°C",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
msg := result.Messages[0]
|
||||
if msg.Role != "tool" {
|
||||
t.Errorf("expected role 'tool', got %q", msg.Role)
|
||||
}
|
||||
if msg.ToolCallID != "call_123" {
|
||||
t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID)
|
||||
}
|
||||
if msg.Content != "The weather in Paris is sunny, 22°C" {
|
||||
t.Errorf("unexpected content: %q", msg.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithTools(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get current weather",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(result.Tools))
|
||||
}
|
||||
|
||||
tool := result.Tools[0]
|
||||
if tool.Type != "function" {
|
||||
t.Errorf("expected type 'function', got %q", tool.Type)
|
||||
}
|
||||
if tool.Function.Name != "get_weather" {
|
||||
t.Errorf("expected name 'get_weather', got %q", tool.Function.Name)
|
||||
}
|
||||
if tool.Function.Description != "Get current weather" {
|
||||
t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithThinking(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Think == nil {
|
||||
t.Fatal("expected Think to be set")
|
||||
}
|
||||
if v, ok := result.Think.Value.(bool); !ok || !v {
|
||||
t.Errorf("expected Think.Value to be true, got %v", result.Think.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
|
||||
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
|
||||
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "Let me think about this...",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
assistantMsg := result.Messages[1]
|
||||
if assistantMsg.Thinking != "Let me think about this..." {
|
||||
t.Errorf("expected thinking content, got %q", assistantMsg.Thinking)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"name": "get_weather",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := FromMessagesRequest(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing tool_use id")
|
||||
}
|
||||
if err.Error() != "tool_use block missing required 'id' field" {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "call_123",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := FromMessagesRequest(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing tool_use name")
|
||||
}
|
||||
if err.Error() != "tool_use block missing required 'name' field" {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Name: "bad_tool",
|
||||
InputSchema: json.RawMessage(`{invalid json`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := FromMessagesRequest(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid tool schema")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessagesResponse_Basic(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "Hello there!",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 10,
|
||||
EvalCount: 5,
|
||||
},
|
||||
}
|
||||
|
||||
result := ToMessagesResponse("msg_123", resp)
|
||||
|
||||
if result.ID != "msg_123" {
|
||||
t.Errorf("expected ID 'msg_123', got %q", result.ID)
|
||||
}
|
||||
if result.Type != "message" {
|
||||
t.Errorf("expected type 'message', got %q", result.Type)
|
||||
}
|
||||
if result.Role != "assistant" {
|
||||
t.Errorf("expected role 'assistant', got %q", result.Role)
|
||||
}
|
||||
if len(result.Content) != 1 {
|
||||
t.Fatalf("expected 1 content block, got %d", len(result.Content))
|
||||
}
|
||||
if result.Content[0].Type != "text" || result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
|
||||
t.Errorf("unexpected content: %+v", result.Content[0])
|
||||
}
|
||||
if result.StopReason != "end_turn" {
|
||||
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
|
||||
}
|
||||
if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 {
|
||||
t.Errorf("unexpected usage: %+v", result.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessagesResponse_WithToolCalls(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
result := ToMessagesResponse("msg_123", resp)
|
||||
|
||||
if len(result.Content) != 1 {
|
||||
t.Fatalf("expected 1 content block, got %d", len(result.Content))
|
||||
}
|
||||
if result.Content[0].Type != "tool_use" {
|
||||
t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type)
|
||||
}
|
||||
if result.Content[0].ID != "call_123" {
|
||||
t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID)
|
||||
}
|
||||
if result.Content[0].Name != "get_weather" {
|
||||
t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name)
|
||||
}
|
||||
if result.StopReason != "tool_use" {
|
||||
t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessagesResponse_WithThinking(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "The answer is 42.",
|
||||
Thinking: "Let me think about this...",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
result := ToMessagesResponse("msg_123", resp)
|
||||
|
||||
if len(result.Content) != 2 {
|
||||
t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
|
||||
}
|
||||
if result.Content[0].Type != "thinking" {
|
||||
t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type)
|
||||
}
|
||||
if result.Content[0].Thinking == nil || *result.Content[0].Thinking != "Let me think about this..." {
|
||||
t.Errorf("unexpected thinking content: %v", result.Content[0].Thinking)
|
||||
}
|
||||
if result.Content[1].Type != "text" {
|
||||
t.Errorf("expected second block type 'text', got %q", result.Content[1].Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapStopReason(t *testing.T) {
|
||||
tests := []struct {
|
||||
reason string
|
||||
hasToolCalls bool
|
||||
want string
|
||||
}{
|
||||
{"stop", false, "end_turn"},
|
||||
{"length", false, "max_tokens"},
|
||||
{"stop", true, "tool_use"},
|
||||
{"other", false, "stop_sequence"},
|
||||
{"", false, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := mapStopReason(tt.reason, tt.hasToolCalls)
|
||||
if got != tt.want {
|
||||
t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewError(t *testing.T) {
|
||||
tests := []struct {
|
||||
code int
|
||||
want string
|
||||
}{
|
||||
{400, "invalid_request_error"},
|
||||
{401, "authentication_error"},
|
||||
{403, "permission_error"},
|
||||
{404, "not_found_error"},
|
||||
{429, "rate_limit_error"},
|
||||
{500, "api_error"},
|
||||
{503, "overloaded_error"},
|
||||
{529, "overloaded_error"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := NewError(tt.code, "test message")
|
||||
if result.Type != "error" {
|
||||
t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type)
|
||||
}
|
||||
if result.Error.Type != tt.want {
|
||||
t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want)
|
||||
}
|
||||
if result.Error.Message != "test message" {
|
||||
t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message)
|
||||
}
|
||||
if result.RequestID == "" {
|
||||
t.Errorf("NewError(%d) request_id should not be empty", tt.code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateMessageID(t *testing.T) {
|
||||
id1 := GenerateMessageID()
|
||||
id2 := GenerateMessageID()
|
||||
|
||||
if id1 == "" {
|
||||
t.Error("GenerateMessageID returned empty string")
|
||||
}
|
||||
if id1 == id2 {
|
||||
t.Error("GenerateMessageID returned duplicate IDs")
|
||||
}
|
||||
if len(id1) < 10 {
|
||||
t.Errorf("GenerateMessageID returned short ID: %q", id1)
|
||||
}
|
||||
if id1[:4] != "msg_" {
|
||||
t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_Basic(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
// First chunk
|
||||
resp1 := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "Hello",
|
||||
},
|
||||
Metrics: api.Metrics{PromptEvalCount: 10},
|
||||
}
|
||||
|
||||
events1 := conv.Process(resp1)
|
||||
if len(events1) < 3 {
|
||||
t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1))
|
||||
}
|
||||
|
||||
// Should have message_start, content_block_start, content_block_delta
|
||||
if events1[0].Event != "message_start" {
|
||||
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
|
||||
}
|
||||
if events1[1].Event != "content_block_start" {
|
||||
t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event)
|
||||
}
|
||||
if events1[2].Event != "content_block_delta" {
|
||||
t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event)
|
||||
}
|
||||
|
||||
// Final chunk
|
||||
resp2 := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: " world!",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{EvalCount: 5},
|
||||
}
|
||||
|
||||
events2 := conv.Process(resp2)
|
||||
|
||||
// Should have content_block_delta, content_block_stop, message_delta, message_stop
|
||||
hasStop := false
|
||||
for _, e := range events2 {
|
||||
if e.Event == "message_stop" {
|
||||
hasStop = true
|
||||
}
|
||||
}
|
||||
if !hasStop {
|
||||
t.Error("expected message_stop event in final chunk")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
hasToolStart := false
|
||||
hasToolDelta := false
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "tool_use" {
|
||||
hasToolStart = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if e.Event == "content_block_delta" {
|
||||
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
|
||||
if delta.Delta.Type == "input_json_delta" {
|
||||
hasToolDelta = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasToolStart {
|
||||
t.Error("expected tool_use content_block_start event")
|
||||
}
|
||||
if !hasToolDelta {
|
||||
t.Error("expected input_json_delta event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||
// Test that unmarshalable arguments (like channels) are handled gracefully
|
||||
// and don't cause a panic or corrupt stream
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
// Create a channel which cannot be JSON marshaled
|
||||
unmarshalable := make(chan int)
|
||||
badArgs := api.NewToolCallFunctionArguments()
|
||||
badArgs.Set("channel", unmarshalable)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_bad",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "bad_function",
|
||||
Arguments: badArgs,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
// Should not panic and should skip the unmarshalable tool call
|
||||
events := conv.Process(resp)
|
||||
|
||||
// Verify no tool_use block was started (since marshal failed before block start)
|
||||
hasToolStart := false
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "tool_use" {
|
||||
hasToolStart = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasToolStart {
|
||||
t.Error("expected no tool_use block when arguments cannot be marshaled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
||||
// Test that valid tool calls still work when mixed with invalid ones
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
unmarshalable := make(chan int)
|
||||
badArgs := api.NewToolCallFunctionArguments()
|
||||
badArgs.Set("channel", unmarshalable)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_good",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "good_function",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "call_bad",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "bad_function",
|
||||
Arguments: badArgs,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
// Count tool_use blocks - should only have 1 (the valid one)
|
||||
toolStartCount := 0
|
||||
toolDeltaCount := 0
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "tool_use" {
|
||||
toolStartCount++
|
||||
if start.ContentBlock.Name != "good_function" {
|
||||
t.Errorf("expected tool name 'good_function', got %q", start.ContentBlock.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if e.Event == "content_block_delta" {
|
||||
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
|
||||
if delta.Delta.Type == "input_json_delta" {
|
||||
toolDeltaCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if toolStartCount != 1 {
|
||||
t.Errorf("expected 1 tool_use block, got %d", toolStartCount)
|
||||
}
|
||||
if toolDeltaCount != 1 {
|
||||
t.Errorf("expected 1 input_json_delta, got %d", toolDeltaCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
|
||||
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
|
||||
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
|
||||
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
|
||||
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
block ContentBlock
|
||||
wantKeys []string
|
||||
}{
|
||||
{
|
||||
name: "text block includes empty text field",
|
||||
block: ContentBlock{
|
||||
Type: "text",
|
||||
Text: ptr(""),
|
||||
},
|
||||
wantKeys: []string{"type", "text"},
|
||||
},
|
||||
{
|
||||
name: "thinking block includes empty thinking field",
|
||||
block: ContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: ptr(""),
|
||||
},
|
||||
wantKeys: []string{"type", "thinking"},
|
||||
},
|
||||
{
|
||||
name: "text block with content",
|
||||
block: ContentBlock{
|
||||
Type: "text",
|
||||
Text: ptr("hello"),
|
||||
},
|
||||
wantKeys: []string{"type", "text"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.block)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
for _, key := range tt.wantKeys {
|
||||
if _, ok := result[key]; !ok {
|
||||
t.Errorf("expected key %q to be present in JSON output, got: %s", key, string(data))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
|
||||
// events include the required empty fields for SDK compatibility.
|
||||
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
t.Run("text block start includes empty text", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{Role: "assistant", Content: "hello"},
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
var foundTextStart bool
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "text" {
|
||||
foundTextStart = true
|
||||
// Marshal and verify the text field is present
|
||||
data, _ := json.Marshal(start)
|
||||
var result map[string]any
|
||||
json.Unmarshal(data, &result)
|
||||
cb := result["content_block"].(map[string]any)
|
||||
if _, ok := cb["text"]; !ok {
|
||||
t.Error("content_block_start for text should include 'text' field")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundTextStart {
|
||||
t.Error("expected text content_block_start event")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{Role: "assistant", Thinking: "let me think..."},
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
var foundThinkingStart bool
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "thinking" {
|
||||
foundThinkingStart = true
|
||||
data, _ := json.Marshal(start)
|
||||
var result map[string]any
|
||||
json.Unmarshal(data, &result)
|
||||
cb := result["content_block"].(map[string]any)
|
||||
if _, ok := cb["thinking"]; !ok {
|
||||
t.Error("content_block_start for thinking should include 'thinking' field")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundThinkingStart {
|
||||
t.Error("expected thinking content_block_start event")
|
||||
}
|
||||
})
|
||||
}
|
||||
162
api/types.go
162
api/types.go
@@ -3,6 +3,7 @@ package api
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"math"
|
||||
"os"
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/internal/orderedmap"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
@@ -227,13 +229,79 @@ type ToolCallFunction struct {
|
||||
Arguments ToolCallFunctionArguments `json:"arguments"`
|
||||
}
|
||||
|
||||
type ToolCallFunctionArguments map[string]any
|
||||
// ToolCallFunctionArguments holds tool call arguments in insertion order.
|
||||
type ToolCallFunctionArguments struct {
|
||||
om *orderedmap.Map[string, any]
|
||||
}
|
||||
|
||||
// NewToolCallFunctionArguments creates a new empty ToolCallFunctionArguments.
|
||||
func NewToolCallFunctionArguments() ToolCallFunctionArguments {
|
||||
return ToolCallFunctionArguments{om: orderedmap.New[string, any]()}
|
||||
}
|
||||
|
||||
// Get retrieves a value by key.
|
||||
func (t *ToolCallFunctionArguments) Get(key string) (any, bool) {
|
||||
if t == nil || t.om == nil {
|
||||
return nil, false
|
||||
}
|
||||
return t.om.Get(key)
|
||||
}
|
||||
|
||||
// Set sets a key-value pair, preserving insertion order.
|
||||
func (t *ToolCallFunctionArguments) Set(key string, value any) {
|
||||
if t == nil {
|
||||
return
|
||||
}
|
||||
if t.om == nil {
|
||||
t.om = orderedmap.New[string, any]()
|
||||
}
|
||||
t.om.Set(key, value)
|
||||
}
|
||||
|
||||
// Len returns the number of arguments.
|
||||
func (t *ToolCallFunctionArguments) Len() int {
|
||||
if t == nil || t.om == nil {
|
||||
return 0
|
||||
}
|
||||
return t.om.Len()
|
||||
}
|
||||
|
||||
// All returns an iterator over all key-value pairs in insertion order.
|
||||
func (t *ToolCallFunctionArguments) All() iter.Seq2[string, any] {
|
||||
if t == nil || t.om == nil {
|
||||
return func(yield func(string, any) bool) {}
|
||||
}
|
||||
return t.om.All()
|
||||
}
|
||||
|
||||
// ToMap returns a regular map (order not preserved).
|
||||
func (t *ToolCallFunctionArguments) ToMap() map[string]any {
|
||||
if t == nil || t.om == nil {
|
||||
return nil
|
||||
}
|
||||
return t.om.ToMap()
|
||||
}
|
||||
|
||||
func (t *ToolCallFunctionArguments) String() string {
|
||||
bts, _ := json.Marshal(t)
|
||||
if t == nil || t.om == nil {
|
||||
return "{}"
|
||||
}
|
||||
bts, _ := json.Marshal(t.om)
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
func (t *ToolCallFunctionArguments) UnmarshalJSON(data []byte) error {
|
||||
t.om = orderedmap.New[string, any]()
|
||||
return json.Unmarshal(data, t.om)
|
||||
}
|
||||
|
||||
func (t ToolCallFunctionArguments) MarshalJSON() ([]byte, error) {
|
||||
if t.om == nil {
|
||||
return []byte("{}"), nil
|
||||
}
|
||||
return json.Marshal(t.om)
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
Type string `json:"type"`
|
||||
Items any `json:"items,omitempty"`
|
||||
@@ -282,12 +350,78 @@ func (pt PropertyType) String() string {
|
||||
return fmt.Sprintf("%v", []string(pt))
|
||||
}
|
||||
|
||||
// ToolPropertiesMap holds tool properties in insertion order.
|
||||
type ToolPropertiesMap struct {
|
||||
om *orderedmap.Map[string, ToolProperty]
|
||||
}
|
||||
|
||||
// NewToolPropertiesMap creates a new empty ToolPropertiesMap.
|
||||
func NewToolPropertiesMap() *ToolPropertiesMap {
|
||||
return &ToolPropertiesMap{om: orderedmap.New[string, ToolProperty]()}
|
||||
}
|
||||
|
||||
// Get retrieves a property by name.
|
||||
func (t *ToolPropertiesMap) Get(key string) (ToolProperty, bool) {
|
||||
if t == nil || t.om == nil {
|
||||
return ToolProperty{}, false
|
||||
}
|
||||
return t.om.Get(key)
|
||||
}
|
||||
|
||||
// Set sets a property, preserving insertion order.
|
||||
func (t *ToolPropertiesMap) Set(key string, value ToolProperty) {
|
||||
if t == nil {
|
||||
return
|
||||
}
|
||||
if t.om == nil {
|
||||
t.om = orderedmap.New[string, ToolProperty]()
|
||||
}
|
||||
t.om.Set(key, value)
|
||||
}
|
||||
|
||||
// Len returns the number of properties.
|
||||
func (t *ToolPropertiesMap) Len() int {
|
||||
if t == nil || t.om == nil {
|
||||
return 0
|
||||
}
|
||||
return t.om.Len()
|
||||
}
|
||||
|
||||
// All returns an iterator over all properties in insertion order.
|
||||
func (t *ToolPropertiesMap) All() iter.Seq2[string, ToolProperty] {
|
||||
if t == nil || t.om == nil {
|
||||
return func(yield func(string, ToolProperty) bool) {}
|
||||
}
|
||||
return t.om.All()
|
||||
}
|
||||
|
||||
// ToMap returns a regular map (order not preserved).
|
||||
func (t *ToolPropertiesMap) ToMap() map[string]ToolProperty {
|
||||
if t == nil || t.om == nil {
|
||||
return nil
|
||||
}
|
||||
return t.om.ToMap()
|
||||
}
|
||||
|
||||
func (t ToolPropertiesMap) MarshalJSON() ([]byte, error) {
|
||||
if t.om == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return json.Marshal(t.om)
|
||||
}
|
||||
|
||||
func (t *ToolPropertiesMap) UnmarshalJSON(data []byte) error {
|
||||
t.om = orderedmap.New[string, ToolProperty]()
|
||||
return json.Unmarshal(data, t.om)
|
||||
}
|
||||
|
||||
type ToolProperty struct {
|
||||
AnyOf []ToolProperty `json:"anyOf,omitempty"`
|
||||
Type PropertyType `json:"type,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
AnyOf []ToolProperty `json:"anyOf,omitempty"`
|
||||
Type PropertyType `json:"type,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
Properties *ToolPropertiesMap `json:"properties,omitempty"`
|
||||
}
|
||||
|
||||
// ToTypeScriptType converts a ToolProperty to a TypeScript type string
|
||||
@@ -336,11 +470,11 @@ func mapToTypeScriptType(jsonType string) string {
|
||||
}
|
||||
|
||||
type ToolFunctionParameters struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
Properties map[string]ToolProperty `json:"properties"`
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
Properties *ToolPropertiesMap `json:"properties"`
|
||||
}
|
||||
|
||||
func (t *ToolFunctionParameters) String() string {
|
||||
@@ -553,6 +687,9 @@ type CreateRequest struct {
|
||||
Renderer string `json:"renderer,omitempty"`
|
||||
Parser string `json:"parser,omitempty"`
|
||||
|
||||
// Requires is the minimum version of Ollama required by the model.
|
||||
Requires string `json:"requires,omitempty"`
|
||||
|
||||
// Info is a map of additional information for the model
|
||||
Info map[string]any `json:"info,omitempty"`
|
||||
|
||||
@@ -603,6 +740,7 @@ type ShowResponse struct {
|
||||
Tensors []Tensor `json:"tensors,omitempty"`
|
||||
Capabilities []model.Capability `json:"capabilities,omitempty"`
|
||||
ModifiedAt time.Time `json:"modified_at,omitempty"`
|
||||
Requires string `json:"requires,omitempty"`
|
||||
}
|
||||
|
||||
// CopyRequest is the request passed to [Client.Copy].
|
||||
|
||||
@@ -11,6 +11,24 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
|
||||
func testPropsMap(m map[string]ToolProperty) *ToolPropertiesMap {
|
||||
props := NewToolPropertiesMap()
|
||||
for k, v := range m {
|
||||
props.Set(k, v)
|
||||
}
|
||||
return props
|
||||
}
|
||||
|
||||
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
|
||||
func testArgs(m map[string]any) ToolCallFunctionArguments {
|
||||
args := NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func TestKeepAliveParsingFromJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -309,9 +327,9 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
|
||||
input: ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"name"},
|
||||
Properties: map[string]ToolProperty{
|
||||
Properties: testPropsMap(map[string]ToolProperty{
|
||||
"name": {Type: PropertyType{"string"}},
|
||||
},
|
||||
}),
|
||||
},
|
||||
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string"}}}`,
|
||||
},
|
||||
@@ -319,9 +337,9 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
|
||||
name: "no required",
|
||||
input: ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]ToolProperty{
|
||||
Properties: testPropsMap(map[string]ToolProperty{
|
||||
"name": {Type: PropertyType{"string"}},
|
||||
},
|
||||
}),
|
||||
},
|
||||
expected: `{"type":"object","properties":{"name":{"type":"string"}}}`,
|
||||
},
|
||||
@@ -339,7 +357,7 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
|
||||
func TestToolCallFunction_IndexAlwaysMarshals(t *testing.T) {
|
||||
fn := ToolCallFunction{
|
||||
Name: "echo",
|
||||
Arguments: ToolCallFunctionArguments{"message": "hi"},
|
||||
Arguments: testArgs(map[string]any{"message": "hi"}),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(fn)
|
||||
@@ -504,6 +522,116 @@ func TestThinking_UnmarshalJSON(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolPropertyNestedProperties(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected ToolProperty
|
||||
}{
|
||||
{
|
||||
name: "nested object properties",
|
||||
input: `{
|
||||
"type": "object",
|
||||
"description": "Location details",
|
||||
"properties": {
|
||||
"address": {
|
||||
"type": "string",
|
||||
"description": "Street address"
|
||||
},
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City name"
|
||||
}
|
||||
}
|
||||
}`,
|
||||
expected: ToolProperty{
|
||||
Type: PropertyType{"object"},
|
||||
Description: "Location details",
|
||||
Properties: testPropsMap(map[string]ToolProperty{
|
||||
"address": {
|
||||
Type: PropertyType{"string"},
|
||||
Description: "Street address",
|
||||
},
|
||||
"city": {
|
||||
Type: PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deeply nested properties",
|
||||
input: `{
|
||||
"type": "object",
|
||||
"description": "Event",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "object",
|
||||
"description": "Location",
|
||||
"properties": {
|
||||
"coordinates": {
|
||||
"type": "object",
|
||||
"description": "GPS coordinates",
|
||||
"properties": {
|
||||
"lat": {"type": "number", "description": "Latitude"},
|
||||
"lng": {"type": "number", "description": "Longitude"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
expected: ToolProperty{
|
||||
Type: PropertyType{"object"},
|
||||
Description: "Event",
|
||||
Properties: testPropsMap(map[string]ToolProperty{
|
||||
"location": {
|
||||
Type: PropertyType{"object"},
|
||||
Description: "Location",
|
||||
Properties: testPropsMap(map[string]ToolProperty{
|
||||
"coordinates": {
|
||||
Type: PropertyType{"object"},
|
||||
Description: "GPS coordinates",
|
||||
Properties: testPropsMap(map[string]ToolProperty{
|
||||
"lat": {Type: PropertyType{"number"}, Description: "Latitude"},
|
||||
"lng": {Type: PropertyType{"number"}, Description: "Longitude"},
|
||||
}),
|
||||
},
|
||||
}),
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var prop ToolProperty
|
||||
err := json.Unmarshal([]byte(tt.input), &prop)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Compare JSON representations since pointer comparison doesn't work
|
||||
expectedJSON, err := json.Marshal(tt.expected)
|
||||
require.NoError(t, err)
|
||||
actualJSON, err := json.Marshal(prop)
|
||||
require.NoError(t, err)
|
||||
assert.JSONEq(t, string(expectedJSON), string(actualJSON))
|
||||
|
||||
// Round-trip test: marshal and unmarshal again
|
||||
data, err := json.Marshal(prop)
|
||||
require.NoError(t, err)
|
||||
|
||||
var prop2 ToolProperty
|
||||
err = json.Unmarshal(data, &prop2)
|
||||
require.NoError(t, err)
|
||||
|
||||
prop2JSON, err := json.Marshal(prop2)
|
||||
require.NoError(t, err)
|
||||
assert.JSONEq(t, string(expectedJSON), string(prop2JSON))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolFunctionParameters_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -515,12 +643,12 @@ func TestToolFunctionParameters_String(t *testing.T) {
|
||||
params: ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"name"},
|
||||
Properties: map[string]ToolProperty{
|
||||
Properties: testPropsMap(map[string]ToolProperty{
|
||||
"name": {
|
||||
Type: PropertyType{"string"},
|
||||
Description: "The name of the person",
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
|
||||
},
|
||||
@@ -537,7 +665,7 @@ func TestToolFunctionParameters_String(t *testing.T) {
|
||||
s.Self = s
|
||||
return s
|
||||
}(),
|
||||
Properties: map[string]ToolProperty{},
|
||||
Properties: testPropsMap(map[string]ToolProperty{}),
|
||||
},
|
||||
expected: "",
|
||||
},
|
||||
@@ -550,3 +678,235 @@ func TestToolFunctionParameters_String(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolCallFunctionArguments_OrderPreservation(t *testing.T) {
|
||||
t.Run("marshal preserves insertion order", func(t *testing.T) {
|
||||
args := NewToolCallFunctionArguments()
|
||||
args.Set("zebra", "z")
|
||||
args.Set("apple", "a")
|
||||
args.Set("mango", "m")
|
||||
|
||||
data, err := json.Marshal(args)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should preserve insertion order, not alphabetical
|
||||
assert.Equal(t, `{"zebra":"z","apple":"a","mango":"m"}`, string(data))
|
||||
})
|
||||
|
||||
t.Run("unmarshal preserves JSON order", func(t *testing.T) {
|
||||
jsonData := `{"zebra":"z","apple":"a","mango":"m"}`
|
||||
|
||||
var args ToolCallFunctionArguments
|
||||
err := json.Unmarshal([]byte(jsonData), &args)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify iteration order matches JSON order
|
||||
var keys []string
|
||||
for k := range args.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
assert.Equal(t, []string{"zebra", "apple", "mango"}, keys)
|
||||
})
|
||||
|
||||
t.Run("round trip preserves order", func(t *testing.T) {
|
||||
original := `{"z":1,"a":2,"m":3,"b":4}`
|
||||
|
||||
var args ToolCallFunctionArguments
|
||||
err := json.Unmarshal([]byte(original), &args)
|
||||
require.NoError(t, err)
|
||||
|
||||
data, err := json.Marshal(args)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, original, string(data))
|
||||
})
|
||||
|
||||
t.Run("String method returns ordered JSON", func(t *testing.T) {
|
||||
args := NewToolCallFunctionArguments()
|
||||
args.Set("c", 3)
|
||||
args.Set("a", 1)
|
||||
args.Set("b", 2)
|
||||
|
||||
assert.Equal(t, `{"c":3,"a":1,"b":2}`, args.String())
|
||||
})
|
||||
|
||||
t.Run("Get retrieves correct values", func(t *testing.T) {
|
||||
args := NewToolCallFunctionArguments()
|
||||
args.Set("key1", "value1")
|
||||
args.Set("key2", 42)
|
||||
|
||||
v, ok := args.Get("key1")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "value1", v)
|
||||
|
||||
v, ok = args.Get("key2")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, 42, v)
|
||||
|
||||
_, ok = args.Get("nonexistent")
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("Len returns correct count", func(t *testing.T) {
|
||||
args := NewToolCallFunctionArguments()
|
||||
assert.Equal(t, 0, args.Len())
|
||||
|
||||
args.Set("a", 1)
|
||||
assert.Equal(t, 1, args.Len())
|
||||
|
||||
args.Set("b", 2)
|
||||
assert.Equal(t, 2, args.Len())
|
||||
})
|
||||
|
||||
t.Run("empty args marshal to empty object", func(t *testing.T) {
|
||||
args := NewToolCallFunctionArguments()
|
||||
data, err := json.Marshal(args)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, `{}`, string(data))
|
||||
})
|
||||
|
||||
t.Run("zero value args marshal to empty object", func(t *testing.T) {
|
||||
var args ToolCallFunctionArguments
|
||||
assert.Equal(t, "{}", args.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestToolPropertiesMap_OrderPreservation(t *testing.T) {
|
||||
t.Run("marshal preserves insertion order", func(t *testing.T) {
|
||||
props := NewToolPropertiesMap()
|
||||
props.Set("zebra", ToolProperty{Type: PropertyType{"string"}})
|
||||
props.Set("apple", ToolProperty{Type: PropertyType{"number"}})
|
||||
props.Set("mango", ToolProperty{Type: PropertyType{"boolean"}})
|
||||
|
||||
data, err := json.Marshal(props)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should preserve insertion order, not alphabetical
|
||||
expected := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}`
|
||||
assert.Equal(t, expected, string(data))
|
||||
})
|
||||
|
||||
t.Run("unmarshal preserves JSON order", func(t *testing.T) {
|
||||
jsonData := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}`
|
||||
|
||||
var props ToolPropertiesMap
|
||||
err := json.Unmarshal([]byte(jsonData), &props)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify iteration order matches JSON order
|
||||
var keys []string
|
||||
for k := range props.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
assert.Equal(t, []string{"zebra", "apple", "mango"}, keys)
|
||||
})
|
||||
|
||||
t.Run("round trip preserves order", func(t *testing.T) {
|
||||
original := `{"z":{"type":"string"},"a":{"type":"number"},"m":{"type":"boolean"}}`
|
||||
|
||||
var props ToolPropertiesMap
|
||||
err := json.Unmarshal([]byte(original), &props)
|
||||
require.NoError(t, err)
|
||||
|
||||
data, err := json.Marshal(props)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, original, string(data))
|
||||
})
|
||||
|
||||
t.Run("Get retrieves correct values", func(t *testing.T) {
|
||||
props := NewToolPropertiesMap()
|
||||
props.Set("name", ToolProperty{Type: PropertyType{"string"}, Description: "The name"})
|
||||
props.Set("age", ToolProperty{Type: PropertyType{"integer"}, Description: "The age"})
|
||||
|
||||
v, ok := props.Get("name")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "The name", v.Description)
|
||||
|
||||
v, ok = props.Get("age")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "The age", v.Description)
|
||||
|
||||
_, ok = props.Get("nonexistent")
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("Len returns correct count", func(t *testing.T) {
|
||||
props := NewToolPropertiesMap()
|
||||
assert.Equal(t, 0, props.Len())
|
||||
|
||||
props.Set("a", ToolProperty{})
|
||||
assert.Equal(t, 1, props.Len())
|
||||
|
||||
props.Set("b", ToolProperty{})
|
||||
assert.Equal(t, 2, props.Len())
|
||||
})
|
||||
|
||||
t.Run("nil props marshal to null", func(t *testing.T) {
|
||||
var props *ToolPropertiesMap
|
||||
data, err := json.Marshal(props)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, `null`, string(data))
|
||||
})
|
||||
|
||||
t.Run("ToMap returns regular map", func(t *testing.T) {
|
||||
props := NewToolPropertiesMap()
|
||||
props.Set("a", ToolProperty{Type: PropertyType{"string"}})
|
||||
props.Set("b", ToolProperty{Type: PropertyType{"number"}})
|
||||
|
||||
m := props.ToMap()
|
||||
assert.Equal(t, 2, len(m))
|
||||
assert.Equal(t, PropertyType{"string"}, m["a"].Type)
|
||||
assert.Equal(t, PropertyType{"number"}, m["b"].Type)
|
||||
})
|
||||
}
|
||||
|
||||
func TestToolCallFunctionArguments_ComplexValues(t *testing.T) {
|
||||
t.Run("nested objects preserve order", func(t *testing.T) {
|
||||
jsonData := `{"outer":{"z":1,"a":2},"simple":"value"}`
|
||||
|
||||
var args ToolCallFunctionArguments
|
||||
err := json.Unmarshal([]byte(jsonData), &args)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Outer keys should be in order
|
||||
var keys []string
|
||||
for k := range args.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
assert.Equal(t, []string{"outer", "simple"}, keys)
|
||||
})
|
||||
|
||||
t.Run("arrays as values", func(t *testing.T) {
|
||||
args := NewToolCallFunctionArguments()
|
||||
args.Set("items", []string{"a", "b", "c"})
|
||||
args.Set("numbers", []int{1, 2, 3})
|
||||
|
||||
data, err := json.Marshal(args)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, `{"items":["a","b","c"],"numbers":[1,2,3]}`, string(data))
|
||||
})
|
||||
}
|
||||
|
||||
func TestToolPropertiesMap_NestedProperties(t *testing.T) {
|
||||
t.Run("nested properties preserve order", func(t *testing.T) {
|
||||
props := NewToolPropertiesMap()
|
||||
|
||||
nestedProps := NewToolPropertiesMap()
|
||||
nestedProps.Set("z_field", ToolProperty{Type: PropertyType{"string"}})
|
||||
nestedProps.Set("a_field", ToolProperty{Type: PropertyType{"number"}})
|
||||
|
||||
props.Set("outer", ToolProperty{
|
||||
Type: PropertyType{"object"},
|
||||
Properties: nestedProps,
|
||||
})
|
||||
|
||||
data, err := json.Marshal(props)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Both outer and inner should preserve order
|
||||
expected := `{"outer":{"type":"object","properties":{"z_field":{"type":"string"},"a_field":{"type":"number"}}}}`
|
||||
assert.Equal(t, expected, string(data))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -147,6 +147,7 @@ export const highlighterPromise = createHighlighter({
|
||||
"c",
|
||||
"cpp",
|
||||
"sql",
|
||||
"swift",
|
||||
"yaml",
|
||||
"markdown",
|
||||
],
|
||||
|
||||
@@ -997,7 +997,7 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
|
||||
for _, toolCall := range res.Message.ToolCalls {
|
||||
// continues loop as tools were executed
|
||||
toolsExecuted = true
|
||||
result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
|
||||
result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments.ToMap())
|
||||
if err != nil {
|
||||
errContent := fmt.Sprintf("Error: %v", err)
|
||||
toolErrMsg := store.NewMessage("tool", errContent, nil)
|
||||
@@ -1558,13 +1558,13 @@ func convertToOllamaTool(toolSchema map[string]any) api.Tool {
|
||||
|
||||
tool.Function.Parameters.Type = "object"
|
||||
tool.Function.Parameters.Required = []string{}
|
||||
tool.Function.Parameters.Properties = make(map[string]api.ToolProperty)
|
||||
tool.Function.Parameters.Properties = api.NewToolPropertiesMap()
|
||||
|
||||
if schemaProps, ok := toolSchema["schema"].(map[string]any); ok {
|
||||
tool.Function.Parameters.Type = getStringFromMap(schemaProps, "type", "object")
|
||||
|
||||
if props, ok := schemaProps["properties"].(map[string]any); ok {
|
||||
tool.Function.Parameters.Properties = make(map[string]api.ToolProperty)
|
||||
tool.Function.Parameters.Properties = api.NewToolPropertiesMap()
|
||||
|
||||
for propName, propDef := range props {
|
||||
if propMap, ok := propDef.(map[string]any); ok {
|
||||
@@ -1572,7 +1572,7 @@ func convertToOllamaTool(toolSchema map[string]any) api.Tool {
|
||||
Type: api.PropertyType{getStringFromMap(propMap, "type", "string")},
|
||||
Description: getStringFromMap(propMap, "description", ""),
|
||||
}
|
||||
tool.Function.Parameters.Properties[propName] = prop
|
||||
tool.Function.Parameters.Properties.Set(propName, prop)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
38
cmd/cmd.go
38
cmd/cmd.go
@@ -45,6 +45,9 @@ import (
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/types/syncmap"
|
||||
"github.com/ollama/ollama/version"
|
||||
xcmd "github.com/ollama/ollama/x/cmd"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
imagegenclient "github.com/ollama/ollama/x/imagegen/client"
|
||||
)
|
||||
|
||||
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
|
||||
@@ -95,6 +98,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
filename, err := getModelfileName(cmd)
|
||||
if os.IsNotExist(err) {
|
||||
if filename == "" {
|
||||
// No Modelfile found - check if current directory is an image gen model
|
||||
if imagegen.IsTensorModelDir(".") {
|
||||
return imagegenclient.CreateModel(args[0], ".", p)
|
||||
}
|
||||
reader = strings.NewReader("FROM .\n")
|
||||
} else {
|
||||
return errModelfileNotFound
|
||||
@@ -456,6 +463,15 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
|
||||
// Check if this is a known image generation model (skip Show/Pull)
|
||||
if imagegen.HasTensorLayers(name) {
|
||||
if opts.Prompt == "" && !interactive {
|
||||
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
|
||||
}
|
||||
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
|
||||
}
|
||||
|
||||
info, err := func() (*api.ShowResponse, error) {
|
||||
showReq := &api.ShowRequest{Name: name}
|
||||
info, err := client.Show(cmd.Context(), showReq)
|
||||
@@ -517,6 +533,10 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
|
||||
}
|
||||
|
||||
// Check for experimental flag
|
||||
isExperimental, _ := cmd.Flags().GetBool("experimental")
|
||||
yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
|
||||
|
||||
if interactive {
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
var sErr api.AuthorizationError
|
||||
@@ -543,6 +563,11 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Use experimental agent loop with tools
|
||||
if isExperimental {
|
||||
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode)
|
||||
}
|
||||
|
||||
return generateInteractive(cmd, opts)
|
||||
}
|
||||
return generate(cmd, opts)
|
||||
@@ -812,6 +837,11 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
func ShowHandler(cmd *cobra.Command, args []string) error {
|
||||
// Check if this is an image generation model
|
||||
if imagegen.HasTensorLayers(args[0]) {
|
||||
return imagegen.Show(args[0], os.Stdout)
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -943,6 +973,9 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
||||
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
|
||||
}
|
||||
rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
|
||||
if resp.Requires != "" {
|
||||
rows = append(rows, []string{"", "requires", resp.Requires})
|
||||
}
|
||||
return
|
||||
})
|
||||
|
||||
@@ -1751,6 +1784,11 @@ func NewCLI() *cobra.Command {
|
||||
runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
|
||||
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
|
||||
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
|
||||
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
|
||||
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
|
||||
|
||||
// Image generation flags (width, height, steps, seed, etc.)
|
||||
imagegen.RegisterFlags(runCmd)
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
Use: "stop MODEL",
|
||||
|
||||
@@ -291,6 +291,31 @@ Weigh anchor!
|
||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("min version", func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
if err := showInfo(&api.ShowResponse{
|
||||
Details: api.ModelDetails{
|
||||
Family: "test",
|
||||
ParameterSize: "7B",
|
||||
QuantizationLevel: "FP16",
|
||||
},
|
||||
Requires: "0.14.0",
|
||||
}, false, &b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expect := ` Model
|
||||
architecture test
|
||||
parameters 7B
|
||||
quantization FP16
|
||||
requires 0.14.0
|
||||
|
||||
`
|
||||
if diff := cmp.Diff(expect, b.String()); diff != "" {
|
||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteHandler(t *testing.T) {
|
||||
|
||||
@@ -40,6 +40,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
|
||||
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
|
||||
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
|
||||
|
||||
|
||||
@@ -6,11 +6,14 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
ofs "github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -18,8 +21,13 @@ type ModelParameters struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
|
||||
// TODO is this needed?
|
||||
ModelType string `json:"model_type"`
|
||||
|
||||
TextModel struct {
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
ModelType string `json:"model_type"`
|
||||
} `json:"text_config"`
|
||||
}
|
||||
|
||||
@@ -33,8 +41,94 @@ type AdapterParameters struct {
|
||||
} `json:"lora_parameters"`
|
||||
}
|
||||
|
||||
func (ModelParameters) KV(t *Tokenizer) ggml.KV {
|
||||
kv := ggml.KV{
|
||||
type KV map[string]any
|
||||
|
||||
func (kv KV) Architecture() string {
|
||||
return kv.String("general.architecture", "unknown")
|
||||
}
|
||||
|
||||
type valueTypes interface {
|
||||
uint8 | int8 | uint16 | int16 |
|
||||
uint32 | int32 | uint64 | int64 |
|
||||
string | float32 | float64 | bool
|
||||
}
|
||||
|
||||
type arrayValueTypes interface {
|
||||
[]uint8 | []int8 | []uint16 | []int16 |
|
||||
[]uint32 | []int32 | []uint64 | []int64 |
|
||||
[]string | []float32 | []float64 | []bool
|
||||
}
|
||||
|
||||
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
|
||||
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
|
||||
key = kv.Architecture() + "." + key
|
||||
}
|
||||
|
||||
if val, ok := kv[key].(T); ok {
|
||||
return val, true
|
||||
}
|
||||
return defaultValue[0], false
|
||||
}
|
||||
|
||||
func (kv KV) String(key string, defaultValue ...string) string {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, "")...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Float(key string, defaultValue ...float32) float32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Bool(key string, defaultValue ...bool) bool {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, false)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []string{""})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []int32{0})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []uint32{0})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []float32{0})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []bool{false})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Len() int {
|
||||
return len(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Keys() iter.Seq[string] {
|
||||
return maps.Keys(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Value(key string) any {
|
||||
return kv[key]
|
||||
}
|
||||
|
||||
func (ModelParameters) KV(t *Tokenizer) KV {
|
||||
kv := KV{
|
||||
"general.file_type": uint32(1),
|
||||
"general.quantization_version": uint32(2),
|
||||
"tokenizer.ggml.pre": t.Pre,
|
||||
@@ -63,7 +157,7 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p AdapterParameters) KV() ggml.KV {
|
||||
func (p AdapterParameters) KV() KV {
|
||||
var alpha float32
|
||||
if p.LoraParameters.Alpha == 0 {
|
||||
alpha = float32(p.Alpha)
|
||||
@@ -71,7 +165,7 @@ func (p AdapterParameters) KV() ggml.KV {
|
||||
alpha = p.LoraParameters.Alpha
|
||||
}
|
||||
|
||||
kv := ggml.KV{
|
||||
kv := KV{
|
||||
"adapter.lora.alpha": alpha,
|
||||
"adapter.type": "lora",
|
||||
"general.file_type": uint32(1),
|
||||
@@ -88,9 +182,14 @@ func (ModelParameters) specialTokenTypes() []string {
|
||||
}
|
||||
}
|
||||
|
||||
type ModelConverter interface {
|
||||
type ModelKV interface {
|
||||
// KV maps parameters to LLM key-values
|
||||
KV(*Tokenizer) ggml.KV
|
||||
KV(*Tokenizer) KV
|
||||
}
|
||||
|
||||
type ModelConverter interface {
|
||||
ModelKV
|
||||
|
||||
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
|
||||
Tensors([]Tensor) []*ggml.Tensor
|
||||
// Replacements returns a list of string pairs to replace in tensor names.
|
||||
@@ -107,7 +206,7 @@ type moreParser interface {
|
||||
|
||||
type AdapterConverter interface {
|
||||
// KV maps parameters to LLM key-values
|
||||
KV(ggml.KV) ggml.KV
|
||||
KV(ofs.Config) KV
|
||||
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
|
||||
Tensors([]Tensor) []*ggml.Tensor
|
||||
// Replacements returns a list of string pairs to replace in tensor names.
|
||||
@@ -115,7 +214,7 @@ type AdapterConverter interface {
|
||||
Replacements() []string
|
||||
}
|
||||
|
||||
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
|
||||
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
|
||||
bts, err := fs.ReadFile(fsys, "adapter_config.json")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -126,8 +225,8 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
|
||||
return err
|
||||
}
|
||||
|
||||
arch, ok := baseKV["general.architecture"]
|
||||
if !ok {
|
||||
arch := baseKV.Architecture()
|
||||
if arch == "" {
|
||||
return errors.New("architecture not set for the base model")
|
||||
}
|
||||
|
||||
@@ -153,23 +252,19 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
|
||||
return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
|
||||
}
|
||||
|
||||
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
|
||||
// and files it finds in the input path.
|
||||
// Supported input model formats include safetensors.
|
||||
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
|
||||
func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
bts, err := fs.ReadFile(fsys, "config.json")
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var p ModelParameters
|
||||
if err := json.Unmarshal(bts, &p); err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if len(p.Architectures) < 1 {
|
||||
return errors.New("unknown architecture")
|
||||
return nil, nil, errors.New("unknown architecture")
|
||||
}
|
||||
|
||||
var conv ModelConverter
|
||||
@@ -202,6 +297,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
conv = &qwen25VLModel{}
|
||||
case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration":
|
||||
conv = &qwen3VLModel{}
|
||||
case "Olmo3ForCausalLM":
|
||||
conv = &olmoModel{}
|
||||
case "BertModel":
|
||||
conv = &bertModel{}
|
||||
case "NomicBertModel", "NomicBertMoEModel":
|
||||
@@ -215,22 +312,22 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
case "DeepseekV3ForCausalLM":
|
||||
conv = &deepseek2Model{}
|
||||
default:
|
||||
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(bts, conv); err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if t, ok := conv.(moreParser); ok {
|
||||
if err := t.parseMore(fsys); err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
t, err := parseTokenizer(fsys, conv.specialTokenTypes())
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
|
||||
@@ -252,6 +349,19 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
default:
|
||||
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
|
||||
}
|
||||
return conv, t, nil
|
||||
}
|
||||
|
||||
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
|
||||
// and files it finds in the input path.
|
||||
// Supported input model formats include safetensors.
|
||||
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
|
||||
func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
kv, t, err := LoadModelMetadata(fsys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conv := kv.(ModelConverter)
|
||||
|
||||
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
|
||||
if err != nil {
|
||||
@@ -261,7 +371,7 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
return writeFile(f, conv.KV(t), conv.Tensors(ts))
|
||||
}
|
||||
|
||||
func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
|
||||
func writeFile(f *os.File, kv KV, ts []*ggml.Tensor) error {
|
||||
for i := range ts {
|
||||
ts[i].Shape = slices.Clone(ts[i].Shape)
|
||||
slices.Reverse(ts[i].Shape)
|
||||
|
||||
@@ -88,7 +88,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *bertModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *bertModel) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "bert"
|
||||
kv["bert.attention.causal"] = false
|
||||
|
||||
@@ -24,7 +24,7 @@ type commandrModel struct {
|
||||
|
||||
var _ ModelConverter = (*commandrModel)(nil)
|
||||
|
||||
func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *commandrModel) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "command-r"
|
||||
kv["general.name"] = "command-r"
|
||||
|
||||
@@ -47,7 +47,7 @@ type deepseek2Model struct {
|
||||
Architecture string
|
||||
}
|
||||
|
||||
func (p *deepseek2Model) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *deepseek2Model) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "deepseek2"
|
||||
kv["general.type"] = "model"
|
||||
|
||||
@@ -41,7 +41,7 @@ type deepseekocr struct {
|
||||
} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
|
||||
func (m *deepseekocr) KV(t *Tokenizer) KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "deepseekocr"
|
||||
kv["block_count"] = m.LanguageConfig.HiddenLayers
|
||||
|
||||
@@ -23,7 +23,7 @@ type gemmaModel struct {
|
||||
|
||||
var _ ModelConverter = (*gemmaModel)(nil)
|
||||
|
||||
func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *gemmaModel) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma"
|
||||
kv["gemma.context_length"] = p.MaxPositionEmbeddings
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
package convert
|
||||
|
||||
import "github.com/ollama/ollama/fs/ggml"
|
||||
|
||||
type gemma2Model struct {
|
||||
gemmaModel
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
@@ -9,7 +7,7 @@ type gemma2Model struct {
|
||||
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
|
||||
}
|
||||
|
||||
func (p *gemma2Model) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *gemma2Model) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma2"
|
||||
kv["gemma2.context_length"] = p.MaxPositionEmbeddings
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -15,7 +16,7 @@ type gemma2Adapter struct {
|
||||
|
||||
var _ AdapterConverter = (*gemma2Adapter)(nil)
|
||||
|
||||
func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
|
||||
func (p *gemma2Adapter) KV(baseKV fs.Config) KV {
|
||||
kv := p.AdapterParameters.KV()
|
||||
kv["general.architecture"] = "gemma2"
|
||||
return kv
|
||||
|
||||
@@ -3,8 +3,6 @@ package convert
|
||||
import (
|
||||
"cmp"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type gemma3Model struct {
|
||||
@@ -55,7 +53,7 @@ const (
|
||||
gemma27BLayerCount = 62
|
||||
)
|
||||
|
||||
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *gemma3Model) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma3"
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ type gemma3nModel struct {
|
||||
VisionModel struct{} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (m *gemma3nModel) KV(t *Tokenizer) KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma3n"
|
||||
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {
|
||||
|
||||
@@ -37,7 +37,7 @@ type gptossModel struct {
|
||||
|
||||
var _ ModelConverter = (*gptossModel)(nil)
|
||||
|
||||
func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (m *gptossModel) KV(t *Tokenizer) KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gptoss"
|
||||
kv["general.file_type"] = uint32(4)
|
||||
|
||||
@@ -48,7 +48,7 @@ type llamaModel struct {
|
||||
|
||||
var _ ModelConverter = (*llamaModel)(nil)
|
||||
|
||||
func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *llamaModel) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "llama"
|
||||
kv["llama.vocab_size"] = p.VocabSize
|
||||
|
||||
@@ -35,7 +35,7 @@ type llama4Model struct {
|
||||
}
|
||||
|
||||
// KV implements ModelConverter.
|
||||
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *llama4Model) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "llama4"
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -18,13 +19,13 @@ type llamaAdapter struct {
|
||||
|
||||
var _ AdapterConverter = (*llamaAdapter)(nil)
|
||||
|
||||
func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
|
||||
func (p *llamaAdapter) KV(baseKV fs.Config) KV {
|
||||
kv := p.AdapterParameters.KV()
|
||||
kv["general.architecture"] = "llama"
|
||||
kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
|
||||
kv["llama.attention.head_count_kv"] = baseKV["llama.attention.head_count_kv"]
|
||||
kv["llama.attention.head_count"] = baseKV.Value("llama.attention.head_count")
|
||||
kv["llama.attention.head_count_kv"] = baseKV.Value("llama.attention.head_count_kv")
|
||||
|
||||
p.NumAttentionHeads = baseKV["llama.attention.head_count"].(uint32)
|
||||
p.NumAttentionHeads = baseKV.Value("llama.attention.head_count").(uint32)
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ type mistral3Model struct {
|
||||
ProjectorHiddenAct string `json:"projector_hidden_act"`
|
||||
}
|
||||
|
||||
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *mistral3Model) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mistral3"
|
||||
kv["mistral3.vocab_size"] = p.TextModel.VocabSize
|
||||
|
||||
@@ -39,7 +39,7 @@ type mistral3CausalModel struct {
|
||||
} `json:"rope_parameters"`
|
||||
}
|
||||
|
||||
func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *mistral3CausalModel) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mistral3"
|
||||
kv["mistral3.vocab_size"] = p.VocabSize
|
||||
|
||||
@@ -12,7 +12,7 @@ type mixtralModel struct {
|
||||
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
|
||||
}
|
||||
|
||||
func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *mixtralModel) KV(t *Tokenizer) KV {
|
||||
kv := p.llamaModel.KV(t)
|
||||
|
||||
if p.NumLocalExperts > 0 {
|
||||
|
||||
@@ -34,7 +34,7 @@ type mllamaModel struct {
|
||||
} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (m *mllamaModel) KV(t *Tokenizer) KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mllama"
|
||||
|
||||
|
||||
@@ -87,7 +87,7 @@ func (p *nomicbertModel) parseMore(fsys fs.FS) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *nomicbertModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *nomicbertModel) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
|
||||
// Determine architecture based on MoE parameters (following qwen3 pattern)
|
||||
|
||||
117
convert/convert_olmo.go
Normal file
117
convert/convert_olmo.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type ropeScaling struct {
|
||||
Factor float32 `json:"factor"`
|
||||
OriginalMaxPositionEmbeds uint32 `json:"original_max_position_embeddings"`
|
||||
AttentionFactor float32 `json:"attention_factor"`
|
||||
BetaFast float32 `json:"beta_fast"`
|
||||
BetaSlow float32 `json:"beta_slow"`
|
||||
RopeType string `json:"rope_type"`
|
||||
ExtrapolationFactor float32 `json:"extrapolation_factor"`
|
||||
}
|
||||
|
||||
type olmoModel struct {
|
||||
ModelParameters
|
||||
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RopeScaling *ropeScaling `json:"rope_scaling"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
LayerTypes []string `json:"layer_types"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*olmoModel)(nil)
|
||||
|
||||
func (p *olmoModel) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "olmo3"
|
||||
kv["olmo3.block_count"] = p.NumHiddenLayers
|
||||
kv["olmo3.context_length"] = p.MaxPositionEmbeddings
|
||||
kv["olmo3.embedding_length"] = p.HiddenSize
|
||||
kv["olmo3.feed_forward_length"] = p.IntermediateSize
|
||||
kv["olmo3.attention.head_count"] = p.NumAttentionHeads
|
||||
kv["olmo3.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
|
||||
|
||||
if p.RopeTheta > 0 {
|
||||
kv["olmo3.rope.freq_base"] = p.RopeTheta
|
||||
}
|
||||
|
||||
if p.RopeScaling != nil {
|
||||
if p.RopeScaling.Factor > 0 {
|
||||
kv["olmo3.rope.scaling.factor"] = p.RopeScaling.Factor
|
||||
}
|
||||
if p.RopeScaling.OriginalMaxPositionEmbeds > 0 {
|
||||
kv["olmo3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeds
|
||||
}
|
||||
if p.RopeScaling.AttentionFactor > 0 {
|
||||
kv["olmo3.rope.scaling.attn_factor"] = p.RopeScaling.AttentionFactor
|
||||
}
|
||||
if p.RopeScaling.RopeType != "" {
|
||||
kv["olmo3.rope.scaling.type"] = p.RopeScaling.RopeType
|
||||
}
|
||||
}
|
||||
|
||||
if p.RMSNormEPS > 0 {
|
||||
kv["olmo3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
|
||||
}
|
||||
|
||||
if p.SlidingWindow > 0 {
|
||||
kv["olmo3.attention.sliding_window"] = p.SlidingWindow
|
||||
}
|
||||
|
||||
if len(p.LayerTypes) > 0 {
|
||||
slidingPattern := make([]bool, len(p.LayerTypes))
|
||||
for i, layerType := range p.LayerTypes {
|
||||
slidingPattern[i] = (layerType == "sliding_attention")
|
||||
}
|
||||
kv["olmo3.attention.sliding_window_pattern"] = slidingPattern
|
||||
}
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *olmoModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
out := make([]*ggml.Tensor, 0, len(ts))
|
||||
for _, t := range ts {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *olmoModel) Replacements() []string {
|
||||
return []string{
|
||||
"lm_head", "output",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.layers", "blk",
|
||||
"model.norm", "output_norm",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"self_attn.q_norm", "attn_q_norm",
|
||||
"self_attn.k_norm", "attn_k_norm",
|
||||
"post_attention_layernorm", "post_attention_norm",
|
||||
"post_feedforward_layernorm", "post_ffw_norm",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
}
|
||||
}
|
||||
@@ -37,7 +37,7 @@ type phi3Model struct {
|
||||
|
||||
var _ ModelConverter = (*phi3Model)(nil)
|
||||
|
||||
func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *phi3Model) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "phi3"
|
||||
kv["phi3.context_length"] = p.MaxPositionEmbeddings
|
||||
|
||||
@@ -22,7 +22,7 @@ type qwen2Model struct {
|
||||
|
||||
var _ ModelConverter = (*qwen2Model)(nil)
|
||||
|
||||
func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
|
||||
func (q *qwen2Model) KV(t *Tokenizer) KV {
|
||||
kv := q.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "qwen2"
|
||||
kv["qwen2.block_count"] = q.HiddenLayers
|
||||
|
||||
@@ -29,7 +29,7 @@ type qwen25VLModel struct {
|
||||
|
||||
var _ ModelConverter = (*qwen25VLModel)(nil)
|
||||
|
||||
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (q *qwen25VLModel) KV(t *Tokenizer) KV {
|
||||
kv := q.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "qwen25vl"
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ type qwen3Model struct {
|
||||
}
|
||||
|
||||
// KV implements ModelConverter.
|
||||
func (q *qwen3Model) KV(t *Tokenizer) ggml.KV {
|
||||
func (q *qwen3Model) KV(t *Tokenizer) KV {
|
||||
arch := "qwen3"
|
||||
if q.NumExperts > 0 {
|
||||
arch += "moe"
|
||||
|
||||
@@ -45,7 +45,7 @@ func (m *qwen3VLModel) parseMore(fsys fs.FS) error {
|
||||
return json.Unmarshal(bts, &m.VisionModel)
|
||||
}
|
||||
|
||||
func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (m *qwen3VLModel) KV(t *Tokenizer) KV {
|
||||
kv := m.qwen3Model.KV(t)
|
||||
|
||||
arch := "qwen3vl"
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
fsc "github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -28,7 +29,7 @@ type tensorData struct {
|
||||
Shape []int `json:"shape"`
|
||||
}
|
||||
|
||||
func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
|
||||
func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors) {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.CreateTemp(t.TempDir(), "f16")
|
||||
@@ -59,9 +60,10 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
|
||||
return r, m.KV(), m.Tensors()
|
||||
}
|
||||
|
||||
func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string {
|
||||
func generateResultsJSON(t *testing.T, f *os.File, kv fsc.Config, tensors ggml.Tensors) map[string]string {
|
||||
actual := make(map[string]string)
|
||||
for k, v := range kv {
|
||||
for k := range kv.Keys() {
|
||||
v := kv.Value(k)
|
||||
if s, ok := v.(json.Marshaler); !ok {
|
||||
actual[k] = fmt.Sprintf("%v", v)
|
||||
} else {
|
||||
@@ -277,7 +279,7 @@ func generateSafetensorTestData(t *testing.T, tempDir string, tensorData map[str
|
||||
func TestConvertAdapter(t *testing.T) {
|
||||
type AdapterCase struct {
|
||||
Name string
|
||||
BaseKV map[string]any
|
||||
BaseKV KV
|
||||
Expected map[string]string
|
||||
}
|
||||
|
||||
|
||||
@@ -49,7 +49,8 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
|
||||
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
|
||||
|
||||
// temporary fix to handle gemma3 broken configs
|
||||
if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
|
||||
// TODO(parthsareen): allow reading of tokenizer.json to allow managing special tokens when using spm
|
||||
if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>", "<start_function_declaration>", "<end_function_declaration>", "<start_function_call>", "<end_function_call>", "<start_function_response>", "<end_function_response>", "<escape>"}, piece.GetPiece()) {
|
||||
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
* [API Reference](https://docs.ollama.com/api)
|
||||
* [Modelfile Reference](https://docs.ollama.com/modelfile)
|
||||
* [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility)
|
||||
* [Anthropic Compatibility](./api/anthropic-compatibility.mdx)
|
||||
|
||||
### Resources
|
||||
|
||||
|
||||
@@ -895,11 +895,11 @@ curl http://localhost:11434/api/chat -d '{
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "get_temperature",
|
||||
"name": "get_weather",
|
||||
"arguments": {
|
||||
"city": "Toronto"
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -907,7 +907,7 @@ curl http://localhost:11434/api/chat -d '{
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "11 degrees celsius",
|
||||
"tool_name": "get_temperature",
|
||||
"tool_name": "get_weather"
|
||||
}
|
||||
],
|
||||
"stream": false,
|
||||
|
||||
406
docs/api/anthropic-compatibility.mdx
Normal file
406
docs/api/anthropic-compatibility.mdx
Normal file
@@ -0,0 +1,406 @@
|
||||
---
|
||||
title: Anthropic compatibility
|
||||
---
|
||||
|
||||
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
|
||||
|
||||
## Recommended models
|
||||
|
||||
For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.
|
||||
|
||||
Pull a model before use:
|
||||
```shell
|
||||
ollama pull qwen3-coder
|
||||
ollama pull glm-4.7:cloud
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Environment variables
|
||||
|
||||
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama # required but ignored
|
||||
```
|
||||
|
||||
### Simple `/v1/messages` example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
```python basic.py
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
base_url='http://localhost:11434',
|
||||
api_key='ollama', # required but ignored
|
||||
)
|
||||
|
||||
message = client.messages.create(
|
||||
model='qwen3-coder',
|
||||
max_tokens=1024,
|
||||
messages=[
|
||||
{'role': 'user', 'content': 'Hello, how are you?'}
|
||||
]
|
||||
)
|
||||
print(message.content[0].text)
|
||||
```
|
||||
|
||||
```javascript basic.js
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
|
||||
const anthropic = new Anthropic({
|
||||
baseURL: "http://localhost:11434",
|
||||
apiKey: "ollama", // required but ignored
|
||||
});
|
||||
|
||||
const message = await anthropic.messages.create({
|
||||
model: "qwen3-coder",
|
||||
max_tokens: 1024,
|
||||
messages: [{ role: "user", content: "Hello, how are you?" }],
|
||||
});
|
||||
|
||||
console.log(message.content[0].text);
|
||||
```
|
||||
|
||||
```shell basic.sh
|
||||
curl -X POST http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-api-key: ollama" \
|
||||
-H "anthropic-version: 2023-06-01" \
|
||||
-d '{
|
||||
"model": "qwen3-coder",
|
||||
"max_tokens": 1024,
|
||||
"messages": [{ "role": "user", "content": "Hello, how are you?" }]
|
||||
}'
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
### Streaming example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
```python streaming.py
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
base_url='http://localhost:11434',
|
||||
api_key='ollama',
|
||||
)
|
||||
|
||||
with client.messages.stream(
|
||||
model='qwen3-coder',
|
||||
max_tokens=1024,
|
||||
messages=[{'role': 'user', 'content': 'Count from 1 to 10'}]
|
||||
) as stream:
|
||||
for text in stream.text_stream:
|
||||
print(text, end='', flush=True)
|
||||
```
|
||||
|
||||
```javascript streaming.js
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
|
||||
const anthropic = new Anthropic({
|
||||
baseURL: "http://localhost:11434",
|
||||
apiKey: "ollama",
|
||||
});
|
||||
|
||||
const stream = await anthropic.messages.stream({
|
||||
model: "qwen3-coder",
|
||||
max_tokens: 1024,
|
||||
messages: [{ role: "user", content: "Count from 1 to 10" }],
|
||||
});
|
||||
|
||||
for await (const event of stream) {
|
||||
if (
|
||||
event.type === "content_block_delta" &&
|
||||
event.delta.type === "text_delta"
|
||||
) {
|
||||
process.stdout.write(event.delta.text);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
```shell streaming.sh
|
||||
curl -X POST http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "qwen3-coder",
|
||||
"max_tokens": 1024,
|
||||
"stream": true,
|
||||
"messages": [{ "role": "user", "content": "Count from 1 to 10" }]
|
||||
}'
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
### Tool calling example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
```python tools.py
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
base_url='http://localhost:11434',
|
||||
api_key='ollama',
|
||||
)
|
||||
|
||||
message = client.messages.create(
|
||||
model='qwen3-coder',
|
||||
max_tokens=1024,
|
||||
tools=[
|
||||
{
|
||||
'name': 'get_weather',
|
||||
'description': 'Get the current weather in a location',
|
||||
'input_schema': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'location': {
|
||||
'type': 'string',
|
||||
'description': 'The city and state, e.g. San Francisco, CA'
|
||||
}
|
||||
},
|
||||
'required': ['location']
|
||||
}
|
||||
}
|
||||
],
|
||||
messages=[{'role': 'user', 'content': "What's the weather in San Francisco?"}]
|
||||
)
|
||||
|
||||
for block in message.content:
|
||||
if block.type == 'tool_use':
|
||||
print(f'Tool: {block.name}')
|
||||
print(f'Input: {block.input}')
|
||||
```
|
||||
|
||||
```javascript tools.js
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
|
||||
const anthropic = new Anthropic({
|
||||
baseURL: "http://localhost:11434",
|
||||
apiKey: "ollama",
|
||||
});
|
||||
|
||||
const message = await anthropic.messages.create({
|
||||
model: "qwen3-coder",
|
||||
max_tokens: 1024,
|
||||
tools: [
|
||||
{
|
||||
name: "get_weather",
|
||||
description: "Get the current weather in a location",
|
||||
input_schema: {
|
||||
type: "object",
|
||||
properties: {
|
||||
location: {
|
||||
type: "string",
|
||||
description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
required: ["location"],
|
||||
},
|
||||
},
|
||||
],
|
||||
messages: [{ role: "user", content: "What's the weather in San Francisco?" }],
|
||||
});
|
||||
|
||||
for (const block of message.content) {
|
||||
if (block.type === "tool_use") {
|
||||
console.log("Tool:", block.name);
|
||||
console.log("Input:", block.input);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
```shell tools.sh
|
||||
curl -X POST http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "qwen3-coder",
|
||||
"max_tokens": 1024,
|
||||
"tools": [
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather in a location",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}
|
||||
],
|
||||
"messages": [{ "role": "user", "content": "What is the weather in San Francisco?" }]
|
||||
}'
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
## Using with Claude Code
|
||||
|
||||
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Or set the environment variables in your shell profile:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
```
|
||||
|
||||
Then run Claude Code with any Ollama model:
|
||||
|
||||
```shell
|
||||
# Local models
|
||||
claude --model qwen3-coder
|
||||
claude --model gpt-oss:20b
|
||||
|
||||
# Cloud models
|
||||
claude --model glm-4.7:cloud
|
||||
claude --model minimax-m2.1:cloud
|
||||
```
|
||||
|
||||
## Endpoints
|
||||
|
||||
### `/v1/messages`
|
||||
|
||||
#### Supported features
|
||||
|
||||
- [x] Messages
|
||||
- [x] Streaming
|
||||
- [x] System prompts
|
||||
- [x] Multi-turn conversations
|
||||
- [x] Vision (images)
|
||||
- [x] Tools (function calling)
|
||||
- [x] Tool results
|
||||
- [x] Thinking/extended thinking
|
||||
|
||||
#### Supported request fields
|
||||
|
||||
- [x] `model`
|
||||
- [x] `max_tokens`
|
||||
- [x] `messages`
|
||||
- [x] Text `content`
|
||||
- [x] Image `content` (base64)
|
||||
- [x] Array of content blocks
|
||||
- [x] `tool_use` blocks
|
||||
- [x] `tool_result` blocks
|
||||
- [x] `thinking` blocks
|
||||
- [x] `system` (string or array)
|
||||
- [x] `stream`
|
||||
- [x] `temperature`
|
||||
- [x] `top_p`
|
||||
- [x] `top_k`
|
||||
- [x] `stop_sequences`
|
||||
- [x] `tools`
|
||||
- [x] `thinking`
|
||||
- [ ] `tool_choice`
|
||||
- [ ] `metadata`
|
||||
|
||||
#### Supported response fields
|
||||
|
||||
- [x] `id`
|
||||
- [x] `type`
|
||||
- [x] `role`
|
||||
- [x] `model`
|
||||
- [x] `content` (text, tool_use, thinking blocks)
|
||||
- [x] `stop_reason` (end_turn, max_tokens, tool_use)
|
||||
- [x] `usage` (input_tokens, output_tokens)
|
||||
|
||||
#### Streaming events
|
||||
|
||||
- [x] `message_start`
|
||||
- [x] `content_block_start`
|
||||
- [x] `content_block_delta` (text_delta, input_json_delta, thinking_delta)
|
||||
- [x] `content_block_stop`
|
||||
- [x] `message_delta`
|
||||
- [x] `message_stop`
|
||||
- [x] `ping`
|
||||
- [x] `error`
|
||||
|
||||
## Models
|
||||
|
||||
Ollama supports both local and cloud models.
|
||||
|
||||
### Local models
|
||||
|
||||
Pull a local model before use:
|
||||
|
||||
```shell
|
||||
ollama pull qwen3-coder
|
||||
```
|
||||
|
||||
Recommended local models:
|
||||
- `qwen3-coder` - Excellent for coding tasks
|
||||
- `gpt-oss:20b` - Strong general-purpose model
|
||||
|
||||
### Cloud models
|
||||
|
||||
Cloud models are available immediately without pulling:
|
||||
|
||||
- `glm-4.7:cloud` - High-performance cloud model
|
||||
- `minimax-m2.1:cloud` - Fast cloud model
|
||||
|
||||
### Default model names
|
||||
|
||||
For tooling that relies on default Anthropic model names such as `claude-3-5-sonnet`, use `ollama cp` to copy an existing model name:
|
||||
|
||||
```shell
|
||||
ollama cp qwen3-coder claude-3-5-sonnet
|
||||
```
|
||||
|
||||
Afterwards, this new model name can be specified in the `model` field:
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "claude-3-5-sonnet",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello!"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
## Differences from the Anthropic API
|
||||
|
||||
### Behavior differences
|
||||
|
||||
- API key is accepted but not validated
|
||||
- `anthropic-version` header is accepted but not used
|
||||
- Token counts are approximations based on the underlying model's tokenizer
|
||||
|
||||
### Not supported
|
||||
|
||||
The following Anthropic API features are not currently supported:
|
||||
|
||||
| Feature | Description |
|
||||
|---------|-------------|
|
||||
| `/v1/messages/count_tokens` | Token counting endpoint |
|
||||
| `tool_choice` | Forcing specific tool use or disabling tools |
|
||||
| `metadata` | Request metadata (user_id) |
|
||||
| Prompt caching | `cache_control` blocks for caching prefixes |
|
||||
| Batches API | `/v1/messages/batches` for async batch processing |
|
||||
| Citations | `citations` content blocks |
|
||||
| PDF support | `document` content blocks with PDF files |
|
||||
| Server-sent errors | `error` events during streaming (errors return HTTP status) |
|
||||
|
||||
### Partial support
|
||||
|
||||
| Feature | Status |
|
||||
|---------|--------|
|
||||
| Image content | Base64 images supported; URL images not supported |
|
||||
| Extended thinking | Basic support; `budget_tokens` accepted but not enforced |
|
||||
@@ -277,6 +277,8 @@ curl -X POST http://localhost:11434/v1/chat/completions \
|
||||
|
||||
### `/v1/responses`
|
||||
|
||||
> Note: Added in Ollama v0.13.3
|
||||
|
||||
Ollama supports the [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses). Only the non-stateful flavor is supported (i.e., there is no `previous_response_id` or `conversation` support).
|
||||
|
||||
#### Supported features
|
||||
|
||||
@@ -36,7 +36,6 @@ Provide an `images` array. SDKs accept file paths, URLs or raw bytes while the R
|
||||
}],
|
||||
"stream": false
|
||||
}'
|
||||
"
|
||||
```
|
||||
</Tab>
|
||||
<Tab title="Python">
|
||||
|
||||
@@ -32,7 +32,9 @@
|
||||
"codeblocks": "system"
|
||||
},
|
||||
"contextual": {
|
||||
"options": ["copy"]
|
||||
"options": [
|
||||
"copy"
|
||||
]
|
||||
},
|
||||
"navbar": {
|
||||
"links": [
|
||||
@@ -52,7 +54,9 @@
|
||||
"display": "simple"
|
||||
},
|
||||
"examples": {
|
||||
"languages": ["curl"]
|
||||
"languages": [
|
||||
"curl"
|
||||
]
|
||||
}
|
||||
},
|
||||
"redirects": [
|
||||
@@ -97,6 +101,7 @@
|
||||
{
|
||||
"group": "Integrations",
|
||||
"pages": [
|
||||
"/integrations/claude-code",
|
||||
"/integrations/vscode",
|
||||
"/integrations/jetbrains",
|
||||
"/integrations/codex",
|
||||
@@ -139,7 +144,8 @@
|
||||
"/api/streaming",
|
||||
"/api/usage",
|
||||
"/api/errors",
|
||||
"/api/openai-compatibility"
|
||||
"/api/openai-compatibility",
|
||||
"/api/anthropic-compatibility"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -14,11 +14,11 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
|
||||
## How can I view the logs?
|
||||
|
||||
Review the [Troubleshooting](./troubleshooting.md) docs for more about using logs.
|
||||
Review the [Troubleshooting](./troubleshooting) docs for more about using logs.
|
||||
|
||||
## Is my GPU compatible with Ollama?
|
||||
|
||||
Please refer to the [GPU docs](./gpu.md).
|
||||
Please refer to the [GPU docs](./gpu).
|
||||
|
||||
## How can I specify the context window size?
|
||||
|
||||
|
||||
10
docs/gpu.mdx
10
docs/gpu.mdx
@@ -33,7 +33,7 @@ Check your compute compatibility to see if your card is supported:
|
||||
| 5.0 | GeForce GTX | `GTX 750 Ti` `GTX 750` `NVS 810` |
|
||||
| | Quadro | `K2200` `K1200` `K620` `M1200` `M520` `M5000M` `M4000M` `M3000M` `M2000M` `M1000M` `K620M` `M600M` `M500M` |
|
||||
|
||||
For building locally to support older GPUs, see [developer.md](./development.md#linux-cuda-nvidia)
|
||||
For building locally to support older GPUs, see [developer](./development#linux-cuda-nvidia)
|
||||
|
||||
### GPU Selection
|
||||
|
||||
@@ -54,7 +54,7 @@ sudo modprobe nvidia_uvm`
|
||||
|
||||
Ollama supports the following AMD GPUs via the ROCm library:
|
||||
|
||||
> [!NOTE]
|
||||
> **NOTE:**
|
||||
> Additional AMD GPU support is provided by the Vulkan Library - see below.
|
||||
|
||||
|
||||
@@ -132,9 +132,9 @@ Ollama supports GPU acceleration on Apple devices via the Metal API.
|
||||
|
||||
## Vulkan GPU Support
|
||||
|
||||
> [!NOTE]
|
||||
> **NOTE:**
|
||||
> Vulkan is currently an Experimental feature. To enable, you must set OLLAMA_VULKAN=1 for the Ollama server as
|
||||
described in the [FAQ](faq.md#how-do-i-configure-ollama-server)
|
||||
described in the [FAQ](faq#how-do-i-configure-ollama-server)
|
||||
|
||||
Additional GPU support on Windows and Linux is provided via
|
||||
[Vulkan](https://www.vulkan.org/). On Windows most GPU vendors drivers come
|
||||
@@ -161,6 +161,6 @@ sudo setcap cap_perfmon+ep /usr/local/bin/ollama
|
||||
|
||||
To select specific Vulkan GPU(s), you can set the environment variable
|
||||
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
|
||||
described in the [FAQ](faq.md#how-do-i-configure-ollama-server). If you
|
||||
described in the [FAQ](faq#how-do-i-configure-ollama-server). If you
|
||||
encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
|
||||
by setting `GGML_VK_VISIBLE_DEVICES=-1`
|
||||
69
docs/integrations/claude-code.mdx
Normal file
69
docs/integrations/claude-code.mdx
Normal file
@@ -0,0 +1,69 @@
|
||||
---
|
||||
title: Claude Code
|
||||
---
|
||||
|
||||
## Install
|
||||
|
||||
Install [Claude Code](https://code.claude.com/docs/en/overview):
|
||||
|
||||
<CodeGroup>
|
||||
|
||||
```shell macOS / Linux
|
||||
curl -fsSL https://claude.ai/install.sh | bash
|
||||
```
|
||||
|
||||
```powershell Windows
|
||||
irm https://claude.ai/install.ps1 | iex
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
Claude Code connects to Ollama using the Anthropic-compatible API.
|
||||
|
||||
1. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
```
|
||||
|
||||
2. Run Claude Code with an Ollama model:
|
||||
|
||||
```shell
|
||||
claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Or run with environment variables inline:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
```
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
|
||||
2. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=https://ollama.com
|
||||
export ANTHROPIC_API_KEY=<your-api-key>
|
||||
```
|
||||
|
||||
3. Run Claude Code with a cloud model:
|
||||
|
||||
```shell
|
||||
claude --model glm-4.7:cloud
|
||||
```
|
||||
|
||||
## Recommended Models
|
||||
|
||||
### Cloud models
|
||||
- `glm-4.7:cloud` - High-performance cloud model
|
||||
- `minimax-m2.1:cloud` - Fast cloud model
|
||||
- `qwen3-coder:480b` - Large coding model
|
||||
|
||||
### Local models
|
||||
- `qwen3-coder` - Excellent for coding tasks
|
||||
- `gpt-oss:20b` - Strong general-purpose model
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: Linux
|
||||
title: "Linux"
|
||||
---
|
||||
|
||||
## Install
|
||||
@@ -13,8 +13,7 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
## Manual install
|
||||
|
||||
<Note>
|
||||
If you are upgrading from a prior version, you should remove the old libraries
|
||||
with `sudo rm -rf /usr/lib/ollama` first.
|
||||
If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
|
||||
</Note>
|
||||
|
||||
Download and extract the package:
|
||||
@@ -113,11 +112,7 @@ sudo systemctl status ollama
|
||||
```
|
||||
|
||||
<Note>
|
||||
While AMD has contributed the `amdgpu` driver upstream to the official linux
|
||||
kernel source, the version is older and may not support all ROCm features. We
|
||||
recommend you install the latest driver from
|
||||
https://www.amd.com/en/support/linux-drivers for best support of your Radeon
|
||||
GPU.
|
||||
While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
|
||||
</Note>
|
||||
|
||||
## Customizing
|
||||
@@ -196,4 +191,4 @@ Remove the downloaded models and Ollama service user and group:
|
||||
sudo userdel ollama
|
||||
sudo groupdel ollama
|
||||
sudo rm -r /usr/share/ollama
|
||||
```
|
||||
```
|
||||
@@ -41,6 +41,7 @@ INSTRUCTION arguments
|
||||
| [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. |
|
||||
| [`LICENSE`](#license) | Specifies the legal license. |
|
||||
| [`MESSAGE`](#message) | Specify message history. |
|
||||
| [`REQUIRES`](#requires) | Specify the minimum version of Ollama required by the model. |
|
||||
|
||||
## Examples
|
||||
|
||||
@@ -248,6 +249,16 @@ MESSAGE user Is Ontario in Canada?
|
||||
MESSAGE assistant yes
|
||||
```
|
||||
|
||||
### REQUIRES
|
||||
|
||||
The `REQUIRES` instruction allows you to specify the minimum version of Ollama required by the model.
|
||||
|
||||
```
|
||||
REQUIRES <version>
|
||||
```
|
||||
|
||||
The version should be a valid Ollama version (e.g. 0.14.0).
|
||||
|
||||
## Notes
|
||||
|
||||
- the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments.
|
||||
|
||||
@@ -87,7 +87,7 @@ When Ollama starts up, it takes inventory of the GPUs present in the system to d
|
||||
|
||||
### Linux NVIDIA Troubleshooting
|
||||
|
||||
If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker.md](./docker.md)
|
||||
If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker](./docker)
|
||||
|
||||
Sometimes the Ollama can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package fs
|
||||
|
||||
import "iter"
|
||||
|
||||
type Config interface {
|
||||
Architecture() string
|
||||
String(string, ...string) string
|
||||
@@ -11,4 +13,8 @@ type Config interface {
|
||||
Ints(string, ...[]int32) []int32
|
||||
Floats(string, ...[]float32) []float32
|
||||
Bools(string, ...[]bool) []bool
|
||||
|
||||
Len() int
|
||||
Keys() iter.Seq[string]
|
||||
Value(key string) any
|
||||
}
|
||||
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -239,20 +241,34 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) Len() int {
|
||||
return len(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Keys() iter.Seq[string] {
|
||||
return maps.Keys(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Value(key string) any {
|
||||
return kv[key]
|
||||
}
|
||||
|
||||
func (kv KV) OllamaEngineRequired() bool {
|
||||
return slices.Contains([]string{
|
||||
"bert",
|
||||
"deepseek2",
|
||||
"deepseekocr",
|
||||
"gemma3",
|
||||
"gemma3n",
|
||||
"gptoss", "gpt-oss",
|
||||
"llama4",
|
||||
"mistral3",
|
||||
"mllama",
|
||||
"nomic-bert",
|
||||
"olmo3",
|
||||
"qwen25vl",
|
||||
"qwen3", "qwen3moe",
|
||||
"qwen3vl", "qwen3vlmoe",
|
||||
"deepseekocr",
|
||||
"deepseek2",
|
||||
"nomic-bert",
|
||||
}, kv.Architecture())
|
||||
}
|
||||
|
||||
@@ -838,9 +854,11 @@ func (f GGML) SupportsFlashAttention() bool {
|
||||
// FlashAttention checks if the model should enable flash attention
|
||||
func (f GGML) FlashAttention() bool {
|
||||
return slices.Contains([]string{
|
||||
"bert",
|
||||
"gemma3",
|
||||
"gptoss", "gpt-oss",
|
||||
"mistral3",
|
||||
"olmo3",
|
||||
"qwen3", "qwen3moe",
|
||||
"qwen3vl", "qwen3vlmoe",
|
||||
}, f.KV().String("general.architecture"))
|
||||
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
@@ -508,7 +508,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
|
||||
return binary.Write(w, binary.LittleEndian, s)
|
||||
}
|
||||
|
||||
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
|
||||
arch := kv.String("general.architecture")
|
||||
if arch == "" {
|
||||
return fmt.Errorf("architecture not set")
|
||||
@@ -526,12 +526,12 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(kv.Len())); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, key := range slices.Sorted(maps.Keys(kv)) {
|
||||
if err := ggufWriteKV(f, arch, key, kv[key]); err != nil {
|
||||
for _, key := range slices.Sorted(kv.Keys()) {
|
||||
if err := ggufWriteKV(f, arch, key, kv.Value(key)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
19
go.mod
19
go.mod
@@ -15,8 +15,8 @@ require (
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/x448/float16 v0.8.4
|
||||
golang.org/x/sync v0.12.0
|
||||
golang.org/x/sys v0.36.0
|
||||
golang.org/x/sync v0.17.0
|
||||
golang.org/x/sys v0.37.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -28,13 +28,17 @@ require (
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8
|
||||
golang.org/x/image v0.22.0
|
||||
golang.org/x/tools v0.30.0
|
||||
golang.org/x/mod v0.30.0
|
||||
golang.org/x/tools v0.38.0
|
||||
gonum.org/v1/gonum v0.15.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/chewxy/hm v1.0.0 // indirect
|
||||
github.com/chewxy/math32 v1.11.0 // indirect
|
||||
@@ -44,6 +48,7 @@ require (
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/flatbuffers v24.3.25+incompatible // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/mailru/easyjson v0.7.7 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
@@ -76,11 +81,11 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.36.0
|
||||
golang.org/x/crypto v0.43.0
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
|
||||
golang.org/x/net v0.38.0 // indirect
|
||||
golang.org/x/term v0.30.0
|
||||
golang.org/x/text v0.23.0
|
||||
golang.org/x/net v0.46.0 // indirect
|
||||
golang.org/x/term v0.36.0
|
||||
golang.org/x/text v0.30.0
|
||||
google.golang.org/protobuf v1.34.1
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
39
go.sum
39
go.sum
@@ -14,7 +14,11 @@ github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6IC
|
||||
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
|
||||
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
|
||||
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
|
||||
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
|
||||
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
|
||||
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
|
||||
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
@@ -123,6 +127,7 @@ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
|
||||
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
|
||||
@@ -143,6 +148,8 @@ github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 h1:QwWKgMY28TAXaDl+
|
||||
github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
|
||||
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||
@@ -207,6 +214,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
|
||||
@@ -224,8 +233,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
@@ -255,6 +264,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
|
||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
@@ -267,8 +278,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
||||
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
|
||||
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -278,8 +289,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
|
||||
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -295,17 +306,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
|
||||
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
|
||||
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -319,8 +330,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
|
||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
|
||||
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
||||
@@ -11,6 +11,15 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
|
||||
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||
props := api.NewToolPropertiesMap()
|
||||
for k, v := range m {
|
||||
props.Set(k, v)
|
||||
}
|
||||
return props
|
||||
}
|
||||
|
||||
func TestAPIToolCalling(t *testing.T) {
|
||||
initialTimeout := 60 * time.Second
|
||||
streamTimeout := 60 * time.Second
|
||||
@@ -57,12 +66,12 @@ func TestAPIToolCalling(t *testing.T) {
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
94
internal/orderedmap/orderedmap.go
Normal file
94
internal/orderedmap/orderedmap.go
Normal file
@@ -0,0 +1,94 @@
|
||||
// Package orderedmap provides a generic ordered map that maintains insertion order.
|
||||
// It wraps github.com/wk8/go-ordered-map/v2 to encapsulate the dependency.
|
||||
package orderedmap
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"iter"
|
||||
|
||||
orderedmap "github.com/wk8/go-ordered-map/v2"
|
||||
)
|
||||
|
||||
// Map is a generic ordered map that maintains insertion order.
|
||||
type Map[K comparable, V any] struct {
|
||||
om *orderedmap.OrderedMap[K, V]
|
||||
}
|
||||
|
||||
// New creates a new empty ordered map.
|
||||
func New[K comparable, V any]() *Map[K, V] {
|
||||
return &Map[K, V]{
|
||||
om: orderedmap.New[K, V](),
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a value by key.
|
||||
func (m *Map[K, V]) Get(key K) (V, bool) {
|
||||
if m == nil || m.om == nil {
|
||||
var zero V
|
||||
return zero, false
|
||||
}
|
||||
return m.om.Get(key)
|
||||
}
|
||||
|
||||
// Set sets a key-value pair. If the key already exists, its value is updated
|
||||
// but its position in the iteration order is preserved. If the key is new,
|
||||
// it is appended to the end.
|
||||
func (m *Map[K, V]) Set(key K, value V) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if m.om == nil {
|
||||
m.om = orderedmap.New[K, V]()
|
||||
}
|
||||
m.om.Set(key, value)
|
||||
}
|
||||
|
||||
// Len returns the number of entries.
|
||||
func (m *Map[K, V]) Len() int {
|
||||
if m == nil || m.om == nil {
|
||||
return 0
|
||||
}
|
||||
return m.om.Len()
|
||||
}
|
||||
|
||||
// All returns an iterator over all key-value pairs in insertion order.
|
||||
func (m *Map[K, V]) All() iter.Seq2[K, V] {
|
||||
return func(yield func(K, V) bool) {
|
||||
if m == nil || m.om == nil {
|
||||
return
|
||||
}
|
||||
for pair := m.om.Oldest(); pair != nil; pair = pair.Next() {
|
||||
if !yield(pair.Key, pair.Value) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ToMap converts to a regular Go map.
|
||||
// Note: The resulting map does not preserve order.
|
||||
func (m *Map[K, V]) ToMap() map[K]V {
|
||||
if m == nil || m.om == nil {
|
||||
return nil
|
||||
}
|
||||
result := make(map[K]V, m.om.Len())
|
||||
for pair := m.om.Oldest(); pair != nil; pair = pair.Next() {
|
||||
result[pair.Key] = pair.Value
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler. The JSON output preserves key order.
|
||||
func (m *Map[K, V]) MarshalJSON() ([]byte, error) {
|
||||
if m == nil || m.om == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return json.Marshal(m.om)
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler. The insertion order matches the
|
||||
// order of keys in the JSON input.
|
||||
func (m *Map[K, V]) UnmarshalJSON(data []byte) error {
|
||||
m.om = orderedmap.New[K, V]()
|
||||
return json.Unmarshal(data, &m.om)
|
||||
}
|
||||
348
internal/orderedmap/orderedmap_test.go
Normal file
348
internal/orderedmap/orderedmap_test.go
Normal file
@@ -0,0 +1,348 @@
|
||||
package orderedmap
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMap_BasicOperations(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
// Test empty map
|
||||
if m.Len() != 0 {
|
||||
t.Errorf("expected Len() = 0, got %d", m.Len())
|
||||
}
|
||||
v, ok := m.Get("a")
|
||||
if ok {
|
||||
t.Error("expected Get on empty map to return false")
|
||||
}
|
||||
if v != 0 {
|
||||
t.Errorf("expected zero value, got %d", v)
|
||||
}
|
||||
|
||||
// Test Set and Get
|
||||
m.Set("a", 1)
|
||||
m.Set("b", 2)
|
||||
m.Set("c", 3)
|
||||
|
||||
if m.Len() != 3 {
|
||||
t.Errorf("expected Len() = 3, got %d", m.Len())
|
||||
}
|
||||
|
||||
v, ok = m.Get("a")
|
||||
if !ok || v != 1 {
|
||||
t.Errorf("expected Get(a) = (1, true), got (%d, %v)", v, ok)
|
||||
}
|
||||
|
||||
v, ok = m.Get("b")
|
||||
if !ok || v != 2 {
|
||||
t.Errorf("expected Get(b) = (2, true), got (%d, %v)", v, ok)
|
||||
}
|
||||
|
||||
v, ok = m.Get("c")
|
||||
if !ok || v != 3 {
|
||||
t.Errorf("expected Get(c) = (3, true), got (%d, %v)", v, ok)
|
||||
}
|
||||
|
||||
// Test updating existing key preserves position
|
||||
m.Set("a", 10)
|
||||
v, ok = m.Get("a")
|
||||
if !ok || v != 10 {
|
||||
t.Errorf("expected Get(a) = (10, true), got (%d, %v)", v, ok)
|
||||
}
|
||||
if m.Len() != 3 {
|
||||
t.Errorf("expected Len() = 3 after update, got %d", m.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_InsertionOrderPreserved(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
// Insert in non-alphabetical order
|
||||
m.Set("z", 1)
|
||||
m.Set("a", 2)
|
||||
m.Set("m", 3)
|
||||
m.Set("b", 4)
|
||||
|
||||
// Verify iteration order matches insertion order
|
||||
var keys []string
|
||||
var values []int
|
||||
for k, v := range m.All() {
|
||||
keys = append(keys, k)
|
||||
values = append(values, v)
|
||||
}
|
||||
|
||||
expectedKeys := []string{"z", "a", "m", "b"}
|
||||
expectedValues := []int{1, 2, 3, 4}
|
||||
|
||||
if !slices.Equal(keys, expectedKeys) {
|
||||
t.Errorf("expected keys %v, got %v", expectedKeys, keys)
|
||||
}
|
||||
if !slices.Equal(values, expectedValues) {
|
||||
t.Errorf("expected values %v, got %v", expectedValues, values)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_UpdatePreservesPosition(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
m.Set("first", 1)
|
||||
m.Set("second", 2)
|
||||
m.Set("third", 3)
|
||||
|
||||
// Update middle element
|
||||
m.Set("second", 20)
|
||||
|
||||
var keys []string
|
||||
for k := range m.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
// Order should still be first, second, third
|
||||
expected := []string{"first", "second", "third"}
|
||||
if !slices.Equal(keys, expected) {
|
||||
t.Errorf("expected keys %v, got %v", expected, keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_MarshalJSON_PreservesOrder(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
// Insert in non-alphabetical order
|
||||
m.Set("z", 1)
|
||||
m.Set("a", 2)
|
||||
m.Set("m", 3)
|
||||
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
// JSON should preserve insertion order, not alphabetical
|
||||
expected := `{"z":1,"a":2,"m":3}`
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_UnmarshalJSON_PreservesOrder(t *testing.T) {
|
||||
// JSON with non-alphabetical key order
|
||||
jsonData := `{"z":1,"a":2,"m":3}`
|
||||
|
||||
m := New[string, int]()
|
||||
if err := json.Unmarshal([]byte(jsonData), m); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify iteration order matches JSON order
|
||||
var keys []string
|
||||
for k := range m.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
expected := []string{"z", "a", "m"}
|
||||
if !slices.Equal(keys, expected) {
|
||||
t.Errorf("expected keys %v, got %v", expected, keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_JSONRoundTrip(t *testing.T) {
|
||||
// Test that unmarshal -> marshal produces identical JSON
|
||||
original := `{"zebra":"z","apple":"a","mango":"m","banana":"b"}`
|
||||
|
||||
m := New[string, string]()
|
||||
if err := json.Unmarshal([]byte(original), m); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
if string(data) != original {
|
||||
t.Errorf("round trip failed: expected %s, got %s", original, string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_ToMap(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
m.Set("a", 1)
|
||||
m.Set("b", 2)
|
||||
|
||||
regular := m.ToMap()
|
||||
|
||||
if len(regular) != 2 {
|
||||
t.Errorf("expected len 2, got %d", len(regular))
|
||||
}
|
||||
if regular["a"] != 1 {
|
||||
t.Errorf("expected regular[a] = 1, got %d", regular["a"])
|
||||
}
|
||||
if regular["b"] != 2 {
|
||||
t.Errorf("expected regular[b] = 2, got %d", regular["b"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_NilSafety(t *testing.T) {
|
||||
var m *Map[string, int]
|
||||
|
||||
// All operations should be safe on nil
|
||||
if m.Len() != 0 {
|
||||
t.Errorf("expected Len() = 0 on nil map, got %d", m.Len())
|
||||
}
|
||||
|
||||
v, ok := m.Get("a")
|
||||
if ok {
|
||||
t.Error("expected Get on nil map to return false")
|
||||
}
|
||||
if v != 0 {
|
||||
t.Errorf("expected zero value from nil map, got %d", v)
|
||||
}
|
||||
|
||||
// Set on nil is a no-op
|
||||
m.Set("a", 1)
|
||||
if m.Len() != 0 {
|
||||
t.Errorf("expected Len() = 0 after Set on nil, got %d", m.Len())
|
||||
}
|
||||
|
||||
// All returns empty iterator
|
||||
var keys []string
|
||||
for k := range m.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
if len(keys) != 0 {
|
||||
t.Errorf("expected empty iteration on nil map, got %v", keys)
|
||||
}
|
||||
|
||||
// ToMap returns nil
|
||||
if m.ToMap() != nil {
|
||||
t.Error("expected ToMap to return nil on nil map")
|
||||
}
|
||||
|
||||
// MarshalJSON returns null
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
if string(data) != "null" {
|
||||
t.Errorf("expected null, got %s", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_EmptyMapMarshal(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
if string(data) != "{}" {
|
||||
t.Errorf("expected {}, got %s", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_NestedValues(t *testing.T) {
|
||||
m := New[string, any]()
|
||||
m.Set("string", "hello")
|
||||
m.Set("number", 42)
|
||||
m.Set("bool", true)
|
||||
m.Set("nested", map[string]int{"x": 1})
|
||||
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
expected := `{"string":"hello","number":42,"bool":true,"nested":{"x":1}}`
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_AllIteratorEarlyExit(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
m.Set("a", 1)
|
||||
m.Set("b", 2)
|
||||
m.Set("c", 3)
|
||||
m.Set("d", 4)
|
||||
|
||||
// Collect only first 2
|
||||
var keys []string
|
||||
for k := range m.All() {
|
||||
keys = append(keys, k)
|
||||
if len(keys) == 2 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
expected := []string{"a", "b"}
|
||||
if !slices.Equal(keys, expected) {
|
||||
t.Errorf("expected %v, got %v", expected, keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_IntegerKeys(t *testing.T) {
|
||||
m := New[int, string]()
|
||||
m.Set(3, "three")
|
||||
m.Set(1, "one")
|
||||
m.Set(2, "two")
|
||||
|
||||
var keys []int
|
||||
for k := range m.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
// Should preserve insertion order, not numerical order
|
||||
expected := []int{3, 1, 2}
|
||||
if !slices.Equal(keys, expected) {
|
||||
t.Errorf("expected %v, got %v", expected, keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_UnmarshalIntoExisting(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
m.Set("existing", 999)
|
||||
|
||||
// Unmarshal should replace contents
|
||||
if err := json.Unmarshal([]byte(`{"new":1}`), m); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
_, ok := m.Get("existing")
|
||||
if ok {
|
||||
t.Error("existing key should be gone after unmarshal")
|
||||
}
|
||||
|
||||
v, ok := m.Get("new")
|
||||
if !ok || v != 1 {
|
||||
t.Errorf("expected Get(new) = (1, true), got (%d, %v)", v, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_LargeOrderPreservation(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
// Create many keys in specific order
|
||||
keys := make([]string, 100)
|
||||
for i := range 100 {
|
||||
keys[i] = string(rune('a' + (99 - i))) // reverse order: 'd', 'c', 'b', 'a' (extended)
|
||||
if i >= 26 {
|
||||
keys[i] = string(rune('A'+i-26)) + string(rune('a'+i%26))
|
||||
}
|
||||
}
|
||||
|
||||
for i, k := range keys {
|
||||
m.Set(k, i)
|
||||
}
|
||||
|
||||
// Verify order preserved
|
||||
var resultKeys []string
|
||||
for k := range m.All() {
|
||||
resultKeys = append(resultKeys, k)
|
||||
}
|
||||
|
||||
if !slices.Equal(keys, resultKeys) {
|
||||
t.Error("large map should preserve insertion order")
|
||||
}
|
||||
}
|
||||
@@ -140,10 +140,6 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
|
||||
c.config.CachePadding = 1
|
||||
}
|
||||
|
||||
if c.config.MaskBatchPadding == 0 {
|
||||
c.config.MaskBatchPadding = 1
|
||||
}
|
||||
|
||||
if c.config.MaskDType == ml.DTypeOther {
|
||||
c.config.MaskDType = ml.DTypeF32
|
||||
}
|
||||
@@ -364,15 +360,12 @@ func roundUp(length, pad int) int {
|
||||
// token in the history should apply. This is based on both the sequence and causality (the
|
||||
// position of the history is not ahead of the token in the batch).
|
||||
func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||
// Align and pad the two dimensions as required by the backend
|
||||
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
|
||||
|
||||
c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
|
||||
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
|
||||
|
||||
length := c.curCellRange.max - c.curCellRange.min + 1
|
||||
|
||||
mask := make([]float32, batchSize*length)
|
||||
mask := make([]float32, c.curBatchSize*length)
|
||||
|
||||
for i := range c.curBatchSize {
|
||||
enabled := !slices.Contains(c.opts.Except, i)
|
||||
@@ -386,13 +379,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
// Mask out any padding tokens we added. For padding that we added to the cache history, this
|
||||
// has already been masked out because the sequence doesn't match.
|
||||
for i := c.curBatchSize * length; i < len(mask); i++ {
|
||||
mask[i] = float32(math.Inf(-1))
|
||||
}
|
||||
|
||||
maskTensor := ctx.Input().FromFloats(mask, length, batchSize)
|
||||
maskTensor := ctx.Input().FromFloats(mask, length, c.curBatchSize)
|
||||
|
||||
if c.config.MaskDType != ml.DTypeF32 {
|
||||
maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
|
||||
|
||||
@@ -14,25 +14,28 @@ make -f Makefile.sync apply-patches
|
||||
|
||||
### Updating Base Commit
|
||||
|
||||
**Pin to new base commit**
|
||||
To update to a new base commit:
|
||||
|
||||
To change the base commit, update `FETCH_HEAD` in Makefile.sync.
|
||||
1. **Update FETCH_HEAD** in `Makefile.sync` to the new commit hash.
|
||||
|
||||
When updating to a newer base commit, the existing patches may not apply cleanly and require manual merge resolution.
|
||||
2. **Check for upstreamed patches**: Before applying, review if any patches have been merged upstream. Remove those patches from `./patches/` to avoid conflicts.
|
||||
|
||||
Start by applying the patches. If any of the patches have conflicts, the `git am` will stop at the first failure.
|
||||
3. **Apply patches**:
|
||||
```shell
|
||||
make -f Makefile.sync apply-patches
|
||||
```
|
||||
|
||||
```shell
|
||||
make -f Makefile.sync apply-patches
|
||||
```
|
||||
4. **Resolve conflicts** (if any): When `git am` fails on a patch:
|
||||
- Fix conflicts in `./vendor/`
|
||||
- Stage the resolved files: `git -C llama/vendor add <file>`
|
||||
- Continue: `git -C llama/vendor am --continue`
|
||||
- Re-run: `make -f Makefile.sync apply-patches`
|
||||
- Repeat until all patches are applied.
|
||||
|
||||
If there are conflicts, you will see an error message. Resolve the conflicts in `./vendor/`, and continue the patch series with `git am --continue` and rerun `make -f Makefile.sync apply-patches`. Repeat until all patches are successfully applied.
|
||||
|
||||
Once all patches are applied, commit the changes to the tracking repository.
|
||||
|
||||
```shell
|
||||
make -f Makefile.sync format-patches sync
|
||||
```
|
||||
5. **Regenerate patches and sync**:
|
||||
```shell
|
||||
make -f Makefile.sync format-patches sync
|
||||
```
|
||||
|
||||
### Generating Patches
|
||||
|
||||
|
||||
2
llama/build-info.cpp
generated
vendored
2
llama/build-info.cpp
generated
vendored
@@ -1,4 +1,4 @@
|
||||
int LLAMA_BUILD_NUMBER = 0;
|
||||
char const *LLAMA_COMMIT = "17f7f4baad8b3a716ee139da7bb56ae984e8c0fa";
|
||||
char const *LLAMA_COMMIT = "b1377188784f9aea26b8abde56d4aee8c733eec7";
|
||||
char const *LLAMA_COMPILER = "";
|
||||
char const *LLAMA_BUILD_TARGET = "";
|
||||
|
||||
@@ -17,6 +17,9 @@ include /tools/mtmd/clip.cpp
|
||||
include /tools/mtmd/mtmd.cpp
|
||||
include /tools/mtmd/mtmd-audio.cpp
|
||||
include /tools/mtmd/mtmd-helper.cpp
|
||||
include /tools/mtmd/models/
|
||||
include /tools/mtmd/models/*.h
|
||||
include /tools/mtmd/models/*.cpp
|
||||
include /src/
|
||||
include /src/llama.*
|
||||
include /src/llama-*.*
|
||||
|
||||
259
llama/llama.cpp/common/common.cpp
vendored
259
llama/llama.cpp/common/common.cpp
vendored
@@ -251,7 +251,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
|
||||
case GGML_SCHED_PRIO_REALTIME: p = -20; break;
|
||||
}
|
||||
|
||||
if (!setpriority(PRIO_PROCESS, 0, p)) {
|
||||
if (setpriority(PRIO_PROCESS, 0, p) != 0) {
|
||||
LOG_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno);
|
||||
return false;
|
||||
}
|
||||
@@ -1013,31 +1013,40 @@ bool tty_can_use_colors() {
|
||||
// Model utils
|
||||
//
|
||||
|
||||
static inline void common_init_sampler_from_model(
|
||||
// TODO: move to common/sampling
|
||||
static void common_init_sampler_from_model(
|
||||
const llama_model * model,
|
||||
common_params_sampling & sparams) {
|
||||
|
||||
const uint64_t config = sparams.user_sampling_config;
|
||||
|
||||
auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
|
||||
if (config & user_config) return;
|
||||
if (config & user_config) {
|
||||
return;
|
||||
}
|
||||
|
||||
char buf[64] = {0};
|
||||
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
|
||||
char * end = nullptr;
|
||||
int32_t v = strtol(buf, &end, 10);
|
||||
if (end && end != buf) dst = v;
|
||||
if (end && end != buf) {
|
||||
dst = v;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
|
||||
if (config & user_config) return;
|
||||
if (config & user_config) {
|
||||
return;
|
||||
}
|
||||
|
||||
char buf[128] = {0};
|
||||
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
|
||||
char * end = nullptr;
|
||||
float v = strtof(buf, &end);
|
||||
if (end && end != buf) dst = v;
|
||||
if (end && end != buf) {
|
||||
dst = v;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1065,31 +1074,162 @@ static inline void common_init_sampler_from_model(
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
|
||||
}
|
||||
|
||||
struct common_init_result common_init_from_params(common_params & params) {
|
||||
common_init_result iparams;
|
||||
struct common_init_result::impl {
|
||||
impl() = default;
|
||||
~impl() = default;
|
||||
|
||||
// note: the order in which model, context, etc. are declared matters because their destructors will be called bottom-to-top
|
||||
|
||||
llama_model_ptr model;
|
||||
llama_context_ptr context;
|
||||
|
||||
std::vector<llama_adapter_lora_ptr> lora;
|
||||
|
||||
std::vector<common_sampler_ptr> samplers;
|
||||
std::vector<llama_sampler_seq_config> samplers_seq_config;
|
||||
};
|
||||
|
||||
common_init_result::common_init_result(common_params & params) :
|
||||
pimpl(new impl{}) {
|
||||
auto mparams = common_model_params_to_llama(params);
|
||||
auto cparams = common_context_params_to_llama(params);
|
||||
|
||||
if (params.fit_params) {
|
||||
LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__);
|
||||
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
|
||||
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx,
|
||||
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
|
||||
}
|
||||
|
||||
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
|
||||
__func__, params.model.path.c_str());
|
||||
return iparams;
|
||||
return;
|
||||
}
|
||||
|
||||
common_init_sampler_from_model(model, params.sampling);
|
||||
pimpl->model.reset(model);
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
auto cparams = common_context_params_to_llama(params);
|
||||
// load and optionally apply lora adapters (must be loaded before context creation)
|
||||
for (auto & la : params.lora_adapters) {
|
||||
llama_adapter_lora_ptr lora;
|
||||
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
|
||||
if (lora == nullptr) {
|
||||
LOG_ERR("%s: failed to load lora adapter '%s'\n", __func__, la.path.c_str());
|
||||
pimpl->model.reset(model);
|
||||
return;
|
||||
}
|
||||
|
||||
char buf[1024];
|
||||
la.ptr = lora.get();
|
||||
llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf));
|
||||
la.task_name = buf;
|
||||
llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
|
||||
la.prompt_prefix = buf;
|
||||
pimpl->lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
|
||||
}
|
||||
|
||||
// updates params.sampling
|
||||
// TODO: fix naming
|
||||
common_init_sampler_from_model(model, params.sampling);
|
||||
|
||||
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
|
||||
params.sampling.ignore_eos = false;
|
||||
}
|
||||
|
||||
// initialize once
|
||||
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
|
||||
if (llama_vocab_is_eog(vocab, i)) {
|
||||
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
|
||||
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
|
||||
}
|
||||
}
|
||||
|
||||
if (params.sampling.ignore_eos) {
|
||||
// add EOG biases to the active set of logit biases
|
||||
params.sampling.logit_bias.insert(
|
||||
params.sampling.logit_bias.end(),
|
||||
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
|
||||
}
|
||||
|
||||
//if (params.sampling.penalty_last_n == -1) {
|
||||
// LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||
// params.sampling.penalty_last_n = llama_n_ctx(lctx);
|
||||
//}
|
||||
|
||||
//if (params.sampling.dry_penalty_last_n == -1) {
|
||||
// LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||
// params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
|
||||
//}
|
||||
|
||||
// init the backend samplers as part of the context creation
|
||||
pimpl->samplers.resize(cparams.n_seq_max);
|
||||
pimpl->samplers_seq_config.resize(cparams.n_seq_max);
|
||||
|
||||
for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
|
||||
pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
|
||||
pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) };
|
||||
}
|
||||
|
||||
// TODO: temporarily gated behind a flag
|
||||
if (params.sampling.backend_sampling) {
|
||||
cparams.samplers = pimpl->samplers_seq_config.data();
|
||||
cparams.n_samplers = pimpl->samplers_seq_config.size();
|
||||
}
|
||||
|
||||
llama_context * lctx = llama_init_from_model(model, cparams);
|
||||
if (lctx == NULL) {
|
||||
LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
|
||||
__func__, params.model.path.c_str());
|
||||
llama_model_free(model);
|
||||
return iparams;
|
||||
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
pimpl->context.reset(lctx);
|
||||
}
|
||||
|
||||
llama_model * common_init_result::model() {
|
||||
return pimpl->model.get();
|
||||
}
|
||||
|
||||
llama_context * common_init_result::context() {
|
||||
return pimpl->context.get();
|
||||
}
|
||||
|
||||
common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
|
||||
return pimpl->samplers[seq_id].get();
|
||||
}
|
||||
|
||||
void common_init_result::reset_samplers() {
|
||||
for (int i = 0; i < (int) pimpl->samplers.size(); ++i) {
|
||||
llama_sampler_reset(common_sampler_get(pimpl->samplers[i].get()));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
|
||||
return pimpl->lora;
|
||||
}
|
||||
|
||||
void common_init_result::free_context() {
|
||||
pimpl->context.reset();
|
||||
}
|
||||
|
||||
common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
common_init_result_ptr res(new common_init_result(params));
|
||||
|
||||
llama_model * model = res->model();
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
|
||||
return res;
|
||||
}
|
||||
|
||||
llama_context * lctx = res->context();
|
||||
if (lctx == NULL) {
|
||||
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
|
||||
return res;
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
|
||||
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
|
||||
params.ctx_shift = false;
|
||||
@@ -1101,10 +1241,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
|
||||
const auto cvec = common_control_vector_load(params.control_vectors);
|
||||
if (cvec.n_embd == -1) {
|
||||
llama_free(lctx);
|
||||
llama_model_free(model);
|
||||
|
||||
return iparams;
|
||||
return res;
|
||||
}
|
||||
|
||||
int err = llama_apply_adapter_cvec(
|
||||
@@ -1115,10 +1252,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
params.control_vector_layer_start,
|
||||
params.control_vector_layer_end);
|
||||
if (err) {
|
||||
llama_free(lctx);
|
||||
llama_model_free(model);
|
||||
|
||||
return iparams;
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1142,67 +1276,14 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
llama_free(lctx);
|
||||
llama_model_free(model);
|
||||
|
||||
return iparams;
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
// load and optionally apply lora adapters
|
||||
for (auto & la : params.lora_adapters) {
|
||||
llama_adapter_lora_ptr lora;
|
||||
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
|
||||
if (lora == nullptr) {
|
||||
LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
|
||||
llama_free(lctx);
|
||||
llama_model_free(model);
|
||||
return iparams;
|
||||
}
|
||||
|
||||
char buf[1024];
|
||||
la.ptr = lora.get();
|
||||
llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf));
|
||||
la.task_name = buf;
|
||||
llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
|
||||
la.prompt_prefix = buf;
|
||||
iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
|
||||
}
|
||||
|
||||
if (!params.lora_init_without_apply) {
|
||||
common_set_adapter_lora(lctx, params.lora_adapters);
|
||||
}
|
||||
|
||||
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
|
||||
params.sampling.ignore_eos = false;
|
||||
}
|
||||
|
||||
// initialize once
|
||||
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
|
||||
if (llama_vocab_is_eog(vocab, i)) {
|
||||
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
|
||||
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
|
||||
}
|
||||
}
|
||||
|
||||
if (params.sampling.ignore_eos) {
|
||||
// add EOG biases to the active set of logit biases
|
||||
params.sampling.logit_bias.insert(
|
||||
params.sampling.logit_bias.end(),
|
||||
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
|
||||
}
|
||||
|
||||
if (params.sampling.penalty_last_n == -1) {
|
||||
LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||
params.sampling.penalty_last_n = llama_n_ctx(lctx);
|
||||
}
|
||||
|
||||
if (params.sampling.dry_penalty_last_n == -1) {
|
||||
LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||
params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
|
||||
}
|
||||
|
||||
if (params.warmup) {
|
||||
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
|
||||
|
||||
@@ -1239,14 +1320,16 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
llama_synchronize(lctx);
|
||||
llama_perf_context_reset(lctx);
|
||||
llama_set_warmup(lctx, false);
|
||||
|
||||
// reset samplers to reset RNG state after warmup to the seeded state
|
||||
res->reset_samplers();
|
||||
}
|
||||
|
||||
iparams.model.reset(model);
|
||||
iparams.context.reset(lctx);
|
||||
|
||||
return iparams;
|
||||
return res;
|
||||
}
|
||||
|
||||
common_init_result::~common_init_result() = default;
|
||||
|
||||
std::string get_model_endpoint() {
|
||||
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
|
||||
// We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
|
||||
@@ -1255,7 +1338,9 @@ std::string get_model_endpoint() {
|
||||
std::string model_endpoint = "https://huggingface.co/";
|
||||
if (endpoint_env) {
|
||||
model_endpoint = endpoint_env;
|
||||
if (model_endpoint.back() != '/') model_endpoint += '/';
|
||||
if (model_endpoint.back() != '/') {
|
||||
model_endpoint += '/';
|
||||
}
|
||||
}
|
||||
return model_endpoint;
|
||||
}
|
||||
@@ -1276,14 +1361,12 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
|
||||
mparams.devices = params.devices.data();
|
||||
}
|
||||
|
||||
if (params.n_gpu_layers != -1) {
|
||||
mparams.n_gpu_layers = params.n_gpu_layers;
|
||||
}
|
||||
|
||||
mparams.n_gpu_layers = params.n_gpu_layers;
|
||||
mparams.main_gpu = params.main_gpu;
|
||||
mparams.split_mode = params.split_mode;
|
||||
mparams.tensor_split = params.tensor_split;
|
||||
mparams.use_mmap = params.use_mmap;
|
||||
mparams.use_direct_io = params.use_direct_io;
|
||||
mparams.use_mlock = params.use_mlock;
|
||||
mparams.check_tensors = params.check_tensors;
|
||||
mparams.use_extra_bufts = !params.no_extra_bufts;
|
||||
|
||||
81
llama/llama.cpp/common/common.h
vendored
81
llama/llama.cpp/common/common.h
vendored
@@ -80,9 +80,11 @@ int32_t cpu_get_num_math();
|
||||
//
|
||||
|
||||
enum llama_example {
|
||||
LLAMA_EXAMPLE_DEBUG,
|
||||
LLAMA_EXAMPLE_COMMON,
|
||||
LLAMA_EXAMPLE_SPECULATIVE,
|
||||
LLAMA_EXAMPLE_MAIN,
|
||||
LLAMA_EXAMPLE_COMPLETION,
|
||||
LLAMA_EXAMPLE_CLI,
|
||||
LLAMA_EXAMPLE_EMBEDDING,
|
||||
LLAMA_EXAMPLE_PERPLEXITY,
|
||||
LLAMA_EXAMPLE_RETRIEVAL,
|
||||
@@ -98,6 +100,7 @@ enum llama_example {
|
||||
LLAMA_EXAMPLE_TTS,
|
||||
LLAMA_EXAMPLE_DIFFUSION,
|
||||
LLAMA_EXAMPLE_FINETUNE,
|
||||
LLAMA_EXAMPLE_FIT_PARAMS,
|
||||
|
||||
LLAMA_EXAMPLE_COUNT,
|
||||
};
|
||||
@@ -194,7 +197,6 @@ struct common_params_sampling {
|
||||
|
||||
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
|
||||
|
||||
|
||||
std::vector<enum common_sampler_type> samplers = {
|
||||
COMMON_SAMPLER_TYPE_PENALTIES,
|
||||
COMMON_SAMPLER_TYPE_DRY,
|
||||
@@ -215,6 +217,12 @@ struct common_params_sampling {
|
||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
|
||||
|
||||
bool backend_sampling = false;
|
||||
|
||||
bool has_logit_bias() const {
|
||||
return !logit_bias.empty();
|
||||
}
|
||||
|
||||
// print the parameters into a string
|
||||
std::string print() const;
|
||||
};
|
||||
@@ -302,8 +310,8 @@ struct lr_opt {
|
||||
struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata);
|
||||
|
||||
struct common_params {
|
||||
int32_t n_predict = -1; // new tokens to predict
|
||||
int32_t n_ctx = 4096; // context size
|
||||
int32_t n_predict = -1; // max. number of new tokens to predict, -1 == no limit
|
||||
int32_t n_ctx = 0; // context size, 0 == context the model was trained with
|
||||
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
|
||||
int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
|
||||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||
@@ -324,9 +332,14 @@ struct common_params {
|
||||
// offload params
|
||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
|
||||
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
||||
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM, -1 is auto, <= -2 is all
|
||||
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
||||
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
|
||||
bool fit_params = true; // whether to fit unset model/context parameters to free device memory
|
||||
int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use
|
||||
|
||||
// margin per device in bytes for fitting parameters to free memory:
|
||||
std::vector<size_t> fit_params_target = std::vector<size_t>(llama_max_devices(), 1024 * 1024*1024);
|
||||
|
||||
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
|
||||
|
||||
@@ -362,6 +375,11 @@ struct common_params {
|
||||
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
|
||||
std::string logits_file = ""; // file for saving *all* logits // NOLINT
|
||||
|
||||
// llama-debug specific options
|
||||
std::string logits_output_dir = "data"; // directory for saving logits output files // NOLINT
|
||||
bool save_logits = false; // whether to save logits to files // NOLINT
|
||||
std::vector<std::string> tensor_filter; // filter tensor names for debug output (regex) // NOLINT
|
||||
|
||||
std::vector<std::string> in_files; // all input files
|
||||
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
|
||||
std::vector<llama_model_kv_override> kv_overrides;
|
||||
@@ -406,12 +424,14 @@ struct common_params {
|
||||
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
|
||||
bool cont_batching = true; // insert new sequences for decoding on-the-fly
|
||||
bool no_perf = false; // disable performance metrics
|
||||
bool show_timings = true; // show timing information on CLI
|
||||
bool ctx_shift = false; // context shift on infinite text generation
|
||||
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
||||
bool kv_unified = false; // enable unified KV cache
|
||||
|
||||
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
|
||||
bool use_mmap = true; // use mmap for faster loads
|
||||
bool use_mmap = true; // enable mmap to use filesystem cache
|
||||
bool use_direct_io = true; // read from disk without buffering for faster model loading
|
||||
bool use_mlock = false; // use mlock to keep model in memory
|
||||
bool verbose_prompt = false; // print prompt tokens before generation
|
||||
bool display_prompt = true; // print prompt before generation
|
||||
@@ -462,11 +482,12 @@ struct common_params {
|
||||
std::string public_path = ""; // NOLINT
|
||||
std::string api_prefix = ""; // NOLINT
|
||||
std::string chat_template = ""; // NOLINT
|
||||
bool use_jinja = false; // NOLINT
|
||||
bool use_jinja = true; // NOLINT
|
||||
bool enable_chat_template = true;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
int reasoning_budget = -1;
|
||||
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
|
||||
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
|
||||
int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time
|
||||
|
||||
std::vector<std::string> api_keys;
|
||||
|
||||
@@ -475,16 +496,20 @@ struct common_params {
|
||||
|
||||
std::map<std::string, std::string> default_template_kwargs;
|
||||
|
||||
// webui configs
|
||||
bool webui = true;
|
||||
std::string webui_config_json;
|
||||
|
||||
// "advanced" endpoints are disabled by default for better security
|
||||
bool webui = true;
|
||||
bool endpoint_slots = true;
|
||||
bool endpoint_props = false; // only control POST requests, not GET
|
||||
bool endpoint_metrics = false;
|
||||
|
||||
// router server configs
|
||||
std::string models_dir = ""; // directory containing models for the router server
|
||||
int models_max = 4; // maximum number of models to load simultaneously
|
||||
bool models_autoload = true; // automatically load models when requested via the router server
|
||||
std::string models_dir = ""; // directory containing models for the router server
|
||||
std::string models_preset = ""; // directory containing model presets for the router server
|
||||
int models_max = 4; // maximum number of models to load simultaneously
|
||||
bool models_autoload = true; // automatically load models when requested via the router server
|
||||
|
||||
bool log_json = false;
|
||||
|
||||
@@ -666,15 +691,31 @@ bool tty_can_use_colors();
|
||||
// Model utils
|
||||
//
|
||||
|
||||
// note: defines object's lifetime
|
||||
struct common_init_result {
|
||||
llama_model_ptr model;
|
||||
llama_context_ptr context;
|
||||
struct common_sampler;
|
||||
|
||||
std::vector<llama_adapter_lora_ptr> lora;
|
||||
// note: defines the model, context, samplers, ets. lifetimes
|
||||
struct common_init_result {
|
||||
common_init_result(common_params & params);
|
||||
~common_init_result();
|
||||
|
||||
llama_model * model();
|
||||
llama_context * context();
|
||||
|
||||
common_sampler * sampler(llama_seq_id seq_id);
|
||||
void reset_samplers();
|
||||
|
||||
std::vector<llama_adapter_lora_ptr> & lora();
|
||||
|
||||
void free_context();
|
||||
|
||||
private:
|
||||
struct impl;
|
||||
std::unique_ptr<impl> pimpl;
|
||||
};
|
||||
|
||||
struct common_init_result common_init_from_params(common_params & params);
|
||||
using common_init_result_ptr = std::unique_ptr<common_init_result>;
|
||||
|
||||
common_init_result_ptr common_init_from_params(common_params & params);
|
||||
|
||||
struct llama_model_params common_model_params_to_llama ( common_params & params);
|
||||
struct llama_context_params common_context_params_to_llama(const common_params & params);
|
||||
|
||||
135
llama/llama.cpp/common/json-schema-to-grammar.cpp
vendored
135
llama/llama.cpp/common/json-schema-to-grammar.cpp
vendored
@@ -305,8 +305,9 @@ static std::string format_literal(const std::string & literal) {
|
||||
|
||||
std::string gbnf_format_literal(const std::string & literal) { return format_literal(literal); }
|
||||
|
||||
class SchemaConverter {
|
||||
class common_schema_converter {
|
||||
private:
|
||||
friend class common_schema_info;
|
||||
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
|
||||
std::function<json(const std::string &)> _fetch_json;
|
||||
bool _dotall;
|
||||
@@ -729,7 +730,7 @@ private:
|
||||
}
|
||||
|
||||
public:
|
||||
SchemaConverter(
|
||||
common_schema_converter(
|
||||
const std::function<json(const std::string &)> & fetch_json,
|
||||
bool dotall)
|
||||
: _fetch_json(fetch_json), _dotall(dotall)
|
||||
@@ -990,6 +991,134 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// common_schema_info implementation (pimpl)
|
||||
|
||||
common_schema_info::common_schema_info()
|
||||
: impl_(std::make_unique<common_schema_converter>(
|
||||
[](const std::string &) { return json(); },
|
||||
false)) {}
|
||||
|
||||
common_schema_info::~common_schema_info() = default;
|
||||
|
||||
common_schema_info::common_schema_info(common_schema_info &&) noexcept = default;
|
||||
common_schema_info & common_schema_info::operator=(common_schema_info &&) noexcept = default;
|
||||
|
||||
void common_schema_info::resolve_refs(nlohmann::ordered_json & schema) {
|
||||
impl_->resolve_refs(schema, "");
|
||||
}
|
||||
|
||||
// Determines if a JSON schema can resolve to a string type through any path.
|
||||
// Some models emit raw string values rather than JSON-encoded strings for string parameters.
|
||||
// If any branch of the schema (via oneOf, anyOf, $ref, etc.) permits a string, this returns
|
||||
// true, allowing callers to handle the value as a raw string for simplicity.
|
||||
bool common_schema_info::resolves_to_string(const nlohmann::ordered_json & schema) {
|
||||
std::unordered_set<std::string> visited_refs;
|
||||
|
||||
std::function<bool(const json &)> check = [&](const json & s) -> bool {
|
||||
if (!s.is_object()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Handle $ref
|
||||
if (s.contains("$ref")) {
|
||||
const std::string & ref = s["$ref"];
|
||||
if (visited_refs.find(ref) != visited_refs.end()) {
|
||||
// Circular reference, assume not a string to be safe
|
||||
return false;
|
||||
}
|
||||
visited_refs.insert(ref);
|
||||
auto it = impl_->_refs.find(ref);
|
||||
if (it != impl_->_refs.end()) {
|
||||
return check(it->second);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check type field
|
||||
if (s.contains("type")) {
|
||||
const json & schema_type = s["type"];
|
||||
if (schema_type.is_string()) {
|
||||
if (schema_type == "string") {
|
||||
return true;
|
||||
}
|
||||
} else if (schema_type.is_array()) {
|
||||
// Type can be an array like ["string", "null"]
|
||||
for (const auto & t : schema_type) {
|
||||
if (t == "string") {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check oneOf/anyOf - if any alternative can be a string
|
||||
if (s.contains("oneOf")) {
|
||||
for (const auto & alt : s["oneOf"]) {
|
||||
if (check(alt)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (s.contains("anyOf")) {
|
||||
for (const auto & alt : s["anyOf"]) {
|
||||
if (check(alt)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check allOf - all components must be compatible with string type
|
||||
if (s.contains("allOf")) {
|
||||
bool all_string = true;
|
||||
for (const auto & component : s["allOf"]) {
|
||||
if (!check(component)) {
|
||||
all_string = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (all_string) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Check const - if the constant value is a string
|
||||
if (s.contains("const")) {
|
||||
if (s["const"].is_string()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Check enum - if any enum value is a string
|
||||
if (s.contains("enum")) {
|
||||
for (const auto & val : s["enum"]) {
|
||||
if (val.is_string()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// String-specific keywords imply string type
|
||||
if (s.contains("pattern") || s.contains("minLength") || s.contains("maxLength")) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check format - many formats imply string
|
||||
if (s.contains("format")) {
|
||||
const std::string & fmt = s["format"];
|
||||
if (fmt == "date" || fmt == "time" || fmt == "date-time" ||
|
||||
fmt == "uri" || fmt == "email" || fmt == "hostname" ||
|
||||
fmt == "ipv4" || fmt == "ipv6" || fmt == "uuid" ||
|
||||
fmt.find("uuid") == 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
return check(schema);
|
||||
}
|
||||
|
||||
std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
|
||||
#ifdef LLAMA_USE_LLGUIDANCE
|
||||
if (!force_gbnf) {
|
||||
@@ -1006,7 +1135,7 @@ std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
|
||||
}
|
||||
|
||||
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
|
||||
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
|
||||
common_schema_converter converter([&](const std::string &) { return json(); }, options.dotall);
|
||||
common_grammar_builder builder {
|
||||
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
|
||||
return converter._add_rule(name, rule);
|
||||
|
||||
20
llama/llama.cpp/common/json-schema-to-grammar.h
vendored
20
llama/llama.cpp/common/json-schema-to-grammar.h
vendored
@@ -3,11 +3,31 @@
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
|
||||
bool force_gbnf = false);
|
||||
|
||||
class common_schema_converter;
|
||||
|
||||
// Probes a JSON schema to extract information about its structure and type constraints.
|
||||
class common_schema_info {
|
||||
std::unique_ptr<common_schema_converter> impl_;
|
||||
|
||||
public:
|
||||
common_schema_info();
|
||||
~common_schema_info();
|
||||
|
||||
common_schema_info(const common_schema_info &) = delete;
|
||||
common_schema_info & operator=(const common_schema_info &) = delete;
|
||||
common_schema_info(common_schema_info &&) noexcept;
|
||||
common_schema_info & operator=(common_schema_info &&) noexcept;
|
||||
|
||||
void resolve_refs(nlohmann::ordered_json & schema);
|
||||
bool resolves_to_string(const nlohmann::ordered_json & schema);
|
||||
};
|
||||
|
||||
struct common_grammar_builder {
|
||||
std::function<std::string(const std::string &, const std::string &)> add_rule;
|
||||
std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
|
||||
|
||||
5
llama/llama.cpp/common/log.cpp
vendored
5
llama/llama.cpp/common/log.cpp
vendored
@@ -420,6 +420,11 @@ void common_log_set_timestamps(struct common_log * log, bool timestamps) {
|
||||
log->set_timestamps(timestamps);
|
||||
}
|
||||
|
||||
void common_log_flush(struct common_log * log) {
|
||||
log->pause();
|
||||
log->resume();
|
||||
}
|
||||
|
||||
static int common_get_verbosity(enum ggml_log_level level) {
|
||||
switch (level) {
|
||||
case GGML_LOG_LEVEL_DEBUG: return LOG_LEVEL_DEBUG;
|
||||
|
||||
1
llama/llama.cpp/common/log.h
vendored
1
llama/llama.cpp/common/log.h
vendored
@@ -84,6 +84,7 @@ void common_log_set_file (struct common_log * log, const char * file); // n
|
||||
void common_log_set_colors (struct common_log * log, log_colors colors); // not thread-safe
|
||||
void common_log_set_prefix (struct common_log * log, bool prefix); // whether to output prefix to each log
|
||||
void common_log_set_timestamps(struct common_log * log, bool timestamps); // whether to output timestamps in the prefix
|
||||
void common_log_flush (struct common_log * log); // flush all pending log messages
|
||||
|
||||
// helper macros for logging
|
||||
// use these to avoid computing log arguments if the verbosity of the log is higher than the threshold
|
||||
|
||||
193
llama/llama.cpp/common/sampling.cpp
vendored
193
llama/llama.cpp/common/sampling.cpp
vendored
@@ -116,22 +116,38 @@ struct common_sampler {
|
||||
void reset() {
|
||||
prev.clear();
|
||||
|
||||
llama_sampler_reset(grmr);
|
||||
llama_sampler_reset(chain);
|
||||
}
|
||||
|
||||
void set_logits(struct llama_context * ctx, int idx) {
|
||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx);
|
||||
const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx);
|
||||
const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx);
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
const int n_vocab = llama_vocab_n_tokens(vocab);
|
||||
|
||||
cur.resize(n_vocab);
|
||||
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
||||
if (sampled_probs) {
|
||||
const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
|
||||
cur.resize(sampled_probs_count);
|
||||
for (uint32_t i = 0; i < sampled_probs_count; ++i) {
|
||||
cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
|
||||
}
|
||||
} else if (sampled_logits) {
|
||||
const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
|
||||
cur.resize(sampled_logits_count);
|
||||
for (uint32_t i = 0; i < sampled_logits_count; i++) {
|
||||
cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
|
||||
}
|
||||
} else {
|
||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
GGML_ASSERT(logits != nullptr);
|
||||
cur.resize(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
||||
}
|
||||
}
|
||||
|
||||
cur_p = { cur.data(), cur.size(), -1, false };
|
||||
@@ -160,14 +176,18 @@ std::string common_params_sampling::print() const {
|
||||
return std::string(result);
|
||||
}
|
||||
|
||||
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
|
||||
struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) {
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
|
||||
|
||||
lparams.no_perf = params.no_perf;
|
||||
|
||||
struct llama_sampler * grmr;
|
||||
llama_sampler * grmr = nullptr;
|
||||
llama_sampler * chain = llama_sampler_chain_init(lparams);
|
||||
|
||||
std::vector<llama_sampler *> samplers;
|
||||
|
||||
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
|
||||
#ifdef LLAMA_USE_LLGUIDANCE
|
||||
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
|
||||
@@ -176,24 +196,30 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
} else {
|
||||
std::vector<std::string> trigger_patterns;
|
||||
std::vector<std::string> patterns_anywhere;
|
||||
std::vector<llama_token> trigger_tokens;
|
||||
for (const auto & trigger : params.grammar_triggers) {
|
||||
switch (trigger.type) {
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
|
||||
{
|
||||
const auto & word = trigger.value;
|
||||
patterns_anywhere.push_back(regex_escape(word));
|
||||
trigger_patterns.push_back(regex_escape(word));
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
||||
{
|
||||
patterns_anywhere.push_back(trigger.value);
|
||||
trigger_patterns.push_back(trigger.value);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
|
||||
{
|
||||
trigger_patterns.push_back(trigger.value);
|
||||
const auto & pattern = trigger.value;
|
||||
std::string anchored = "^$";
|
||||
if (!pattern.empty()) {
|
||||
anchored = (pattern.front() != '^' ? "^" : "")
|
||||
+ pattern
|
||||
+ (pattern.back() != '$' ? "$" : "");
|
||||
}
|
||||
trigger_patterns.push_back(anchored);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
|
||||
@@ -207,40 +233,26 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
}
|
||||
}
|
||||
|
||||
if (!patterns_anywhere.empty()) {
|
||||
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
|
||||
}
|
||||
|
||||
std::vector<const char *> trigger_patterns_c;
|
||||
trigger_patterns_c.reserve(trigger_patterns.size());
|
||||
for (const auto & regex : trigger_patterns) {
|
||||
trigger_patterns_c.push_back(regex.c_str());
|
||||
}
|
||||
|
||||
grmr = params.grammar_lazy
|
||||
? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
|
||||
trigger_patterns_c.data(), trigger_patterns_c.size(),
|
||||
trigger_tokens.data(), trigger_tokens.size())
|
||||
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
|
||||
if (!grmr) {
|
||||
return nullptr;
|
||||
if (!params.grammar.empty()) {
|
||||
if (params.grammar_lazy) {
|
||||
grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
|
||||
trigger_patterns_c.data(), trigger_patterns_c.size(),
|
||||
trigger_tokens.data(), trigger_tokens.size());
|
||||
} else {
|
||||
grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto * result = new common_sampler {
|
||||
/* .params = */ params,
|
||||
/* .grmr = */ grmr,
|
||||
/* .chain = */ llama_sampler_chain_init(lparams),
|
||||
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
|
||||
/* .cur = */ {},
|
||||
/* .cur_p = */ {},
|
||||
};
|
||||
|
||||
llama_sampler_chain_add(result->chain,
|
||||
llama_sampler_init_logit_bias(
|
||||
llama_vocab_n_tokens(vocab),
|
||||
params.logit_bias.size(),
|
||||
params.logit_bias.data()));
|
||||
if (params.has_logit_bias()) {
|
||||
samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
|
||||
}
|
||||
|
||||
if (params.mirostat == 0) {
|
||||
for (const auto & cnstr : params.samplers) {
|
||||
@@ -253,58 +265,77 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
c_breakers.push_back(str.c_str());
|
||||
}
|
||||
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
||||
samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
||||
samplers.push_back(llama_sampler_init_top_k (params.top_k));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
|
||||
samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
|
||||
samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_MIN_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
||||
samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_XTC:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||
samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
|
||||
samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||
samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_INFILL:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
|
||||
samplers.push_back(llama_sampler_init_infill (vocab));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_PENALTIES:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
||||
samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown sampler type");
|
||||
}
|
||||
}
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
||||
|
||||
samplers.push_back(llama_sampler_init_dist(params.seed));
|
||||
} else if (params.mirostat == 1) {
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
|
||||
samplers.push_back(llama_sampler_init_temp(params.temp));
|
||||
samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
|
||||
} else if (params.mirostat == 2) {
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
|
||||
samplers.push_back(llama_sampler_init_temp(params.temp));
|
||||
samplers.push_back(llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
|
||||
} else {
|
||||
GGML_ASSERT(false && "unknown mirostat version");
|
||||
}
|
||||
|
||||
for (auto * smpl : samplers) {
|
||||
llama_sampler_chain_add(chain, smpl);
|
||||
}
|
||||
|
||||
if (grmr && params.backend_sampling) {
|
||||
LOG_WRN("%s: backend sampling is not compatible with grammar, disabling\n", __func__);
|
||||
|
||||
params.backend_sampling = false;
|
||||
}
|
||||
|
||||
auto * result = new common_sampler {
|
||||
/* .params = */ params,
|
||||
/* .grmr = */ grmr,
|
||||
/* .chain = */ chain,
|
||||
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
|
||||
/* .cur = */ {},
|
||||
/* .cur_p = */ {},
|
||||
};
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void common_sampler_free(struct common_sampler * gsmpl) {
|
||||
if (gsmpl) {
|
||||
llama_sampler_free(gsmpl->grmr);
|
||||
|
||||
llama_sampler_free(gsmpl->chain);
|
||||
|
||||
delete gsmpl;
|
||||
@@ -314,7 +345,7 @@ void common_sampler_free(struct common_sampler * gsmpl) {
|
||||
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
|
||||
const auto tm = gsmpl->tm();
|
||||
|
||||
if (accept_grammar) {
|
||||
if (gsmpl->grmr && accept_grammar) {
|
||||
llama_sampler_accept(gsmpl->grmr, token);
|
||||
}
|
||||
|
||||
@@ -329,12 +360,12 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
|
||||
|
||||
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
||||
return new common_sampler {
|
||||
/* .params = */ gsmpl->params,
|
||||
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
|
||||
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
||||
/* .prev = */ gsmpl->prev,
|
||||
/* .cur = */ gsmpl->cur,
|
||||
/* .cur_p = */ gsmpl->cur_p,
|
||||
/* .params = */ gsmpl->params,
|
||||
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
|
||||
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
||||
/* .prev = */ gsmpl->prev,
|
||||
/* .cur = */ gsmpl->cur,
|
||||
/* .cur_p = */ gsmpl->cur_p,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -383,33 +414,56 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
|
||||
}
|
||||
}
|
||||
|
||||
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
|
||||
return gsmpl->chain;
|
||||
}
|
||||
|
||||
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
|
||||
llama_synchronize(ctx);
|
||||
|
||||
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
|
||||
const auto tm = gsmpl->tm();
|
||||
|
||||
gsmpl->set_logits(ctx, idx);
|
||||
llama_token id = LLAMA_TOKEN_NULL;
|
||||
|
||||
auto & grmr = gsmpl->grmr;
|
||||
auto & chain = gsmpl->chain;
|
||||
auto & cur_p = gsmpl->cur_p; // initialized by set_logits
|
||||
|
||||
// Check if a backend sampler has already sampled a token in which case we
|
||||
// return that token id directly.
|
||||
{
|
||||
id = llama_get_sampled_token_ith(ctx, idx);
|
||||
|
||||
if (id != LLAMA_TOKEN_NULL) {
|
||||
LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
|
||||
|
||||
GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
|
||||
|
||||
// TODO: simplify
|
||||
gsmpl->cur.resize(1);
|
||||
gsmpl->cur[0] = { id, 0.0f, 1.0f };
|
||||
cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true };
|
||||
|
||||
return id;
|
||||
}
|
||||
}
|
||||
|
||||
gsmpl->set_logits(ctx, idx);
|
||||
|
||||
if (grammar_first) {
|
||||
llama_sampler_apply(grmr, &cur_p);
|
||||
}
|
||||
|
||||
llama_sampler_apply(chain, &cur_p);
|
||||
|
||||
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
|
||||
|
||||
const llama_token id = cur_p.data[cur_p.selected].id;
|
||||
id = cur_p.data[cur_p.selected].id;
|
||||
|
||||
if (grammar_first) {
|
||||
return id;
|
||||
}
|
||||
|
||||
// check if it the sampled token fits the grammar
|
||||
// check if it the sampled token fits the grammar (grammar-based rejection sampling)
|
||||
{
|
||||
llama_token_data single_token_data = { id, 1.0f, 0.0f };
|
||||
llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
|
||||
@@ -429,9 +483,11 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||
llama_sampler_apply(grmr, &cur_p);
|
||||
llama_sampler_apply(chain, &cur_p);
|
||||
|
||||
GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
|
||||
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
|
||||
|
||||
return cur_p.data[cur_p.selected].id;
|
||||
id = cur_p.data[cur_p.selected].id;
|
||||
|
||||
return id;
|
||||
}
|
||||
|
||||
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
|
||||
@@ -515,7 +571,8 @@ std::string common_sampler_print(const struct common_sampler * gsmpl) {
|
||||
|
||||
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
|
||||
const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
|
||||
result += std::string("-> ") + llama_sampler_name(smpl) + " ";
|
||||
result += std::string("-> ");
|
||||
result += std::string(llama_sampler_name(smpl)) + " ";
|
||||
}
|
||||
|
||||
return result;
|
||||
|
||||
12
llama/llama.cpp/common/sampling.h
vendored
12
llama/llama.cpp/common/sampling.h
vendored
@@ -36,7 +36,8 @@ struct common_sampler;
|
||||
|
||||
// llama_sampler API overloads
|
||||
|
||||
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
|
||||
// note: can mutate params in some cases
|
||||
struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params);
|
||||
|
||||
void common_sampler_free(struct common_sampler * gsmpl);
|
||||
|
||||
@@ -48,6 +49,9 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
|
||||
// arguments can be nullptr to skip printing
|
||||
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
|
||||
|
||||
// get the underlying llama_sampler_chain
|
||||
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
|
||||
|
||||
// extended sampling implementation:
|
||||
//
|
||||
// - set logits
|
||||
@@ -107,3 +111,9 @@ std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std:
|
||||
|
||||
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
|
||||
const char * grammar_kind, const char * grammar_data);
|
||||
|
||||
struct common_sampler_deleter {
|
||||
void operator()(common_sampler * s) { common_sampler_free(s); }
|
||||
};
|
||||
|
||||
typedef std::unique_ptr<common_sampler, common_sampler_deleter> common_sampler_ptr;
|
||||
|
||||
130
llama/llama.cpp/include/llama.h
vendored
130
llama/llama.cpp/include/llama.h
vendored
@@ -286,7 +286,7 @@ extern "C" {
|
||||
// NULL-terminated list of buffer types to use for tensors that match a pattern
|
||||
const struct llama_model_tensor_buft_override * tensor_buft_overrides;
|
||||
|
||||
int32_t n_gpu_layers; // number of layers to store in VRAM
|
||||
int32_t n_gpu_layers; // number of layers to store in VRAM, a negative value means all layers
|
||||
enum llama_split_mode split_mode; // how to split the model across multiple GPUs
|
||||
|
||||
// the GPU that is used for the entire model when split_mode is LLAMA_SPLIT_MODE_NONE
|
||||
@@ -309,10 +309,17 @@ extern "C" {
|
||||
// Keep the booleans together to avoid misalignment during copy-by-value.
|
||||
bool vocab_only; // only load the vocabulary, no weights
|
||||
bool use_mmap; // use mmap if possible
|
||||
bool use_direct_io; // use direct io, takes precedence over use_mmap
|
||||
bool use_mlock; // force system to keep model in RAM
|
||||
bool check_tensors; // validate model tensor data
|
||||
bool use_extra_bufts; // use extra buffer types (used for weight repacking)
|
||||
bool no_host; // bypass host buffer allowing extra buffers to be used
|
||||
bool no_alloc; // only load metadata and simulate memory allocations
|
||||
};
|
||||
|
||||
struct llama_sampler_seq_config {
|
||||
llama_seq_id seq_id;
|
||||
struct llama_sampler * sampler;
|
||||
};
|
||||
|
||||
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
|
||||
@@ -363,6 +370,12 @@ extern "C" {
|
||||
bool kv_unified; // use a unified buffer across the input sequences when computing the attention
|
||||
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
|
||||
|
||||
// [EXPERIMENTAL]
|
||||
// backend sampler chain configuration (make sure the caller keeps the sampler chains alive)
|
||||
// note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
|
||||
struct llama_sampler_seq_config * samplers;
|
||||
size_t n_samplers;
|
||||
};
|
||||
|
||||
// model quantization parameters
|
||||
@@ -466,10 +479,31 @@ extern "C" {
|
||||
// Frees all allocated memory
|
||||
LLAMA_API void llama_free(struct llama_context * ctx);
|
||||
|
||||
enum llama_params_fit_status {
|
||||
LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit
|
||||
LLAMA_PARAMS_FIT_STATUS_FAILURE = 1, // could not find allocations that are projected to fit
|
||||
LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occured, e.g. because no model could be found at the specified path
|
||||
};
|
||||
|
||||
// fits mparams and cparams to free device memory (assumes system memory is unlimited)
|
||||
// - returns true if the parameters could be successfully modified to fit device memory
|
||||
// - this function is NOT thread safe because it modifies the global llama logger state
|
||||
// - only parameters that have the same value as in llama_default_model_params are modified
|
||||
LLAMA_API enum llama_params_fit_status llama_params_fit(
|
||||
const char * path_model,
|
||||
struct llama_model_params * mparams,
|
||||
struct llama_context_params * cparams,
|
||||
float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements
|
||||
struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements
|
||||
size_t * margins, // margins of memory to leave per device in bytes
|
||||
uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use
|
||||
enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log
|
||||
|
||||
LLAMA_API int64_t llama_time_us(void);
|
||||
|
||||
LLAMA_API size_t llama_max_devices(void);
|
||||
LLAMA_API size_t llama_max_parallel_sequences(void);
|
||||
LLAMA_API size_t llama_max_tensor_buft_overrides(void);
|
||||
|
||||
LLAMA_API bool llama_supports_mmap (void);
|
||||
LLAMA_API bool llama_supports_mlock (void);
|
||||
@@ -502,6 +536,7 @@ extern "C" {
|
||||
LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
|
||||
@@ -585,6 +620,8 @@ extern "C" {
|
||||
//
|
||||
|
||||
// Load a LoRA adapter from file
|
||||
// The adapter is valid as long as the associated model is not freed
|
||||
// All adapters must be loaded before context creation
|
||||
LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init(
|
||||
struct llama_model * model,
|
||||
const char * path_lora);
|
||||
@@ -968,6 +1005,32 @@ extern "C" {
|
||||
// otherwise: float[n_embd] (1-dimensional)
|
||||
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
|
||||
|
||||
//
|
||||
// backend sampling API [EXPERIMENTAL]
|
||||
// note: use only if the llama_context was created with at least one llama_sampler_seq_config
|
||||
//
|
||||
|
||||
// Get the backend sampled token for the ith token.
|
||||
// Returns LLAMA_TOKEN_NULL if no token was sampled.
|
||||
LLAMA_API llama_token llama_get_sampled_token_ith(struct llama_context * ctx, int32_t i);
|
||||
|
||||
// Get the backend sampled probabilites for the ith token
|
||||
// The index matches llama_get_sampled_token_ith().
|
||||
// Returns NULL if no probabilites were generated.
|
||||
LLAMA_API float * llama_get_sampled_probs_ith (struct llama_context * ctx, int32_t i);
|
||||
LLAMA_API uint32_t llama_get_sampled_probs_count_ith(struct llama_context * ctx, int32_t i);
|
||||
|
||||
// Get the backend sampled logits for the ith token
|
||||
// Returns NULL if no logits were sampled.
|
||||
LLAMA_API float * llama_get_sampled_logits_ith (struct llama_context * ctx, int32_t i);
|
||||
LLAMA_API uint32_t llama_get_sampled_logits_count_ith(struct llama_context * ctx, int32_t i);
|
||||
|
||||
// Get the backend sampled candidates (token ids) for the ith token
|
||||
// These are needed to map probability/logit indices to vocab token ids.
|
||||
// Returns NULL if no candidates were sampled.
|
||||
LLAMA_API llama_token * llama_get_sampled_candidates_ith (struct llama_context * ctx, int32_t i);
|
||||
LLAMA_API uint32_t llama_get_sampled_candidates_count_ith(struct llama_context * ctx, int32_t i);
|
||||
|
||||
//
|
||||
// Vocab
|
||||
//
|
||||
@@ -1139,11 +1202,16 @@ extern "C" {
|
||||
//
|
||||
// llama_sampler_free(smpl);
|
||||
//
|
||||
// TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
|
||||
//
|
||||
|
||||
typedef void * llama_sampler_context_t;
|
||||
|
||||
struct llama_sampler_data {
|
||||
struct ggml_tensor * logits;
|
||||
struct ggml_tensor * probs;
|
||||
struct ggml_tensor * sampled;
|
||||
struct ggml_tensor * candidates;
|
||||
};
|
||||
|
||||
// user code can implement the interface below in order to create custom llama_sampler
|
||||
struct llama_sampler_i {
|
||||
const char * (*name) (const struct llama_sampler * smpl); // can be NULL
|
||||
@@ -1153,17 +1221,45 @@ extern "C" {
|
||||
struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL
|
||||
void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL
|
||||
|
||||
// TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
|
||||
//void (*apply_ggml) (struct llama_sampler * smpl, ...);
|
||||
// [EXPERIMENTAL]
|
||||
// backend sampling interface:
|
||||
|
||||
// return true if the backend supports all ops needed by the sampler
|
||||
// note: call once per sampler
|
||||
bool (*backend_init)(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft);
|
||||
|
||||
// call after .backend_apply()
|
||||
void (*backend_accept)(
|
||||
struct llama_sampler * smpl,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_tensor * selected_token);
|
||||
|
||||
// call after .backend_init()
|
||||
void (*backend_apply)(
|
||||
struct llama_sampler * smpl,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct llama_sampler_data * data);
|
||||
|
||||
// called before graph execution to set inputs for the current ubatch
|
||||
void (*backend_set_input)(struct llama_sampler * smpl);
|
||||
};
|
||||
|
||||
struct llama_sampler {
|
||||
const struct llama_sampler_i * iface;
|
||||
llama_sampler_context_t ctx;
|
||||
struct llama_sampler_i * iface;
|
||||
|
||||
llama_sampler_context_t ctx;
|
||||
};
|
||||
|
||||
// [EXPERIMENTAL]
|
||||
// attach a sampler to the context
|
||||
// note: prefer initializing the context with llama_context_params.samplers when possible
|
||||
// note: changing the samplers of a context can cause graph reallocations and degraded performance
|
||||
LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
|
||||
|
||||
// mirror of llama_sampler_i:
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_i * iface, llama_sampler_context_t ctx);
|
||||
LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
|
||||
LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
|
||||
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
|
||||
@@ -1179,7 +1275,15 @@ extern "C" {
|
||||
|
||||
// important: takes ownership of the sampler object and will free it when llama_sampler_free is called
|
||||
LLAMA_API void llama_sampler_chain_add( struct llama_sampler * chain, struct llama_sampler * smpl);
|
||||
LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i);
|
||||
|
||||
// return NULL if:
|
||||
// - the sampler is NULL
|
||||
// - the sampler is not a llama_sampler_chain
|
||||
// - the index is out of bounds, unless i == -1
|
||||
// - if i == -1, returns the chain itself (can be used to check if the sampler is a chain)
|
||||
LLAMA_API struct llama_sampler * llama_sampler_chain_get( struct llama_sampler * chain, int32_t i);
|
||||
|
||||
// the total number of samplers in the chain
|
||||
LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain);
|
||||
|
||||
// after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed
|
||||
@@ -1188,7 +1292,9 @@ extern "C" {
|
||||
// available samplers:
|
||||
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
|
||||
|
||||
/// seed == LLAMA_DEFAULT_SEED to use a random seed.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_dist(uint32_t seed);
|
||||
|
||||
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
/// Setting k <= 0 makes this a noop
|
||||
@@ -1354,7 +1460,9 @@ extern "C" {
|
||||
|
||||
// Set callback for all future logging events.
|
||||
// If this is not called, or NULL is supplied, everything is output on stderr.
|
||||
LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
|
||||
// The logger state is global so these functions are NOT thread safe.
|
||||
LLAMA_API void llama_log_get(ggml_log_callback * log_callback, void ** user_data);
|
||||
LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
|
||||
|
||||
//
|
||||
// Performance utils
|
||||
|
||||
15
llama/llama.cpp/src/llama-adapter.cpp
vendored
15
llama/llama.cpp/src/llama-adapter.cpp
vendored
@@ -146,9 +146,11 @@ llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static void llama_adapter_lora_init_impl(llama_model & model, const char * path_lora, llama_adapter_lora & adapter) {
|
||||
static void llama_adapter_lora_init_impl(const char * path_lora, llama_adapter_lora & adapter) {
|
||||
LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
|
||||
|
||||
llama_model & model = adapter.model;
|
||||
|
||||
ggml_context * ctx_init;
|
||||
gguf_init_params meta_gguf_params = {
|
||||
/* .no_alloc = */ true,
|
||||
@@ -411,14 +413,17 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
|
||||
}
|
||||
}
|
||||
|
||||
// update number of nodes used
|
||||
model.n_lora_nodes += adapter.get_n_nodes();
|
||||
|
||||
LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
|
||||
}
|
||||
|
||||
llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) {
|
||||
llama_adapter_lora * adapter = new llama_adapter_lora();
|
||||
llama_adapter_lora * adapter = new llama_adapter_lora(*model);
|
||||
|
||||
try {
|
||||
llama_adapter_lora_init_impl(*model, path_lora, *adapter);
|
||||
llama_adapter_lora_init_impl(path_lora, *adapter);
|
||||
return adapter;
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
|
||||
@@ -469,6 +474,10 @@ int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter,
|
||||
}
|
||||
|
||||
void llama_adapter_lora_free(llama_adapter_lora * adapter) {
|
||||
// update number of nodes used
|
||||
GGML_ASSERT(adapter->model.n_lora_nodes >= adapter->get_n_nodes());
|
||||
adapter->model.n_lora_nodes -= adapter->get_n_nodes();
|
||||
|
||||
delete adapter;
|
||||
}
|
||||
|
||||
|
||||
8
llama/llama.cpp/src/llama-adapter.h
vendored
8
llama/llama.cpp/src/llama-adapter.h
vendored
@@ -59,6 +59,8 @@ struct llama_adapter_lora_weight {
|
||||
};
|
||||
|
||||
struct llama_adapter_lora {
|
||||
llama_model & model;
|
||||
|
||||
// map tensor name to lora_a_b
|
||||
std::unordered_map<std::string, llama_adapter_lora_weight> ab_map;
|
||||
|
||||
@@ -73,10 +75,14 @@ struct llama_adapter_lora {
|
||||
// activated lora (aLoRA)
|
||||
std::vector<llama_token> alora_invocation_tokens;
|
||||
|
||||
llama_adapter_lora() = default;
|
||||
llama_adapter_lora(llama_model & model) : model(model) {}
|
||||
~llama_adapter_lora() = default;
|
||||
|
||||
llama_adapter_lora_weight * get_weight(ggml_tensor * w);
|
||||
|
||||
uint32_t get_n_nodes() const {
|
||||
return ab_map.size() * 6u; // a, b, scale, add, 2 x mul_mat
|
||||
}
|
||||
};
|
||||
|
||||
using llama_adapter_loras = std::unordered_map<llama_adapter_lora *, float>;
|
||||
|
||||
4251
llama/llama.cpp/src/llama-arch.cpp
vendored
4251
llama/llama.cpp/src/llama-arch.cpp
vendored
File diff suppressed because it is too large
Load Diff
19
llama/llama.cpp/src/llama-arch.h
vendored
19
llama/llama.cpp/src/llama-arch.h
vendored
@@ -3,6 +3,7 @@
|
||||
#include "ggml.h" // ggml_op
|
||||
|
||||
#include <string>
|
||||
#include <set>
|
||||
|
||||
//
|
||||
// gguf constants (sync with gguf.py)
|
||||
@@ -23,6 +24,7 @@ enum llm_arch {
|
||||
LLM_ARCH_STARCODER,
|
||||
LLM_ARCH_REFACT,
|
||||
LLM_ARCH_BERT,
|
||||
LLM_ARCH_MODERN_BERT,
|
||||
LLM_ARCH_NOMIC_BERT,
|
||||
LLM_ARCH_NOMIC_BERT_MOE,
|
||||
LLM_ARCH_NEO_BERT,
|
||||
@@ -44,6 +46,7 @@ enum llm_arch {
|
||||
LLM_ARCH_PHIMOE,
|
||||
LLM_ARCH_PLAMO,
|
||||
LLM_ARCH_PLAMO2,
|
||||
LLM_ARCH_PLAMO3,
|
||||
LLM_ARCH_CODESHELL,
|
||||
LLM_ARCH_ORION,
|
||||
LLM_ARCH_INTERNLM2,
|
||||
@@ -79,6 +82,7 @@ enum llm_arch {
|
||||
LLM_ARCH_JAIS,
|
||||
LLM_ARCH_NEMOTRON,
|
||||
LLM_ARCH_NEMOTRON_H,
|
||||
LLM_ARCH_NEMOTRON_H_MOE,
|
||||
LLM_ARCH_EXAONE,
|
||||
LLM_ARCH_EXAONE4,
|
||||
LLM_ARCH_RWKV6,
|
||||
@@ -117,6 +121,9 @@ enum llm_arch {
|
||||
LLM_ARCH_RND1,
|
||||
LLM_ARCH_PANGU_EMBED,
|
||||
LLM_ARCH_MISTRAL3,
|
||||
LLM_ARCH_MIMO2,
|
||||
LLM_ARCH_LLAMA_EMBED,
|
||||
LLM_ARCH_MAINCODER,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
@@ -150,6 +157,7 @@ enum llm_kv {
|
||||
LLM_KV_VOCAB_SIZE,
|
||||
LLM_KV_CONTEXT_LENGTH,
|
||||
LLM_KV_EMBEDDING_LENGTH,
|
||||
LLM_KV_EMBEDDING_LENGTH_OUT,
|
||||
LLM_KV_FEATURES_LENGTH,
|
||||
LLM_KV_BLOCK_COUNT,
|
||||
LLM_KV_LEADING_DENSE_BLOCK_COUNT,
|
||||
@@ -207,6 +215,7 @@ enum llm_kv {
|
||||
LLM_KV_ATTENTION_GATE_LORA_RANK,
|
||||
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
|
||||
LLM_KV_ATTENTION_SLIDING_WINDOW,
|
||||
LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN,
|
||||
LLM_KV_ATTENTION_SCALE,
|
||||
LLM_KV_ATTENTION_OUTPUT_SCALE,
|
||||
LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
|
||||
@@ -218,6 +227,7 @@ enum llm_kv {
|
||||
LLM_KV_ROPE_DIMENSION_COUNT,
|
||||
LLM_KV_ROPE_DIMENSION_SECTIONS,
|
||||
LLM_KV_ROPE_FREQ_BASE,
|
||||
LLM_KV_ROPE_FREQ_BASE_SWA,
|
||||
LLM_KV_ROPE_SCALE_LINEAR,
|
||||
LLM_KV_ROPE_SCALING_TYPE,
|
||||
LLM_KV_ROPE_SCALING_FACTOR,
|
||||
@@ -317,6 +327,7 @@ enum llm_tensor {
|
||||
LLM_TENSOR_DENSE_3_OUT,
|
||||
LLM_TENSOR_OUTPUT,
|
||||
LLM_TENSOR_OUTPUT_NORM,
|
||||
LLM_TENSOR_OUTPUT_NORM_LFM2, // fix for wrong tensor name
|
||||
LLM_TENSOR_ROPE_FREQS,
|
||||
LLM_TENSOR_ROPE_FACTORS_LONG,
|
||||
LLM_TENSOR_ROPE_FACTORS_SHORT,
|
||||
@@ -528,6 +539,10 @@ struct LLM_TN_IMPL {
|
||||
const int bid;
|
||||
const int xid;
|
||||
|
||||
const std::set<llm_tensor> model_tensors;
|
||||
|
||||
LLM_TN_IMPL(llm_arch arch, llm_tensor tensor, const char * suffix, int bid, int xid);
|
||||
|
||||
std::string str() const;
|
||||
|
||||
operator std::string() const {
|
||||
@@ -549,11 +564,11 @@ struct LLM_TN {
|
||||
llm_arch arch;
|
||||
|
||||
LLM_TN_IMPL operator()(llm_tensor tensor, const char * suffix, int bid = -1, int xid = -1) const {
|
||||
return { arch, tensor, suffix, bid, xid };
|
||||
return LLM_TN_IMPL(arch, tensor, suffix, bid, xid);
|
||||
}
|
||||
|
||||
LLM_TN_IMPL operator()(llm_tensor tensor, int bid = -1, int xid = -1) const {
|
||||
return { arch, tensor, nullptr, bid, xid };
|
||||
return LLM_TN_IMPL(arch, tensor, nullptr, bid, xid);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
14
llama/llama.cpp/src/llama-batch.cpp
vendored
14
llama/llama.cpp/src/llama-batch.cpp
vendored
@@ -695,6 +695,8 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
||||
udata->seq_idx .resize(LLAMA_MAX_SEQ, -1);
|
||||
udata->output .resize(n_tokens);
|
||||
|
||||
udata->seq_id_data.reserve(n_tokens);
|
||||
|
||||
seq_set_t seq_set_unq;
|
||||
|
||||
for (size_t i = 0; i < idxs.size(); ++i) {
|
||||
@@ -716,11 +718,13 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
||||
}
|
||||
|
||||
udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
|
||||
udata->seq_id[i] = batch.seq_id[idxs[i]];
|
||||
udata->output[i] = batch.logits[idxs[i]];
|
||||
|
||||
for (int s = 0; s < udata->n_seq_id[i]; ++s) {
|
||||
seq_set_unq.set(udata->seq_id[i][s]);
|
||||
const llama_seq_id seq_id = batch.seq_id[idxs[i]][s];
|
||||
|
||||
udata->seq_id_data.push_back(seq_id);
|
||||
seq_set_unq.set(seq_id);
|
||||
}
|
||||
|
||||
if (udata->output[i]) {
|
||||
@@ -728,6 +732,12 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
||||
}
|
||||
}
|
||||
|
||||
llama_seq_id * seq_id_ptr = udata->seq_id_data.data();
|
||||
for (size_t i = 0; i < idxs.size(); ++i) {
|
||||
udata->seq_id[i] = seq_id_ptr;
|
||||
seq_id_ptr += udata->n_seq_id[i];
|
||||
}
|
||||
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_set_unq.test(s)) {
|
||||
udata->seq_idx[s] = udata->seq_id_unq.size();
|
||||
|
||||
6
llama/llama.cpp/src/llama-batch.h
vendored
6
llama/llama.cpp/src/llama-batch.h
vendored
@@ -56,13 +56,15 @@ struct llama_ubatch {
|
||||
std::vector<float> embd;
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id *> seq_id;
|
||||
std::vector<llama_seq_id *> seq_id; // these point into the seq_id_data below
|
||||
std::vector<llama_seq_id> seq_id_unq;
|
||||
std::vector<int32_t> seq_idx;
|
||||
std::vector<int8_t> output;
|
||||
|
||||
std::vector<llama_seq_id> seq_id_data;
|
||||
};
|
||||
|
||||
// the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data
|
||||
// the llama_ubatch pointers above point to this data if set. otherwise - point to external non-owning data
|
||||
std::shared_ptr<data_t> data;
|
||||
};
|
||||
|
||||
|
||||
11
llama/llama.cpp/src/llama-chat.cpp
vendored
11
llama/llama.cpp/src/llama-chat.cpp
vendored
@@ -74,6 +74,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
||||
{ "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS },
|
||||
{ "grok-2", LLM_CHAT_TEMPLATE_GROK_2 },
|
||||
{ "pangu-embedded", LLM_CHAT_TEMPLATE_PANGU_EMBED },
|
||||
{ "solar-open", LLM_CHAT_TEMPLATE_SOLAR_OPEN },
|
||||
};
|
||||
|
||||
llm_chat_template llm_chat_template_from_str(const std::string & name) {
|
||||
@@ -216,6 +217,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
||||
return LLM_CHAT_TEMPLATE_GROK_2;
|
||||
} else if (tmpl_contains(LU8("[unused9]系统:[unused10]"))) {
|
||||
return LLM_CHAT_TEMPLATE_PANGU_EMBED;
|
||||
} else if (tmpl_contains("<|begin|>") && tmpl_contains("<|end|>") && tmpl_contains("<|content|>")) {
|
||||
return LLM_CHAT_TEMPLATE_SOLAR_OPEN;
|
||||
}
|
||||
return LLM_CHAT_TEMPLATE_UNKNOWN;
|
||||
}
|
||||
@@ -845,6 +848,14 @@ int32_t llm_chat_apply_template(
|
||||
if (add_ass) {
|
||||
ss << "[unused9]助手:";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_SOLAR_OPEN) {
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<|begin|>" << role << "<|content|>" << message->content << "<|end|>";
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|begin|>assistant";
|
||||
}
|
||||
} else {
|
||||
// template not supported
|
||||
return -1;
|
||||
|
||||
1
llama/llama.cpp/src/llama-chat.h
vendored
1
llama/llama.cpp/src/llama-chat.h
vendored
@@ -54,6 +54,7 @@ enum llm_chat_template {
|
||||
LLM_CHAT_TEMPLATE_SEED_OSS,
|
||||
LLM_CHAT_TEMPLATE_GROK_2,
|
||||
LLM_CHAT_TEMPLATE_PANGU_EMBED,
|
||||
LLM_CHAT_TEMPLATE_SOLAR_OPEN,
|
||||
LLM_CHAT_TEMPLATE_UNKNOWN,
|
||||
};
|
||||
|
||||
|
||||
766
llama/llama.cpp/src/llama-context.cpp
vendored
766
llama/llama.cpp/src/llama-context.cpp
vendored
File diff suppressed because it is too large
Load Diff
54
llama/llama.cpp/src/llama-context.h
vendored
54
llama/llama.cpp/src/llama-context.h
vendored
@@ -26,6 +26,10 @@ struct llama_memory_breakdown_data {
|
||||
size_t model = 0; // memory allocated for the model
|
||||
size_t context = 0; // memory allocated for the context
|
||||
size_t compute = 0; // memory allocated for temporary compute buffers
|
||||
|
||||
size_t total() const {
|
||||
return model + context + compute;
|
||||
}
|
||||
};
|
||||
|
||||
struct llama_context {
|
||||
@@ -66,6 +70,18 @@ struct llama_context {
|
||||
float * get_embeddings_ith(int32_t i);
|
||||
float * get_embeddings_seq(llama_seq_id seq_id);
|
||||
|
||||
llama_token * get_sampled_tokens() const;
|
||||
llama_token get_sampled_token_ith(int32_t idx);
|
||||
|
||||
float * get_sampled_logits_ith(int32_t idx);
|
||||
size_t get_sampled_logits_count(int32_t idx);
|
||||
|
||||
float * get_sampled_probs_ith(int32_t idx);
|
||||
size_t get_sampled_probs_count(int32_t idx);
|
||||
|
||||
const llama_token * get_sampled_candidates_ith(int32_t idx);
|
||||
size_t get_sampled_candidates_count(int32_t idx);
|
||||
|
||||
void attach_threadpool(
|
||||
ggml_threadpool_t threadpool,
|
||||
ggml_threadpool_t threadpool_batch);
|
||||
@@ -188,10 +204,13 @@ private:
|
||||
|
||||
// Make sure enough space is available for outputs.
|
||||
// Returns max number of outputs for which space was reserved.
|
||||
uint32_t output_reserve(int32_t n_outputs);
|
||||
uint32_t output_reserve(int32_t n_outputs, const llama_batch & batch);
|
||||
|
||||
void output_reorder();
|
||||
|
||||
// map the output row index `i` to batch index
|
||||
int64_t output_resolve_row(int32_t i) const;
|
||||
|
||||
//
|
||||
// graph
|
||||
//
|
||||
@@ -206,7 +225,10 @@ public:
|
||||
ggml_status graph_compute(ggml_cgraph * gf, bool batched);
|
||||
|
||||
// reserve a graph with a dummy ubatch of the specified size
|
||||
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);
|
||||
ggml_cgraph * graph_reserve(
|
||||
uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false, size_t * sizes = nullptr);
|
||||
|
||||
bool set_sampler(llama_seq_id seq_id, llama_sampler * sampler);
|
||||
|
||||
private:
|
||||
llm_graph_params graph_params(
|
||||
@@ -247,6 +269,31 @@ private:
|
||||
size_t embd_size = 0; // capacity (of floats) for embeddings
|
||||
float * embd = nullptr;
|
||||
|
||||
// TODO: simplify
|
||||
struct sampling_info {
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
|
||||
float * logits = nullptr;
|
||||
size_t logits_size = 0;
|
||||
|
||||
llama_token * sampled = nullptr;
|
||||
size_t sampled_size = 0;
|
||||
|
||||
float * probs = nullptr;
|
||||
size_t probs_size = 0;
|
||||
|
||||
llama_token * candidates = nullptr;
|
||||
size_t candidates_size = 0;
|
||||
|
||||
std::vector<uint32_t> logits_count;
|
||||
std::vector<uint32_t> probs_count;
|
||||
std::vector<uint32_t> candidates_count;
|
||||
|
||||
std::vector<llama_token> token_ids_full_vocab;
|
||||
};
|
||||
|
||||
sampling_info sampling;
|
||||
|
||||
// sequence embeddings output (map of [n_embd] vectors)
|
||||
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
|
||||
std::map<llama_seq_id, std::vector<float>> embd_seq;
|
||||
@@ -281,9 +328,10 @@ private:
|
||||
|
||||
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
|
||||
|
||||
// buffer types used for the compute buffer of each backend
|
||||
// pointers and buffer types used for the compute buffer of each backend
|
||||
std::vector<ggml_backend_t> backend_ptrs;
|
||||
std::vector<ggml_backend_buffer_type_t> backend_buft;
|
||||
std::vector<size_t> backend_buf_exp_size; // expected buffer sizes
|
||||
|
||||
llm_graph_result_ptr gf_res_prev;
|
||||
llm_graph_result_ptr gf_res_reserve;
|
||||
|
||||
53
llama/llama.cpp/src/llama-grammar.cpp
vendored
53
llama/llama.cpp/src/llama-grammar.cpp
vendored
@@ -369,6 +369,44 @@ static void print_rule(
|
||||
fprintf(file, "\n");
|
||||
}
|
||||
|
||||
//
|
||||
// Regex utilities
|
||||
//
|
||||
|
||||
size_t llama_grammar_trigger_pattern::find(const std::string & input) const {
|
||||
auto find_start_pos = [](const std::smatch & match) {
|
||||
// get from the first matched capturing group to the end of the string
|
||||
size_t start = std::string::npos;
|
||||
for (auto i = 1u; i < match.size(); i++) {
|
||||
if (match.length(i) > 0) {
|
||||
start = match.position(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (start == std::string::npos) {
|
||||
start = match.position(0);
|
||||
}
|
||||
return start;
|
||||
};
|
||||
|
||||
if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') {
|
||||
// match against the entire input
|
||||
std::smatch match;
|
||||
if (std::regex_match(input, match, regex)) {
|
||||
return find_start_pos(match);
|
||||
}
|
||||
}
|
||||
|
||||
// search anywhere
|
||||
std::smatch match;
|
||||
if (std::regex_search(input, match, regex)) {
|
||||
return find_start_pos(match);
|
||||
}
|
||||
|
||||
return std::string::npos;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// implementation
|
||||
//
|
||||
@@ -1321,21 +1359,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||
grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
|
||||
grammar.trigger_buffer += piece;
|
||||
|
||||
std::smatch match;
|
||||
for (const auto & trigger_pattern : grammar.trigger_patterns) {
|
||||
if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
|
||||
auto start = trigger_pattern.find(grammar.trigger_buffer);
|
||||
if (start != std::string::npos) {
|
||||
grammar.awaiting_trigger = false;
|
||||
// get from the first matched capturing group to the end of the string
|
||||
size_t start = std::string::npos;
|
||||
for (auto i = 1u; i < match.size(); i++) {
|
||||
if (match.length(i) > 0) {
|
||||
start = match.position(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (start == std::string::npos) {
|
||||
start = match.position(0);
|
||||
}
|
||||
|
||||
// replay tokens that overlap with [start, end)
|
||||
for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
|
||||
|
||||
2
llama/llama.cpp/src/llama-grammar.h
vendored
2
llama/llama.cpp/src/llama-grammar.h
vendored
@@ -130,6 +130,8 @@ struct llama_grammar_parser {
|
||||
struct llama_grammar_trigger_pattern {
|
||||
std::string pattern;
|
||||
std::regex regex;
|
||||
|
||||
size_t find(const std::string & input) const;
|
||||
};
|
||||
|
||||
struct llama_grammar {
|
||||
|
||||
278
llama/llama.cpp/src/llama-graph.cpp
vendored
278
llama/llama.cpp/src/llama-graph.cpp
vendored
@@ -12,6 +12,7 @@
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <unordered_set>
|
||||
|
||||
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
||||
if (ubatch->token) {
|
||||
@@ -32,7 +33,7 @@ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
|
||||
bool res = true;
|
||||
|
||||
res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
|
||||
res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens);
|
||||
res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -62,7 +63,7 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
|
||||
bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
|
||||
bool res = true;
|
||||
|
||||
res &= pos->ne[0] == params.ubatch.n_tokens;
|
||||
res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd;
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -78,7 +79,7 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
const float pos = ubatch->pos[i];
|
||||
attn_scale_data[i] = std::log(
|
||||
std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
|
||||
std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1.0
|
||||
) * f_attn_temp_scale + 1.0;
|
||||
}
|
||||
|
||||
@@ -254,6 +255,24 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
|
||||
const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);
|
||||
|
||||
this->mctx = mctx;
|
||||
|
||||
bool res = true;
|
||||
|
||||
res &= s_copy->ne[0] == mctx->get_n_rs();
|
||||
|
||||
res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
|
||||
res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
|
||||
|
||||
res &= head == mctx->get_head();
|
||||
res &= rs_z == mctx->get_rs_z();
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
|
||||
GGML_UNUSED(ubatch);
|
||||
|
||||
@@ -385,7 +404,7 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
|
||||
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= self_kq_mask->ne[0] == mctx->get_n_kv();
|
||||
res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
|
||||
res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -416,10 +435,10 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
|
||||
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
|
||||
res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
|
||||
res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
|
||||
|
||||
res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
|
||||
res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
|
||||
res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -452,7 +471,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int i = n_tokens; i < n_tokens; ++i) {
|
||||
for (int j = 0; j < n_enc; ++j) {
|
||||
data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
|
||||
}
|
||||
@@ -461,8 +480,83 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
||||
}
|
||||
|
||||
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
||||
inp_attn->set_input(ubatch);
|
||||
inp_rs->set_input(ubatch);
|
||||
mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
|
||||
mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
|
||||
|
||||
mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
|
||||
|
||||
const int64_t n_rs = mctx->get_recr()->get_n_rs();
|
||||
|
||||
if (inp_rs->s_copy) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
|
||||
int32_t * data = (int32_t *) inp_rs->s_copy->data;
|
||||
|
||||
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
||||
for (uint32_t i = 0; i < n_rs; ++i) {
|
||||
data[i] = mctx->get_recr()->s_copy(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
|
||||
const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
|
||||
|
||||
this->mctx = mctx;
|
||||
|
||||
bool res = true;
|
||||
|
||||
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
|
||||
res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
|
||||
|
||||
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
|
||||
|
||||
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
|
||||
res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
|
||||
|
||||
res &= inp_rs->head == mctx->get_recr()->get_head();
|
||||
res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
|
||||
// set the inputs only for the active samplers in the current ubatch
|
||||
std::unordered_set<llama_seq_id> active_samplers;
|
||||
for (uint32_t i = 0; i < ubatch->n_tokens; i++) {
|
||||
if (ubatch->output[i]) {
|
||||
llama_seq_id seq_id = ubatch->seq_id[i][0];
|
||||
active_samplers.insert(seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto seq_id : active_samplers) {
|
||||
if (samplers.find(seq_id) == samplers.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto & sampler = samplers[seq_id];
|
||||
|
||||
if (sampler->iface->backend_set_input) {
|
||||
sampler->iface->backend_set_input(sampler);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
|
||||
if (samplers.size() != params.samplers.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const auto & [seq_id, sampler] : params.samplers) {
|
||||
if (samplers[seq_id] != sampler) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
//
|
||||
@@ -485,6 +579,10 @@ void llm_graph_result::reset() {
|
||||
t_logits = nullptr;
|
||||
t_embd = nullptr;
|
||||
t_embd_pooled = nullptr;
|
||||
t_sampled.clear();
|
||||
t_sampled_probs.clear();
|
||||
t_sampled_logits.clear();
|
||||
t_candidates.clear();
|
||||
|
||||
params = {};
|
||||
|
||||
@@ -509,6 +607,38 @@ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
|
||||
}
|
||||
}
|
||||
|
||||
void llm_graph_result::set_outputs() {
|
||||
if (t_logits != nullptr) {
|
||||
ggml_set_output(t_logits);
|
||||
}
|
||||
if (t_embd != nullptr) {
|
||||
ggml_set_output(t_embd);
|
||||
}
|
||||
if (t_embd_pooled != nullptr) {
|
||||
ggml_set_output(t_embd_pooled);
|
||||
}
|
||||
for (auto & [seq_id, t] : t_sampled) {
|
||||
if (t != nullptr) {
|
||||
ggml_set_output(t);
|
||||
}
|
||||
}
|
||||
for (auto & [seq_id, t] : t_sampled_probs) {
|
||||
if (t != nullptr) {
|
||||
ggml_set_output(t);
|
||||
}
|
||||
}
|
||||
for (auto & [seq_id, t] : t_sampled_logits) {
|
||||
if (t != nullptr) {
|
||||
ggml_set_output(t);
|
||||
}
|
||||
}
|
||||
for (auto & [seq_id, t] : t_candidates) {
|
||||
if (t != nullptr) {
|
||||
ggml_set_output(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_graph_result::can_reuse(const llm_graph_params & params) {
|
||||
if (!this->params.allow_reuse(params)) {
|
||||
if (debug > 1) {
|
||||
@@ -590,6 +720,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
||||
loras (params.loras),
|
||||
mctx (params.mctx),
|
||||
cross (params.cross),
|
||||
samplers (params.samplers),
|
||||
cb_func (params.cb),
|
||||
res (params.res),
|
||||
ctx0 (res->get_ctx()),
|
||||
@@ -1089,6 +1220,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
cur = ggml_relu(ctx0, cur);
|
||||
cb(cur, "ffn_moe_relu", il);
|
||||
} break;
|
||||
case LLM_FFN_RELU_SQR:
|
||||
if (gate_exps) {
|
||||
// TODO: add support for gated squared relu
|
||||
GGML_ABORT("fatal error: gated squared relu not implemented");
|
||||
} else {
|
||||
cur = ggml_relu(ctx0, cur);
|
||||
cur = ggml_sqr(ctx0, cur);
|
||||
cb(cur, "ffn_moe_relu_sqr", il);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
@@ -1186,6 +1326,10 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
|
||||
// make sure the produced embeddings are immediately materialized in the ggml graph
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/18599
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
@@ -1203,7 +1347,7 @@ ggml_tensor * llm_graph_context::build_inp_pos() const {
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
|
||||
auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
|
||||
auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset);
|
||||
|
||||
auto & cur = inp->attn_scale;
|
||||
|
||||
@@ -1470,13 +1614,13 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
|
||||
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
|
||||
|
||||
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
|
||||
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
@@ -1558,7 +1702,7 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
|
||||
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
@@ -1701,7 +1845,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
|
||||
|
||||
const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
|
||||
|
||||
inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1);
|
||||
ggml_set_input(inp->cross_kq_mask);
|
||||
|
||||
inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
|
||||
@@ -1767,10 +1911,12 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
||||
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
ggml_set_name(inp->self_kq_mask, "self_kq_mask");
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
|
||||
}
|
||||
|
||||
{
|
||||
@@ -1781,10 +1927,12 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
||||
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
|
||||
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
|
||||
}
|
||||
|
||||
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
|
||||
@@ -1841,6 +1989,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
|
||||
inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
|
||||
inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
|
||||
|
||||
inp->head = mctx_cur->get_head();
|
||||
inp->rs_z = mctx_cur->get_rs_z();
|
||||
|
||||
return inp;
|
||||
}
|
||||
|
||||
@@ -1909,10 +2060,10 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
|
||||
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
||||
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
|
||||
|
||||
auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
|
||||
auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
|
||||
auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
|
||||
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
|
||||
|
||||
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
|
||||
}
|
||||
@@ -1920,14 +2071,18 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
||||
void llm_graph_context::build_dense_out(
|
||||
ggml_tensor * dense_2,
|
||||
ggml_tensor * dense_3) const {
|
||||
if (!cparams.embeddings || dense_2 == nullptr || dense_3 == nullptr) {
|
||||
if (!cparams.embeddings || !(dense_2 || dense_3)) {
|
||||
return;
|
||||
}
|
||||
ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
|
||||
GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");
|
||||
|
||||
cur = ggml_mul_mat(ctx0, dense_2, cur);
|
||||
cur = ggml_mul_mat(ctx0, dense_3, cur);
|
||||
if (dense_2) {
|
||||
cur = ggml_mul_mat(ctx0, dense_2, cur);
|
||||
}
|
||||
if (dense_3) {
|
||||
cur = ggml_mul_mat(ctx0, dense_3, cur);
|
||||
}
|
||||
cb(cur, "result_embd_pooled", -1);
|
||||
res->t_embd_pooled = cur;
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
@@ -2018,6 +2173,87 @@ void llm_graph_context::build_pooling(
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
void llm_graph_context::build_sampling() const {
|
||||
if (samplers.empty() || !res->t_logits) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
|
||||
res->add_input(std::move(inp_sampling));
|
||||
|
||||
std::map<llama_seq_id, int32_t> seq_to_logit_row;
|
||||
int32_t logit_row_idx = 0;
|
||||
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
|
||||
if (ubatch.output[i]) {
|
||||
llama_seq_id seq_id = ubatch.seq_id[i][0];
|
||||
seq_to_logit_row[seq_id] = logit_row_idx;
|
||||
logit_row_idx++;
|
||||
}
|
||||
}
|
||||
|
||||
// res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1)
|
||||
GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");
|
||||
|
||||
// add a dummy row of logits
|
||||
// this trick makes the graph static, regardless of which samplers are activated
|
||||
// this is important in order to minimize graph reallocations
|
||||
// TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550)
|
||||
ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
|
||||
|
||||
for (const auto & [seq_id, sampler] : samplers) {
|
||||
const auto it = seq_to_logit_row.find(seq_id);
|
||||
|
||||
// inactive samplers always work on the first row
|
||||
const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0;
|
||||
|
||||
ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
|
||||
ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
|
||||
|
||||
struct llama_sampler_data data = {
|
||||
/*.logits =*/ logits_seq,
|
||||
/*.probs =*/ nullptr,
|
||||
/*.sampled =*/ nullptr,
|
||||
/*.candidates =*/ nullptr,
|
||||
};
|
||||
|
||||
assert(sampler->iface->backend_apply);
|
||||
sampler->iface->backend_apply(sampler, ctx0, gf, &data);
|
||||
|
||||
if (data.sampled != nullptr) {
|
||||
res->t_sampled[seq_id] = data.sampled;
|
||||
ggml_build_forward_expand(gf, data.sampled);
|
||||
}
|
||||
|
||||
if (data.probs != nullptr) {
|
||||
res->t_sampled_probs[seq_id] = data.probs;
|
||||
ggml_build_forward_expand(gf, data.probs);
|
||||
}
|
||||
|
||||
if (data.logits != nullptr) {
|
||||
res->t_sampled_logits[seq_id] = data.logits;
|
||||
ggml_build_forward_expand(gf, data.logits);
|
||||
}
|
||||
|
||||
if (data.candidates != nullptr) {
|
||||
res->t_candidates[seq_id] = data.candidates;
|
||||
ggml_build_forward_expand(gf, data.candidates);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Call llama_sampler_accept_ggml after all samplers have been applied.
|
||||
/*
|
||||
for (const auto & [seq_id, sampler] : samplers) {
|
||||
if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) {
|
||||
ggml_tensor * selected_token = it->second;
|
||||
if (selected_token != nullptr) {
|
||||
llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token);
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
|
||||
// TODO move to hparams if a T5 variant appears that uses a different value
|
||||
const int64_t max_distance = 128;
|
||||
|
||||
98
llama/llama.cpp/src/llama-graph.h
vendored
98
llama/llama.cpp/src/llama-graph.h
vendored
@@ -10,6 +10,7 @@
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
|
||||
struct ggml_cgraph;
|
||||
struct ggml_context;
|
||||
@@ -132,8 +133,8 @@ public:
|
||||
// temperature tuning, used by llama4
|
||||
class llm_graph_input_attn_temp : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
|
||||
: n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
|
||||
llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale, float f_attn_temp_offset)
|
||||
: n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale), f_attn_temp_offset(f_attn_temp_offset) {}
|
||||
virtual ~llm_graph_input_attn_temp() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
@@ -142,6 +143,7 @@ public:
|
||||
|
||||
const uint32_t n_attn_temp_floor_scale;
|
||||
const float f_attn_temp_scale;
|
||||
const float f_attn_temp_offset;
|
||||
};
|
||||
|
||||
class llm_graph_input_pos_bucket : public llm_graph_input_i {
|
||||
@@ -224,6 +226,8 @@ public:
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
bool can_reuse(const llm_graph_params & params) override;
|
||||
|
||||
ggml_tensor * s_copy; // I32 [n_rs]
|
||||
|
||||
// views of s_copy, computed once per graph
|
||||
@@ -232,6 +236,10 @@ public:
|
||||
ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]
|
||||
|
||||
const llama_memory_recurrent_context * mctx;
|
||||
|
||||
// used in view offsets, need to match for valid graph reuse
|
||||
uint32_t head;
|
||||
int32_t rs_z;
|
||||
};
|
||||
|
||||
class llm_graph_input_cross_embd : public llm_graph_input_i {
|
||||
@@ -364,25 +372,43 @@ public:
|
||||
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_mem_hybrid(
|
||||
const llama_cparams & cparams,
|
||||
std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
|
||||
std::unique_ptr<llm_graph_input_rs> inp_rs,
|
||||
const llama_memory_hybrid_context * mctx) :
|
||||
std::unique_ptr<llm_graph_input_rs> inp_rs,
|
||||
const llama_memory_hybrid_context * mctx) :
|
||||
inp_attn(std::move(inp_attn)),
|
||||
inp_rs(std::move(inp_rs)),
|
||||
cparams(cparams),
|
||||
mctx(mctx) { }
|
||||
virtual ~llm_graph_input_mem_hybrid() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
bool can_reuse(const llm_graph_params & params) override;
|
||||
|
||||
std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
|
||||
std::unique_ptr<llm_graph_input_rs> inp_rs;
|
||||
|
||||
llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
|
||||
llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
|
||||
|
||||
const llama_cparams cparams;
|
||||
|
||||
const llama_memory_hybrid_context * mctx;
|
||||
};
|
||||
|
||||
class llm_graph_input_sampling : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) :
|
||||
samplers(std::move(samplers)) { }
|
||||
virtual ~llm_graph_input_sampling() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
bool can_reuse(const llm_graph_params & params) override;
|
||||
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
};
|
||||
|
||||
//
|
||||
// llm_graph_result
|
||||
//
|
||||
@@ -416,6 +442,23 @@ struct llm_graph_params {
|
||||
const llama_memory_context_i * mctx;
|
||||
const llama_cross * cross;
|
||||
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
|
||||
static bool samplers_equal(
|
||||
const std::map<llama_seq_id, llama_sampler *> & lhs,
|
||||
const std::map<llama_seq_id, llama_sampler *> & rhs) {
|
||||
if (lhs.size() != rhs.size()) {
|
||||
return false;
|
||||
}
|
||||
for (const auto & [seq_id, sampler] : lhs) {
|
||||
auto it = rhs.find(seq_id);
|
||||
if (it == rhs.end() || it->second != sampler) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
uint32_t n_outputs;
|
||||
|
||||
llm_graph_cb cb;
|
||||
@@ -455,15 +498,36 @@ struct llm_graph_params {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (n_outputs != other.n_outputs) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!samplers_equal(samplers, other.samplers)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (samplers.size() > 0) {
|
||||
if (!ubatch.data || !other.ubatch.data) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// check that the outputs are the same for all samplers
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
||||
if (ubatch.output[i] != other.ubatch.output[i] ||
|
||||
ubatch.seq_id[i][0] != other.ubatch.seq_id[i][0]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
cparams.embeddings == other.cparams.embeddings &&
|
||||
cparams.causal_attn == other.cparams.causal_attn &&
|
||||
arch == other.arch &&
|
||||
gtype == other.gtype &&
|
||||
cvec == other.cvec &&
|
||||
loras == other.loras &&
|
||||
cross == other.cross &&
|
||||
n_outputs == other.n_outputs;
|
||||
arch == other.arch &&
|
||||
gtype == other.gtype &&
|
||||
cvec == other.cvec &&
|
||||
loras == other.loras &&
|
||||
cross == other.cross;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -486,6 +550,7 @@ public:
|
||||
void reset();
|
||||
|
||||
void set_inputs(const llama_ubatch * ubatch);
|
||||
void set_outputs();
|
||||
|
||||
// try to update the existing graph result using the new graph parameters in order to reuse it
|
||||
// this can only be done if we determine that the resulting graph using the new graph parameters
|
||||
@@ -504,6 +569,11 @@ public:
|
||||
ggml_tensor * t_embd = nullptr;
|
||||
ggml_tensor * t_embd_pooled = nullptr;
|
||||
|
||||
std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
|
||||
std::map<llama_seq_id, ggml_tensor*> t_candidates;
|
||||
std::map<llama_seq_id, ggml_tensor*> t_sampled;
|
||||
std::map<llama_seq_id, ggml_tensor*> t_sampled_probs;
|
||||
|
||||
std::vector<llm_graph_input_ptr> inputs;
|
||||
|
||||
ggml_context_ptr ctx_compute;
|
||||
@@ -579,6 +649,8 @@ struct llm_graph_context {
|
||||
const llama_memory_context_i * mctx;
|
||||
const llama_cross * cross;
|
||||
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
|
||||
const llm_graph_cb & cb_func;
|
||||
|
||||
llm_graph_result * res;
|
||||
@@ -819,6 +891,12 @@ struct llm_graph_context {
|
||||
ggml_tensor * cls_out,
|
||||
ggml_tensor * cls_out_b) const;
|
||||
|
||||
//
|
||||
// sampling (backend sampling)
|
||||
//
|
||||
|
||||
void build_sampling() const;
|
||||
|
||||
//
|
||||
// dense (out)
|
||||
//
|
||||
|
||||
10
llama/llama.cpp/src/llama-hparams.cpp
vendored
10
llama/llama.cpp/src/llama-hparams.cpp
vendored
@@ -1,6 +1,8 @@
|
||||
#include "llama-hparams.h"
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
|
||||
void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
|
||||
@@ -70,6 +72,10 @@ uint32_t llama_hparams::n_embd_inp() const {
|
||||
return n_embd_inp;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::get_n_embd_out() const {
|
||||
return n_embd_out > 0 ? n_embd_out : n_embd;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
|
||||
const uint32_t n_head_kv = this->n_head_kv(il);
|
||||
|
||||
@@ -237,3 +243,7 @@ bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool llama_hparams::use_mrope() const {
|
||||
return rope_sections[0] > 0 && rope_sections[1] > 0;
|
||||
}
|
||||
|
||||
23
llama/llama.cpp/src/llama-hparams.h
vendored
23
llama/llama.cpp/src/llama-hparams.h
vendored
@@ -34,6 +34,7 @@ struct llama_hparams_convnext {
|
||||
|
||||
struct llama_hparams {
|
||||
bool vocab_only;
|
||||
bool no_alloc;
|
||||
bool rope_finetuned;
|
||||
bool use_par_res;
|
||||
bool swin_norm;
|
||||
@@ -106,9 +107,10 @@ struct llama_hparams {
|
||||
|
||||
float rope_attn_factor = 1.0f;
|
||||
float rope_freq_base_train;
|
||||
float rope_freq_base_train_swa;
|
||||
float rope_freq_base_train_swa = 10000.0f;
|
||||
float rope_freq_scale_train;
|
||||
float rope_freq_scale_train_swa;
|
||||
float rope_freq_scale_train_swa = 1.0f;
|
||||
|
||||
uint32_t n_ctx_orig_yarn;
|
||||
float rope_yarn_log_mul = 0.0f;
|
||||
|
||||
@@ -123,10 +125,11 @@ struct llama_hparams {
|
||||
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
// the size of the sliding window (0 - no SWA)
|
||||
uint32_t n_swa = 0;
|
||||
// if swa_layers[il] == true, then layer il is SWA
|
||||
// if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
|
||||
// if swa_layers[il] == 1, then layer il is SWA
|
||||
// if swa_layers[il] == 0, then layer il is dense (i.e. non-SWA)
|
||||
// by default, all layers are dense
|
||||
std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
|
||||
// note: using uint32_t type for compatibility reason
|
||||
std::array<uint32_t, LLAMA_MAX_LAYERS> swa_layers;
|
||||
|
||||
// for State Space Models
|
||||
uint32_t ssm_d_conv = 0;
|
||||
@@ -161,11 +164,15 @@ struct llama_hparams {
|
||||
// for Classifiers
|
||||
uint32_t n_cls_out = 1;
|
||||
|
||||
// output embedding dimension (0 = use n_embd)
|
||||
uint32_t n_embd_out = 0;
|
||||
|
||||
// llama4 smallthinker
|
||||
uint32_t n_moe_layer_step = 0;
|
||||
uint32_t n_no_rope_layer_step = 4;
|
||||
uint32_t n_attn_temp_floor_scale = 0;
|
||||
float f_attn_temp_scale = 0.0f;
|
||||
float f_attn_temp_offset = 0.0f; // offset position index
|
||||
|
||||
// gemma3n altup
|
||||
uint32_t n_altup = 4; // altup_num_inputs
|
||||
@@ -232,6 +239,9 @@ struct llama_hparams {
|
||||
// dimension of main + auxiliary input embeddings
|
||||
uint32_t n_embd_inp() const;
|
||||
|
||||
// dimension of output embeddings
|
||||
uint32_t get_n_embd_out() const;
|
||||
|
||||
// dimension of key embeddings across all k-v heads
|
||||
uint32_t n_embd_k_gqa(uint32_t il = 0) const;
|
||||
|
||||
@@ -272,7 +282,8 @@ struct llama_hparams {
|
||||
// TODO: think of a better place for this function
|
||||
// TODO: pack the SWA params in a struct?
|
||||
static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
|
||||
|
||||
bool use_mrope() const;
|
||||
};
|
||||
|
||||
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
|
||||
|
||||
|
||||
4
llama/llama.cpp/src/llama-impl.cpp
vendored
4
llama/llama.cpp/src/llama-impl.cpp
vendored
@@ -25,6 +25,10 @@ time_meas::~time_meas() {
|
||||
}
|
||||
}
|
||||
|
||||
void llama_log_get(ggml_log_callback * log_callback, void ** user_data) {
|
||||
ggml_log_get(log_callback, user_data);
|
||||
}
|
||||
|
||||
void llama_log_set(ggml_log_callback log_callback, void * user_data) {
|
||||
ggml_log_set(log_callback, user_data);
|
||||
g_logger_state.log_callback = log_callback ? log_callback : llama_log_callback_default;
|
||||
|
||||
132
llama/llama.cpp/src/llama-kv-cache.cpp
vendored
132
llama/llama.cpp/src/llama-kv-cache.cpp
vendored
@@ -175,7 +175,15 @@ llama_kv_cache::llama_kv_cache(
|
||||
|
||||
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
||||
for (auto & [buft, ctx] : ctx_map) {
|
||||
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
|
||||
ggml_backend_buffer_t buf;
|
||||
if (model.hparams.no_alloc) {
|
||||
buf = ggml_backend_buft_alloc_buffer(buft, /*size =*/ 0); // dummy buffer
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != nullptr; t = ggml_get_next_tensor(ctx.get(), t)) {
|
||||
t->buffer = buf; // set dummy buffer for KV cache so that the backend scheduler won't try to allocate it
|
||||
}
|
||||
} else {
|
||||
buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); // real buffer
|
||||
}
|
||||
if (!buf) {
|
||||
throw std::runtime_error("failed to allocate buffer for kv cache");
|
||||
}
|
||||
@@ -482,9 +490,18 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
|
||||
|
||||
std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
|
||||
std::map<ggml_backend_buffer_type_t, size_t> ret;
|
||||
for (const auto & [_, buf] : ctxs_bufs) {
|
||||
ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
|
||||
for (const auto & [ctx, buf] : ctxs_bufs) {
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(buf.get());
|
||||
|
||||
if (hparams.no_alloc) {
|
||||
GGML_ASSERT(ggml_backend_buffer_get_base(buf.get()) == nullptr);
|
||||
ret[buft] += ggml_backend_alloc_ctx_tensors_from_buft_size(ctx.get(), buft);
|
||||
} else {
|
||||
// GGML_ASSERT(ggml_backend_buffer_get_base(buf.get()) != nullptr); // multi_buffer does not have a defined base
|
||||
ret[buft] += ggml_backend_buffer_get_size(buf.get());
|
||||
}
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -1232,8 +1249,7 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
|
||||
GGML_ASSERT(n_tokens%n_stream == 0);
|
||||
|
||||
// n_tps == n_tokens_per_stream
|
||||
const int64_t n_tps = n_tokens/n_stream;
|
||||
const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
|
||||
const int64_t n_tps = n_tokens/n_stream;
|
||||
|
||||
std::fill(data, data + ggml_nelements(dst), -INFINITY);
|
||||
|
||||
@@ -1266,7 +1282,7 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
|
||||
const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
|
||||
const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
|
||||
|
||||
const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
|
||||
const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii);
|
||||
|
||||
for (uint32_t j = 0; j < n_kv; ++j) {
|
||||
if (cells.is_empty(j)) {
|
||||
@@ -1370,9 +1386,10 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
|
||||
float freq_scale) const {
|
||||
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
|
||||
|
||||
const auto & yarn_ext_factor = cparams.yarn_ext_factor;
|
||||
const auto & yarn_beta_fast = cparams.yarn_beta_fast;
|
||||
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
|
||||
const auto & yarn_ext_factor = cparams.yarn_ext_factor;
|
||||
const auto & yarn_beta_fast = cparams.yarn_beta_fast;
|
||||
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
|
||||
const auto & yarn_attn_factor = cparams.yarn_attn_factor;
|
||||
|
||||
const auto & n_rot = hparams.n_rot;
|
||||
const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
|
||||
@@ -1383,12 +1400,6 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
|
||||
? LLAMA_ROPE_TYPE_NEOX
|
||||
: hparams.rope_type;
|
||||
|
||||
// See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
|
||||
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
|
||||
const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
|
||||
? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
|
||||
: cparams.yarn_attn_factor;
|
||||
|
||||
ggml_tensor * tmp;
|
||||
|
||||
if (ggml_is_quantized(cur->type)) {
|
||||
@@ -1550,9 +1561,11 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama
|
||||
|
||||
const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
|
||||
|
||||
slot_info sinfo;
|
||||
|
||||
bool res = true;
|
||||
res = res && state_read_meta(io, strm, cell_count, seq_id);
|
||||
res = res && state_read_data(io, strm, cell_count);
|
||||
res = res && state_read_meta(io, strm, cell_count, sinfo, seq_id);
|
||||
res = res && state_read_data(io, strm, cell_count, sinfo);
|
||||
|
||||
if (!res) {
|
||||
if (seq_id == -1) {
|
||||
@@ -1691,7 +1704,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
|
||||
}
|
||||
}
|
||||
|
||||
bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
|
||||
bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id) {
|
||||
auto & cells = v_cells[strm];
|
||||
auto & head = v_heads[strm];
|
||||
|
||||
@@ -1728,7 +1741,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
||||
ubatch.seq_id[i] = &dest_seq_id;
|
||||
}
|
||||
|
||||
const auto sinfo = find_slot(ubatch, true);
|
||||
sinfo = find_slot(ubatch, false);
|
||||
if (sinfo.empty()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
||||
return false;
|
||||
@@ -1738,20 +1751,16 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
||||
// see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
|
||||
apply_ubatch(sinfo, ubatch);
|
||||
|
||||
const auto head_cur = sinfo.head();
|
||||
LLAMA_LOG_DEBUG("%s: cell_count = %d, dest_seq_id = %d\n", __func__, cell_count, dest_seq_id);
|
||||
|
||||
// keep the head at the old position because we will read the KV data into it in state_read_data()
|
||||
head = head_cur;
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id);
|
||||
|
||||
// DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
||||
// Assume that this is one contiguous block of cells
|
||||
GGML_ASSERT(head_cur + cell_count <= cells.size());
|
||||
GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
|
||||
GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
|
||||
GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
|
||||
GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
|
||||
// DEBUG CHECK: verify that all cells were allocated and have correct seq_id and pos values
|
||||
GGML_ASSERT(sinfo.n_stream() == 1);
|
||||
GGML_ASSERT(sinfo.idxs[0].size() == cell_count);
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
const uint32_t idx = sinfo.idxs[0][i];
|
||||
GGML_ASSERT(cells.pos_get(idx) == ubatch.pos[i]);
|
||||
GGML_ASSERT(cells.seq_has(idx, dest_seq_id));
|
||||
}
|
||||
} else {
|
||||
// whole KV cache restore
|
||||
|
||||
@@ -1784,15 +1793,24 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
||||
}
|
||||
}
|
||||
|
||||
// Create contiguous slot_info for whole cache restore
|
||||
sinfo.s0 = strm;
|
||||
sinfo.s1 = strm;
|
||||
sinfo.resize(1);
|
||||
sinfo.strm[0] = strm;
|
||||
sinfo.idxs[0].resize(cell_count);
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
sinfo.idxs[0][i] = i;
|
||||
}
|
||||
|
||||
head = 0;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
|
||||
bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo) {
|
||||
auto & cells = v_cells[strm];
|
||||
auto & head = v_heads[strm];
|
||||
|
||||
uint32_t v_trans;
|
||||
uint32_t n_layer;
|
||||
@@ -1842,8 +1860,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
}
|
||||
|
||||
if (cell_count) {
|
||||
// Read and set the keys for the whole cell range
|
||||
ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
|
||||
if (sinfo.is_contiguous()) {
|
||||
// Fast path: contiguous cells, single memcpy
|
||||
ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row);
|
||||
} else {
|
||||
// Slow path: scatter to non-contiguous positions
|
||||
const void * src = io.read(cell_count * k_size_row);
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
const size_t dst_offset = sinfo.idxs[0][i] * k_size_row;
|
||||
ggml_backend_tensor_set(k, (const char*)src + i * k_size_row, dst_offset, k_size_row);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1874,8 +1901,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
}
|
||||
|
||||
if (cell_count) {
|
||||
// Read and set the values for the whole cell range
|
||||
ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
|
||||
if (sinfo.is_contiguous()) {
|
||||
// Fast path: contiguous cells, single memcpy
|
||||
ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row);
|
||||
} else {
|
||||
// Slow path: scatter to non-contiguous positions
|
||||
const void * src = io.read(cell_count * v_size_row);
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
const size_t dst_offset = sinfo.idxs[0][i] * v_size_row;
|
||||
ggml_backend_tensor_set(v, (const char*)src + i * v_size_row, dst_offset, v_size_row);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -1914,10 +1950,22 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
}
|
||||
|
||||
if (cell_count) {
|
||||
// For each row in the transposed matrix, read the values for the whole cell range
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
const size_t dst_offset = (head + j * cells.size()) * v_size_el;
|
||||
ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
||||
if (sinfo.is_contiguous()) {
|
||||
// Fast path: contiguous cells
|
||||
const uint32_t h = sinfo.head();
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
const size_t dst_offset = (h + j * cells.size()) * v_size_el;
|
||||
ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
||||
}
|
||||
} else {
|
||||
// Slow path: scatter to non-contiguous positions
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
const void * src = io.read(cell_count * v_size_el);
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el;
|
||||
ggml_backend_tensor_set(v, (const char*)src + i * v_size_el, dst_offset, v_size_el);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
23
llama/llama.cpp/src/llama-kv-cache.h
vendored
23
llama/llama.cpp/src/llama-kv-cache.h
vendored
@@ -72,6 +72,23 @@ public:
|
||||
void clear() {
|
||||
idxs.clear();
|
||||
}
|
||||
|
||||
// check if indices are contiguous starting from head()
|
||||
bool is_contiguous() const {
|
||||
if (idxs.empty() || idxs[0].empty()) {
|
||||
return true;
|
||||
}
|
||||
if (idxs.size() > 1) {
|
||||
return false;
|
||||
}
|
||||
const uint32_t h = idxs[0][0];
|
||||
for (size_t i = 0; i < idxs[0].size(); ++i) {
|
||||
if (idxs[0][i] != h + i) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
using slot_info_vec_t = std::vector<slot_info>;
|
||||
@@ -264,8 +281,8 @@ private:
|
||||
void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
|
||||
void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
|
||||
|
||||
bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
|
||||
bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id = -1);
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo);
|
||||
};
|
||||
|
||||
class llama_kv_cache_context : public llama_memory_context_i {
|
||||
@@ -288,7 +305,7 @@ public:
|
||||
bool do_shift,
|
||||
stream_copy_info sc_info);
|
||||
|
||||
// used to create a batch procesing context from a batch
|
||||
// used to create a batch processing context from a batch
|
||||
llama_kv_cache_context(
|
||||
llama_kv_cache * kv,
|
||||
slot_info_vec_t sinfos,
|
||||
|
||||
2
llama/llama.cpp/src/llama-memory-hybrid.cpp
vendored
2
llama/llama.cpp/src/llama-memory-hybrid.cpp
vendored
@@ -222,7 +222,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
|
||||
ubatches(std::move(ubatches)),
|
||||
// note: here we copy the ubatches. not sure if this is ideal
|
||||
ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
|
||||
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
|
||||
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
|
||||
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
||||
}
|
||||
|
||||
|
||||
209
llama/llama.cpp/src/llama-mmap.cpp
vendored
209
llama/llama.cpp/src/llama-mmap.cpp
vendored
@@ -13,9 +13,10 @@
|
||||
#ifdef __has_include
|
||||
#if __has_include(<unistd.h>)
|
||||
#include <unistd.h>
|
||||
#include <fcntl.h>
|
||||
#include <sys/stat.h>
|
||||
#if defined(_POSIX_MAPPED_FILES)
|
||||
#include <sys/mman.h>
|
||||
#include <fcntl.h>
|
||||
#endif
|
||||
#if defined(_POSIX_MEMLOCK_RANGE)
|
||||
#include <sys/resource.h>
|
||||
@@ -74,7 +75,7 @@ struct llama_file::impl {
|
||||
return ret;
|
||||
}
|
||||
|
||||
impl(const char * fname, const char * mode) {
|
||||
impl(const char * fname, const char * mode, [[maybe_unused]] const bool use_direct_io = false) {
|
||||
fp = ggml_fopen(fname, mode);
|
||||
if (fp == NULL) {
|
||||
throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
|
||||
@@ -109,7 +110,7 @@ struct llama_file::impl {
|
||||
}
|
||||
}
|
||||
|
||||
void read_raw(void * ptr, size_t len) const {
|
||||
void read_raw(void * ptr, size_t len) {
|
||||
size_t bytes_read = 0;
|
||||
while (bytes_read < len) {
|
||||
size_t chunk_size = std::min<size_t>(len - bytes_read, 64*1024*1024);
|
||||
@@ -126,7 +127,7 @@ struct llama_file::impl {
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t read_u32() const {
|
||||
uint32_t read_u32() {
|
||||
uint32_t val;
|
||||
read_raw(&val, sizeof(val));
|
||||
return val;
|
||||
@@ -153,16 +154,55 @@ struct llama_file::impl {
|
||||
write_raw(&val, sizeof(val));
|
||||
}
|
||||
|
||||
bool has_direct_io() const {
|
||||
return true;
|
||||
}
|
||||
|
||||
~impl() {
|
||||
if (fp) {
|
||||
std::fclose(fp);
|
||||
}
|
||||
}
|
||||
#else
|
||||
impl(const char * fname, const char * mode) {
|
||||
fp = ggml_fopen(fname, mode);
|
||||
impl(const char * fname, const char * mode, [[maybe_unused]] const bool use_direct_io = false) : fname(fname) {
|
||||
#ifdef __linux__
|
||||
// Try unbuffered I/O for read only
|
||||
if (use_direct_io && std::strcmp(mode, "rb") == 0) {
|
||||
if (init_fd()) {
|
||||
return;
|
||||
}
|
||||
LLAMA_LOG_WARN("Failed to open file '%s' with error: %s. Falling back to buffered I/O",
|
||||
fname, strerror(errno));
|
||||
}
|
||||
#endif
|
||||
init_fp(mode);
|
||||
}
|
||||
|
||||
#ifdef __linux__
|
||||
bool init_fd() {
|
||||
fd = open(fname.c_str(), O_RDONLY | O_DIRECT);
|
||||
|
||||
if (fd != -1) {
|
||||
struct stat file_stats{};
|
||||
fstat(fd, &file_stats);
|
||||
|
||||
size = file_stats.st_size;
|
||||
alignment = file_stats.st_blksize;
|
||||
|
||||
off_t ret = lseek(fd, 0, SEEK_SET);
|
||||
if (ret == -1) {
|
||||
throw std::runtime_error(format("seek error: %s", strerror(errno)));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
void init_fp(const char * mode) {
|
||||
fp = ggml_fopen(fname.c_str(), mode);
|
||||
if (fp == NULL) {
|
||||
throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
|
||||
throw std::runtime_error(format("failed to open %s: %s", fname.c_str(), strerror(errno)));
|
||||
}
|
||||
seek(0, SEEK_END);
|
||||
size = tell();
|
||||
@@ -170,46 +210,118 @@ struct llama_file::impl {
|
||||
}
|
||||
|
||||
size_t tell() const {
|
||||
// TODO: this ifdef is never true?
|
||||
#ifdef _WIN32
|
||||
__int64 ret = _ftelli64(fp);
|
||||
#else
|
||||
long ret = std::ftell(fp);
|
||||
#endif
|
||||
if (ret == -1) {
|
||||
throw std::runtime_error(format("ftell error: %s", strerror(errno)));
|
||||
if (fd == -1) {
|
||||
long ret = std::ftell(fp);
|
||||
if (ret == -1) {
|
||||
throw std::runtime_error(format("ftell error: %s", strerror(errno)));
|
||||
}
|
||||
|
||||
return (size_t) ret;
|
||||
}
|
||||
|
||||
return (size_t) ret;
|
||||
off_t pos = lseek(fd, 0, SEEK_CUR);
|
||||
if (pos == -1) {
|
||||
throw std::runtime_error(format("lseek error: %s", strerror(errno)));
|
||||
}
|
||||
return (size_t) pos;
|
||||
}
|
||||
|
||||
void seek(size_t offset, int whence) const {
|
||||
// TODO: this ifdef is never true?
|
||||
#ifdef _WIN32
|
||||
int ret = _fseeki64(fp, (__int64) offset, whence);
|
||||
#else
|
||||
int ret = std::fseek(fp, (long) offset, whence);
|
||||
#endif
|
||||
if (ret != 0) {
|
||||
off_t ret = 0;
|
||||
if (fd == -1) {
|
||||
ret = std::fseek(fp, (long) offset, whence);
|
||||
} else {
|
||||
ret = lseek(fd, offset, whence);
|
||||
}
|
||||
if (ret == -1) {
|
||||
throw std::runtime_error(format("seek error: %s", strerror(errno)));
|
||||
}
|
||||
}
|
||||
|
||||
void read_raw(void * ptr, size_t len) const {
|
||||
void read_raw_unsafe(void * ptr, size_t len) {
|
||||
if (len == 0) {
|
||||
return;
|
||||
}
|
||||
errno = 0;
|
||||
std::size_t ret = std::fread(ptr, len, 1, fp);
|
||||
if (ferror(fp)) {
|
||||
throw std::runtime_error(format("read error: %s", strerror(errno)));
|
||||
}
|
||||
if (ret != 1) {
|
||||
throw std::runtime_error("unexpectedly reached end of file");
|
||||
if (fd == -1) {
|
||||
std::size_t ret = std::fread(ptr, len, 1, fp);
|
||||
if (ferror(fp)) {
|
||||
throw std::runtime_error(format("read error: %s", strerror(errno)));
|
||||
}
|
||||
if (ret != 1) {
|
||||
throw std::runtime_error("unexpectedly reached end of file");
|
||||
}
|
||||
} else {
|
||||
size_t bytes_read = 0;
|
||||
while (bytes_read < len) {
|
||||
const size_t to_read = len - bytes_read;
|
||||
ssize_t ret = ::read(fd, reinterpret_cast<char *>(ptr) + bytes_read, to_read);
|
||||
|
||||
if (ret == -1) {
|
||||
if (errno == EINTR) {
|
||||
continue; // Interrupted by signal, retry
|
||||
}
|
||||
// Fallback to std::fread in case the DMA controller cannot access the buffer
|
||||
if (errno == EFAULT) {
|
||||
auto curr_off = tell();
|
||||
close(fd);
|
||||
fd = -1;
|
||||
alignment = 1;
|
||||
init_fp("rb");
|
||||
seek(curr_off, SEEK_SET);
|
||||
read_raw_unsafe(ptr, len);
|
||||
return;
|
||||
}
|
||||
throw std::runtime_error(format("read error: %s", strerror(errno)));
|
||||
}
|
||||
if (ret == 0) {
|
||||
// EOF: allow if this read was only pulling alignment padding past file end
|
||||
off_t pos = lseek(fd, 0, SEEK_CUR);
|
||||
if (pos != -1 && (size_t) pos == size) {
|
||||
std::memset(reinterpret_cast<char *>(ptr) + bytes_read, 0, len - bytes_read);
|
||||
return;
|
||||
}
|
||||
throw std::runtime_error("unexpectedly reached end of file");
|
||||
}
|
||||
|
||||
bytes_read += (size_t) ret;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t read_u32() const {
|
||||
void read_aligned_chunk(void * dest, size_t size) {
|
||||
size_t offset = tell();
|
||||
off_t aligned_offset = offset & ~(alignment - 1);
|
||||
off_t offset_from_alignment = offset - aligned_offset;
|
||||
size_t bytes_to_read = (offset_from_alignment + size + alignment - 1) & ~(alignment - 1);
|
||||
|
||||
void * raw_buffer = nullptr;
|
||||
int ret = posix_memalign(&raw_buffer, alignment, bytes_to_read);
|
||||
if (ret != 0) {
|
||||
throw std::runtime_error(format("posix_memalign failed with error %d", ret));
|
||||
}
|
||||
|
||||
struct aligned_buffer_deleter {
|
||||
void operator()(void * p) const { free(p); }
|
||||
};
|
||||
std::unique_ptr<void, aligned_buffer_deleter> buffer(raw_buffer);
|
||||
|
||||
seek(aligned_offset, SEEK_SET);
|
||||
read_raw_unsafe(buffer.get(), bytes_to_read);
|
||||
|
||||
uintptr_t actual_data = reinterpret_cast<uintptr_t>(buffer.get()) + offset_from_alignment;
|
||||
memcpy(dest, reinterpret_cast<void *>(actual_data), size);
|
||||
}
|
||||
|
||||
void read_raw(void * ptr, size_t len) {
|
||||
if (has_direct_io()) {
|
||||
read_aligned_chunk(ptr, len);
|
||||
} else {
|
||||
read_raw_unsafe(ptr, len);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t read_u32() {
|
||||
uint32_t ret;
|
||||
read_raw(&ret, sizeof(ret));
|
||||
return ret;
|
||||
@@ -230,23 +342,41 @@ struct llama_file::impl {
|
||||
write_raw(&val, sizeof(val));
|
||||
}
|
||||
|
||||
bool has_direct_io() const {
|
||||
return fd != -1 && alignment > 1;
|
||||
}
|
||||
|
||||
~impl() {
|
||||
if (fp) {
|
||||
if (fd != -1) {
|
||||
close(fd);
|
||||
} else {
|
||||
std::fclose(fp);
|
||||
}
|
||||
}
|
||||
int fd = -1;
|
||||
std::string fname;
|
||||
#endif
|
||||
|
||||
FILE * fp;
|
||||
size_t size;
|
||||
size_t read_alignment() const {
|
||||
return alignment;
|
||||
}
|
||||
|
||||
size_t alignment = 1;
|
||||
|
||||
FILE * fp{};
|
||||
size_t size{};
|
||||
};
|
||||
|
||||
llama_file::llama_file(const char * fname, const char * mode) : pimpl(std::make_unique<impl>(fname, mode)) {}
|
||||
llama_file::llama_file(const char * fname, const char * mode, const bool use_direct_io) :
|
||||
pimpl(std::make_unique<impl>(fname, mode, use_direct_io)) {}
|
||||
llama_file::~llama_file() = default;
|
||||
|
||||
size_t llama_file::tell() const { return pimpl->tell(); }
|
||||
size_t llama_file::size() const { return pimpl->size; }
|
||||
|
||||
size_t llama_file::read_alignment() const { return pimpl->read_alignment(); }
|
||||
bool llama_file::has_direct_io() const { return pimpl->has_direct_io(); }
|
||||
|
||||
int llama_file::file_id() const {
|
||||
#ifdef _WIN32
|
||||
return _fileno(pimpl->fp);
|
||||
@@ -260,9 +390,14 @@ int llama_file::file_id() const {
|
||||
}
|
||||
|
||||
void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); }
|
||||
void llama_file::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); }
|
||||
void llama_file::read_raw(void * ptr, size_t len) { pimpl->read_raw(ptr, len); }
|
||||
#ifdef _WIN32
|
||||
void llama_file::read_raw_unsafe(void * ptr, size_t len) { pimpl->read_raw(ptr, len); }
|
||||
#else
|
||||
void llama_file::read_raw_unsafe(void * ptr, size_t len) { pimpl->read_raw_unsafe(ptr, len); }
|
||||
#endif
|
||||
|
||||
uint32_t llama_file::read_u32() const { return pimpl->read_u32(); }
|
||||
uint32_t llama_file::read_u32() { return pimpl->read_u32(); }
|
||||
|
||||
void llama_file::write_raw(const void * ptr, size_t len) const { pimpl->write_raw(ptr, len); }
|
||||
void llama_file::write_u32(uint32_t val) const { pimpl->write_u32(val); }
|
||||
|
||||
11
llama/llama.cpp/src/llama-mmap.h
vendored
11
llama/llama.cpp/src/llama-mmap.h
vendored
@@ -3,6 +3,7 @@
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <cstdio>
|
||||
|
||||
struct llama_file;
|
||||
struct llama_mmap;
|
||||
@@ -13,7 +14,7 @@ using llama_mmaps = std::vector<std::unique_ptr<llama_mmap>>;
|
||||
using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
|
||||
|
||||
struct llama_file {
|
||||
llama_file(const char * fname, const char * mode);
|
||||
llama_file(const char * fname, const char * mode, bool use_direct_io = false);
|
||||
~llama_file();
|
||||
|
||||
size_t tell() const;
|
||||
@@ -23,12 +24,16 @@ struct llama_file {
|
||||
|
||||
void seek(size_t offset, int whence) const;
|
||||
|
||||
void read_raw(void * ptr, size_t len) const;
|
||||
uint32_t read_u32() const;
|
||||
void read_raw(void * ptr, size_t len);
|
||||
void read_raw_unsafe(void * ptr, size_t len);
|
||||
void read_aligned_chunk(void * dest, size_t size);
|
||||
uint32_t read_u32();
|
||||
|
||||
void write_raw(const void * ptr, size_t len) const;
|
||||
void write_u32(uint32_t val) const;
|
||||
|
||||
size_t read_alignment() const;
|
||||
bool has_direct_io() const;
|
||||
private:
|
||||
struct impl;
|
||||
std::unique_ptr<impl> pimpl;
|
||||
|
||||
98
llama/llama.cpp/src/llama-model-loader.cpp
vendored
98
llama/llama.cpp/src/llama-model-loader.cpp
vendored
@@ -462,6 +462,29 @@ namespace GGUFMeta {
|
||||
return get_key_or_arr(llm_kv(kid), result, n, required);
|
||||
}
|
||||
|
||||
bool llama_model_loader::get_key_or_arr(enum llm_kv kid, uint32_t & result, bool required) {
|
||||
const std::string key = llm_kv(kid);
|
||||
|
||||
const int id = gguf_find_key(meta.get(), key.c_str());
|
||||
|
||||
if (id < 0) {
|
||||
if (required) {
|
||||
throw std::runtime_error(format("key not found in model: %s", key.c_str()));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// throw and error if type is an array
|
||||
if (gguf_get_kv_type(meta.get(), id) == GGUF_TYPE_ARRAY) {
|
||||
if (required) {
|
||||
throw std::runtime_error(format("expected scalar, found array for key: %s", key.c_str()));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return get_key(key, result, required);
|
||||
}
|
||||
|
||||
// TODO: this is not very clever - figure out something better
|
||||
template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required);
|
||||
template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required);
|
||||
@@ -472,7 +495,9 @@ llama_model_loader::llama_model_loader(
|
||||
const std::string & fname,
|
||||
std::vector<std::string> & splits,
|
||||
bool use_mmap,
|
||||
bool use_direct_io,
|
||||
bool check_tensors,
|
||||
bool no_alloc,
|
||||
const llama_model_kv_override * param_overrides_p,
|
||||
const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
|
||||
int trace = 0;
|
||||
@@ -503,9 +528,17 @@ llama_model_loader::llama_model_loader(
|
||||
get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
|
||||
llm_kv = LLM_KV(llm_arch_from_string(arch_name));
|
||||
|
||||
files.emplace_back(new llama_file(fname.c_str(), "rb"));
|
||||
files.emplace_back(new llama_file(fname.c_str(), "rb", use_direct_io));
|
||||
contexts.emplace_back(ctx);
|
||||
|
||||
use_direct_io = use_direct_io && files.back()->has_direct_io();
|
||||
|
||||
// Disable mmap in case Direct I/O is enabled and available
|
||||
if (use_direct_io && use_mmap) {
|
||||
use_mmap = false;
|
||||
LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__);
|
||||
}
|
||||
|
||||
// Save tensors data offset of the main file.
|
||||
// For subsidiary files, `meta` tensor data offset must not be used,
|
||||
// so we build a unified tensors index for weights.
|
||||
@@ -571,7 +604,7 @@ llama_model_loader::llama_model_loader(
|
||||
}
|
||||
}
|
||||
|
||||
files.emplace_back(new llama_file(fname_split, "rb"));
|
||||
files.emplace_back(new llama_file(fname_split, "rb", use_direct_io));
|
||||
contexts.emplace_back(ctx);
|
||||
|
||||
// Save tensors data offset info of the shard.
|
||||
@@ -715,7 +748,9 @@ llama_model_loader::llama_model_loader(
|
||||
}
|
||||
|
||||
this->use_mmap = use_mmap;
|
||||
this->use_direct_io = use_direct_io;
|
||||
this->check_tensors = check_tensors;
|
||||
this->no_alloc = no_alloc;
|
||||
}
|
||||
|
||||
std::string llama_model_loader::get_arch_name() const {
|
||||
@@ -933,7 +968,15 @@ bool llama_model_loader::load_all_data(
|
||||
// 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives.
|
||||
// NVMe raid configurations might require more / larger buffers.
|
||||
constexpr size_t n_buffers = 4;
|
||||
constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB
|
||||
|
||||
size_t alignment = 1;
|
||||
for (const auto & file : files) {
|
||||
alignment = std::max(file->read_alignment(), alignment);
|
||||
}
|
||||
|
||||
// Buffer size: balance between memory usage and I/O efficiency
|
||||
// 64MB works well for NVMe drives
|
||||
const size_t buffer_size = alignment != 1 ? 64 * 1024 * 1024 + 2 * alignment : 1 * 1024 * 1024;
|
||||
|
||||
std::vector<ggml_backend_buffer_t> host_buffers;
|
||||
std::vector<ggml_backend_event_t> events;
|
||||
@@ -983,6 +1026,7 @@ bool llama_model_loader::load_all_data(
|
||||
// If the backend is supported, create pinned memory buffers and events for synchronisation.
|
||||
for (size_t idx = 0; idx < n_buffers; ++idx) {
|
||||
auto * buf = ggml_backend_buft_alloc_buffer(host_buft, buffer_size);
|
||||
|
||||
if (!buf) {
|
||||
LLAMA_LOG_DEBUG("%s: failed to allocate host buffer for async uploads for device %s\n", func,
|
||||
ggml_backend_dev_name(dev));
|
||||
@@ -1064,6 +1108,7 @@ bool llama_model_loader::load_all_data(
|
||||
}
|
||||
} else {
|
||||
const auto & file = files.at(weight->idx);
|
||||
|
||||
if (ggml_backend_buffer_is_host(cur->buffer)) {
|
||||
file->seek(weight->offs, SEEK_SET);
|
||||
file->read_raw(cur->data, n_size);
|
||||
@@ -1075,19 +1120,54 @@ bool llama_model_loader::load_all_data(
|
||||
} else {
|
||||
// If upload_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU.
|
||||
if (upload_backend) {
|
||||
file->seek(weight->offs, SEEK_SET);
|
||||
size_t offset = weight->offs;
|
||||
alignment = file->read_alignment();
|
||||
size_t aligned_offset = offset & ~(alignment - 1);
|
||||
size_t offset_from_alignment = offset - aligned_offset;
|
||||
file->seek(aligned_offset, SEEK_SET);
|
||||
|
||||
// Calculate aligned read boundaries
|
||||
size_t read_start = aligned_offset;
|
||||
size_t read_end = (offset + n_size + alignment - 1) & ~(alignment - 1);
|
||||
|
||||
size_t bytes_read = 0;
|
||||
size_t data_read = 0; // Actual tensor data copied (excluding padding)
|
||||
|
||||
while (bytes_read < n_size) {
|
||||
size_t read_iteration = std::min<size_t>(buffer_size, n_size - bytes_read);
|
||||
while (bytes_read < read_end - read_start) {
|
||||
size_t read_size = std::min<size_t>(buffer_size, read_end - read_start - bytes_read);
|
||||
|
||||
// Align the destination pointer within the pinned buffer
|
||||
uintptr_t ptr_dest_aligned = (reinterpret_cast<uintptr_t>(host_ptrs[buffer_idx]) + alignment - 1) & ~(alignment - 1);
|
||||
|
||||
// Wait for previous upload to complete before reusing buffer
|
||||
ggml_backend_event_synchronize(events[buffer_idx]);
|
||||
file->read_raw(host_ptrs[buffer_idx], read_iteration);
|
||||
ggml_backend_tensor_set_async(upload_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration);
|
||||
|
||||
// Read aligned chunk from file
|
||||
file->read_raw_unsafe(reinterpret_cast<void *>(ptr_dest_aligned), read_size);
|
||||
|
||||
// Calculate actual data portion (excluding alignment padding)
|
||||
uintptr_t ptr_data = ptr_dest_aligned;
|
||||
size_t data_to_copy = read_size;
|
||||
|
||||
// Skip alignment padding at start of first chunk
|
||||
if (bytes_read == 0) {
|
||||
ptr_data += offset_from_alignment;
|
||||
data_to_copy -= offset_from_alignment;
|
||||
}
|
||||
|
||||
// Trim alignment padding at end of last chunk
|
||||
if (aligned_offset + bytes_read + read_size > offset + n_size) {
|
||||
data_to_copy -= (read_end - (offset + n_size));
|
||||
}
|
||||
|
||||
// Async upload actual data to GPU
|
||||
ggml_backend_tensor_set_async(upload_backend, cur,
|
||||
reinterpret_cast<void *>(ptr_data), data_read, data_to_copy);
|
||||
ggml_backend_event_record(events[buffer_idx], upload_backend);
|
||||
|
||||
bytes_read += read_iteration;
|
||||
data_read += data_to_copy;
|
||||
bytes_read += read_size;
|
||||
|
||||
++buffer_idx;
|
||||
buffer_idx %= n_buffers;
|
||||
}
|
||||
|
||||
6
llama/llama.cpp/src/llama-model-loader.h
vendored
6
llama/llama.cpp/src/llama-model-loader.h
vendored
@@ -70,7 +70,9 @@ struct llama_model_loader {
|
||||
size_t n_bytes = 0;
|
||||
|
||||
bool use_mmap = false;
|
||||
bool use_direct_io = false;
|
||||
bool check_tensors;
|
||||
bool no_alloc;
|
||||
|
||||
llama_files files;
|
||||
llama_ftype ftype;
|
||||
@@ -96,7 +98,9 @@ struct llama_model_loader {
|
||||
const std::string & fname,
|
||||
std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
|
||||
bool use_mmap,
|
||||
bool use_direct_io,
|
||||
bool check_tensors,
|
||||
bool no_alloc,
|
||||
const llama_model_kv_override * param_overrides_p,
|
||||
const llama_model_tensor_buft_override * param_tensor_buft_overrides_p);
|
||||
|
||||
@@ -129,6 +133,8 @@ struct llama_model_loader {
|
||||
template<typename T>
|
||||
bool get_key_or_arr(enum llm_kv kid, T & result, uint32_t n, bool required = true);
|
||||
|
||||
bool get_key_or_arr(enum llm_kv kid, uint32_t & result, bool required = true);
|
||||
|
||||
std::string get_arch_name() const;
|
||||
|
||||
enum llm_arch get_arch() const;
|
||||
|
||||
3
llama/llama.cpp/src/llama-model-saver.cpp
vendored
3
llama/llama.cpp/src/llama-model-saver.cpp
vendored
@@ -146,6 +146,9 @@ void llama_model_saver::add_kv_from_model() {
|
||||
add_kv(LLM_KV_VOCAB_SIZE, vocab.n_tokens());
|
||||
add_kv(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
|
||||
add_kv(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
|
||||
if (hparams.n_embd_out > 0) {
|
||||
add_kv(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out);
|
||||
}
|
||||
add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer);
|
||||
add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
|
||||
add_kv(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, true);
|
||||
|
||||
536
llama/llama.cpp/src/llama-model.cpp
vendored
536
llama/llama.cpp/src/llama-model.cpp
vendored
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user