Compare commits

..

1 Commits

Author SHA1 Message Date
ParthSareen
233d5c5eda refactor(agent): implement three-tier approval system with warn patterns
- Remove git commands from auto-allowlist
- Add new warn patterns tier for commands requiring explicit approval
- Move network commands and env files from deny to warn
- Add IsWarn() and containsWord() helper functions
- Enhanced git prefix extraction for granular allowlisting
- Move credential path patterns to denyPathPatterns
- UI improvements: dynamic warning messages and allowlist info
- Update tests: add TestIsWarn(), adjust expectations
2026-01-09 00:10:10 -08:00
464 changed files with 14520 additions and 60957 deletions

View File

@@ -13,7 +13,7 @@ body:
id: logs
attributes:
label: Relevant log output
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.mdx#how-to-troubleshoot-issues) for details.
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
render: shell
validations:
required: false

View File

@@ -92,7 +92,7 @@ jobs:
flags: ''
- os: windows
arch: amd64
preset: 'CUDA 13 Windows'
preset: 'CUDA 13'
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
cuda-components:
- '"cudart"'
@@ -372,17 +372,13 @@ jobs:
outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
cache-to: type=inline
- name: Deduplicate CUDA libraries
run: |
./scripts/deduplicate_cuda_libs.sh dist/${{ matrix.os }}-${{ matrix.arch }}
- run: |
for COMPONENT in bin/* lib/ollama/*; do
case "$COMPONENT" in
bin/ollama*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;

View File

@@ -48,10 +48,9 @@ if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
set(GGML_CPU_ALL_VARIANTS ON)
endif()
if(APPLE)
if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
set(CMAKE_BUILD_RPATH "@loader_path")
set(CMAKE_INSTALL_RPATH "@loader_path")
set(CMAKE_BUILD_WITH_INSTALL_RPATH ON)
endif()
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
@@ -190,21 +189,13 @@ if(MLX_ENGINE)
install(TARGETS mlx mlxc
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
)
# Install the Metal library for macOS arm64 (must be colocated with the binary)
# Metal backend is only built for arm64, not x86_64
if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
install(FILES ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/backend/metal/kernels/mlx.metallib
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
# Manually install cudart and cublas since they might not be picked up as direct dependencies
if(CUDAToolkit_FOUND)
file(GLOB CUDART_LIBS

View File

@@ -40,17 +40,7 @@
"name": "CUDA 13",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
"CMAKE_CUDA_FLAGS": "-t 4",
"OLLAMA_RUNNER_DIR": "cuda_v13"
}
},
{
"name": "CUDA 13 Windows",
"inherits": [ "CUDA" ],
"description": "Reduced architecture set for Windows to avoid MSVC template compilation issues",
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;89-virtual;100-virtual;120-virtual",
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
"CMAKE_CUDA_FLAGS": "-t 4",
"OLLAMA_RUNNER_DIR": "cuda_v13"
}
@@ -148,11 +138,6 @@
"inherits": [ "CUDA" ],
"configurePreset": "CUDA 13"
},
{
"name": "CUDA 13 Windows",
"inherits": [ "CUDA" ],
"configurePreset": "CUDA 13 Windows"
},
{
"name": "JetPack 5",
"inherits": [ "CUDA" ],

View File

@@ -32,7 +32,7 @@ ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
# install epel-release for ccache
RUN yum install -y yum-utils epel-release \
&& dnf install -y clang ccache git \
&& dnf install -y clang ccache \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
ENV CC=clang CXX=clang++
@@ -149,7 +149,6 @@ COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/ml/backend/mlx x/ml/backend/mlx
COPY go.mod go.sum .
COPY MLX_VERSION .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
@@ -157,6 +156,15 @@ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
COPY . .
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
# TODO wire up the actual MLX engine here instead of building the main binary...
RUN mkdir -p dist/bin
RUN go build -tags mlx -trimpath -buildmode=pie -o dist/bin/imagegen ./x/imagegen/cmd/engine
FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama
@@ -165,16 +173,12 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
COPY . .
# Clone mlx-c headers for CGO (version from MLX_VERSION file)
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
RUN --mount=type=cache,target=/root/.cache/go-build \
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
go build -trimpath -buildmode=pie -o /bin/ollama .
FROM --platform=linux/amd64 scratch AS amd64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
@@ -182,6 +186,7 @@ COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
COPY --from=vulkan dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/
FROM --platform=linux/arm64 scratch AS arm64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
@@ -200,7 +205,7 @@ COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04
RUN apt-get update \
&& apt-get install -y ca-certificates libvulkan1 libopenblas0 \
&& apt-get install -y ca-certificates libvulkan1 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
COPY --from=archive /bin /usr/bin

View File

@@ -1 +0,0 @@
v0.4.1

View File

@@ -1,6 +1,6 @@
UPSTREAM=https://github.com/ggml-org/llama.cpp.git
WORKDIR=llama/vendor
FETCH_HEAD=a5bb8ba4c50257437630c136210396810741bbf7
FETCH_HEAD=ec98e2002
.PHONY: help
help:

View File

@@ -48,7 +48,7 @@ ollama run gemma3
## Model library
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library "ollama model library")
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library 'ollama model library')
Here are some example models that can be downloaded:
@@ -79,7 +79,7 @@ Here are some example models that can be downloaded:
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| LLaVA | 7B | 4.5GB | `ollama run llava` |
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
> [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -260,38 +260,6 @@ Finally, in a separate shell, run a model:
./ollama run llama3.2
```
## Building with MLX (experimental)
First build the MLX libraries:
```shell
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install build --component MLX
```
When building with the `-tags mlx` flag, the main `ollama` binary includes MLX support for experimental features like image generation:
```shell
go build -tags mlx .
```
Finally, start the server:
```
./ollama serve
```
### Building MLX with CUDA
When building with CUDA, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with default architectures:
```shell
cmake --preset 'MLX CUDA 13'
cmake --build --preset 'MLX CUDA 13' --parallel
cmake --install build --component MLX
```
## REST API
Ollama has a REST API for running and managing models.
@@ -322,7 +290,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Web & Desktop
- [Onyx](https://github.com/onyx-dot-app/onyx)
- [Open WebUI](https://github.com/open-webui/open-webui)
- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
@@ -454,7 +421,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
@@ -526,7 +493,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Database
- [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector)
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
- [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases)
@@ -558,7 +525,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LiteLLM](https://github.com/BerriAI/litellm)
- [OllamaFarm for Go](https://github.com/presbrey/ollamafarm)
- [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp)
- [Ollama for Ruby](https://github.com/crmne/ruby_llm)
- [Ollama for Ruby](https://github.com/gbaptista/ollama-ai)
- [Ollama-rs for Rust](https://github.com/pepperoni21/ollama-rs)
- [Ollama-hpp for C++](https://github.com/jmont-dev/ollama-hpp)
- [Ollama4j for Java](https://github.com/ollama4j/ollama4j)
@@ -669,7 +636,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
### Observability
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
@@ -678,5 +644,4 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
### Security
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)

View File

@@ -1,778 +0,0 @@
package anthropic
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"time"
"github.com/ollama/ollama/api"
)
// Error types matching Anthropic API
type Error struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ErrorResponse struct {
Type string `json:"type"` // always "error"
Error Error `json:"error"`
RequestID string `json:"request_id,omitempty"`
}
// NewError creates a new ErrorResponse with the appropriate error type based on HTTP status code
func NewError(code int, message string) ErrorResponse {
var etype string
switch code {
case http.StatusBadRequest:
etype = "invalid_request_error"
case http.StatusUnauthorized:
etype = "authentication_error"
case http.StatusForbidden:
etype = "permission_error"
case http.StatusNotFound:
etype = "not_found_error"
case http.StatusTooManyRequests:
etype = "rate_limit_error"
case http.StatusServiceUnavailable, 529:
etype = "overloaded_error"
default:
etype = "api_error"
}
return ErrorResponse{
Type: "error",
Error: Error{Type: etype, Message: message},
RequestID: generateID("req"),
}
}
// Request types
// MessagesRequest represents an Anthropic Messages API request
type MessagesRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
Messages []MessageParam `json:"messages"`
System any `json:"system,omitempty"` // string or []ContentBlock
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
Metadata *Metadata `json:"metadata,omitempty"`
}
// MessageParam represents a message in the request
type MessageParam struct {
Role string `json:"role"` // "user" or "assistant"
Content any `json:"content"` // string or []ContentBlock
}
// ContentBlock represents a content block in a message.
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
// only when set, which is required for SDK streaming accumulation.
type ContentBlock struct {
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
Text *string `json:"text,omitempty"`
// For image blocks
Source *ImageSource `json:"source,omitempty"`
// For tool_use blocks
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
// For tool_result blocks
ToolUseID string `json:"tool_use_id,omitempty"`
Content any `json:"content,omitempty"` // string or []ContentBlock
IsError bool `json:"is_error,omitempty"`
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
Thinking *string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ImageSource represents the source of an image
type ImageSource struct {
Type string `json:"type"` // "base64" or "url"
MediaType string `json:"media_type,omitempty"`
Data string `json:"data,omitempty"`
URL string `json:"url,omitempty"`
}
// Tool represents a tool definition
type Tool struct {
Type string `json:"type,omitempty"` // "custom" for user-defined tools
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema json.RawMessage `json:"input_schema,omitempty"`
}
// ToolChoice controls how the model uses tools
type ToolChoice struct {
Type string `json:"type"` // "auto", "any", "tool", "none"
Name string `json:"name,omitempty"`
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
}
// ThinkingConfig controls extended thinking
type ThinkingConfig struct {
Type string `json:"type"` // "enabled" or "disabled"
BudgetTokens int `json:"budget_tokens,omitempty"`
}
// Metadata for the request
type Metadata struct {
UserID string `json:"user_id,omitempty"`
}
// Response types
// MessagesResponse represents an Anthropic Messages API response
type MessagesResponse struct {
ID string `json:"id"`
Type string `json:"type"` // "message"
Role string `json:"role"` // "assistant"
Model string `json:"model"`
Content []ContentBlock `json:"content"`
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
Usage Usage `json:"usage"`
}
// Usage contains token usage information
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
// Streaming event types
// MessageStartEvent is sent at the start of streaming
type MessageStartEvent struct {
Type string `json:"type"` // "message_start"
Message MessagesResponse `json:"message"`
}
// ContentBlockStartEvent signals the start of a content block
type ContentBlockStartEvent struct {
Type string `json:"type"` // "content_block_start"
Index int `json:"index"`
ContentBlock ContentBlock `json:"content_block"`
}
// ContentBlockDeltaEvent contains incremental content updates
type ContentBlockDeltaEvent struct {
Type string `json:"type"` // "content_block_delta"
Index int `json:"index"`
Delta Delta `json:"delta"`
}
// Delta represents an incremental update
type Delta struct {
Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta"
Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ContentBlockStopEvent signals the end of a content block
type ContentBlockStopEvent struct {
Type string `json:"type"` // "content_block_stop"
Index int `json:"index"`
}
// MessageDeltaEvent contains updates to the message
type MessageDeltaEvent struct {
Type string `json:"type"` // "message_delta"
Delta MessageDelta `json:"delta"`
Usage DeltaUsage `json:"usage"`
}
// MessageDelta contains stop information
type MessageDelta struct {
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
}
// DeltaUsage contains cumulative token usage
type DeltaUsage struct {
OutputTokens int `json:"output_tokens"`
}
// MessageStopEvent signals the end of the message
type MessageStopEvent struct {
Type string `json:"type"` // "message_stop"
}
// PingEvent is a keepalive event
type PingEvent struct {
Type string `json:"type"` // "ping"
}
// StreamErrorEvent is an error during streaming
type StreamErrorEvent struct {
Type string `json:"type"` // "error"
Error Error `json:"error"`
}
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
var messages []api.Message
if r.System != nil {
switch sys := r.System.(type) {
case string:
if sys != "" {
messages = append(messages, api.Message{Role: "system", Content: sys})
}
case []any:
// System can be an array of content blocks
var content strings.Builder
for _, block := range sys {
if blockMap, ok := block.(map[string]any); ok {
if blockMap["type"] == "text" {
if text, ok := blockMap["text"].(string); ok {
content.WriteString(text)
}
}
}
}
if content.Len() > 0 {
messages = append(messages, api.Message{Role: "system", Content: content.String()})
}
}
}
for _, msg := range r.Messages {
converted, err := convertMessage(msg)
if err != nil {
return nil, err
}
messages = append(messages, converted...)
}
options := make(map[string]any)
options["num_predict"] = r.MaxTokens
if r.Temperature != nil {
options["temperature"] = *r.Temperature
}
if r.TopP != nil {
options["top_p"] = *r.TopP
}
if r.TopK != nil {
options["top_k"] = *r.TopK
}
if len(r.StopSequences) > 0 {
options["stop"] = r.StopSequences
}
var tools api.Tools
for _, t := range r.Tools {
tool, err := convertTool(t)
if err != nil {
return nil, err
}
tools = append(tools, tool)
}
var think *api.ThinkValue
if r.Thinking != nil && r.Thinking.Type == "enabled" {
think = &api.ThinkValue{Value: true}
}
stream := r.Stream
return &api.ChatRequest{
Model: r.Model,
Messages: messages,
Options: options,
Stream: &stream,
Tools: tools,
Think: think,
}, nil
}
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
func convertMessage(msg MessageParam) ([]api.Message, error) {
var messages []api.Message
role := strings.ToLower(msg.Role)
switch content := msg.Content.(type) {
case string:
messages = append(messages, api.Message{Role: role, Content: content})
case []any:
var textContent strings.Builder
var images []api.ImageData
var toolCalls []api.ToolCall
var thinking string
var toolResults []api.Message
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
return nil, errors.New("invalid content block format")
}
blockType, _ := blockMap["type"].(string)
switch blockType {
case "text":
if text, ok := blockMap["text"].(string); ok {
textContent.WriteString(text)
}
case "image":
source, ok := blockMap["source"].(map[string]any)
if !ok {
return nil, errors.New("invalid image source")
}
sourceType, _ := source["type"].(string)
if sourceType == "base64" {
data, _ := source["data"].(string)
decoded, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return nil, fmt.Errorf("invalid base64 image data: %w", err)
}
images = append(images, decoded)
} else {
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
}
// URL images would need to be fetched - skip for now
case "tool_use":
id, ok := blockMap["id"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'id' field")
}
name, ok := blockMap["name"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'name' field")
}
tc := api.ToolCall{
ID: id,
Function: api.ToolCallFunction{
Name: name,
},
}
if input, ok := blockMap["input"].(map[string]any); ok {
tc.Function.Arguments = mapToArgs(input)
}
toolCalls = append(toolCalls, tc)
case "tool_result":
toolUseID, _ := blockMap["tool_use_id"].(string)
var resultContent string
switch c := blockMap["content"].(type) {
case string:
resultContent = c
case []any:
for _, cb := range c {
if cbMap, ok := cb.(map[string]any); ok {
if cbMap["type"] == "text" {
if text, ok := cbMap["text"].(string); ok {
resultContent += text
}
}
}
}
}
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: resultContent,
ToolCallID: toolUseID,
})
case "thinking":
if t, ok := blockMap["thinking"].(string); ok {
thinking = t
}
}
}
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
m := api.Message{
Role: role,
Content: textContent.String(),
Images: images,
ToolCalls: toolCalls,
Thinking: thinking,
}
messages = append(messages, m)
}
// Add tool results as separate messages
messages = append(messages, toolResults...)
default:
return nil, fmt.Errorf("invalid message content type: %T", content)
}
return messages, nil
}
// convertTool converts an Anthropic Tool to an Ollama api.Tool
func convertTool(t Tool) (api.Tool, error) {
var params api.ToolFunctionParameters
if len(t.InputSchema) > 0 {
if err := json.Unmarshal(t.InputSchema, &params); err != nil {
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
}
}
return api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: t.Name,
Description: t.Description,
Parameters: params,
},
}, nil
}
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
func ToMessagesResponse(id string, r api.ChatResponse) MessagesResponse {
var content []ContentBlock
if r.Message.Thinking != "" {
content = append(content, ContentBlock{
Type: "thinking",
Thinking: ptr(r.Message.Thinking),
})
}
if r.Message.Content != "" {
content = append(content, ContentBlock{
Type: "text",
Text: ptr(r.Message.Content),
})
}
for _, tc := range r.Message.ToolCalls {
content = append(content, ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: tc.Function.Arguments,
})
}
stopReason := mapStopReason(r.DoneReason, len(r.Message.ToolCalls) > 0)
return MessagesResponse{
ID: id,
Type: "message",
Role: "assistant",
Model: r.Model,
Content: content,
StopReason: stopReason,
Usage: Usage{
InputTokens: r.Metrics.PromptEvalCount,
OutputTokens: r.Metrics.EvalCount,
},
}
}
// mapStopReason converts Ollama done_reason to Anthropic stop_reason
func mapStopReason(reason string, hasToolCalls bool) string {
if hasToolCalls {
return "tool_use"
}
switch reason {
case "stop":
return "end_turn"
case "length":
return "max_tokens"
default:
if reason != "" {
return "stop_sequence"
}
return ""
}
}
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
type StreamConverter struct {
ID string
Model string
firstWrite bool
contentIndex int
inputTokens int
outputTokens int
thinkingStarted bool
thinkingDone bool
textStarted bool
toolCallsSent map[string]bool
}
func NewStreamConverter(id, model string) *StreamConverter {
return &StreamConverter{
ID: id,
Model: model,
firstWrite: true,
toolCallsSent: make(map[string]bool),
}
}
// StreamEvent represents a streaming event to be sent to the client
type StreamEvent struct {
Event string
Data any
}
// Process converts an Ollama ChatResponse to Anthropic streaming events
func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
var events []StreamEvent
if c.firstWrite {
c.firstWrite = false
c.inputTokens = r.Metrics.PromptEvalCount
events = append(events, StreamEvent{
Event: "message_start",
Data: MessageStartEvent{
Type: "message_start",
Message: MessagesResponse{
ID: c.ID,
Type: "message",
Role: "assistant",
Model: c.Model,
Content: []ContentBlock{},
Usage: Usage{
InputTokens: c.inputTokens,
OutputTokens: 0,
},
},
},
})
}
if r.Message.Thinking != "" && !c.thinkingDone {
if !c.thinkingStarted {
c.thinkingStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "thinking_delta",
Thinking: r.Message.Thinking,
},
},
})
}
if r.Message.Content != "" {
if c.thinkingStarted && !c.thinkingDone {
c.thinkingDone = true
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
}
if !c.textStarted {
c.textStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "text",
Text: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "text_delta",
Text: r.Message.Content,
},
},
})
}
for _, tc := range r.Message.ToolCalls {
if c.toolCallsSent[tc.ID] {
continue
}
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
c.textStarted = false
}
argsJSON, err := json.Marshal(tc.Function.Arguments)
if err != nil {
slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
continue
}
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: map[string]any{},
},
},
})
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "input_json_delta",
PartialJSON: string(argsJSON),
},
},
})
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.toolCallsSent[tc.ID] = true
c.contentIndex++
}
if r.Done {
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
} else if c.thinkingStarted && !c.thinkingDone {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
}
c.outputTokens = r.Metrics.EvalCount
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
events = append(events, StreamEvent{
Event: "message_delta",
Data: MessageDeltaEvent{
Type: "message_delta",
Delta: MessageDelta{
StopReason: stopReason,
},
Usage: DeltaUsage{
OutputTokens: c.outputTokens,
},
},
})
events = append(events, StreamEvent{
Event: "message_stop",
Data: MessageStopEvent{
Type: "message_stop",
},
})
}
return events
}
// generateID generates a unique ID with the given prefix using crypto/rand
func generateID(prefix string) string {
b := make([]byte, 12)
if _, err := rand.Read(b); err != nil {
// Fallback to time-based ID if crypto/rand fails
return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
}
return fmt.Sprintf("%s_%x", prefix, b)
}
// GenerateMessageID generates a unique message ID
func GenerateMessageID() string {
return generateID("msg")
}
// ptr returns a pointer to the given string value
func ptr(s string) *string {
return &s
}
// mapToArgs converts a map to ToolCallFunctionArguments
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}

View File

@@ -1,953 +0,0 @@
package anthropic
import (
"encoding/base64"
"encoding/json"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
const (
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
func TestFromMessagesRequest_Basic(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", result.Model)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
t.Errorf("unexpected message: %+v", result.Messages[0])
}
if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 1024 {
t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"])
}
}
func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: "You are a helpful assistant.",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." {
t.Errorf("unexpected system message: %+v", result.Messages[0])
}
}
func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: []any{
map[string]any{"type": "text", "text": "You are helpful."},
map[string]any{"type": "text", "text": " Be concise."},
},
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Content != "You are helpful. Be concise." {
t.Errorf("unexpected system message content: %q", result.Messages[0].Content)
}
}
func TestFromMessagesRequest_WithOptions(t *testing.T) {
temp := 0.7
topP := 0.9
topK := 40
req := MessagesRequest{
Model: "test-model",
MaxTokens: 2048,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Temperature: &temp,
TopP: &topP,
TopK: &topK,
StopSequences: []string{"\n", "END"},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Options["temperature"] != 0.7 {
t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"])
}
if result.Options["top_p"] != 0.9 {
t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"])
}
if result.Options["top_k"] != 40 {
t.Errorf("expected top_k 40, got %v", result.Options["top_k"])
}
if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" {
t.Errorf("stop sequences mismatch: %s", diff)
}
}
func TestFromMessagesRequest_WithImage(t *testing.T) {
imgData, _ := base64.StdEncoding.DecodeString(testImage)
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{"type": "text", "text": "What's in this image?"},
map[string]any{
"type": "image",
"source": map[string]any{
"type": "base64",
"media_type": "image/png",
"data": testImage,
},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Content != "What's in this image?" {
t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content)
}
if len(result.Messages[0].Images) != 1 {
t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images))
}
if string(result.Messages[0].Images[0]) != string(imgData) {
t.Error("image data mismatch")
}
}
func TestFromMessagesRequest_WithToolUse(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
"name": "get_weather",
"input": map[string]any{"location": "Paris"},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if len(result.Messages[1].ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls))
}
tc := result.Messages[1].ToolCalls[0]
if tc.ID != "call_123" {
t.Errorf("expected tool call ID 'call_123', got %q", tc.ID)
}
if tc.Function.Name != "get_weather" {
t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name)
}
}
func TestFromMessagesRequest_WithToolResult(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{
"type": "tool_result",
"tool_use_id": "call_123",
"content": "The weather in Paris is sunny, 22°C",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
msg := result.Messages[0]
if msg.Role != "tool" {
t.Errorf("expected role 'tool', got %q", msg.Role)
}
if msg.ToolCallID != "call_123" {
t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID)
}
if msg.Content != "The weather in Paris is sunny, 22°C" {
t.Errorf("unexpected content: %q", msg.Content)
}
}
func TestFromMessagesRequest_WithTools(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "get_weather",
Description: "Get current weather",
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(result.Tools))
}
tool := result.Tools[0]
if tool.Type != "function" {
t.Errorf("expected type 'function', got %q", tool.Type)
}
if tool.Function.Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", tool.Function.Name)
}
if tool.Function.Description != "Get current weather" {
t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description)
}
}
func TestFromMessagesRequest_WithThinking(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Think == nil {
t.Fatal("expected Think to be set")
}
if v, ok := result.Think.Value.(bool); !ok || !v {
t.Errorf("expected Think.Value to be true, got %v", result.Think.Value)
}
}
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "thinking",
"thinking": "Let me think about this...",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
assistantMsg := result.Messages[1]
if assistantMsg.Thinking != "Let me think about this..." {
t.Errorf("expected thinking content, got %q", assistantMsg.Thinking)
}
}
func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"name": "get_weather",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use id")
}
if err.Error() != "tool_use block missing required 'id' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use name")
}
if err.Error() != "tool_use block missing required 'name' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "bad_tool",
InputSchema: json.RawMessage(`{invalid json`),
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for invalid tool schema")
}
}
func TestToMessagesResponse_Basic(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello there!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 5,
},
}
result := ToMessagesResponse("msg_123", resp)
if result.ID != "msg_123" {
t.Errorf("expected ID 'msg_123', got %q", result.ID)
}
if result.Type != "message" {
t.Errorf("expected type 'message', got %q", result.Type)
}
if result.Role != "assistant" {
t.Errorf("expected role 'assistant', got %q", result.Role)
}
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "text" || result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
t.Errorf("unexpected content: %+v", result.Content[0])
}
if result.StopReason != "end_turn" {
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
}
if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 {
t.Errorf("unexpected usage: %+v", result.Usage)
}
}
func TestToMessagesResponse_WithToolCalls(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "tool_use" {
t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type)
}
if result.Content[0].ID != "call_123" {
t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID)
}
if result.Content[0].Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name)
}
if result.StopReason != "tool_use" {
t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason)
}
}
func TestToMessagesResponse_WithThinking(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "The answer is 42.",
Thinking: "Let me think about this...",
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 2 {
t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
}
if result.Content[0].Type != "thinking" {
t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type)
}
if result.Content[0].Thinking == nil || *result.Content[0].Thinking != "Let me think about this..." {
t.Errorf("unexpected thinking content: %v", result.Content[0].Thinking)
}
if result.Content[1].Type != "text" {
t.Errorf("expected second block type 'text', got %q", result.Content[1].Type)
}
}
func TestMapStopReason(t *testing.T) {
tests := []struct {
reason string
hasToolCalls bool
want string
}{
{"stop", false, "end_turn"},
{"length", false, "max_tokens"},
{"stop", true, "tool_use"},
{"other", false, "stop_sequence"},
{"", false, ""},
}
for _, tt := range tests {
got := mapStopReason(tt.reason, tt.hasToolCalls)
if got != tt.want {
t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want)
}
}
}
func TestNewError(t *testing.T) {
tests := []struct {
code int
want string
}{
{400, "invalid_request_error"},
{401, "authentication_error"},
{403, "permission_error"},
{404, "not_found_error"},
{429, "rate_limit_error"},
{500, "api_error"},
{503, "overloaded_error"},
{529, "overloaded_error"},
}
for _, tt := range tests {
result := NewError(tt.code, "test message")
if result.Type != "error" {
t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type)
}
if result.Error.Type != tt.want {
t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want)
}
if result.Error.Message != "test message" {
t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message)
}
if result.RequestID == "" {
t.Errorf("NewError(%d) request_id should not be empty", tt.code)
}
}
}
func TestGenerateMessageID(t *testing.T) {
id1 := GenerateMessageID()
id2 := GenerateMessageID()
if id1 == "" {
t.Error("GenerateMessageID returned empty string")
}
if id1 == id2 {
t.Error("GenerateMessageID returned duplicate IDs")
}
if len(id1) < 10 {
t.Errorf("GenerateMessageID returned short ID: %q", id1)
}
if id1[:4] != "msg_" {
t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4])
}
}
func TestStreamConverter_Basic(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
// First chunk
resp1 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello",
},
Metrics: api.Metrics{PromptEvalCount: 10},
}
events1 := conv.Process(resp1)
if len(events1) < 3 {
t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1))
}
// Should have message_start, content_block_start, content_block_delta
if events1[0].Event != "message_start" {
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
}
if events1[1].Event != "content_block_start" {
t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event)
}
if events1[2].Event != "content_block_delta" {
t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event)
}
// Final chunk
resp2 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: " world!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{EvalCount: 5},
}
events2 := conv.Process(resp2)
// Should have content_block_delta, content_block_stop, message_delta, message_stop
hasStop := false
for _, e := range events2 {
if e.Event == "message_stop" {
hasStop = true
}
}
if !hasStop {
t.Error("expected message_stop event in final chunk")
}
}
func TestStreamConverter_WithToolCalls(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
}
events := conv.Process(resp)
hasToolStart := false
hasToolDelta := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
hasToolDelta = true
}
}
}
}
if !hasToolStart {
t.Error("expected tool_use content_block_start event")
}
if !hasToolDelta {
t.Error("expected input_json_delta event")
}
}
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
// Test that unmarshalable arguments (like channels) are handled gracefully
// and don't cause a panic or corrupt stream
conv := NewStreamConverter("msg_123", "test-model")
// Create a channel which cannot be JSON marshaled
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
// Should not panic and should skip the unmarshalable tool call
events := conv.Process(resp)
// Verify no tool_use block was started (since marshal failed before block start)
hasToolStart := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
}
if hasToolStart {
t.Error("expected no tool_use block when arguments cannot be marshaled")
}
}
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
// Test that valid tool calls still work when mixed with invalid ones
conv := NewStreamConverter("msg_123", "test-model")
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_good",
Function: api.ToolCallFunction{
Name: "good_function",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
events := conv.Process(resp)
// Count tool_use blocks - should only have 1 (the valid one)
toolStartCount := 0
toolDeltaCount := 0
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
toolStartCount++
if start.ContentBlock.Name != "good_function" {
t.Errorf("expected tool name 'good_function', got %q", start.ContentBlock.Name)
}
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
toolDeltaCount++
}
}
}
}
if toolStartCount != 1 {
t.Errorf("expected 1 tool_use block, got %d", toolStartCount)
}
if toolDeltaCount != 1 {
t.Errorf("expected 1 input_json_delta, got %d", toolDeltaCount)
}
}
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
tests := []struct {
name string
block ContentBlock
wantKeys []string
}{
{
name: "text block includes empty text field",
block: ContentBlock{
Type: "text",
Text: ptr(""),
},
wantKeys: []string{"type", "text"},
},
{
name: "thinking block includes empty thinking field",
block: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
wantKeys: []string{"type", "thinking"},
},
{
name: "text block with content",
block: ContentBlock{
Type: "text",
Text: ptr("hello"),
},
wantKeys: []string{"type", "text"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.block)
if err != nil {
t.Fatalf("failed to marshal: %v", err)
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
for _, key := range tt.wantKeys {
if _, ok := result[key]; !ok {
t.Errorf("expected key %q to be present in JSON output, got: %s", key, string(data))
}
}
})
}
}
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
// events include the required empty fields for SDK compatibility.
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
t.Run("text block start includes empty text", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Content: "hello"},
}
events := conv.Process(resp)
var foundTextStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "text" {
foundTextStart = true
// Marshal and verify the text field is present
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["text"]; !ok {
t.Error("content_block_start for text should include 'text' field")
}
}
}
}
}
if !foundTextStart {
t.Error("expected text content_block_start event")
}
})
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Thinking: "let me think..."},
}
events := conv.Process(resp)
var foundThinkingStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "thinking" {
foundThinkingStart = true
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["thinking"]; !ok {
t.Error("content_block_start for thinking should include 'thinking' field")
}
}
}
}
}
if !foundThinkingStart {
t.Error("expected thinking content_block_start event")
}
})
}

View File

@@ -165,7 +165,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
return nil
}
const maxBufferSize = 8 * format.MegaByte
const maxBufferSize = 512 * format.KiloByte
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf io.Reader

View File

@@ -127,20 +127,6 @@ type GenerateRequest struct {
// each with an associated log probability. Only applies when Logprobs is true.
// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
TopLogprobs int `json:"top_logprobs,omitempty"`
// Experimental: Image generation fields (may change or be removed)
// Width is the width of the generated image in pixels.
// Only used for image generation models.
Width int32 `json:"width,omitempty"`
// Height is the height of the generated image in pixels.
// Only used for image generation models.
Height int32 `json:"height,omitempty"`
// Steps is the number of diffusion steps for image generation.
// Only used for image generation models.
Steps int32 `json:"steps,omitempty"`
}
// ChatRequest describes a request sent by [Client.Chat].
@@ -749,7 +735,7 @@ type ShowResponse struct {
Messages []Message `json:"messages,omitempty"`
RemoteModel string `json:"remote_model,omitempty"`
RemoteHost string `json:"remote_host,omitempty"`
ModelInfo map[string]any `json:"model_info"`
ModelInfo map[string]any `json:"model_info,omitempty"`
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
Tensors []Tensor `json:"tensors,omitempty"`
Capabilities []model.Capability `json:"capabilities,omitempty"`
@@ -874,20 +860,6 @@ type GenerateResponse struct {
// Logprobs contains log probability information for the generated tokens,
// if requested via the Logprobs parameter.
Logprobs []Logprob `json:"logprobs,omitempty"`
// Experimental: Image generation fields (may change or be removed)
// Image contains a base64-encoded generated image.
// Only present for image generation models.
Image string `json:"image,omitempty"`
// Completed is the number of completed steps in image generation.
// Only present for image generation models during streaming.
Completed int64 `json:"completed,omitempty"`
// Total is the total number of steps for image generation.
// Only present for image generation models during streaming.
Total int64 `json:"total,omitempty"`
}
// ModelDetails provides details about a model.

View File

@@ -75,9 +75,9 @@ The `-dev` flag enables:
CI builds with Xcode 14.1 for OS compatibility prior to v13. If you want to manually build v11+ support, you can download the older Xcode [here](https://developer.apple.com/services-account/download?path=/Developer_Tools/Xcode_14.1/Xcode_14.1.xip), extract, then `mv ./Xcode.app /Applications/Xcode_14.1.0.app` then activate with:
```
export CGO_CFLAGS="-O3 -mmacosx-version-min=12.0"
export CGO_CXXFLAGS="-O3 -mmacosx-version-min=12.0"
export CGO_LDFLAGS="-mmacosx-version-min=12.0"
export CGO_CFLAGS=-mmacosx-version-min=12.0
export CGO_CXXFLAGS=-mmacosx-version-min=12.0
export CGO_LDFLAGS=-mmacosx-version-min=12.0
export SDKROOT=/Applications/Xcode_14.1.0.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
export DEVELOPER_DIR=/Applications/Xcode_14.1.0.app/Contents/Developer
```

View File

@@ -14,7 +14,6 @@ extern NSString *SystemWidePath;
@interface AppDelegate () <NSWindowDelegate, WKNavigationDelegate, WKUIDelegate>
@property(strong, nonatomic) NSStatusItem *statusItem;
@property(assign, nonatomic) BOOL updateAvailable;
@property(assign, nonatomic) BOOL systemShutdownInProgress;
@end
@implementation AppDelegate
@@ -41,13 +40,6 @@ bool firstTimeRun,startHidden; // Set in run before initialization
}
- (void)applicationDidFinishLaunching:(NSNotification *)aNotification {
// Register for system shutdown/restart notification so we can allow termination
[[[NSWorkspace sharedWorkspace] notificationCenter]
addObserver:self
selector:@selector(systemWillPowerOff:)
name:NSWorkspaceWillPowerOffNotification
object:nil];
// if we're in development mode, set the app icon
NSString *bundlePath = [[NSBundle mainBundle] bundlePath];
if (![bundlePath hasSuffix:@".app"]) {
@@ -286,18 +278,7 @@ bool firstTimeRun,startHidden; // Set in run before initialization
[NSApp activateIgnoringOtherApps:YES];
}
- (void)systemWillPowerOff:(NSNotification *)notification {
// Set flag so applicationShouldTerminate: knows to allow termination.
// The system will call applicationShouldTerminate: after posting this notification.
self.systemShutdownInProgress = YES;
}
- (NSApplicationTerminateReply)applicationShouldTerminate:(NSApplication *)sender {
// Allow termination if the system is shutting down or restarting
if (self.systemShutdownInProgress) {
return NSTerminateNow;
}
// Otherwise just hide the app (for Cmd+Q, close button, etc.)
[NSApp hide:nil];
[NSApp setActivationPolicy:NSApplicationActivationPolicyAccessory];
return NSTerminateCancel;

View File

@@ -35,7 +35,6 @@ import (
"golang.org/x/term"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
@@ -47,9 +46,6 @@ import (
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd"
"github.com/ollama/ollama/x/create"
xcreateclient "github.com/ollama/ollama/x/create/client"
"github.com/ollama/ollama/x/imagegen"
)
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
@@ -95,88 +91,11 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()
// Validate model name early to fail fast
modelName := args[0]
name := model.ParseName(modelName)
if !name.IsValid() {
return fmt.Errorf("invalid model name: %s", modelName)
}
// Check for --experimental flag for safetensors model creation
experimental, _ := cmd.Flags().GetBool("experimental")
if experimental {
// Get Modelfile content - either from -f flag or default to "FROM ."
var reader io.Reader
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) || filename == "" {
// No Modelfile specified or found - use default
reader = strings.NewReader("FROM .\n")
} else if err != nil {
return err
} else {
f, err := os.Open(filename)
if err != nil {
return err
}
defer f.Close()
reader = f
}
// Parse the Modelfile
modelfile, err := parser.ParseFile(reader)
if err != nil {
return fmt.Errorf("failed to parse Modelfile: %w", err)
}
// Extract FROM path and configuration
var modelDir string
mfConfig := &xcreateclient.ModelfileConfig{}
for _, cmd := range modelfile.Commands {
switch cmd.Name {
case "model":
modelDir = cmd.Args
case "template":
mfConfig.Template = cmd.Args
case "system":
mfConfig.System = cmd.Args
case "license":
mfConfig.License = cmd.Args
}
}
if modelDir == "" {
modelDir = "."
}
// Resolve relative paths based on Modelfile location
if !filepath.IsAbs(modelDir) && filename != "" {
modelDir = filepath.Join(filepath.Dir(filename), modelDir)
}
quantize, _ := cmd.Flags().GetString("quantize")
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
ModelName: modelName,
ModelDir: modelDir,
Quantize: quantize,
Modelfile: mfConfig,
}, p)
}
var reader io.Reader
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) {
if filename == "" {
// No Modelfile found - check if current directory is an image gen model
if create.IsTensorModelDir(".") {
quantize, _ := cmd.Flags().GetString("quantize")
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
ModelName: modelName,
ModelDir: ".",
Quantize: quantize,
}, p)
}
reader = strings.NewReader("FROM .\n")
} else {
return errModelfileNotFound
@@ -208,7 +127,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}
spinner.Stop()
req.Model = modelName
req.Model = args[0]
quantize, _ := cmd.Flags().GetString("quantize")
if quantize != "" {
req.Quantize = quantize
@@ -538,7 +457,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
name := args[0]
info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: name}
info, err := client.Show(cmd.Context(), showReq)
@@ -600,18 +518,9 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
}
// Check if this is an image generation model
if slices.Contains(info.Capabilities, model.CapabilityImage) {
if opts.Prompt == "" && !interactive {
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
}
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
}
// Check for experimental flag
isExperimental, _ := cmd.Flags().GetBool("experimental")
yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
enableWebsearch, _ := cmd.Flags().GetBool("experimental-websearch")
yoloMode, _ := cmd.Flags().GetBool("yolo")
if interactive {
if err := loadOrUnloadModel(cmd, &opts); err != nil {
@@ -641,7 +550,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
// Use experimental agent loop with tools
if isExperimental {
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode, enableWebsearch)
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode)
}
return generateInteractive(cmd, opts)
@@ -747,11 +656,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
bar, ok := bars[resp.Digest]
if !ok {
msg := resp.Status
if msg == "" {
msg = fmt.Sprintf("pushing %s...", resp.Digest[7:19])
}
bar = progress.NewBar(msg, resp.Total, resp.Completed)
bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
@@ -900,11 +805,11 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
for _, arg := range args {
// Unload the model if it's running before deletion
if err := loadOrUnloadModel(cmd, &runOptions{
Model: arg,
Model: args[0],
KeepAlive: &api.Duration{Duration: 0},
}); err != nil {
if !strings.Contains(strings.ToLower(err.Error()), "not found") {
fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", arg)
fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", args[0])
}
}
@@ -1019,10 +924,8 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
}
if resp.ModelInfo != nil {
arch, _ := resp.ModelInfo["general.architecture"].(string)
if arch != "" {
rows = append(rows, []string{"", "architecture", arch})
}
arch := resp.ModelInfo["general.architecture"].(string)
rows = append(rows, []string{"", "architecture", arch})
var paramStr string
if resp.Details.ParameterSize != "" {
@@ -1032,9 +935,7 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
paramStr = format.HumanNumber(uint64(f))
}
}
if paramStr != "" {
rows = append(rows, []string{"", "parameters", paramStr})
}
rows = append(rows, []string{"", "parameters", paramStr})
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
if f, ok := v.(float64); ok {
@@ -1820,22 +1721,15 @@ func NewCLI() *cobra.Command {
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
createCmd := &cobra.Command{
Use: "create MODEL",
Short: "Create a model",
Args: cobra.ExactArgs(1),
PreRunE: func(cmd *cobra.Command, args []string) error {
// Skip server check for experimental mode (writes directly to disk)
if experimental, _ := cmd.Flags().GetBool("experimental"); experimental {
return nil
}
return checkServerHeartbeat(cmd, args)
},
RunE: CreateHandler,
Use: "create MODEL",
Short: "Create a model",
Args: cobra.ExactArgs(1),
PreRunE: checkServerHeartbeat,
RunE: CreateHandler,
}
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")")
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
createCmd.Flags().Bool("experimental", false, "Enable experimental safetensors model creation")
showCmd := &cobra.Command{
Use: "show MODEL",
@@ -1871,11 +1765,7 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
runCmd.Flags().Bool("experimental-websearch", false, "Enable web search tool in experimental mode")
// Image generation flags (width, height, steps, seed, etc.)
imagegen.RegisterFlags(runCmd)
runCmd.Flags().BoolP("yolo", "y", false, "Skip all tool approval prompts (use with caution)")
stopCmd := &cobra.Command{
Use: "stop MODEL",
@@ -1990,7 +1880,6 @@ func NewCLI() *cobra.Command {
} {
switch cmd {
case runCmd:
imagegen.AppendFlagsDocs(cmd)
appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]})
case serveCmd:
appendEnvDocs(cmd, []envconfig.EnvVar{
@@ -2031,7 +1920,6 @@ func NewCLI() *cobra.Command {
copyCmd,
deleteCmd,
runnerCmd,
config.LaunchCmd(checkServerHeartbeat),
)
return rootCmd

View File

@@ -1547,79 +1547,6 @@ func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) {
}
}
func TestShowInfoImageGen(t *testing.T) {
var b bytes.Buffer
err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "ZImagePipeline",
ParameterSize: "10.3B",
QuantizationLevel: "FP8",
},
Capabilities: []model.Capability{model.CapabilityImage},
Requires: "0.14.0",
}, false, &b)
if err != nil {
t.Fatal(err)
}
expect := " Model\n" +
" architecture ZImagePipeline \n" +
" parameters 10.3B \n" +
" quantization FP8 \n" +
" requires 0.14.0 \n" +
"\n" +
" Capabilities\n" +
" image \n" +
"\n"
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
}
func TestPushProgressMessage(t *testing.T) {
tests := []struct {
name string
status string
digest string
wantMsg string
}{
{
name: "uses status when provided",
status: "uploading model",
digest: "sha256:abc123456789def",
wantMsg: "uploading model",
},
{
name: "falls back to digest when status empty",
status: "",
digest: "sha256:abc123456789def",
wantMsg: "pushing abc123456789...",
},
{
name: "handles short digest gracefully",
status: "",
digest: "sha256:abc",
wantMsg: "pushing sha256:abc...",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msg := tt.status
if msg == "" {
if len(tt.digest) >= 19 {
msg = fmt.Sprintf("pushing %s...", tt.digest[7:19])
} else {
msg = fmt.Sprintf("pushing %s...", tt.digest)
}
}
if msg != tt.wantMsg {
t.Errorf("got %q, want %q", msg, tt.wantMsg)
}
})
}
}
func TestRunOptions_Copy_Independence(t *testing.T) {
// Test that modifications to original don't affect copy
originalThink := &api.ThinkValue{Value: "original"}

View File

@@ -1,58 +0,0 @@
package config
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
)
// Claude implements Runner for Claude Code integration
type Claude struct{}
func (c *Claude) String() string { return "Claude Code" }
func (c *Claude) args(model string) []string {
if model != "" {
return []string{"--model", model}
}
return nil
}
func (c *Claude) findPath() (string, error) {
if p, err := exec.LookPath("claude"); err == nil {
return p, nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
name := "claude"
if runtime.GOOS == "windows" {
name = "claude.exe"
}
fallback := filepath.Join(home, ".claude", "local", name)
if _, err := os.Stat(fallback); err != nil {
return "", err
}
return fallback, nil
}
func (c *Claude) Run(model string) error {
claudePath, err := c.findPath()
if err != nil {
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
}
cmd := exec.Command(claudePath, c.args(model)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(),
"ANTHROPIC_BASE_URL=http://localhost:11434",
"ANTHROPIC_API_KEY=",
"ANTHROPIC_AUTH_TOKEN=ollama",
)
return cmd.Run()
}

View File

@@ -1,101 +0,0 @@
package config
import (
"os"
"path/filepath"
"runtime"
"slices"
"testing"
)
func TestClaudeIntegration(t *testing.T) {
c := &Claude{}
t.Run("String", func(t *testing.T) {
if got := c.String(); got != "Claude Code" {
t.Errorf("String() = %q, want %q", got, "Claude Code")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = c
})
}
func TestClaudeFindPath(t *testing.T) {
c := &Claude{}
t.Run("finds claude in PATH", func(t *testing.T) {
tmpDir := t.TempDir()
name := "claude"
if runtime.GOOS == "windows" {
name = "claude.exe"
}
fakeBin := filepath.Join(tmpDir, name)
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
t.Setenv("PATH", tmpDir)
got, err := c.findPath()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != fakeBin {
t.Errorf("findPath() = %q, want %q", got, fakeBin)
}
})
t.Run("falls back to ~/.claude/local/claude", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
name := "claude"
if runtime.GOOS == "windows" {
name = "claude.exe"
}
fallback := filepath.Join(tmpDir, ".claude", "local", name)
os.MkdirAll(filepath.Dir(fallback), 0o755)
os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)
got, err := c.findPath()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != fallback {
t.Errorf("findPath() = %q, want %q", got, fallback)
}
})
t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
_, err := c.findPath()
if err == nil {
t.Fatal("expected error, got nil")
}
})
}
func TestClaudeArgs(t *testing.T) {
c := &Claude{}
tests := []struct {
name string
model string
want []string
}{
{"with model", "llama3.2", []string{"--model", "llama3.2"}},
{"empty model", "", nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
}
})
}
}

View File

@@ -1,193 +0,0 @@
package config
import (
"bytes"
"encoding/json"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"strings"
)
type Clawdbot struct{}
func (c *Clawdbot) String() string { return "Clawdbot" }
const ansiGreen = "\033[32m"
func (c *Clawdbot) Run(model string) error {
if _, err := exec.LookPath("clawdbot"); err != nil {
return fmt.Errorf("clawdbot is not installed, install from https://docs.clawd.bot")
}
models := []string{model}
if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
models = config.Models
}
if err := c.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("clawdbot", "gateway")
cmd.Stdin = os.Stdin
// Capture output to detect "already running" message
var outputBuf bytes.Buffer
cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf)
cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf)
err := cmd.Run()
if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
fmt.Fprintf(os.Stderr, "%sClawdbot has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
return nil
}
return err
}
func (c *Clawdbot) Paths() []string {
home, _ := os.UserHomeDir()
p := filepath.Join(home, ".clawdbot", "clawdbot.json")
if _, err := os.Stat(p); err == nil {
return []string{p}
}
return nil
}
func (c *Clawdbot) Edit(models []string) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
// Read into map[string]any to preserve unknown fields
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config)
}
// Navigate/create: models.providers.ollama (preserving other providers)
modelsSection, _ := config["models"].(map[string]any)
if modelsSection == nil {
modelsSection = make(map[string]any)
}
providers, _ := modelsSection["providers"].(map[string]any)
if providers == nil {
providers = make(map[string]any)
}
ollama, _ := providers["ollama"].(map[string]any)
if ollama == nil {
ollama = make(map[string]any)
}
ollama["baseUrl"] = "http://127.0.0.1:11434/v1"
// needed to register provider
ollama["apiKey"] = "ollama-local"
// TODO(parthsareen): potentially move to responses
ollama["api"] = "openai-completions"
// Build map of existing models to preserve user customizations
existingModels, _ := ollama["models"].([]any)
existingByID := make(map[string]map[string]any)
for _, m := range existingModels {
if entry, ok := m.(map[string]any); ok {
if id, ok := entry["id"].(string); ok {
existingByID[id] = entry
}
}
}
var newModels []any
for _, model := range models {
entry := map[string]any{
"id": model,
"name": model,
"reasoning": false,
"input": []any{"text"},
"cost": map[string]any{
"input": 0,
"output": 0,
"cacheRead": 0,
"cacheWrite": 0,
},
// TODO(parthsareen): get these values from API
"contextWindow": 131072,
"maxTokens": 16384,
}
// Merge existing fields (user customizations)
if existing, ok := existingByID[model]; ok {
for k, v := range existing {
if _, isNew := entry[k]; !isNew {
entry[k] = v
}
}
}
newModels = append(newModels, entry)
}
ollama["models"] = newModels
providers["ollama"] = ollama
modelsSection["providers"] = providers
config["models"] = modelsSection
// Update agents.defaults.model.primary (preserving other agent settings)
agents, _ := config["agents"].(map[string]any)
if agents == nil {
agents = make(map[string]any)
}
defaults, _ := agents["defaults"].(map[string]any)
if defaults == nil {
defaults = make(map[string]any)
}
modelConfig, _ := defaults["model"].(map[string]any)
if modelConfig == nil {
modelConfig = make(map[string]any)
}
modelConfig["primary"] = "ollama/" + models[0]
defaults["model"] = modelConfig
agents["defaults"] = defaults
config["agents"] = agents
data, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
return writeWithBackup(configPath, data)
}
func (c *Clawdbot) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
config, err := readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
if err != nil {
return nil
}
modelsSection, _ := config["models"].(map[string]any)
providers, _ := modelsSection["providers"].(map[string]any)
ollama, _ := providers["ollama"].(map[string]any)
modelList, _ := ollama["models"].([]any)
var result []string
for _, m := range modelList {
if entry, ok := m.(map[string]any); ok {
if id, ok := entry["id"].(string); ok {
result = append(result, id)
}
}
}
return result
}

View File

@@ -1,625 +0,0 @@
package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
)
func TestClawdbotIntegration(t *testing.T) {
c := &Clawdbot{}
t.Run("String", func(t *testing.T) {
if got := c.String(); got != "Clawdbot" {
t.Errorf("String() = %q, want %q", got, "Clawdbot")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = c
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = c
})
}
func TestClawdbotEdit(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
cleanup := func() { os.RemoveAll(configDir) }
t.Run("fresh install", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "llama3.2")
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2")
})
t.Run("multiple models - first is primary", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "llama3.2")
assertClawdbotModelExists(t, configPath, "mistral")
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2")
})
t.Run("preserve other providers", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"anthropic":{"apiKey":"xxx"}}}}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
if providers["anthropic"] == nil {
t.Error("anthropic provider was removed")
}
})
t.Run("preserve top-level keys", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"theme":"dark","mcp":{"servers":{}}}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
if cfg["theme"] != "dark" {
t.Error("theme was removed")
}
if cfg["mcp"] == nil {
t.Error("mcp was removed")
}
})
t.Run("preserve user customizations on models", func(t *testing.T) {
cleanup()
c.Edit([]string{"llama3.2"})
// User adds custom field
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelList := ollama["models"].([]any)
entry := modelList[0].(map[string]any)
entry["customField"] = "user-value"
configData, _ := json.MarshalIndent(cfg, "", " ")
os.WriteFile(configPath, configData, 0o644)
// Re-run Edit
c.Edit([]string{"llama3.2"})
data, _ = os.ReadFile(configPath)
json.Unmarshal(data, &cfg)
models = cfg["models"].(map[string]any)
providers = models["providers"].(map[string]any)
ollama = providers["ollama"].(map[string]any)
modelList = ollama["models"].([]any)
entry = modelList[0].(map[string]any)
if entry["customField"] != "user-value" {
t.Error("custom field was lost")
}
})
t.Run("edit replaces models list", func(t *testing.T) {
cleanup()
c.Edit([]string{"llama3.2", "mistral"})
c.Edit([]string{"llama3.2"})
assertClawdbotModelExists(t, configPath, "llama3.2")
assertClawdbotModelNotExists(t, configPath, "mistral")
})
t.Run("empty models is no-op", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
original := `{"existing":"data"}`
os.WriteFile(configPath, []byte(original), 0o644)
c.Edit([]string{})
data, _ := os.ReadFile(configPath)
if string(data) != original {
t.Error("empty models should not modify file")
}
})
t.Run("corrupted JSON treated as empty", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Error("result should be valid JSON")
}
})
t.Run("wrong type models section", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":"not a map"}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "llama3.2")
})
}
func TestClawdbotModels(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("no config returns nil", func(t *testing.T) {
if models := c.Models(); len(models) > 0 {
t.Errorf("expected nil/empty, got %v", models)
}
})
t.Run("returns all ollama models", func(t *testing.T) {
configDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "clawdbot.json"), []byte(`{
"models":{"providers":{"ollama":{"models":[
{"id":"llama3.2"},
{"id":"mistral"}
]}}}
}`), 0o644)
models := c.Models()
if len(models) != 2 {
t.Errorf("expected 2 models, got %v", models)
}
})
}
// Helper functions
func assertClawdbotModelExists(t *testing.T, path, model string) {
t.Helper()
data, _ := os.ReadFile(path)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelList := ollama["models"].([]any)
for _, m := range modelList {
if entry, ok := m.(map[string]any); ok {
if entry["id"] == model {
return
}
}
}
t.Errorf("model %s not found", model)
}
func assertClawdbotModelNotExists(t *testing.T, path, model string) {
t.Helper()
data, _ := os.ReadFile(path)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models, _ := cfg["models"].(map[string]any)
providers, _ := models["providers"].(map[string]any)
ollama, _ := providers["ollama"].(map[string]any)
modelList, _ := ollama["models"].([]any)
for _, m := range modelList {
if entry, ok := m.(map[string]any); ok {
if entry["id"] == model {
t.Errorf("model %s should not exist", model)
}
}
}
}
func assertClawdbotPrimaryModel(t *testing.T, path, expected string) {
t.Helper()
data, _ := os.ReadFile(path)
var cfg map[string]any
json.Unmarshal(data, &cfg)
agents := cfg["agents"].(map[string]any)
defaults := agents["defaults"].(map[string]any)
model := defaults["model"].(map[string]any)
if model["primary"] != expected {
t.Errorf("primary model = %v, want %v", model["primary"], expected)
}
}
func TestClawdbotPaths(t *testing.T) {
c := &Clawdbot{}
t.Run("returns path when config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "clawdbot.json"), []byte(`{}`), 0o644)
paths := c.Paths()
if len(paths) != 1 {
t.Errorf("expected 1 path, got %d", len(paths))
}
})
t.Run("returns nil when config missing", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if paths := c.Paths(); paths != nil {
t.Errorf("expected nil, got %v", paths)
}
})
}
func TestClawdbotModelsEdgeCases(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
cleanup := func() { os.RemoveAll(configDir) }
t.Run("corrupted JSON returns nil", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("wrong type at models level", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":"string"}`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("wrong type at providers level", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":"string"}}`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("wrong type at ollama level", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":"string"}}}`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("model entry missing id", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"name":"test"}]}}}}`), 0o644)
if len(c.Models()) != 0 {
t.Error("expected empty for missing id")
}
})
t.Run("model id is not string", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"id":123}]}}}}`), 0o644)
if len(c.Models()) != 0 {
t.Error("expected empty for non-string id")
}
})
}
func TestClawdbotEditSchemaFields(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configPath := filepath.Join(tmpDir, ".clawdbot", "clawdbot.json")
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelList := ollama["models"].([]any)
entry := modelList[0].(map[string]any)
// Verify required schema fields
if entry["reasoning"] != false {
t.Error("reasoning should be false")
}
if entry["input"] == nil {
t.Error("input should be set")
}
if entry["contextWindow"] == nil {
t.Error("contextWindow should be set")
}
if entry["maxTokens"] == nil {
t.Error("maxTokens should be set")
}
cost := entry["cost"].(map[string]any)
if cost["cacheRead"] == nil {
t.Error("cost.cacheRead should be set")
}
if cost["cacheWrite"] == nil {
t.Error("cost.cacheWrite should be set")
}
}
func TestClawdbotEditModelNames(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configPath := filepath.Join(tmpDir, ".clawdbot", "clawdbot.json")
cleanup := func() { os.RemoveAll(filepath.Join(tmpDir, ".clawdbot")) }
t.Run("model with colon tag", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"llama3.2:70b"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "llama3.2:70b")
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2:70b")
})
t.Run("model with slash", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"library/model:tag"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "library/model:tag")
assertClawdbotPrimaryModel(t, configPath, "ollama/library/model:tag")
})
t.Run("model with hyphen", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"test-model"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "test-model")
})
}
func TestClawdbotEditAgentsPreservation(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
cleanup := func() { os.RemoveAll(configDir) }
t.Run("preserve other agent defaults", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{"model":{"primary":"old"},"temperature":0.7}}}`), 0o644)
c.Edit([]string{"llama3.2"})
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
agents := cfg["agents"].(map[string]any)
defaults := agents["defaults"].(map[string]any)
if defaults["temperature"] != 0.7 {
t.Error("temperature setting was lost")
}
})
t.Run("preserve other agents besides defaults", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{},"custom-agent":{"foo":"bar"}}}`), 0o644)
c.Edit([]string{"llama3.2"})
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
agents := cfg["agents"].(map[string]any)
if agents["custom-agent"] == nil {
t.Error("custom-agent was lost")
}
})
}
const testClawdbotFixture = `{
"theme": "dark",
"mcp": {"servers": {"custom": {"enabled": true}}},
"models": {
"providers": {
"anthropic": {"apiKey": "xxx"},
"ollama": {
"baseUrl": "http://127.0.0.1:11434/v1",
"models": [{"id": "old-model", "customField": "preserved"}]
}
}
},
"agents": {
"defaults": {"model": {"primary": "old"}, "temperature": 0.7},
"custom-agent": {"foo": "bar"}
}
}`
func TestClawdbotEdit_RoundTrip(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
// Verify top-level preserved
if cfg["theme"] != "dark" {
t.Error("theme not preserved")
}
mcp := cfg["mcp"].(map[string]any)
servers := mcp["servers"].(map[string]any)
if servers["custom"] == nil {
t.Error("mcp.servers.custom not preserved")
}
// Verify other providers preserved
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
if providers["anthropic"] == nil {
t.Error("anthropic provider not preserved")
}
// Verify agents preserved
agents := cfg["agents"].(map[string]any)
if agents["custom-agent"] == nil {
t.Error("custom-agent not preserved")
}
defaults := agents["defaults"].(map[string]any)
if defaults["temperature"] != 0.7 {
t.Error("temperature not preserved")
}
}
func TestClawdbotEdit_Idempotent(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
c.Edit([]string{"llama3.2", "mistral"})
firstData, _ := os.ReadFile(configPath)
c.Edit([]string{"llama3.2", "mistral"})
secondData, _ := os.ReadFile(configPath)
if string(firstData) != string(secondData) {
t.Error("repeated edits with same models produced different results")
}
}
func TestClawdbotEdit_MultipleConsecutiveEdits(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
for i := range 10 {
models := []string{"model-a", "model-b"}
if i%2 == 0 {
models = []string{"model-x", "model-y", "model-z"}
}
if err := c.Edit(models); err != nil {
t.Fatalf("edit %d failed: %v", i, err)
}
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatalf("file is not valid JSON after multiple edits: %v", err)
}
if cfg["theme"] != "dark" {
t.Error("theme lost after multiple edits")
}
}
func TestClawdbotEdit_BackupCreated(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
backupDir := filepath.Join(os.TempDir(), "ollama-backups")
os.MkdirAll(configDir, 0o755)
uniqueMarker := fmt.Sprintf("test-marker-%d", os.Getpid())
original := fmt.Sprintf(`{"theme": "%s"}`, uniqueMarker)
os.WriteFile(configPath, []byte(original), 0o644)
if err := c.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
backups, _ := filepath.Glob(filepath.Join(backupDir, "clawdbot.json.*"))
foundBackup := false
for _, backup := range backups {
data, _ := os.ReadFile(backup)
if string(data) == original {
foundBackup = true
break
}
}
if !foundBackup {
t.Error("backup with original content not found")
}
}
func TestClawdbotEdit_CreatesDirectoryIfMissing(t *testing.T) {
c := &Clawdbot{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
if _, err := os.Stat(configDir); !os.IsNotExist(err) {
t.Fatal("directory should not exist before test")
}
if err := c.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(configDir); os.IsNotExist(err) {
t.Fatal("directory was not created")
}
}

View File

@@ -1,61 +0,0 @@
package config
import (
"fmt"
"os"
"os/exec"
"strings"
"golang.org/x/mod/semver"
)
// Codex implements Runner for Codex integration
type Codex struct{}
func (c *Codex) String() string { return "Codex" }
func (c *Codex) args(model string) []string {
args := []string{"--oss"}
if model != "" {
args = append(args, "-m", model)
}
return args
}
func (c *Codex) Run(model string) error {
if err := checkCodexVersion(); err != nil {
return err
}
cmd := exec.Command("codex", c.args(model)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func checkCodexVersion() error {
if _, err := exec.LookPath("codex"); err != nil {
return fmt.Errorf("codex is not installed, install with: npm install -g @openai/codex")
}
out, err := exec.Command("codex", "--version").Output()
if err != nil {
return fmt.Errorf("failed to get codex version: %w", err)
}
// Parse output like "codex-cli 0.87.0"
fields := strings.Fields(strings.TrimSpace(string(out)))
if len(fields) < 2 {
return fmt.Errorf("unexpected codex version output: %s", string(out))
}
version := "v" + fields[len(fields)-1]
minVersion := "v0.81.0"
if semver.Compare(version, minVersion) < 0 {
return fmt.Errorf("codex version %s is too old, minimum required is %s, update with: npm update -g @openai/codex", fields[len(fields)-1], "0.81.0")
}
return nil
}

View File

@@ -1,28 +0,0 @@
package config
import (
"slices"
"testing"
)
func TestCodexArgs(t *testing.T) {
c := &Codex{}
tests := []struct {
name string
model string
want []string
}{
{"with model", "llama3.2", []string{"--oss", "-m", "llama3.2"}},
{"empty model", "", []string{"--oss"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
}
})
}
}

View File

@@ -1,115 +0,0 @@
// Package config provides integration configuration for external coding tools
// (Claude Code, Codex, Droid, OpenCode) to use Ollama models.
package config
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
)
type integration struct {
Models []string `json:"models"`
}
type config struct {
Integrations map[string]*integration `json:"integrations"`
}
func configPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", "config", "config.json"), nil
}
func load() (*config, error) {
path, err := configPath()
if err != nil {
return nil, err
}
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return &config{Integrations: make(map[string]*integration)}, nil
}
return nil, err
}
var cfg config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("failed to parse config: %w, at: %s", err, path)
}
if cfg.Integrations == nil {
cfg.Integrations = make(map[string]*integration)
}
return &cfg, nil
}
func save(cfg *config) error {
path, err := configPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return err
}
return writeWithBackup(path, data)
}
func saveIntegration(appName string, models []string) error {
if appName == "" {
return errors.New("app name cannot be empty")
}
cfg, err := load()
if err != nil {
return err
}
cfg.Integrations[strings.ToLower(appName)] = &integration{
Models: models,
}
return save(cfg)
}
func loadIntegration(appName string) (*integration, error) {
cfg, err := load()
if err != nil {
return nil, err
}
ic, ok := cfg.Integrations[strings.ToLower(appName)]
if !ok {
return nil, os.ErrNotExist
}
return ic, nil
}
func listIntegrations() ([]integration, error) {
cfg, err := load()
if err != nil {
return nil, err
}
result := make([]integration, 0, len(cfg.Integrations))
for _, ic := range cfg.Integrations {
result = append(result, *ic)
}
return result, nil
}

View File

@@ -1,373 +0,0 @@
package config
import (
"os"
"path/filepath"
"strings"
"testing"
)
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
func setTestHome(t *testing.T, dir string) {
t.Setenv("HOME", dir)
t.Setenv("USERPROFILE", dir)
}
// editorPaths is a test helper that safely calls Paths if the runner implements Editor
func editorPaths(r Runner) []string {
if editor, ok := r.(Editor); ok {
return editor.Paths()
}
return nil
}
func TestIntegrationConfig(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("save and load round-trip", func(t *testing.T) {
models := []string{"llama3.2", "mistral", "qwen2.5"}
if err := saveIntegration("claude", models); err != nil {
t.Fatal(err)
}
config, err := loadIntegration("claude")
if err != nil {
t.Fatal(err)
}
if len(config.Models) != len(models) {
t.Errorf("expected %d models, got %d", len(models), len(config.Models))
}
for i, m := range models {
if config.Models[i] != m {
t.Errorf("model %d: expected %s, got %s", i, m, config.Models[i])
}
}
})
t.Run("defaultModel returns first model", func(t *testing.T) {
saveIntegration("codex", []string{"model-a", "model-b"})
config, _ := loadIntegration("codex")
defaultModel := ""
if len(config.Models) > 0 {
defaultModel = config.Models[0]
}
if defaultModel != "model-a" {
t.Errorf("expected model-a, got %s", defaultModel)
}
})
t.Run("defaultModel returns empty for no models", func(t *testing.T) {
config := &integration{Models: []string{}}
defaultModel := ""
if len(config.Models) > 0 {
defaultModel = config.Models[0]
}
if defaultModel != "" {
t.Errorf("expected empty string, got %s", defaultModel)
}
})
t.Run("app name is case-insensitive", func(t *testing.T) {
saveIntegration("Claude", []string{"model-x"})
config, err := loadIntegration("claude")
if err != nil {
t.Fatal(err)
}
defaultModel := ""
if len(config.Models) > 0 {
defaultModel = config.Models[0]
}
if defaultModel != "model-x" {
t.Errorf("expected model-x, got %s", defaultModel)
}
})
t.Run("multiple integrations in single file", func(t *testing.T) {
saveIntegration("app1", []string{"model-1"})
saveIntegration("app2", []string{"model-2"})
config1, _ := loadIntegration("app1")
config2, _ := loadIntegration("app2")
defaultModel1 := ""
if len(config1.Models) > 0 {
defaultModel1 = config1.Models[0]
}
defaultModel2 := ""
if len(config2.Models) > 0 {
defaultModel2 = config2.Models[0]
}
if defaultModel1 != "model-1" {
t.Errorf("expected model-1, got %s", defaultModel1)
}
if defaultModel2 != "model-2" {
t.Errorf("expected model-2, got %s", defaultModel2)
}
})
}
func TestListIntegrations(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("returns empty when no integrations", func(t *testing.T) {
configs, err := listIntegrations()
if err != nil {
t.Fatal(err)
}
if len(configs) != 0 {
t.Errorf("expected 0 integrations, got %d", len(configs))
}
})
t.Run("returns all saved integrations", func(t *testing.T) {
saveIntegration("claude", []string{"model-1"})
saveIntegration("droid", []string{"model-2"})
configs, err := listIntegrations()
if err != nil {
t.Fatal(err)
}
if len(configs) != 2 {
t.Errorf("expected 2 integrations, got %d", len(configs))
}
})
}
func TestEditorPaths(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("returns empty for claude (no Editor)", func(t *testing.T) {
r := integrations["claude"]
paths := editorPaths(r)
if len(paths) != 0 {
t.Errorf("expected no paths for claude, got %v", paths)
}
})
t.Run("returns empty for codex (no Editor)", func(t *testing.T) {
r := integrations["codex"]
paths := editorPaths(r)
if len(paths) != 0 {
t.Errorf("expected no paths for codex, got %v", paths)
}
})
t.Run("returns empty for droid when no config exists", func(t *testing.T) {
r := integrations["droid"]
paths := editorPaths(r)
if len(paths) != 0 {
t.Errorf("expected no paths, got %v", paths)
}
})
t.Run("returns path for droid when config exists", func(t *testing.T) {
settingsDir, _ := os.UserHomeDir()
settingsDir = filepath.Join(settingsDir, ".factory")
os.MkdirAll(settingsDir, 0o755)
os.WriteFile(filepath.Join(settingsDir, "settings.json"), []byte(`{}`), 0o644)
r := integrations["droid"]
paths := editorPaths(r)
if len(paths) != 1 {
t.Errorf("expected 1 path, got %d", len(paths))
}
})
t.Run("returns paths for opencode when configs exist", func(t *testing.T) {
home, _ := os.UserHomeDir()
configDir := filepath.Join(home, ".config", "opencode")
stateDir := filepath.Join(home, ".local", "state", "opencode")
os.MkdirAll(configDir, 0o755)
os.MkdirAll(stateDir, 0o755)
os.WriteFile(filepath.Join(configDir, "opencode.json"), []byte(`{}`), 0o644)
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
r := integrations["opencode"]
paths := editorPaths(r)
if len(paths) != 2 {
t.Errorf("expected 2 paths, got %d: %v", len(paths), paths)
}
})
}
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Create corrupted config.json file
dir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(dir, 0o755)
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
// Corrupted file is treated as empty, so loadIntegration returns not found
_, err := loadIntegration("test")
if err == nil {
t.Error("expected error for nonexistent integration in corrupted file")
}
}
func TestSaveIntegration_NilModels(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if err := saveIntegration("test", nil); err != nil {
t.Fatalf("saveIntegration with nil models failed: %v", err)
}
config, err := loadIntegration("test")
if err != nil {
t.Fatalf("loadIntegration failed: %v", err)
}
if config.Models == nil {
// nil is acceptable
} else if len(config.Models) != 0 {
t.Errorf("expected empty or nil models, got %v", config.Models)
}
}
func TestSaveIntegration_EmptyAppName(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
err := saveIntegration("", []string{"model"})
if err == nil {
t.Error("expected error for empty app name, got nil")
}
if err != nil && !strings.Contains(err.Error(), "app name cannot be empty") {
t.Errorf("expected 'app name cannot be empty' error, got: %v", err)
}
}
func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
_, err := loadIntegration("nonexistent")
if err == nil {
t.Error("expected error for nonexistent integration, got nil")
}
if !os.IsNotExist(err) {
t.Logf("error type is os.ErrNotExist as expected: %v", err)
}
}
func TestConfigPath(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
path, err := configPath()
if err != nil {
t.Fatal(err)
}
expected := filepath.Join(tmpDir, ".ollama", "config", "config.json")
if path != expected {
t.Errorf("expected %s, got %s", expected, path)
}
}
func TestLoad(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("returns empty config when file does not exist", func(t *testing.T) {
cfg, err := load()
if err != nil {
t.Fatal(err)
}
if cfg == nil {
t.Fatal("expected non-nil config")
}
if cfg.Integrations == nil {
t.Error("expected non-nil Integrations map")
}
if len(cfg.Integrations) != 0 {
t.Errorf("expected empty Integrations, got %d", len(cfg.Integrations))
}
})
t.Run("loads existing config", func(t *testing.T) {
path, _ := configPath()
os.MkdirAll(filepath.Dir(path), 0o755)
os.WriteFile(path, []byte(`{"integrations":{"test":{"models":["model-a"]}}}`), 0o644)
cfg, err := load()
if err != nil {
t.Fatal(err)
}
if cfg.Integrations["test"] == nil {
t.Fatal("expected test integration")
}
if len(cfg.Integrations["test"].Models) != 1 {
t.Errorf("expected 1 model, got %d", len(cfg.Integrations["test"].Models))
}
})
t.Run("returns error for corrupted JSON", func(t *testing.T) {
path, _ := configPath()
os.MkdirAll(filepath.Dir(path), 0o755)
os.WriteFile(path, []byte(`{corrupted`), 0o644)
_, err := load()
if err == nil {
t.Error("expected error for corrupted JSON")
}
})
}
func TestSave(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("creates config file", func(t *testing.T) {
cfg := &config{
Integrations: map[string]*integration{
"test": {Models: []string{"model-a", "model-b"}},
},
}
if err := save(cfg); err != nil {
t.Fatal(err)
}
path, _ := configPath()
if _, err := os.Stat(path); os.IsNotExist(err) {
t.Error("config file was not created")
}
})
t.Run("round-trip preserves data", func(t *testing.T) {
cfg := &config{
Integrations: map[string]*integration{
"claude": {Models: []string{"llama3.2", "mistral"}},
"codex": {Models: []string{"qwen2.5"}},
},
}
if err := save(cfg); err != nil {
t.Fatal(err)
}
loaded, err := load()
if err != nil {
t.Fatal(err)
}
if len(loaded.Integrations) != 2 {
t.Errorf("expected 2 integrations, got %d", len(loaded.Integrations))
}
if loaded.Integrations["claude"] == nil {
t.Error("missing claude integration")
}
if len(loaded.Integrations["claude"].Models) != 2 {
t.Errorf("expected 2 models for claude, got %d", len(loaded.Integrations["claude"].Models))
}
})
}

View File

@@ -1,184 +0,0 @@
package config
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"slices"
)
// Droid implements Runner and Editor for Droid integration
type Droid struct{}
// droidSettings represents the Droid settings.json file (only fields we use)
type droidSettings struct {
CustomModels []modelEntry `json:"customModels"`
SessionDefaultSettings sessionSettings `json:"sessionDefaultSettings"`
}
type sessionSettings struct {
Model string `json:"model"`
ReasoningEffort string `json:"reasoningEffort"`
}
type modelEntry struct {
Model string `json:"model"`
DisplayName string `json:"displayName"`
BaseURL string `json:"baseUrl"`
APIKey string `json:"apiKey"`
Provider string `json:"provider"`
MaxOutputTokens int `json:"maxOutputTokens"`
SupportsImages bool `json:"supportsImages"`
ID string `json:"id"`
Index int `json:"index"`
}
func (d *Droid) String() string { return "Droid" }
func (d *Droid) Run(model string) error {
if _, err := exec.LookPath("droid"); err != nil {
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
}
// Call Edit() to ensure config is up-to-date before launch
models := []string{model}
if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
models = config.Models
}
if err := d.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("droid")
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func (d *Droid) Paths() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
p := filepath.Join(home, ".factory", "settings.json")
if _, err := os.Stat(p); err == nil {
return []string{p}
}
return nil
}
func (d *Droid) Edit(models []string) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
settingsPath := filepath.Join(home, ".factory", "settings.json")
if err := os.MkdirAll(filepath.Dir(settingsPath), 0o755); err != nil {
return err
}
// Read file once, unmarshal twice:
// map preserves unknown fields for writing back (including extra fields in model entries)
settingsMap := make(map[string]any)
var settings droidSettings
if data, err := os.ReadFile(settingsPath); err == nil {
if err := json.Unmarshal(data, &settingsMap); err != nil {
return fmt.Errorf("failed to parse settings file: %w, at: %s", err, settingsPath)
}
json.Unmarshal(data, &settings) // ignore error, zero values are fine
}
// Keep only non-Ollama models from the raw map (preserves extra fields)
// Rebuild Ollama models
var nonOllamaModels []any
if rawModels, ok := settingsMap["customModels"].([]any); ok {
for _, raw := range rawModels {
if m, ok := raw.(map[string]any); ok {
if m["apiKey"] != "ollama" {
nonOllamaModels = append(nonOllamaModels, raw)
}
}
}
}
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
var newModels []any
var defaultModelID string
for i, model := range models {
modelID := fmt.Sprintf("custom:%s-%d", model, i)
newModels = append(newModels, modelEntry{
Model: model,
DisplayName: model,
BaseURL: "http://localhost:11434/v1",
APIKey: "ollama",
Provider: "generic-chat-completion-api",
MaxOutputTokens: 64000,
SupportsImages: false,
ID: modelID,
Index: i,
})
if i == 0 {
defaultModelID = modelID
}
}
settingsMap["customModels"] = append(newModels, nonOllamaModels...)
// Update session default settings (preserve unknown fields in the nested object)
sessionSettings, ok := settingsMap["sessionDefaultSettings"].(map[string]any)
if !ok {
sessionSettings = make(map[string]any)
}
sessionSettings["model"] = defaultModelID
if !isValidReasoningEffort(settings.SessionDefaultSettings.ReasoningEffort) {
sessionSettings["reasoningEffort"] = "none"
}
settingsMap["sessionDefaultSettings"] = sessionSettings
data, err := json.MarshalIndent(settingsMap, "", " ")
if err != nil {
return err
}
return writeWithBackup(settingsPath, data)
}
func (d *Droid) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
data, err := os.ReadFile(filepath.Join(home, ".factory", "settings.json"))
if err != nil {
return nil
}
var settings droidSettings
if err := json.Unmarshal(data, &settings); err != nil {
return nil
}
var result []string
for _, m := range settings.CustomModels {
if m.APIKey == "ollama" {
result = append(result, m.Model)
}
}
return result
}
var validReasoningEfforts = []string{"high", "medium", "low", "none"}
func isValidReasoningEffort(effort string) bool {
return slices.Contains(validReasoningEfforts, effort)
}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,99 +0,0 @@
package config
import (
"bytes"
"encoding/json"
"fmt"
"os"
"path/filepath"
"time"
)
func readJSONFile(path string) (map[string]any, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
return nil, err
}
return result, nil
}
func copyFile(src, dst string) error {
info, err := os.Stat(src)
if err != nil {
return err
}
data, err := os.ReadFile(src)
if err != nil {
return err
}
return os.WriteFile(dst, data, info.Mode().Perm())
}
func backupDir() string {
return filepath.Join(os.TempDir(), "ollama-backups")
}
func backupToTmp(srcPath string) (string, error) {
dir := backupDir()
if err := os.MkdirAll(dir, 0o755); err != nil {
return "", err
}
backupPath := filepath.Join(dir, fmt.Sprintf("%s.%d", filepath.Base(srcPath), time.Now().Unix()))
if err := copyFile(srcPath, backupPath); err != nil {
return "", err
}
return backupPath, nil
}
// writeWithBackup writes data to path via temp file + rename, backing up any existing file first
func writeWithBackup(path string, data []byte) error {
var backupPath string
// backup must be created before any writes to the target file
if existingContent, err := os.ReadFile(path); err == nil {
if !bytes.Equal(existingContent, data) {
backupPath, err = backupToTmp(path)
if err != nil {
return fmt.Errorf("backup failed: %w", err)
}
}
} else if !os.IsNotExist(err) {
return fmt.Errorf("read existing file: %w", err)
}
dir := filepath.Dir(path)
tmp, err := os.CreateTemp(dir, ".tmp-*")
if err != nil {
return fmt.Errorf("create temp failed: %w", err)
}
tmpPath := tmp.Name()
if _, err := tmp.Write(data); err != nil {
_ = tmp.Close()
_ = os.Remove(tmpPath)
return fmt.Errorf("write failed: %w", err)
}
if err := tmp.Sync(); err != nil {
_ = tmp.Close()
_ = os.Remove(tmpPath)
return fmt.Errorf("sync failed: %w", err)
}
if err := tmp.Close(); err != nil {
_ = os.Remove(tmpPath)
return fmt.Errorf("close failed: %w", err)
}
if err := os.Rename(tmpPath, path); err != nil {
_ = os.Remove(tmpPath)
if backupPath != "" {
_ = copyFile(backupPath, path)
}
return fmt.Errorf("rename failed: %w", err)
}
return nil
}

View File

@@ -1,502 +0,0 @@
package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"runtime"
"testing"
)
func mustMarshal(t *testing.T, v any) []byte {
t.Helper()
data, err := json.MarshalIndent(v, "", " ")
if err != nil {
t.Fatal(err)
}
return data
}
func TestWriteWithBackup(t *testing.T) {
tmpDir := t.TempDir()
t.Run("creates file", func(t *testing.T) {
path := filepath.Join(tmpDir, "new.json")
data := mustMarshal(t, map[string]string{"key": "value"})
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
content, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var result map[string]string
if err := json.Unmarshal(content, &result); err != nil {
t.Fatal(err)
}
if result["key"] != "value" {
t.Errorf("expected value, got %s", result["key"])
}
})
t.Run("creates backup in /tmp/ollama-backups", func(t *testing.T) {
path := filepath.Join(tmpDir, "backup.json")
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
data := mustMarshal(t, map[string]bool{"updated": true})
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries, err := os.ReadDir(backupDir())
if err != nil {
t.Fatal("backup directory not created")
}
var foundBackup bool
for _, entry := range entries {
if filepath.Ext(entry.Name()) != ".json" {
name := entry.Name()
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
backupPath := filepath.Join(backupDir(), name)
backup, err := os.ReadFile(backupPath)
if err == nil {
var backupData map[string]bool
json.Unmarshal(backup, &backupData)
if backupData["original"] {
foundBackup = true
os.Remove(backupPath)
break
}
}
}
}
}
if !foundBackup {
t.Error("backup file not created in /tmp/ollama-backups")
}
current, _ := os.ReadFile(path)
var currentData map[string]bool
json.Unmarshal(current, &currentData)
if !currentData["updated"] {
t.Error("file doesn't contain updated data")
}
})
t.Run("no backup for new file", func(t *testing.T) {
path := filepath.Join(tmpDir, "nobak.json")
data := mustMarshal(t, map[string]string{"new": "file"})
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries, _ := os.ReadDir(backupDir())
for _, entry := range entries {
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
t.Error("backup should not exist for new file")
}
}
})
t.Run("no backup when content unchanged", func(t *testing.T) {
path := filepath.Join(tmpDir, "unchanged.json")
data := mustMarshal(t, map[string]string{"key": "value"})
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries1, _ := os.ReadDir(backupDir())
countBefore := 0
for _, e := range entries1 {
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
countBefore++
}
}
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries2, _ := os.ReadDir(backupDir())
countAfter := 0
for _, e := range entries2 {
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
countAfter++
}
}
if countAfter != countBefore {
t.Errorf("backup was created when content unchanged (before=%d, after=%d)", countBefore, countAfter)
}
})
t.Run("backup filename contains unix timestamp", func(t *testing.T) {
path := filepath.Join(tmpDir, "timestamped.json")
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
data := mustMarshal(t, map[string]int{"v": 2})
if err := writeWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries, _ := os.ReadDir(backupDir())
var found bool
for _, entry := range entries {
name := entry.Name()
if len(name) > len("timestamped.json.") && name[:len("timestamped.json.")] == "timestamped.json." {
timestamp := name[len("timestamped.json."):]
for _, c := range timestamp {
if c < '0' || c > '9' {
t.Errorf("backup filename timestamp contains non-numeric character: %s", name)
}
}
found = true
os.Remove(filepath.Join(backupDir(), name))
break
}
}
if !found {
t.Error("backup file with timestamp not found")
}
})
}
// Edge case tests for files.go
// TestWriteWithBackup_FailsIfBackupFails documents critical behavior: if backup fails, we must not proceed.
// User could lose their config with no way to recover.
func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "config.json")
// Create original file
originalContent := []byte(`{"original": true}`)
os.WriteFile(path, originalContent, 0o644)
// Make backup directory read-only to force backup failure
backupDir := backupDir()
os.MkdirAll(backupDir, 0o755)
os.Chmod(backupDir, 0o444) // Read-only
defer os.Chmod(backupDir, 0o755)
newContent := []byte(`{"updated": true}`)
err := writeWithBackup(path, newContent)
// Should fail because backup couldn't be created
if err == nil {
t.Error("expected error when backup fails, got nil")
}
// Original file should be preserved
current, _ := os.ReadFile(path)
if string(current) != string(originalContent) {
t.Errorf("original file was modified despite backup failure: got %s", string(current))
}
}
// TestWriteWithBackup_PermissionDenied verifies clear error when target file has wrong permissions.
// Common issue when config owned by root or wrong perms.
func TestWriteWithBackup_PermissionDenied(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := t.TempDir()
// Create a read-only directory
readOnlyDir := filepath.Join(tmpDir, "readonly")
os.MkdirAll(readOnlyDir, 0o755)
os.Chmod(readOnlyDir, 0o444)
defer os.Chmod(readOnlyDir, 0o755)
path := filepath.Join(readOnlyDir, "config.json")
err := writeWithBackup(path, []byte(`{"test": true}`))
if err == nil {
t.Error("expected permission error, got nil")
}
}
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
// writeWithBackup doesn't create directories - caller is responsible.
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
err := writeWithBackup(path, []byte(`{"test": true}`))
// Should fail because directory doesn't exist
if err == nil {
t.Error("expected error for nonexistent directory, got nil")
}
}
// TestWriteWithBackup_SymlinkTarget documents behavior when target is a symlink.
// Documents what happens if user symlinks their config file.
func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("symlink tests may require admin on Windows")
}
tmpDir := t.TempDir()
realFile := filepath.Join(tmpDir, "real.json")
symlink := filepath.Join(tmpDir, "link.json")
// Create real file and symlink
os.WriteFile(realFile, []byte(`{"v": 1}`), 0o644)
os.Symlink(realFile, symlink)
// Write through symlink
err := writeWithBackup(symlink, []byte(`{"v": 2}`))
if err != nil {
t.Fatalf("writeWithBackup through symlink failed: %v", err)
}
// The real file should be updated (symlink followed for temp file creation)
content, _ := os.ReadFile(symlink)
if string(content) != `{"v": 2}` {
t.Errorf("symlink target not updated correctly: got %s", string(content))
}
}
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
// User may have config files with unusual names.
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
tmpDir := t.TempDir()
// File with spaces and special chars
path := filepath.Join(tmpDir, "my config (backup).json")
os.WriteFile(path, []byte(`{"test": true}`), 0o644)
backupPath, err := backupToTmp(path)
if err != nil {
t.Fatalf("backupToTmp with special chars failed: %v", err)
}
// Verify backup exists and has correct content
content, err := os.ReadFile(backupPath)
if err != nil {
t.Fatalf("could not read backup: %v", err)
}
if string(content) != `{"test": true}` {
t.Errorf("backup content mismatch: got %s", string(content))
}
os.Remove(backupPath)
}
// TestCopyFile_PreservesPermissions verifies that copyFile preserves file permissions.
func TestCopyFile_PreservesPermissions(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission preservation tests unreliable on Windows")
}
tmpDir := t.TempDir()
src := filepath.Join(tmpDir, "src.json")
dst := filepath.Join(tmpDir, "dst.json")
// Create source with specific permissions
os.WriteFile(src, []byte(`{"test": true}`), 0o600)
err := copyFile(src, dst)
if err != nil {
t.Fatalf("copyFile failed: %v", err)
}
srcInfo, _ := os.Stat(src)
dstInfo, _ := os.Stat(dst)
if srcInfo.Mode().Perm() != dstInfo.Mode().Perm() {
t.Errorf("permissions not preserved: src=%v, dst=%v", srcInfo.Mode().Perm(), dstInfo.Mode().Perm())
}
}
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
func TestCopyFile_SourceNotFound(t *testing.T) {
tmpDir := t.TempDir()
src := filepath.Join(tmpDir, "nonexistent.json")
dst := filepath.Join(tmpDir, "dst.json")
err := copyFile(src, dst)
if err == nil {
t.Error("expected error for nonexistent source, got nil")
}
}
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
tmpDir := t.TempDir()
dirPath := filepath.Join(tmpDir, "actualdir")
os.MkdirAll(dirPath, 0o755)
err := writeWithBackup(dirPath, []byte(`{"test": true}`))
if err == nil {
t.Error("expected error when target is a directory, got nil")
}
}
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
func TestWriteWithBackup_EmptyData(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "empty.json")
err := writeWithBackup(path, []byte{})
if err != nil {
t.Fatalf("writeWithBackup with empty data failed: %v", err)
}
content, err := os.ReadFile(path)
if err != nil {
t.Fatalf("could not read file: %v", err)
}
if len(content) != 0 {
t.Errorf("expected empty file, got %d bytes", len(content))
}
}
// TestWriteWithBackup_FileUnreadableButDirWritable verifies behavior when existing file
// cannot be read (for backup comparison) but directory is writable.
func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "unreadable.json")
// Create file and make it unreadable
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
os.Chmod(path, 0o000)
defer os.Chmod(path, 0o644)
// Should fail because we can't read the file to compare/backup
err := writeWithBackup(path, []byte(`{"updated": true}`))
if err == nil {
t.Error("expected error when file is unreadable, got nil")
}
}
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
// within the same second (timestamp collision scenario).
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "rapid.json")
// Create initial file
os.WriteFile(path, []byte(`{"v": 0}`), 0o644)
// Rapid successive writes
for i := 1; i <= 3; i++ {
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
if err := writeWithBackup(path, data); err != nil {
t.Fatalf("write %d failed: %v", i, err)
}
}
// Verify final content
content, _ := os.ReadFile(path)
if string(content) != `{"v": 3}` {
t.Errorf("expected final content {\"v\": 3}, got %s", string(content))
}
// Verify at least one backup exists
entries, _ := os.ReadDir(backupDir())
var backupCount int
for _, e := range entries {
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
backupCount++
}
}
if backupCount == 0 {
t.Error("expected at least one backup file from rapid writes")
}
}
// TestWriteWithBackup_BackupDirIsFile verifies error when backup directory path is a file.
func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("test modifies system temp directory")
}
// Create a file at the backup directory path
backupPath := backupDir()
// Clean up any existing directory first
os.RemoveAll(backupPath)
// Create a file instead of directory
os.WriteFile(backupPath, []byte("not a directory"), 0o644)
defer func() {
os.Remove(backupPath)
os.MkdirAll(backupPath, 0o755)
}()
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "test.json")
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
err := writeWithBackup(path, []byte(`{"updated": true}`))
if err == nil {
t.Error("expected error when backup dir is a file, got nil")
}
}
// TestWriteWithBackup_NoOrphanTempFiles verifies temp files are cleaned up on failure.
func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := t.TempDir()
// Count existing temp files
countTempFiles := func() int {
entries, _ := os.ReadDir(tmpDir)
count := 0
for _, e := range entries {
if len(e.Name()) > 4 && e.Name()[:4] == ".tmp" {
count++
}
}
return count
}
before := countTempFiles()
// Create a file, then make directory read-only to cause rename failure
path := filepath.Join(tmpDir, "orphan.json")
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
// Make a subdirectory and try to write there after making parent read-only
subDir := filepath.Join(tmpDir, "subdir")
os.MkdirAll(subDir, 0o755)
subPath := filepath.Join(subDir, "config.json")
os.WriteFile(subPath, []byte(`{"v": 1}`), 0o644)
// Make subdir read-only after creating temp file would succeed but rename would fail
// This is tricky to test - the temp file is created in the same dir, so if we can't
// rename, we also couldn't create. Let's just verify normal failure cleanup works.
// Force a failure by making the target a directory
badPath := filepath.Join(tmpDir, "isdir")
os.MkdirAll(badPath, 0o755)
_ = writeWithBackup(badPath, []byte(`{"test": true}`))
after := countTempFiles()
if after > before {
t.Errorf("orphan temp files left behind: before=%d, after=%d", before, after)
}
}

View File

@@ -1,355 +0,0 @@
package config
import (
"context"
"errors"
"fmt"
"maps"
"os"
"os/exec"
"runtime"
"slices"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/spf13/cobra"
)
// Runners execute the launching of a model with the integration - claude, codex
// Editors can edit config files (supports multi-model selection) - opencode, droid
// They are composable interfaces where in some cases an editor is also a runner - opencode, droid
// Runner can run an integration with a model.
type Runner interface {
Run(model string) error
// String returns the human-readable name of the integration
String() string
}
// Editor can edit config files (supports multi-model selection)
type Editor interface {
// Paths returns the paths to the config files for the integration
Paths() []string
// Edit updates the config files for the integration with the given models
Edit(models []string) error
// Models returns the models currently configured for the integration
// TODO(parthsareen): add error return to Models()
Models() []string
}
// integrations is the registry of available integrations.
var integrations = map[string]Runner{
"claude": &Claude{},
"clawdbot": &Clawdbot{},
"codex": &Codex{},
"droid": &Droid{},
"opencode": &OpenCode{},
}
func selectIntegration() (string, error) {
if len(integrations) == 0 {
return "", fmt.Errorf("no integrations available")
}
names := slices.Sorted(maps.Keys(integrations))
var items []selectItem
for _, name := range names {
r := integrations[name]
description := r.String()
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
}
items = append(items, selectItem{Name: name, Description: description})
}
return selectPrompt("Select integration:", items)
}
// selectModels lets the user select models for an integration
func selectModels(ctx context.Context, name, current string) ([]string, error) {
r, ok := integrations[name]
if !ok {
return nil, fmt.Errorf("unknown integration: %s", name)
}
client, err := api.ClientFromEnvironment()
if err != nil {
return nil, err
}
models, err := client.List(ctx)
if err != nil {
return nil, err
}
if len(models.Models) == 0 {
return nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
var items []selectItem
cloudModels := make(map[string]bool)
for _, m := range models.Models {
if m.RemoteModel != "" {
cloudModels[m.Name] = true
}
items = append(items, selectItem{Name: m.Name})
}
if len(items) == 0 {
return nil, fmt.Errorf("no local models available, run 'ollama pull <model>' first")
}
// Get previously configured models (saved config takes precedence)
var preChecked []string
if saved, err := loadIntegration(name); err == nil {
preChecked = saved.Models
} else if editor, ok := r.(Editor); ok {
preChecked = editor.Models()
}
checked := make(map[string]bool, len(preChecked))
for _, n := range preChecked {
checked[n] = true
}
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
for _, item := range items {
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
current = item.Name
break
}
}
// If current model is configured, move to front of preChecked
if checked[current] {
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
}
// Sort: checked first, then alphabetical
slices.SortFunc(items, func(a, b selectItem) int {
ac, bc := checked[a.Name], checked[b.Name]
if ac != bc {
if ac {
return -1
}
return 1
}
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
})
var selected []string
// only editors support multi-model selection
if _, ok := r.(Editor); ok {
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
if err != nil {
return nil, err
}
} else {
model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items)
if err != nil {
return nil, err
}
selected = []string{model}
}
// if any model in selected is a cloud model, ensure signed in
var selectedCloudModels []string
for _, m := range selected {
if cloudModels[m] {
selectedCloudModels = append(selectedCloudModels, m)
}
}
if len(selectedCloudModels) > 0 {
// ensure user is signed in
user, err := client.Whoami(ctx)
if err == nil && user != nil && user.Name != "" {
return selected, nil
}
var aErr api.AuthorizationError
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
return nil, err
}
modelList := strings.Join(selectedCloudModels, ", ")
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
if err != nil || !yes {
return nil, fmt.Errorf("%s requires sign in", modelList)
}
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
// TODO(parthsareen): extract into auth package for cmd
// Auto-open browser (best effort, fail silently)
switch runtime.GOOS {
case "darwin":
_ = exec.Command("open", aErr.SigninURL).Start()
case "linux":
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
case "windows":
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
}
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\r\033[K")
return nil, ctx.Err()
case <-ticker.C:
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
// poll every 10th frame (~2 seconds)
if frame%10 == 0 {
u, err := client.Whoami(ctx)
if err == nil && u != nil && u.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
return selected, nil
}
}
}
}
}
return selected, nil
}
func runIntegration(name, modelName string) error {
r, ok := integrations[name]
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
return r.Run(modelName)
}
// LaunchCmd returns the cobra command for launching integrations.
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
var modelFlag string
var configFlag bool
cmd := &cobra.Command{
Use: "launch [INTEGRATION]",
Short: "Launch an integration with Ollama",
Long: `Launch an integration configured with Ollama models.
Supported integrations:
claude Claude Code
clawdbot Clawdbot
codex Codex
droid Droid
opencode OpenCode
Examples:
ollama launch
ollama launch claude
ollama launch claude --model <model>
ollama launch droid --config (does not auto-launch)`,
Args: cobra.MaximumNArgs(1),
PreRunE: checkServerHeartbeat,
RunE: func(cmd *cobra.Command, args []string) error {
var name string
if len(args) > 0 {
name = args[0]
} else {
var err error
name, err = selectIntegration()
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
}
r, ok := integrations[strings.ToLower(name)]
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
// If launching without --model, use saved config if available
if !configFlag && modelFlag == "" {
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
return runIntegration(name, config.Models[0])
}
}
var models []string
if modelFlag != "" {
// When --model is specified, merge with existing models (new model becomes default)
models = []string{modelFlag}
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
for _, m := range existing.Models {
if m != modelFlag {
models = append(models, m)
}
}
}
} else {
var err error
models, err = selectModels(cmd.Context(), name, "")
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
}
if editor, isEditor := r.(Editor); isEditor {
paths := editor.Paths()
if len(paths) > 0 {
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r)
for _, p := range paths {
fmt.Fprintf(os.Stderr, " %s\n", p)
}
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir())
if ok, _ := confirmPrompt("Proceed?"); !ok {
return nil
}
}
}
if err := saveIntegration(name, models); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
if editor, isEditor := r.(Editor); isEditor {
if err := editor.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
}
if _, isEditor := r.(Editor); isEditor {
if len(models) == 1 {
fmt.Fprintf(os.Stderr, "Added %s to %s\n", models[0], r)
} else {
fmt.Fprintf(os.Stderr, "Added %d models to %s (default: %s)\n", len(models), r, models[0])
}
}
if configFlag {
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
return runIntegration(name, models[0])
}
fmt.Fprintf(os.Stderr, "Run 'ollama launch %s' to start with %s\n", strings.ToLower(name), models[0])
return nil
}
return runIntegration(name, models[0])
},
}
cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use")
cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
return cmd
}

View File

@@ -1,188 +0,0 @@
package config
import (
"slices"
"strings"
"testing"
"github.com/spf13/cobra"
)
func TestIntegrationLookup(t *testing.T) {
tests := []struct {
name string
input string
wantFound bool
wantName string
}{
{"claude lowercase", "claude", true, "Claude Code"},
{"claude uppercase", "CLAUDE", true, "Claude Code"},
{"claude mixed case", "Claude", true, "Claude Code"},
{"codex", "codex", true, "Codex"},
{"droid", "droid", true, "Droid"},
{"opencode", "opencode", true, "OpenCode"},
{"unknown integration", "unknown", false, ""},
{"empty string", "", false, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r, found := integrations[strings.ToLower(tt.input)]
if found != tt.wantFound {
t.Errorf("integrations[%q] found = %v, want %v", tt.input, found, tt.wantFound)
}
if found && r.String() != tt.wantName {
t.Errorf("integrations[%q].String() = %q, want %q", tt.input, r.String(), tt.wantName)
}
})
}
}
func TestIntegrationRegistry(t *testing.T) {
expectedIntegrations := []string{"claude", "codex", "droid", "opencode"}
for _, name := range expectedIntegrations {
t.Run(name, func(t *testing.T) {
r, ok := integrations[name]
if !ok {
t.Fatalf("integration %q not found in registry", name)
}
if r.String() == "" {
t.Error("integration.String() should not be empty")
}
})
}
}
func TestHasLocalModel(t *testing.T) {
tests := []struct {
name string
models []string
want bool
}{
{"empty list", []string{}, false},
{"single local model", []string{"llama3.2"}, true},
{"single cloud model", []string{"cloud-model"}, false},
{"mixed models", []string{"cloud-model", "llama3.2"}, true},
{"multiple local models", []string{"llama3.2", "qwen2.5"}, true},
{"multiple cloud models", []string{"cloud-a", "cloud-b"}, false},
{"local model first", []string{"llama3.2", "cloud-model"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := slices.ContainsFunc(tt.models, func(m string) bool {
return !strings.Contains(m, "cloud")
})
if got != tt.want {
t.Errorf("hasLocalModel(%v) = %v, want %v", tt.models, got, tt.want)
}
})
}
}
func TestLaunchCmd(t *testing.T) {
// Mock checkServerHeartbeat that always succeeds
mockCheck := func(cmd *cobra.Command, args []string) error {
return nil
}
cmd := LaunchCmd(mockCheck)
t.Run("command structure", func(t *testing.T) {
if cmd.Use != "launch [INTEGRATION]" {
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION]")
}
if cmd.Short == "" {
t.Error("Short description should not be empty")
}
if cmd.Long == "" {
t.Error("Long description should not be empty")
}
})
t.Run("flags exist", func(t *testing.T) {
modelFlag := cmd.Flags().Lookup("model")
if modelFlag == nil {
t.Error("--model flag should exist")
}
configFlag := cmd.Flags().Lookup("config")
if configFlag == nil {
t.Error("--config flag should exist")
}
})
t.Run("PreRunE is set", func(t *testing.T) {
if cmd.PreRunE == nil {
t.Error("PreRunE should be set to checkServerHeartbeat")
}
})
}
func TestRunIntegration_UnknownIntegration(t *testing.T) {
err := runIntegration("unknown-integration", "model")
if err == nil {
t.Error("expected error for unknown integration, got nil")
}
if !strings.Contains(err.Error(), "unknown integration") {
t.Errorf("error should mention 'unknown integration', got: %v", err)
}
}
func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {
tests := []struct {
name string
models []string
want bool
reason string
}{
{"empty list", []string{}, false, "empty list has no local models"},
{"contains-cloud-substring", []string{"deepseek-r1:cloud"}, false, "model with 'cloud' substring is considered cloud"},
{"cloud-in-name", []string{"my-cloud-model"}, false, "'cloud' anywhere in name = cloud model"},
{"cloudless", []string{"cloudless-model"}, false, "'cloudless' still contains 'cloud'"},
{"local-model", []string{"llama3.2"}, true, "no 'cloud' = local"},
{"mixed", []string{"cloud-model", "llama3.2"}, true, "one local model = hasLocalModel true"},
{"all-cloud", []string{"cloud-a", "cloud-b"}, false, "all contain 'cloud'"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := slices.ContainsFunc(tt.models, func(m string) bool {
return !strings.Contains(m, "cloud")
})
if got != tt.want {
t.Errorf("hasLocalModel(%v) = %v, want %v (%s)", tt.models, got, tt.want, tt.reason)
}
})
}
}
func TestLaunchCmd_NilHeartbeat(t *testing.T) {
// This should not panic - cmd creation should work even with nil
cmd := LaunchCmd(nil)
if cmd == nil {
t.Fatal("LaunchCmd returned nil")
}
// PreRunE should be nil when passed nil
if cmd.PreRunE != nil {
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
}
}
func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
for name, r := range integrations {
t.Run(name, func(t *testing.T) {
// Test String() doesn't panic and returns non-empty
displayName := r.String()
if displayName == "" {
t.Error("String() should not return empty")
}
// Test Run() exists (we can't call it without actually running the command)
// Just verify the method is available
var _ func(string) error = r.Run
})
}
}

View File

@@ -1,224 +0,0 @@
package config
import (
"encoding/json"
"fmt"
"maps"
"os"
"os/exec"
"path/filepath"
"slices"
"strings"
)
// OpenCode implements Runner and Editor for OpenCode integration
type OpenCode struct{}
func (o *OpenCode) String() string { return "OpenCode" }
func (o *OpenCode) Run(model string) error {
if _, err := exec.LookPath("opencode"); err != nil {
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
}
// Call Edit() to ensure config is up-to-date before launch
models := []string{model}
if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
models = config.Models
}
if err := o.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("opencode")
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func (o *OpenCode) Paths() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
var paths []string
p := filepath.Join(home, ".config", "opencode", "opencode.json")
if _, err := os.Stat(p); err == nil {
paths = append(paths, p)
}
sp := filepath.Join(home, ".local", "state", "opencode", "model.json")
if _, err := os.Stat(sp); err == nil {
paths = append(paths, sp)
}
return paths
}
func (o *OpenCode) Edit(modelList []string) error {
if len(modelList) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".config", "opencode", "opencode.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config) // Ignore parse errors; treat missing/corrupt files as empty
}
config["$schema"] = "https://opencode.ai/config.json"
provider, ok := config["provider"].(map[string]any)
if !ok {
provider = make(map[string]any)
}
ollama, ok := provider["ollama"].(map[string]any)
if !ok {
ollama = map[string]any{
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama (local)",
"options": map[string]any{
"baseURL": "http://localhost:11434/v1",
},
}
}
models, ok := ollama["models"].(map[string]any)
if !ok {
models = make(map[string]any)
}
selectedSet := make(map[string]bool)
for _, m := range modelList {
selectedSet[m] = true
}
for name, cfg := range models {
if cfgMap, ok := cfg.(map[string]any); ok {
if isOllamaModel(cfgMap) && !selectedSet[name] {
delete(models, name)
}
}
}
for _, model := range modelList {
if existing, ok := models[model].(map[string]any); ok {
// migrate existing models without _launch marker
if isOllamaModel(existing) {
existing["_launch"] = true
if name, ok := existing["name"].(string); ok {
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
}
}
continue
}
models[model] = map[string]any{
"name": model,
"_launch": true,
}
}
ollama["models"] = models
provider["ollama"] = ollama
config["provider"] = provider
configData, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
if err := writeWithBackup(configPath, configData); err != nil {
return err
}
statePath := filepath.Join(home, ".local", "state", "opencode", "model.json")
if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
return err
}
state := map[string]any{
"recent": []any{},
"favorite": []any{},
"variant": map[string]any{},
}
if data, err := os.ReadFile(statePath); err == nil {
_ = json.Unmarshal(data, &state) // Ignore parse errors; use defaults
}
recent, _ := state["recent"].([]any)
modelSet := make(map[string]bool)
for _, m := range modelList {
modelSet[m] = true
}
// Filter out existing Ollama models we're about to re-add
newRecent := slices.DeleteFunc(slices.Clone(recent), func(entry any) bool {
e, ok := entry.(map[string]any)
if !ok || e["providerID"] != "ollama" {
return false
}
modelID, _ := e["modelID"].(string)
return modelSet[modelID]
})
// Prepend models in reverse order so first model ends up first
for _, model := range slices.Backward(modelList) {
newRecent = slices.Insert(newRecent, 0, any(map[string]any{
"providerID": "ollama",
"modelID": model,
}))
}
const maxRecentModels = 10
newRecent = newRecent[:min(len(newRecent), maxRecentModels)]
state["recent"] = newRecent
stateData, err := json.MarshalIndent(state, "", " ")
if err != nil {
return err
}
return writeWithBackup(statePath, stateData)
}
func (o *OpenCode) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
config, err := readJSONFile(filepath.Join(home, ".config", "opencode", "opencode.json"))
if err != nil {
return nil
}
provider, _ := config["provider"].(map[string]any)
ollama, _ := provider["ollama"].(map[string]any)
models, _ := ollama["models"].(map[string]any)
if len(models) == 0 {
return nil
}
keys := slices.Collect(maps.Keys(models))
slices.Sort(keys)
return keys
}
// isOllamaModel reports whether a model config entry is managed by us
func isOllamaModel(cfg map[string]any) bool {
if v, ok := cfg["_launch"].(bool); ok && v {
return true
}
// previously used [Ollama] as a suffix for the model managed by ollama launch
if name, ok := cfg["name"].(string); ok {
return strings.HasSuffix(name, "[Ollama]")
}
return false
}

View File

@@ -1,507 +0,0 @@
package config
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
func TestOpenCodeIntegration(t *testing.T) {
o := &OpenCode{}
t.Run("String", func(t *testing.T) {
if got := o.String(); got != "OpenCode" {
t.Errorf("String() = %q, want %q", got, "OpenCode")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = o
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = o
})
}
func TestOpenCodeEdit(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
statePath := filepath.Join(stateDir, "model.json")
cleanup := func() {
os.RemoveAll(configDir)
os.RemoveAll(stateDir)
}
t.Run("fresh install", func(t *testing.T) {
cleanup()
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertOpenCodeModelExists(t, configPath, "llama3.2")
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
})
t.Run("preserve other providers", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"provider":{"anthropic":{"apiKey":"xxx"}}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider := cfg["provider"].(map[string]any)
if provider["anthropic"] == nil {
t.Error("anthropic provider was removed")
}
assertOpenCodeModelExists(t, configPath, "llama3.2")
})
t.Run("preserve other models", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"mistral":{"name":"Mistral"}}}}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertOpenCodeModelExists(t, configPath, "mistral")
assertOpenCodeModelExists(t, configPath, "llama3.2")
})
t.Run("update existing model", func(t *testing.T) {
cleanup()
o.Edit([]string{"llama3.2"})
o.Edit([]string{"llama3.2"})
assertOpenCodeModelExists(t, configPath, "llama3.2")
})
t.Run("preserve top-level keys", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"theme":"dark","keybindings":{}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
if cfg["theme"] != "dark" {
t.Error("theme was removed")
}
if cfg["keybindings"] == nil {
t.Error("keybindings was removed")
}
})
t.Run("model state - insert at index 0", func(t *testing.T) {
cleanup()
os.MkdirAll(stateDir, 0o755)
os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
assertOpenCodeRecentModel(t, statePath, 1, "anthropic", "claude")
})
t.Run("model state - preserve favorites and variants", func(t *testing.T) {
cleanup()
os.MkdirAll(stateDir, 0o755)
os.WriteFile(statePath, []byte(`{"recent":[],"favorite":[{"providerID":"x","modelID":"y"}],"variant":{"a":"b"}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(statePath)
var state map[string]any
json.Unmarshal(data, &state)
if len(state["favorite"].([]any)) != 1 {
t.Error("favorite was modified")
}
if state["variant"].(map[string]any)["a"] != "b" {
t.Error("variant was modified")
}
})
t.Run("model state - deduplicate on re-add", func(t *testing.T) {
cleanup()
os.MkdirAll(stateDir, 0o755)
os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"ollama","modelID":"llama3.2"},{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(statePath)
var state map[string]any
json.Unmarshal(data, &state)
recent := state["recent"].([]any)
if len(recent) != 2 {
t.Errorf("expected 2 recent entries, got %d", len(recent))
}
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
})
t.Run("remove model", func(t *testing.T) {
cleanup()
// First add two models
o.Edit([]string{"llama3.2", "mistral"})
assertOpenCodeModelExists(t, configPath, "llama3.2")
assertOpenCodeModelExists(t, configPath, "mistral")
// Then remove one by only selecting the other
o.Edit([]string{"llama3.2"})
assertOpenCodeModelExists(t, configPath, "llama3.2")
assertOpenCodeModelNotExists(t, configPath, "mistral")
})
t.Run("preserve user customizations on managed models", func(t *testing.T) {
cleanup()
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
// Add custom fields to the model entry (simulating user edits)
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider := cfg["provider"].(map[string]any)
ollama := provider["ollama"].(map[string]any)
models := ollama["models"].(map[string]any)
entry := models["llama3.2"].(map[string]any)
entry["_myPref"] = "custom-value"
entry["_myNum"] = 42
configData, _ := json.MarshalIndent(cfg, "", " ")
os.WriteFile(configPath, configData, 0o644)
// Re-run Edit — should preserve custom fields
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ = os.ReadFile(configPath)
json.Unmarshal(data, &cfg)
provider = cfg["provider"].(map[string]any)
ollama = provider["ollama"].(map[string]any)
models = ollama["models"].(map[string]any)
entry = models["llama3.2"].(map[string]any)
if entry["_myPref"] != "custom-value" {
t.Errorf("_myPref was lost: got %v", entry["_myPref"])
}
if entry["_myNum"] != float64(42) {
t.Errorf("_myNum was lost: got %v", entry["_myNum"])
}
if v, ok := entry["_launch"].(bool); !ok || !v {
t.Errorf("_launch marker missing or false: got %v", entry["_launch"])
}
})
t.Run("migrate legacy [Ollama] suffix entries", func(t *testing.T) {
cleanup()
// Write a config with a legacy entry (has [Ollama] suffix but no _launch marker)
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"llama3.2":{"name":"llama3.2 [Ollama]"}}}}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider := cfg["provider"].(map[string]any)
ollama := provider["ollama"].(map[string]any)
models := ollama["models"].(map[string]any)
entry := models["llama3.2"].(map[string]any)
// _launch marker should be added
if v, ok := entry["_launch"].(bool); !ok || !v {
t.Errorf("_launch marker not added during migration: got %v", entry["_launch"])
}
// [Ollama] suffix should be stripped
if name, ok := entry["name"].(string); !ok || name != "llama3.2" {
t.Errorf("name suffix not stripped: got %q", entry["name"])
}
})
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Add a non-Ollama model manually
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"external":{"name":"External Model"}}}}}`), 0o644)
o.Edit([]string{"llama3.2"})
assertOpenCodeModelExists(t, configPath, "llama3.2")
assertOpenCodeModelExists(t, configPath, "external") // Should be preserved
})
}
func assertOpenCodeModelExists(t *testing.T, path, model string) {
t.Helper()
data, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatal(err)
}
provider, ok := cfg["provider"].(map[string]any)
if !ok {
t.Fatal("provider not found")
}
ollama, ok := provider["ollama"].(map[string]any)
if !ok {
t.Fatal("ollama provider not found")
}
models, ok := ollama["models"].(map[string]any)
if !ok {
t.Fatal("models not found")
}
if models[model] == nil {
t.Errorf("model %s not found", model)
}
}
func assertOpenCodeModelNotExists(t *testing.T, path, model string) {
t.Helper()
data, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatal(err)
}
provider, ok := cfg["provider"].(map[string]any)
if !ok {
return // No provider means no model
}
ollama, ok := provider["ollama"].(map[string]any)
if !ok {
return // No ollama means no model
}
models, ok := ollama["models"].(map[string]any)
if !ok {
return // No models means no model
}
if models[model] != nil {
t.Errorf("model %s should not exist but was found", model)
}
}
func assertOpenCodeRecentModel(t *testing.T, path string, index int, providerID, modelID string) {
t.Helper()
data, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var state map[string]any
if err := json.Unmarshal(data, &state); err != nil {
t.Fatal(err)
}
recent, ok := state["recent"].([]any)
if !ok {
t.Fatal("recent not found")
}
if index >= len(recent) {
t.Fatalf("index %d out of range (len=%d)", index, len(recent))
}
entry, ok := recent[index].(map[string]any)
if !ok {
t.Fatal("entry is not a map")
}
if entry["providerID"] != providerID {
t.Errorf("expected providerID %s, got %s", providerID, entry["providerID"])
}
if entry["modelID"] != modelID {
t.Errorf("expected modelID %s, got %s", modelID, entry["modelID"])
}
}
// Edge case tests for opencode.go
func TestOpenCodeEdit_CorruptedConfigJSON(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{corrupted json content`), 0o644)
// Should not panic - corrupted JSON should be treated as empty
err := o.Edit([]string{"llama3.2"})
if err != nil {
t.Fatalf("Edit failed with corrupted config: %v", err)
}
// Verify valid JSON was created
data, _ := os.ReadFile(configPath)
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Errorf("resulting config is not valid JSON: %v", err)
}
}
func TestOpenCodeEdit_CorruptedStateJSON(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
statePath := filepath.Join(stateDir, "model.json")
os.MkdirAll(stateDir, 0o755)
os.WriteFile(statePath, []byte(`{corrupted state`), 0o644)
err := o.Edit([]string{"llama3.2"})
if err != nil {
t.Fatalf("Edit failed with corrupted state: %v", err)
}
// Verify valid state was created
data, _ := os.ReadFile(statePath)
var state map[string]any
if err := json.Unmarshal(data, &state); err != nil {
t.Errorf("resulting state is not valid JSON: %v", err)
}
}
func TestOpenCodeEdit_WrongTypeProvider(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"provider": "not a map"}`), 0o644)
err := o.Edit([]string{"llama3.2"})
if err != nil {
t.Fatalf("Edit with wrong type provider failed: %v", err)
}
// Verify provider is now correct type
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider, ok := cfg["provider"].(map[string]any)
if !ok {
t.Fatalf("provider should be map after setup, got %T", cfg["provider"])
}
if provider["ollama"] == nil {
t.Error("ollama provider should be created")
}
}
func TestOpenCodeEdit_WrongTypeRecent(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
statePath := filepath.Join(stateDir, "model.json")
os.MkdirAll(stateDir, 0o755)
os.WriteFile(statePath, []byte(`{"recent": "not an array", "favorite": [], "variant": {}}`), 0o644)
err := o.Edit([]string{"llama3.2"})
if err != nil {
t.Fatalf("Edit with wrong type recent failed: %v", err)
}
// The function should handle this gracefully
data, _ := os.ReadFile(statePath)
var state map[string]any
json.Unmarshal(data, &state)
// recent should be properly set after setup
recent, ok := state["recent"].([]any)
if !ok {
t.Logf("Note: recent type after setup is %T (documenting behavior)", state["recent"])
} else if len(recent) == 0 {
t.Logf("Note: recent is empty (documenting behavior)")
}
}
func TestOpenCodeEdit_EmptyModels(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
os.MkdirAll(configDir, 0o755)
originalContent := `{"provider":{"ollama":{"models":{"existing":{}}}}}`
os.WriteFile(configPath, []byte(originalContent), 0o644)
// Empty models should be no-op
err := o.Edit([]string{})
if err != nil {
t.Fatalf("Edit with empty models failed: %v", err)
}
// Original content should be preserved (file not modified)
data, _ := os.ReadFile(configPath)
if string(data) != originalContent {
t.Errorf("empty models should not modify file, but content changed")
}
}
func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Model name with special characters (though unusual)
specialModel := `model-with-"quotes"`
err := o.Edit([]string{specialModel})
if err != nil {
t.Fatalf("Edit with special chars failed: %v", err)
}
// Verify it was stored correctly
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
data, _ := os.ReadFile(configPath)
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatalf("resulting config is invalid JSON: %v", err)
}
// Model should be accessible
provider, _ := cfg["provider"].(map[string]any)
ollama, _ := provider["ollama"].(map[string]any)
models, _ := ollama["models"].(map[string]any)
if models[specialModel] == nil {
t.Errorf("model with special chars not found in config")
}
}
func TestOpenCodeModels_NoConfig(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
models := o.Models()
if len(models) > 0 {
t.Errorf("expected nil/empty for missing config, got %v", models)
}
}

View File

@@ -1,499 +0,0 @@
package config
import (
"errors"
"fmt"
"io"
"os"
"strings"
"golang.org/x/term"
)
// ANSI escape sequences for terminal formatting.
const (
ansiHideCursor = "\033[?25l"
ansiShowCursor = "\033[?25h"
ansiBold = "\033[1m"
ansiReset = "\033[0m"
ansiGray = "\033[37m"
ansiClearDown = "\033[J"
)
const maxDisplayedItems = 10
var errCancelled = errors.New("cancelled")
type selectItem struct {
Name string
Description string
}
type inputEvent int
const (
eventNone inputEvent = iota
eventEnter
eventEscape
eventUp
eventDown
eventTab
eventBackspace
eventChar
)
type selectState struct {
items []selectItem
filter string
selected int
scrollOffset int
}
func newSelectState(items []selectItem) *selectState {
return &selectState{items: items}
}
func (s *selectState) filtered() []selectItem {
return filterItems(s.items, s.filter)
}
func (s *selectState) handleInput(event inputEvent, char byte) (done bool, result string, err error) {
filtered := s.filtered()
switch event {
case eventEnter:
if len(filtered) > 0 && s.selected < len(filtered) {
return true, filtered[s.selected].Name, nil
}
case eventEscape:
return true, "", errCancelled
case eventBackspace:
if len(s.filter) > 0 {
s.filter = s.filter[:len(s.filter)-1]
s.selected = 0
s.scrollOffset = 0
}
case eventUp:
if s.selected > 0 {
s.selected--
if s.selected < s.scrollOffset {
s.scrollOffset = s.selected
}
}
case eventDown:
if s.selected < len(filtered)-1 {
s.selected++
if s.selected >= s.scrollOffset+maxDisplayedItems {
s.scrollOffset = s.selected - maxDisplayedItems + 1
}
}
case eventChar:
s.filter += string(char)
s.selected = 0
s.scrollOffset = 0
}
return false, "", nil
}
type multiSelectState struct {
items []selectItem
itemIndex map[string]int
filter string
highlighted int
scrollOffset int
checked map[int]bool
checkOrder []int
focusOnButton bool
}
func newMultiSelectState(items []selectItem, preChecked []string) *multiSelectState {
s := &multiSelectState{
items: items,
itemIndex: make(map[string]int, len(items)),
checked: make(map[int]bool),
}
for i, item := range items {
s.itemIndex[item.Name] = i
}
for _, name := range preChecked {
if idx, ok := s.itemIndex[name]; ok {
s.checked[idx] = true
s.checkOrder = append(s.checkOrder, idx)
}
}
return s
}
func (s *multiSelectState) filtered() []selectItem {
return filterItems(s.items, s.filter)
}
func (s *multiSelectState) toggleItem() {
filtered := s.filtered()
if len(filtered) == 0 || s.highlighted >= len(filtered) {
return
}
item := filtered[s.highlighted]
origIdx := s.itemIndex[item.Name]
if s.checked[origIdx] {
delete(s.checked, origIdx)
for i, idx := range s.checkOrder {
if idx == origIdx {
s.checkOrder = append(s.checkOrder[:i], s.checkOrder[i+1:]...)
break
}
}
} else {
s.checked[origIdx] = true
s.checkOrder = append(s.checkOrder, origIdx)
}
}
func (s *multiSelectState) handleInput(event inputEvent, char byte) (done bool, result []string, err error) {
filtered := s.filtered()
switch event {
case eventEnter:
if s.focusOnButton && len(s.checkOrder) > 0 {
var res []string
for _, idx := range s.checkOrder {
res = append(res, s.items[idx].Name)
}
return true, res, nil
} else if !s.focusOnButton {
s.toggleItem()
}
case eventTab:
if len(s.checkOrder) > 0 {
s.focusOnButton = !s.focusOnButton
}
case eventEscape:
return true, nil, errCancelled
case eventBackspace:
if len(s.filter) > 0 {
s.filter = s.filter[:len(s.filter)-1]
s.highlighted = 0
s.scrollOffset = 0
s.focusOnButton = false
}
case eventUp:
if s.focusOnButton {
s.focusOnButton = false
} else if s.highlighted > 0 {
s.highlighted--
if s.highlighted < s.scrollOffset {
s.scrollOffset = s.highlighted
}
}
case eventDown:
if s.focusOnButton {
s.focusOnButton = false
} else if s.highlighted < len(filtered)-1 {
s.highlighted++
if s.highlighted >= s.scrollOffset+maxDisplayedItems {
s.scrollOffset = s.highlighted - maxDisplayedItems + 1
}
}
case eventChar:
s.filter += string(char)
s.highlighted = 0
s.scrollOffset = 0
s.focusOnButton = false
}
return false, nil, nil
}
func (s *multiSelectState) selectedCount() int {
return len(s.checkOrder)
}
// Terminal I/O handling
type terminalState struct {
fd int
oldState *term.State
}
func enterRawMode() (*terminalState, error) {
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return nil, err
}
fmt.Fprint(os.Stderr, ansiHideCursor)
return &terminalState{fd: fd, oldState: oldState}, nil
}
func (t *terminalState) restore() {
fmt.Fprint(os.Stderr, ansiShowCursor)
term.Restore(t.fd, t.oldState)
}
func clearLines(n int) {
if n > 0 {
fmt.Fprintf(os.Stderr, "\033[%dA", n)
fmt.Fprint(os.Stderr, ansiClearDown)
}
}
func parseInput(r io.Reader) (inputEvent, byte, error) {
buf := make([]byte, 3)
n, err := r.Read(buf)
if err != nil {
return 0, 0, err
}
switch {
case n == 1 && buf[0] == 13:
return eventEnter, 0, nil
case n == 1 && (buf[0] == 3 || buf[0] == 27):
return eventEscape, 0, nil
case n == 1 && buf[0] == 9:
return eventTab, 0, nil
case n == 1 && buf[0] == 127:
return eventBackspace, 0, nil
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 65:
return eventUp, 0, nil
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 66:
return eventDown, 0, nil
case n == 1 && buf[0] >= 32 && buf[0] < 127:
return eventChar, buf[0], nil
}
return eventNone, 0, nil
}
// Rendering
func renderSelect(w io.Writer, prompt string, s *selectState) int {
filtered := s.filtered()
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
lineCount := 1
if len(filtered) == 0 {
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
lineCount++
} else {
displayCount := min(len(filtered), maxDisplayedItems)
for i := range displayCount {
idx := s.scrollOffset + i
if idx >= len(filtered) {
break
}
item := filtered[idx]
prefix := " "
if idx == s.selected {
prefix = " " + ansiBold + "> "
}
if item.Description != "" {
fmt.Fprintf(w, "%s%s%s %s- %s%s\r\n", prefix, item.Name, ansiReset, ansiGray, item.Description, ansiReset)
} else {
fmt.Fprintf(w, "%s%s%s\r\n", prefix, item.Name, ansiReset)
}
lineCount++
}
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
lineCount++
}
}
return lineCount
}
func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
filtered := s.filtered()
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
lineCount := 1
if len(filtered) == 0 {
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
lineCount++
} else {
displayCount := min(len(filtered), maxDisplayedItems)
for i := range displayCount {
idx := s.scrollOffset + i
if idx >= len(filtered) {
break
}
item := filtered[idx]
origIdx := s.itemIndex[item.Name]
checkbox := "[ ]"
if s.checked[origIdx] {
checkbox = "[x]"
}
prefix := " "
suffix := ""
if idx == s.highlighted && !s.focusOnButton {
prefix = "> "
}
if len(s.checkOrder) > 0 && s.checkOrder[0] == origIdx {
suffix = " " + ansiGray + "(default)" + ansiReset
}
if idx == s.highlighted && !s.focusOnButton {
fmt.Fprintf(w, " %s%s %s %s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, suffix)
} else {
fmt.Fprintf(w, " %s %s %s%s\r\n", prefix, checkbox, item.Name, suffix)
}
lineCount++
}
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
lineCount++
}
}
fmt.Fprintf(w, "\r\n")
lineCount++
count := s.selectedCount()
switch {
case count == 0:
fmt.Fprintf(w, " %sSelect at least one model.%s\r\n", ansiGray, ansiReset)
case s.focusOnButton:
fmt.Fprintf(w, " %s> [ Continue ]%s %s(%d selected)%s\r\n", ansiBold, ansiReset, ansiGray, count, ansiReset)
default:
fmt.Fprintf(w, " %s[ Continue ] (%d selected) - press Tab%s\r\n", ansiGray, count, ansiReset)
}
lineCount++
return lineCount
}
// selectPrompt prompts the user to select a single item from a list.
func selectPrompt(prompt string, items []selectItem) (string, error) {
if len(items) == 0 {
return "", fmt.Errorf("no items to select from")
}
ts, err := enterRawMode()
if err != nil {
return "", err
}
defer ts.restore()
state := newSelectState(items)
var lastLineCount int
render := func() {
clearLines(lastLineCount)
lastLineCount = renderSelect(os.Stderr, prompt, state)
}
render()
for {
event, char, err := parseInput(os.Stdin)
if err != nil {
return "", err
}
done, result, err := state.handleInput(event, char)
if done {
clearLines(lastLineCount)
if err != nil {
return "", err
}
return result, nil
}
render()
}
}
// multiSelectPrompt prompts the user to select multiple items from a list.
func multiSelectPrompt(prompt string, items []selectItem, preChecked []string) ([]string, error) {
if len(items) == 0 {
return nil, fmt.Errorf("no items to select from")
}
ts, err := enterRawMode()
if err != nil {
return nil, err
}
defer ts.restore()
state := newMultiSelectState(items, preChecked)
var lastLineCount int
render := func() {
clearLines(lastLineCount)
lastLineCount = renderMultiSelect(os.Stderr, prompt, state)
}
render()
for {
event, char, err := parseInput(os.Stdin)
if err != nil {
return nil, err
}
done, result, err := state.handleInput(event, char)
if done {
clearLines(lastLineCount)
if err != nil {
return nil, err
}
return result, nil
}
render()
}
}
func confirmPrompt(prompt string) (bool, error) {
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return false, err
}
defer term.Restore(fd, oldState)
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
buf := make([]byte, 1)
for {
if _, err := os.Stdin.Read(buf); err != nil {
return false, err
}
switch buf[0] {
case 'Y', 'y', 13:
fmt.Fprintf(os.Stderr, "yes\r\n")
return true, nil
case 'N', 'n', 27, 3:
fmt.Fprintf(os.Stderr, "no\r\n")
return false, nil
}
}
}
func filterItems(items []selectItem, filter string) []selectItem {
if filter == "" {
return items
}
var result []selectItem
filterLower := strings.ToLower(filter)
for _, item := range items {
if strings.Contains(strings.ToLower(item.Name), filterLower) {
result = append(result, item)
}
}
return result
}

View File

@@ -1,913 +0,0 @@
package config
import (
"bytes"
"strings"
"testing"
)
func TestFilterItems(t *testing.T) {
items := []selectItem{
{Name: "llama3.2:latest"},
{Name: "qwen2.5:7b"},
{Name: "deepseek-v3:cloud"},
{Name: "GPT-OSS:20b"},
}
t.Run("EmptyFilter_ReturnsAllItems", func(t *testing.T) {
result := filterItems(items, "")
if len(result) != len(items) {
t.Errorf("expected %d items, got %d", len(items), len(result))
}
})
t.Run("CaseInsensitive_UppercaseFilterMatchesLowercase", func(t *testing.T) {
result := filterItems(items, "LLAMA")
if len(result) != 1 || result[0].Name != "llama3.2:latest" {
t.Errorf("expected llama3.2:latest, got %v", result)
}
})
t.Run("CaseInsensitive_LowercaseFilterMatchesUppercase", func(t *testing.T) {
result := filterItems(items, "gpt")
if len(result) != 1 || result[0].Name != "GPT-OSS:20b" {
t.Errorf("expected GPT-OSS:20b, got %v", result)
}
})
t.Run("PartialMatch", func(t *testing.T) {
result := filterItems(items, "deep")
if len(result) != 1 || result[0].Name != "deepseek-v3:cloud" {
t.Errorf("expected deepseek-v3:cloud, got %v", result)
}
})
t.Run("NoMatch_ReturnsEmpty", func(t *testing.T) {
result := filterItems(items, "nonexistent")
if len(result) != 0 {
t.Errorf("expected 0 items, got %d", len(result))
}
})
}
func TestSelectState(t *testing.T) {
items := []selectItem{
{Name: "item1"},
{Name: "item2"},
{Name: "item3"},
}
t.Run("InitialState", func(t *testing.T) {
s := newSelectState(items)
if s.selected != 0 {
t.Errorf("expected selected=0, got %d", s.selected)
}
if s.filter != "" {
t.Errorf("expected empty filter, got %q", s.filter)
}
if s.scrollOffset != 0 {
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
}
})
t.Run("Enter_SelectsCurrentItem", func(t *testing.T) {
s := newSelectState(items)
done, result, err := s.handleInput(eventEnter, 0)
if !done || result != "item1" || err != nil {
t.Errorf("expected (true, item1, nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Enter_WithFilter_SelectsFilteredItem", func(t *testing.T) {
s := newSelectState(items)
s.filter = "item3"
done, result, err := s.handleInput(eventEnter, 0)
if !done || result != "item3" || err != nil {
t.Errorf("expected (true, item3, nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Enter_EmptyFilteredList_DoesNothing", func(t *testing.T) {
s := newSelectState(items)
s.filter = "nonexistent"
done, result, err := s.handleInput(eventEnter, 0)
if done || result != "" || err != nil {
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
s := newSelectState(items)
done, result, err := s.handleInput(eventEscape, 0)
if !done || result != "" || err != errCancelled {
t.Errorf("expected (true, '', errCancelled), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Down_MovesSelection", func(t *testing.T) {
s := newSelectState(items)
s.handleInput(eventDown, 0)
if s.selected != 1 {
t.Errorf("expected selected=1, got %d", s.selected)
}
})
t.Run("Down_AtBottom_StaysAtBottom", func(t *testing.T) {
s := newSelectState(items)
s.selected = 2
s.handleInput(eventDown, 0)
if s.selected != 2 {
t.Errorf("expected selected=2 (stayed at bottom), got %d", s.selected)
}
})
t.Run("Up_MovesSelection", func(t *testing.T) {
s := newSelectState(items)
s.selected = 2
s.handleInput(eventUp, 0)
if s.selected != 1 {
t.Errorf("expected selected=1, got %d", s.selected)
}
})
t.Run("Up_AtTop_StaysAtTop", func(t *testing.T) {
s := newSelectState(items)
s.handleInput(eventUp, 0)
if s.selected != 0 {
t.Errorf("expected selected=0 (stayed at top), got %d", s.selected)
}
})
t.Run("Char_AppendsToFilter", func(t *testing.T) {
s := newSelectState(items)
s.handleInput(eventChar, 'i')
s.handleInput(eventChar, 't')
s.handleInput(eventChar, 'e')
s.handleInput(eventChar, 'm')
s.handleInput(eventChar, '2')
if s.filter != "item2" {
t.Errorf("expected filter='item2', got %q", s.filter)
}
filtered := s.filtered()
if len(filtered) != 1 || filtered[0].Name != "item2" {
t.Errorf("expected [item2], got %v", filtered)
}
})
t.Run("Char_ResetsSelectionToZero", func(t *testing.T) {
s := newSelectState(items)
s.selected = 2
s.handleInput(eventChar, 'x')
if s.selected != 0 {
t.Errorf("expected selected=0 after typing, got %d", s.selected)
}
})
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
s := newSelectState(items)
s.filter = "test"
s.handleInput(eventBackspace, 0)
if s.filter != "tes" {
t.Errorf("expected filter='tes', got %q", s.filter)
}
})
t.Run("Backspace_EmptyFilter_DoesNothing", func(t *testing.T) {
s := newSelectState(items)
s.handleInput(eventBackspace, 0)
if s.filter != "" {
t.Errorf("expected filter='', got %q", s.filter)
}
})
t.Run("Backspace_ResetsSelectionToZero", func(t *testing.T) {
s := newSelectState(items)
s.filter = "test"
s.selected = 2
s.handleInput(eventBackspace, 0)
if s.selected != 0 {
t.Errorf("expected selected=0 after backspace, got %d", s.selected)
}
})
t.Run("Scroll_DownPastVisibleItems_ScrollsViewport", func(t *testing.T) {
// maxDisplayedItems is 10, so with 15 items we need to scroll
manyItems := make([]selectItem, 15)
for i := range manyItems {
manyItems[i] = selectItem{Name: string(rune('a' + i))}
}
s := newSelectState(manyItems)
// move down 12 times (past the 10-item viewport)
for range 12 {
s.handleInput(eventDown, 0)
}
if s.selected != 12 {
t.Errorf("expected selected=12, got %d", s.selected)
}
if s.scrollOffset != 3 {
t.Errorf("expected scrollOffset=3 (12-10+1), got %d", s.scrollOffset)
}
})
t.Run("Scroll_UpPastScrollOffset_ScrollsViewport", func(t *testing.T) {
manyItems := make([]selectItem, 15)
for i := range manyItems {
manyItems[i] = selectItem{Name: string(rune('a' + i))}
}
s := newSelectState(manyItems)
s.selected = 5
s.scrollOffset = 5
s.handleInput(eventUp, 0)
if s.selected != 4 {
t.Errorf("expected selected=4, got %d", s.selected)
}
if s.scrollOffset != 4 {
t.Errorf("expected scrollOffset=4, got %d", s.scrollOffset)
}
})
}
func TestMultiSelectState(t *testing.T) {
items := []selectItem{
{Name: "item1"},
{Name: "item2"},
{Name: "item3"},
}
t.Run("InitialState_NoPrechecked", func(t *testing.T) {
s := newMultiSelectState(items, nil)
if s.highlighted != 0 {
t.Errorf("expected highlighted=0, got %d", s.highlighted)
}
if s.selectedCount() != 0 {
t.Errorf("expected 0 selected, got %d", s.selectedCount())
}
if s.focusOnButton {
t.Error("expected focusOnButton=false initially")
}
})
t.Run("InitialState_WithPrechecked", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item2", "item3"})
if s.selectedCount() != 2 {
t.Errorf("expected 2 selected, got %d", s.selectedCount())
}
if !s.checked[1] || !s.checked[2] {
t.Error("expected item2 and item3 to be checked")
}
})
t.Run("Prechecked_PreservesSelectionOrder", func(t *testing.T) {
// order matters: first checked = default model
s := newMultiSelectState(items, []string{"item3", "item1"})
if len(s.checkOrder) != 2 {
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
}
if s.checkOrder[0] != 2 || s.checkOrder[1] != 0 {
t.Errorf("expected checkOrder=[2,0] (item3 first), got %v", s.checkOrder)
}
})
t.Run("Prechecked_IgnoresInvalidNames", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1", "nonexistent"})
if s.selectedCount() != 1 {
t.Errorf("expected 1 selected (nonexistent ignored), got %d", s.selectedCount())
}
})
t.Run("Toggle_ChecksUncheckedItem", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.toggleItem()
if !s.checked[0] {
t.Error("expected item1 to be checked after toggle")
}
})
t.Run("Toggle_UnchecksCheckedItem", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.toggleItem()
if s.checked[0] {
t.Error("expected item1 to be unchecked after toggle")
}
})
t.Run("Toggle_RemovesFromCheckOrder", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1", "item2", "item3"})
s.highlighted = 1 // toggle item2
s.toggleItem()
if len(s.checkOrder) != 2 {
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
}
// should be [0, 2] (item1, item3) with item2 removed
if s.checkOrder[0] != 0 || s.checkOrder[1] != 2 {
t.Errorf("expected checkOrder=[0,2], got %v", s.checkOrder)
}
})
t.Run("Enter_TogglesWhenNotOnButton", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.handleInput(eventEnter, 0)
if !s.checked[0] {
t.Error("expected item1 to be checked after enter")
}
})
t.Run("Enter_OnButton_ReturnsSelection", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item2", "item1"})
s.focusOnButton = true
done, result, err := s.handleInput(eventEnter, 0)
if !done || err != nil {
t.Errorf("expected done=true, err=nil, got done=%v, err=%v", done, err)
}
// result should preserve selection order
if len(result) != 2 || result[0] != "item2" || result[1] != "item1" {
t.Errorf("expected [item2, item1], got %v", result)
}
})
t.Run("Enter_OnButton_EmptySelection_DoesNothing", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.focusOnButton = true
done, result, err := s.handleInput(eventEnter, 0)
if done || result != nil || err != nil {
t.Errorf("expected (false, nil, nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Tab_SwitchesToButton_WhenHasSelection", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.handleInput(eventTab, 0)
if !s.focusOnButton {
t.Error("expected focus on button after tab")
}
})
t.Run("Tab_DoesNothing_WhenNoSelection", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.handleInput(eventTab, 0)
if s.focusOnButton {
t.Error("tab should not focus button when nothing selected")
}
})
t.Run("Tab_TogglesButtonFocus", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.handleInput(eventTab, 0)
if !s.focusOnButton {
t.Error("expected focus on button after first tab")
}
s.handleInput(eventTab, 0)
if s.focusOnButton {
t.Error("expected focus back on list after second tab")
}
})
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
done, result, err := s.handleInput(eventEscape, 0)
if !done || result != nil || err != errCancelled {
t.Errorf("expected (true, nil, errCancelled), got (%v, %v, %v)", done, result, err)
}
})
t.Run("IsDefault_TrueForFirstChecked", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item2", "item1"})
if !(len(s.checkOrder) > 0 && s.checkOrder[0] == 1) {
t.Error("expected item2 (idx 1) to be default (first checked)")
}
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
t.Error("expected item1 (idx 0) to NOT be default")
}
})
t.Run("IsDefault_FalseWhenNothingChecked", func(t *testing.T) {
s := newMultiSelectState(items, nil)
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
t.Error("expected isDefault=false when nothing checked")
}
})
t.Run("Down_MovesHighlight", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.handleInput(eventDown, 0)
if s.highlighted != 1 {
t.Errorf("expected highlighted=1, got %d", s.highlighted)
}
})
t.Run("Up_MovesHighlight", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.highlighted = 1
s.handleInput(eventUp, 0)
if s.highlighted != 0 {
t.Errorf("expected highlighted=0, got %d", s.highlighted)
}
})
t.Run("Arrow_ReturnsFocusFromButton", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.focusOnButton = true
s.handleInput(eventDown, 0)
if s.focusOnButton {
t.Error("expected focus to return to list on arrow key")
}
})
t.Run("Char_AppendsToFilter", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.handleInput(eventChar, 'x')
if s.filter != "x" {
t.Errorf("expected filter='x', got %q", s.filter)
}
})
t.Run("Char_ResetsHighlightAndScroll", func(t *testing.T) {
manyItems := make([]selectItem, 15)
for i := range manyItems {
manyItems[i] = selectItem{Name: string(rune('a' + i))}
}
s := newMultiSelectState(manyItems, nil)
s.highlighted = 10
s.scrollOffset = 5
s.handleInput(eventChar, 'x')
if s.highlighted != 0 {
t.Errorf("expected highlighted=0, got %d", s.highlighted)
}
if s.scrollOffset != 0 {
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
}
})
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
s := newMultiSelectState(items, nil)
s.filter = "test"
s.handleInput(eventBackspace, 0)
if s.filter != "tes" {
t.Errorf("expected filter='tes', got %q", s.filter)
}
})
t.Run("Backspace_RemovesFocusFromButton", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
s.filter = "x"
s.focusOnButton = true
s.handleInput(eventBackspace, 0)
if s.focusOnButton {
t.Error("expected focusOnButton=false after backspace")
}
})
}
func TestParseInput(t *testing.T) {
t.Run("Enter", func(t *testing.T) {
event, char, err := parseInput(bytes.NewReader([]byte{13}))
if err != nil || event != eventEnter || char != 0 {
t.Errorf("expected (eventEnter, 0, nil), got (%v, %v, %v)", event, char, err)
}
})
t.Run("Escape", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{27}))
if err != nil || event != eventEscape {
t.Errorf("expected eventEscape, got %v", event)
}
})
t.Run("CtrlC_TreatedAsEscape", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{3}))
if err != nil || event != eventEscape {
t.Errorf("expected eventEscape for Ctrl+C, got %v", event)
}
})
t.Run("Tab", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{9}))
if err != nil || event != eventTab {
t.Errorf("expected eventTab, got %v", event)
}
})
t.Run("Backspace", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{127}))
if err != nil || event != eventBackspace {
t.Errorf("expected eventBackspace, got %v", event)
}
})
t.Run("UpArrow", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 65}))
if err != nil || event != eventUp {
t.Errorf("expected eventUp, got %v", event)
}
})
t.Run("DownArrow", func(t *testing.T) {
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 66}))
if err != nil || event != eventDown {
t.Errorf("expected eventDown, got %v", event)
}
})
t.Run("PrintableChars", func(t *testing.T) {
tests := []struct {
name string
char byte
}{
{"lowercase", 'a'},
{"uppercase", 'Z'},
{"digit", '5'},
{"space", ' '},
{"tilde", '~'},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
event, char, err := parseInput(bytes.NewReader([]byte{tt.char}))
if err != nil || event != eventChar || char != tt.char {
t.Errorf("expected (eventChar, %q), got (%v, %q)", tt.char, event, char)
}
})
}
})
}
func TestRenderSelect(t *testing.T) {
items := []selectItem{
{Name: "item1", Description: "first item"},
{Name: "item2"},
}
t.Run("ShowsPromptAndItems", func(t *testing.T) {
s := newSelectState(items)
var buf bytes.Buffer
lineCount := renderSelect(&buf, "Select:", s)
output := buf.String()
if !strings.Contains(output, "Select:") {
t.Error("expected prompt in output")
}
if !strings.Contains(output, "item1") {
t.Error("expected item1 in output")
}
if !strings.Contains(output, "first item") {
t.Error("expected description in output")
}
if !strings.Contains(output, "item2") {
t.Error("expected item2 in output")
}
if lineCount != 3 { // 1 prompt + 2 items
t.Errorf("expected 3 lines, got %d", lineCount)
}
})
t.Run("EmptyFilteredList_ShowsNoMatches", func(t *testing.T) {
s := newSelectState(items)
s.filter = "xyz"
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "no matches") {
t.Error("expected 'no matches' message")
}
})
t.Run("LongList_ShowsRemainingCount", func(t *testing.T) {
manyItems := make([]selectItem, 15)
for i := range manyItems {
manyItems[i] = selectItem{Name: string(rune('a' + i))}
}
s := newSelectState(manyItems)
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
// 15 items - 10 displayed = 5 more
if !strings.Contains(buf.String(), "5 more") {
t.Error("expected '5 more' indicator")
}
})
}
func TestRenderMultiSelect(t *testing.T) {
items := []selectItem{
{Name: "item1"},
{Name: "item2"},
}
t.Run("ShowsCheckboxes", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
var buf bytes.Buffer
renderMultiSelect(&buf, "Select:", s)
output := buf.String()
if !strings.Contains(output, "[x]") {
t.Error("expected checked checkbox [x]")
}
if !strings.Contains(output, "[ ]") {
t.Error("expected unchecked checkbox [ ]")
}
})
t.Run("ShowsDefaultMarker", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1"})
var buf bytes.Buffer
renderMultiSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "(default)") {
t.Error("expected (default) marker for first checked item")
}
})
t.Run("ShowsSelectedCount", func(t *testing.T) {
s := newMultiSelectState(items, []string{"item1", "item2"})
var buf bytes.Buffer
renderMultiSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "2 selected") {
t.Error("expected '2 selected' in output")
}
})
t.Run("NoSelection_ShowsHelperText", func(t *testing.T) {
s := newMultiSelectState(items, nil)
var buf bytes.Buffer
renderMultiSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "Select at least one") {
t.Error("expected 'Select at least one' helper text")
}
})
}
func TestErrCancelled(t *testing.T) {
t.Run("NotNil", func(t *testing.T) {
if errCancelled == nil {
t.Error("errCancelled should not be nil")
}
})
t.Run("Message", func(t *testing.T) {
if errCancelled.Error() != "cancelled" {
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
}
})
}
// Edge case tests for selector.go
// TestSelectState_SingleItem verifies that single item list works without crash.
// List with only one item should still work.
func TestSelectState_SingleItem(t *testing.T) {
items := []selectItem{{Name: "only-one"}}
s := newSelectState(items)
// Down should do nothing (already at bottom)
s.handleInput(eventDown, 0)
if s.selected != 0 {
t.Errorf("down on single item: expected selected=0, got %d", s.selected)
}
// Up should do nothing (already at top)
s.handleInput(eventUp, 0)
if s.selected != 0 {
t.Errorf("up on single item: expected selected=0, got %d", s.selected)
}
// Enter should select the only item
done, result, err := s.handleInput(eventEnter, 0)
if !done || result != "only-one" || err != nil {
t.Errorf("enter on single item: expected (true, 'only-one', nil), got (%v, %q, %v)", done, result, err)
}
}
// TestSelectState_ExactlyMaxItems verifies boundary condition at maxDisplayedItems.
// List with exactly maxDisplayedItems items should not scroll.
func TestSelectState_ExactlyMaxItems(t *testing.T) {
items := make([]selectItem, maxDisplayedItems)
for i := range items {
items[i] = selectItem{Name: string(rune('a' + i))}
}
s := newSelectState(items)
// Move to last item
for range maxDisplayedItems - 1 {
s.handleInput(eventDown, 0)
}
if s.selected != maxDisplayedItems-1 {
t.Errorf("expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
}
// Should not scroll when exactly at max
if s.scrollOffset != 0 {
t.Errorf("expected scrollOffset=0 for exactly maxDisplayedItems, got %d", s.scrollOffset)
}
// One more down should do nothing
s.handleInput(eventDown, 0)
if s.selected != maxDisplayedItems-1 {
t.Errorf("down at max: expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
}
}
// TestFilterItems_RegexSpecialChars verifies that filter is literal, not regex.
// User typing "model.v1" shouldn't match "modelsv1".
func TestFilterItems_RegexSpecialChars(t *testing.T) {
items := []selectItem{
{Name: "model.v1"},
{Name: "modelsv1"},
{Name: "model-v1"},
}
// Filter with dot should only match literal dot
result := filterItems(items, "model.v1")
if len(result) != 1 {
t.Errorf("expected 1 exact match, got %d", len(result))
}
if len(result) > 0 && result[0].Name != "model.v1" {
t.Errorf("expected 'model.v1', got %s", result[0].Name)
}
// Other regex special chars should be literal too
items2 := []selectItem{
{Name: "test[0]"},
{Name: "test0"},
{Name: "test(1)"},
}
result2 := filterItems(items2, "test[0]")
if len(result2) != 1 || result2[0].Name != "test[0]" {
t.Errorf("expected only 'test[0]', got %v", result2)
}
}
// TestMultiSelectState_DuplicateNames documents handling of duplicate item names.
// itemIndex uses name as key - duplicates cause collision. This documents
// the current behavior: the last index for a duplicate name is stored
func TestMultiSelectState_DuplicateNames(t *testing.T) {
// Duplicate names - this is an edge case that shouldn't happen in practice
items := []selectItem{
{Name: "duplicate"},
{Name: "duplicate"},
{Name: "unique"},
}
s := newMultiSelectState(items, nil)
// DOCUMENTED BEHAVIOR: itemIndex maps name to LAST index
// When there are duplicates, only the last occurrence's index is stored
if s.itemIndex["duplicate"] != 1 {
t.Errorf("itemIndex should map 'duplicate' to last index (1), got %d", s.itemIndex["duplicate"])
}
// Toggle item at highlighted=0 (first "duplicate")
// Due to name collision, toggleItem uses itemIndex["duplicate"] = 1
// So it actually toggles the SECOND duplicate item, not the first
s.toggleItem()
// This documents the potentially surprising behavior:
// We toggled at highlighted=0, but itemIndex lookup returned 1
if !s.checked[1] {
t.Error("toggle should check index 1 (due to name collision in itemIndex)")
}
if s.checked[0] {
t.Log("Note: index 0 is NOT checked, even though highlighted=0 (name collision behavior)")
}
}
// TestSelectState_FilterReducesBelowSelection verifies selection resets when filter reduces list.
// Prevents index-out-of-bounds on next keystroke
func TestSelectState_FilterReducesBelowSelection(t *testing.T) {
items := []selectItem{
{Name: "apple"},
{Name: "banana"},
{Name: "cherry"},
}
s := newSelectState(items)
s.selected = 2 // Select "cherry"
// Type a filter that removes cherry from results
s.handleInput(eventChar, 'a') // Filter to "a" - matches "apple" and "banana"
// Selection should reset to 0
if s.selected != 0 {
t.Errorf("expected selected=0 after filter, got %d", s.selected)
}
filtered := s.filtered()
if len(filtered) != 2 {
t.Errorf("expected 2 filtered items, got %d", len(filtered))
}
}
// TestFilterItems_UnicodeCharacters verifies filtering works with UTF-8.
// Model names might contain unicode characters
func TestFilterItems_UnicodeCharacters(t *testing.T) {
items := []selectItem{
{Name: "llama-日本語"},
{Name: "模型-chinese"},
{Name: "émoji-🦙"},
{Name: "regular-model"},
}
t.Run("filter japanese", func(t *testing.T) {
result := filterItems(items, "日本")
if len(result) != 1 || result[0].Name != "llama-日本語" {
t.Errorf("expected llama-日本語, got %v", result)
}
})
t.Run("filter chinese", func(t *testing.T) {
result := filterItems(items, "模型")
if len(result) != 1 || result[0].Name != "模型-chinese" {
t.Errorf("expected 模型-chinese, got %v", result)
}
})
t.Run("filter emoji", func(t *testing.T) {
result := filterItems(items, "🦙")
if len(result) != 1 || result[0].Name != "émoji-🦙" {
t.Errorf("expected émoji-🦙, got %v", result)
}
})
t.Run("filter accented char", func(t *testing.T) {
result := filterItems(items, "émoji")
if len(result) != 1 || result[0].Name != "émoji-🦙" {
t.Errorf("expected émoji-🦙, got %v", result)
}
})
}
// TestMultiSelectState_FilterReducesBelowHighlight verifies highlight resets when filter reduces list.
func TestMultiSelectState_FilterReducesBelowHighlight(t *testing.T) {
items := []selectItem{
{Name: "apple"},
{Name: "banana"},
{Name: "cherry"},
}
s := newMultiSelectState(items, nil)
s.highlighted = 2 // Highlight "cherry"
// Type a filter that removes cherry
s.handleInput(eventChar, 'a')
if s.highlighted != 0 {
t.Errorf("expected highlighted=0 after filter, got %d", s.highlighted)
}
}
// TestMultiSelectState_EmptyItems verifies handling of empty item list.
// Empty list should be handled gracefully.
func TestMultiSelectState_EmptyItems(t *testing.T) {
s := newMultiSelectState([]selectItem{}, nil)
// Toggle should not panic on empty list
s.toggleItem()
if s.selectedCount() != 0 {
t.Errorf("expected 0 selected for empty list, got %d", s.selectedCount())
}
// Render should handle empty list
var buf bytes.Buffer
lineCount := renderMultiSelect(&buf, "Select:", s)
if lineCount == 0 {
t.Error("renderMultiSelect should produce output even for empty list")
}
if !strings.Contains(buf.String(), "no matches") {
t.Error("expected 'no matches' for empty list")
}
}
// TestSelectState_RenderWithDescriptions verifies rendering items with descriptions.
func TestSelectState_RenderWithDescriptions(t *testing.T) {
items := []selectItem{
{Name: "item1", Description: "First item description"},
{Name: "item2", Description: ""},
{Name: "item3", Description: "Third item"},
}
s := newSelectState(items)
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
output := buf.String()
if !strings.Contains(output, "First item description") {
t.Error("expected description to be rendered")
}
if !strings.Contains(output, "item2") {
t.Error("expected item without description to be rendered")
}
}

View File

@@ -116,7 +116,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
Prompt: ">>> ",
AltPrompt: "... ",
Placeholder: "Send a message (/? for help)",
AltPlaceholder: "Press Enter to send",
AltPlaceholder: `Use """ to end multi-line input`,
})
if err != nil {
return err
@@ -159,7 +159,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
sb.WriteString(before)
if !ok {
fmt.Fprintln(&sb)
scanner.Prompt.UseAlt = true
continue
}

View File

@@ -311,10 +311,6 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &deepseekocr{}
case "DeepseekV3ForCausalLM":
conv = &deepseek2Model{}
case "Glm4MoeLiteForCausalLM":
conv = &glm4MoeLiteModel{}
case "Lfm2ForCausalLM":
conv = &lfm2Model{}
default:
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}

View File

@@ -1,264 +0,0 @@
package convert
import (
"cmp"
"fmt"
"log/slog"
"regexp"
"strconv"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type glm4MoeLiteModel struct {
ModelParameters
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
QKNopeHeadDim uint32 `json:"qk_nope_head_dim"`
QKRopeHeadDim uint32 `json:"qk_rope_head_dim"`
KVLoraRank uint32 `json:"kv_lora_rank"`
QLoraRank uint32 `json:"q_lora_rank"`
VHeadDim uint32 `json:"v_head_dim"`
ExpertCount uint32 `json:"n_routed_experts"`
ExpertSharedCount uint32 `json:"n_shared_experts"`
ExpertIntermediateSize uint32 `json:"moe_intermediate_size"`
ExpertUsedCount uint32 `json:"num_experts_per_tok"`
ExpertWeightsNorm bool `json:"norm_topk_prob"`
ExpertWeightsScale float32 `json:"routed_scaling_factor"`
LeadingDenseBlockCount uint32 `json:"first_k_dense_replace"`
}
func (p *glm4MoeLiteModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "glm4moelite"
kv["general.type"] = "model"
kv["glm4moelite.block_count"] = p.HiddenLayers
numHeads := p.NumAttentionHeads
numKVHeads := p.NumKeyValueHeads
kv["glm4moelite.attention.head_count"] = numHeads
kv["glm4moelite.attention.head_count_kv"] = numKVHeads
kv["glm4moelite.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim
kv["glm4moelite.attention.kv_lora_rank"] = p.KVLoraRank
kv["glm4moelite.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["glm4moelite.attention.q_lora_rank"] = p.QLoraRank
kv["glm4moelite.attention.value_length"] = p.VHeadDim
kv["glm4moelite.context_length"] = p.MaxPositionEmbeddings
kv["glm4moelite.embedding_length"] = p.HiddenSize
kv["glm4moelite.expert_count"] = p.ExpertCount
kv["glm4moelite.expert_feed_forward_length"] = p.ExpertIntermediateSize
kv["glm4moelite.expert_shared_count"] = p.ExpertSharedCount
kv["glm4moelite.expert_gating_func"] = uint32(2)
kv["glm4moelite.expert_used_count"] = p.ExpertUsedCount
kv["glm4moelite.expert_weights_norm"] = p.ExpertWeightsNorm
kv["glm4moelite.expert_weights_scale"] = p.ExpertWeightsScale
kv["glm4moelite.feed_forward_length"] = p.IntermediateSize
kv["glm4moelite.leading_dense_block_count"] = p.LeadingDenseBlockCount
kv["glm4moelite.rope.dimension_count"] = p.QKRopeHeadDim
kv["glm4moelite.rope.freq_base"] = cmp.Or(p.RopeTheta, float32(1000000.0))
kv["glm4moelite.attention.key_length_mla"] = p.KVLoraRank + p.QKRopeHeadDim
kv["glm4moelite.attention.value_length_mla"] = p.KVLoraRank
kv["tokenizer.ggml.pre"] = "glm4"
return kv
}
func (p *glm4MoeLiteModel) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.kv_a_proj_with_mqa", "attn_kv_a_mqa",
"self_attn.kv_a_layernorm", "attn_kv_a_norm",
"self_attn.kv_b_proj", "attn_kv_b",
"self_attn.q_a_proj", "attn_q_a",
"self_attn.q_a_layernorm", "attn_q_a_norm",
"self_attn.q_b_proj", "attn_q_b",
"self_attn.o_proj", "attn_output",
"post_attention_layernorm", "ffn_norm",
"mlp.shared_experts.down_proj", "ffn_down_shexp",
"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
"mlp.shared_experts.up_proj", "ffn_up_shexp",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"mlp.gate.e_score_correction_bias", "exp_probs_b.bias",
"mlp.gate", "ffn_gate_inp",
}
}
// repackKVB extracts K or V from the combined KV_B tensor for MLA absorption.
// K output row-major: [n_head, kv_lora_rank, qk_nope] -> GGML ne[]={qk_nope, kv_lora_rank, n_head}
// V output row-major: [n_head, v_head, kv_lora_rank] -> GGML ne[]={kv_lora_rank, v_head, n_head}
func (p *glm4MoeLiteModel) repackKVB(extractK bool, kvFirst bool, numHeads int) Repacker {
qkNope := int(p.QKNopeHeadDim)
vHeadDim := int(p.VHeadDim)
kvLoraRank := int(p.KVLoraRank)
kvPerHead := qkNope + vHeadDim
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
var err error
// Normalize to [n_head * (qk_nope + v_head), kv_lora_rank] layout
if kvFirst {
tt, err = tensor.Transpose(tt, 1, 0)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
}
// Reshape to [n_head, qk_nope + v_head, kv_lora_rank]
if err := tt.Reshape(numHeads, kvPerHead, kvLoraRank); err != nil {
return nil, err
}
if extractK {
// Slice K: [n_head, qk_nope, kv_lora_rank]
tt, err = tt.Slice(nil, tensor.S(0, qkNope), nil)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
// Transpose to [n_head, kv_lora_rank, qk_nope]
tt, err = tensor.Transpose(tt, 0, 2, 1)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
} else {
// Slice V: [n_head, v_head, kv_lora_rank] - already correct layout
tt, err = tt.Slice(nil, tensor.S(qkNope, kvPerHead), nil)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
}
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(tt.(*tensor.Dense))
}
}
func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
merges := make([]merge, p.HiddenLayers*3)
for i := range p.HiddenLayers {
merges[i*3+0] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}
merges[i*3+1] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}
merges[i*3+2] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}
}
skipLayer := func(n string, minValue uint32) bool {
re := regexp.MustCompile(`^blk\.(\d+)`)
matches := re.FindStringSubmatch(n)
if matches == nil {
return false
}
blkNum, err := strconv.Atoi(matches[1])
if err != nil {
return false
}
return uint32(blkNum) >= minValue
}
out, s = mergeTensors(s, merges...)
for _, t := range s {
// skip any additional layers (such as the Multi-Token Prediction layer)
if skipLayer(t.Name(), p.HiddenLayers) {
slog.Debug("skipping layer", "name", t.Name())
continue
}
// Split attn_kv_b into separate attn_k_b and attn_v_b for MLA absorption
if strings.HasSuffix(t.Name(), ".attn_kv_b.weight") {
qkNope := int(p.QKNopeHeadDim)
vHeadDim := int(p.VHeadDim)
kvLoraRank := int(p.KVLoraRank)
kvPerHead := qkNope + vHeadDim
numHeads := int(p.NumAttentionHeads)
kvFirst := true
if len(t.Shape()) == 2 {
switch {
case int(t.Shape()[0]) == kvLoraRank:
if kvPerHead > 0 && int(t.Shape()[1])%kvPerHead == 0 {
numHeads = int(t.Shape()[1]) / kvPerHead
}
kvFirst = true
case int(t.Shape()[1]) == kvLoraRank:
if kvPerHead > 0 && int(t.Shape()[0])%kvPerHead == 0 {
numHeads = int(t.Shape()[0]) / kvPerHead
}
kvFirst = false
default:
slog.Warn("glm4moelite: unexpected attn_kv_b layout", "name", t.Name(), "shape", t.Shape())
}
}
kTensor := t.Clone()
kTensor.SetRepacker(p.repackKVB(true, kvFirst, numHeads))
out = append(out, &ggml.Tensor{
Name: strings.Replace(t.Name(), "attn_kv_b", "attn_k_b", 1),
Kind: t.Kind(),
Shape: []uint64{uint64(numHeads), uint64(kvLoraRank), uint64(qkNope)},
WriterTo: kTensor,
})
vTensor := t.Clone()
vTensor.SetRepacker(p.repackKVB(false, kvFirst, numHeads))
out = append(out, &ggml.Tensor{
Name: strings.Replace(t.Name(), "attn_kv_b", "attn_v_b", 1),
Kind: t.Kind(),
Shape: []uint64{uint64(numHeads), uint64(vHeadDim), uint64(kvLoraRank)},
WriterTo: vTensor,
})
continue
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}

View File

@@ -1,100 +0,0 @@
package convert
import (
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type lfm2Model struct {
ModelParameters
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
NormEps float32 `json:"norm_eps"`
ConvLCache uint32 `json:"conv_L_cache"`
LayerTypes []string `json:"layer_types"`
TieEmbedding bool `json:"tie_embedding"`
}
var _ ModelConverter = (*lfm2Model)(nil)
func (p *lfm2Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "lfm2"
kv["lfm2.vocab_size"] = p.VocabSize
kv["lfm2.block_count"] = p.NumHiddenLayers
kv["lfm2.embedding_length"] = p.HiddenSize
kv["lfm2.feed_forward_length"] = p.IntermediateSize
kv["lfm2.context_length"] = p.MaxPositionEmbeddings
// Build per-layer KV head count array based on layer_types
// (0 = shortconv layer, non-zero = attention layer with that many KV heads)
kvHeadCounts := make([]uint32, p.NumHiddenLayers)
for i := range p.NumHiddenLayers {
if int(i) < len(p.LayerTypes) && p.LayerTypes[i] == "full_attention" {
kvHeadCounts[i] = p.NumKeyValueHeads
}
}
kv["lfm2.attention.head_count"] = p.NumAttentionHeads
kv["lfm2.attention.head_count_kv"] = kvHeadCounts
kv["lfm2.attention.key_length"] = p.HiddenSize / p.NumAttentionHeads
kv["lfm2.attention.value_length"] = p.HiddenSize / p.NumAttentionHeads
kv["lfm2.attention.layer_norm_rms_epsilon"] = p.NormEps
kv["lfm2.rope.freq_base"] = p.RopeTheta
kv["lfm2.shortconv.l_cache"] = p.ConvLCache
return kv
}
func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
shape := t.Shape()
// Squeeze conv weights: [D, 1, K] -> [D, K]
if strings.HasSuffix(t.Name(), "shortconv.conv.weight") {
if len(shape) == 3 && shape[1] == 1 {
shape = []uint64{shape[0], shape[2]}
}
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
}
return out
}
func (p *lfm2Model) Replacements() []string {
return []string{
"model.embed_tokens", "token_embd",
"model.embedding_norm", "output_norm",
"model.layers", "blk",
"operator_norm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.out_proj", "attn_output",
"self_attn.q_layernorm", "attn_q_norm",
"self_attn.k_layernorm", "attn_k_norm",
"conv.conv", "shortconv.conv",
"conv.in_proj", "shortconv.in_proj",
"conv.out_proj", "shortconv.out_proj",
"feed_forward.w1", "ffn_gate",
"feed_forward.w2", "ffn_down",
"feed_forward.w3", "ffn_up",
"ffn_norm", "ffn_norm",
}
}

View File

@@ -40,7 +40,6 @@ const (
func (t tensorBase) Kind() uint32 {
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
strings.HasSuffix(t.name, ".bias") ||
strings.HasSuffix(t.name, ".shortconv.conv.weight") ||
t.name == "token_types.weight" ||
t.name == "v.positional_embedding_vlm" ||
t.name == "v.tile_position_embd.weight" ||

View File

@@ -14,7 +14,6 @@
* [API Reference](https://docs.ollama.com/api)
* [Modelfile Reference](https://docs.ollama.com/modelfile)
* [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility)
* [Anthropic Compatibility](./api/anthropic-compatibility.mdx)
### Resources

View File

@@ -16,7 +16,6 @@
- [Generate Embeddings](#generate-embeddings)
- [List Running Models](#list-running-models)
- [Version](#version)
- [Experimental: Image Generation](#image-generation-experimental)
## Conventions
@@ -59,15 +58,6 @@ Advanced parameters (optional):
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
- `context` (deprecated): the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
Experimental image generation parameters (for image generation models only):
> [!WARNING]
> These parameters are experimental and may change in future versions.
- `width`: width of the generated image in pixels
- `height`: height of the generated image in pixels
- `steps`: number of diffusion steps
#### Structured outputs
Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [structured outputs](#request-structured-outputs) example below.
@@ -1877,55 +1867,3 @@ curl http://localhost:11434/api/version
"version": "0.5.1"
}
```
## Experimental Features
### Image Generation (Experimental)
> [!WARNING]
> Image generation is experimental and may change in future versions.
Image generation is now supported through the standard `/api/generate` endpoint when using image generation models. The API automatically detects when an image generation model is being used.
See the [Generate a completion](#generate-a-completion) section for the full API documentation. The experimental image generation parameters (`width`, `height`, `steps`) are documented there.
#### Example
##### Request
```shell
curl http://localhost:11434/api/generate -d '{
"model": "x/z-image-turbo",
"prompt": "a sunset over mountains",
"width": 1024,
"height": 768
}'
```
##### Response (streaming)
Progress updates during generation:
```json
{
"model": "x/z-image-turbo",
"created_at": "2024-01-15T10:30:00.000000Z",
"completed": 5,
"total": 20,
"done": false
}
```
##### Final Response
```json
{
"model": "x/z-image-turbo",
"created_at": "2024-01-15T10:30:15.000000Z",
"image": "iVBORw0KGgoAAAANSUhEUg...",
"done": true,
"done_reason": "stop",
"total_duration": 15000000000,
"load_duration": 2000000000
}
```

View File

@@ -1,423 +0,0 @@
---
title: Anthropic compatibility
---
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
## Usage
### Environment variables
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
export ANTHROPIC_API_KEY="" # required but ignored
export ANTHROPIC_BASE_URL=http://localhost:11434
```
### Simple `/v1/messages` example
<CodeGroup dropdown>
```python basic.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama', # required but ignored
)
message = client.messages.create(
model='qwen3-coder',
max_tokens=1024,
messages=[
{'role': 'user', 'content': 'Hello, how are you?'}
]
)
print(message.content[0].text)
```
```javascript basic.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama", // required but ignored
});
const message = await anthropic.messages.create({
model: "qwen3-coder",
max_tokens: 1024,
messages: [{ role: "user", content: "Hello, how are you?" }],
});
console.log(message.content[0].text);
```
```shell basic.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-H "x-api-key: ollama" \
-H "anthropic-version: 2023-06-01" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"messages": [{ "role": "user", "content": "Hello, how are you?" }]
}'
```
</CodeGroup>
### Streaming example
<CodeGroup dropdown>
```python streaming.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
with client.messages.stream(
model='qwen3-coder',
max_tokens=1024,
messages=[{'role': 'user', 'content': 'Count from 1 to 10'}]
) as stream:
for text in stream.text_stream:
print(text, end='', flush=True)
```
```javascript streaming.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const stream = await anthropic.messages.stream({
model: "qwen3-coder",
max_tokens: 1024,
messages: [{ role: "user", content: "Count from 1 to 10" }],
});
for await (const event of stream) {
if (
event.type === "content_block_delta" &&
event.delta.type === "text_delta"
) {
process.stdout.write(event.delta.text);
}
}
```
```shell streaming.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"stream": true,
"messages": [{ "role": "user", "content": "Count from 1 to 10" }]
}'
```
</CodeGroup>
### Tool calling example
<CodeGroup dropdown>
```python tools.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
message = client.messages.create(
model='qwen3-coder',
max_tokens=1024,
tools=[
{
'name': 'get_weather',
'description': 'Get the current weather in a location',
'input_schema': {
'type': 'object',
'properties': {
'location': {
'type': 'string',
'description': 'The city and state, e.g. San Francisco, CA'
}
},
'required': ['location']
}
}
],
messages=[{'role': 'user', 'content': "What's the weather in San Francisco?"}]
)
for block in message.content:
if block.type == 'tool_use':
print(f'Tool: {block.name}')
print(f'Input: {block.input}')
```
```javascript tools.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const message = await anthropic.messages.create({
model: "qwen3-coder",
max_tokens: 1024,
tools: [
{
name: "get_weather",
description: "Get the current weather in a location",
input_schema: {
type: "object",
properties: {
location: {
type: "string",
description: "The city and state, e.g. San Francisco, CA",
},
},
required: ["location"],
},
},
],
messages: [{ role: "user", content: "What's the weather in San Francisco?" }],
});
for (const block of message.content) {
if (block.type === "tool_use") {
console.log("Tool:", block.name);
console.log("Input:", block.input);
}
}
```
```shell tools.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3-coder",
"max_tokens": 1024,
"tools": [
{
"name": "get_weather",
"description": "Get the current weather in a location",
"input_schema": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state"
}
},
"required": ["location"]
}
}
],
"messages": [{ "role": "user", "content": "What is the weather in San Francisco?" }]
}'
```
</CodeGroup>
## Using with Claude Code
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend.
### Recommended models
For coding use cases, models like `glm-4.7`, `minimax-m2.1`, and `qwen3-coder` are recommended.
Download a model before use:
```shell
ollama pull qwen3-coder
```
> Note: Qwen 3 coder is a 30B parameter model requiring at least 24GB of VRAM to run smoothly. More is required for longer context lengths.
```shell
ollama pull glm-4.7:cloud
```
### Quick setup
```shell
ollama launch claude
```
This will prompt you to select a model, configure Claude Code automatically, and launch it. To configure without launching:
```shell
ollama launch claude --config
```
### Manual setup
Set the environment variables and run Claude Code:
```shell
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
```
Or set the environment variables in your shell profile:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=""
```
Then run Claude Code with any Ollama model:
```shell
claude --model qwen3-coder
```
## Endpoints
### `/v1/messages`
#### Supported features
- [x] Messages
- [x] Streaming
- [x] System prompts
- [x] Multi-turn conversations
- [x] Vision (images)
- [x] Tools (function calling)
- [x] Tool results
- [x] Thinking/extended thinking
#### Supported request fields
- [x] `model`
- [x] `max_tokens`
- [x] `messages`
- [x] Text `content`
- [x] Image `content` (base64)
- [x] Array of content blocks
- [x] `tool_use` blocks
- [x] `tool_result` blocks
- [x] `thinking` blocks
- [x] `system` (string or array)
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `top_k`
- [x] `stop_sequences`
- [x] `tools`
- [x] `thinking`
- [ ] `tool_choice`
- [ ] `metadata`
#### Supported response fields
- [x] `id`
- [x] `type`
- [x] `role`
- [x] `model`
- [x] `content` (text, tool_use, thinking blocks)
- [x] `stop_reason` (end_turn, max_tokens, tool_use)
- [x] `usage` (input_tokens, output_tokens)
#### Streaming events
- [x] `message_start`
- [x] `content_block_start`
- [x] `content_block_delta` (text_delta, input_json_delta, thinking_delta)
- [x] `content_block_stop`
- [x] `message_delta`
- [x] `message_stop`
- [x] `ping`
- [x] `error`
## Models
Ollama supports both local and cloud models.
### Local models
Pull a local model before use:
```shell
ollama pull qwen3-coder
```
Recommended local models:
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model
### Cloud models
Cloud models are available immediately without pulling:
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
### Default model names
For tooling that relies on default Anthropic model names such as `claude-3-5-sonnet`, use `ollama cp` to copy an existing model name:
```shell
ollama cp qwen3-coder claude-3-5-sonnet
```
Afterwards, this new model name can be specified in the `model` field:
```shell
curl http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "claude-3-5-sonnet",
"max_tokens": 1024,
"messages": [
{
"role": "user",
"content": "Hello!"
}
]
}'
```
## Differences from the Anthropic API
### Behavior differences
- API key is accepted but not validated
- `anthropic-version` header is accepted but not used
- Token counts are approximations based on the underlying model's tokenizer
### Not supported
The following Anthropic API features are not currently supported:
| Feature | Description |
|---------|-------------|
| `/v1/messages/count_tokens` | Token counting endpoint |
| `tool_choice` | Forcing specific tool use or disabling tools |
| `metadata` | Request metadata (user_id) |
| Prompt caching | `cache_control` blocks for caching prefixes |
| Batches API | `/v1/messages/batches` for async batch processing |
| Citations | `citations` content blocks |
| PDF support | `document` content blocks with PDF files |
| Server-sent errors | `error` events during streaming (errors return HTTP status) |
### Partial support
| Feature | Status |
|---------|--------|
| Image content | Base64 images supported; URL images not supported |
| Extended thinking | Basic support; `budget_tokens` accepted but not enforced |

View File

@@ -275,73 +275,6 @@ curl -X POST http://localhost:11434/v1/chat/completions \
- [x] `dimensions`
- [ ] `user`
### `/v1/images/generations` (experimental)
> Note: This endpoint is experimental and may change or be removed in future versions.
Generate images using image generation models.
<CodeGroup dropdown>
```python images.py
from openai import OpenAI
client = OpenAI(
base_url='http://localhost:11434/v1/',
api_key='ollama', # required but ignored
)
response = client.images.generate(
model='x/z-image-turbo',
prompt='A cute robot learning to paint',
size='1024x1024',
response_format='b64_json',
)
print(response.data[0].b64_json[:50] + '...')
```
```javascript images.js
import OpenAI from "openai";
const openai = new OpenAI({
baseURL: "http://localhost:11434/v1/",
apiKey: "ollama", // required but ignored
});
const response = await openai.images.generate({
model: "x/z-image-turbo",
prompt: "A cute robot learning to paint",
size: "1024x1024",
response_format: "b64_json",
});
console.log(response.data[0].b64_json.slice(0, 50) + "...");
```
```shell images.sh
curl -X POST http://localhost:11434/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"model": "x/z-image-turbo",
"prompt": "A cute robot learning to paint",
"size": "1024x1024",
"response_format": "b64_json"
}'
```
</CodeGroup>
#### Supported request fields
- [x] `model`
- [x] `prompt`
- [x] `size` (e.g. "1024x1024")
- [x] `response_format` (only `b64_json` supported)
- [ ] `n`
- [ ] `quality`
- [ ] `style`
- [ ] `user`
### `/v1/responses`
> Note: Added in Ollama v0.13.3

View File

@@ -110,7 +110,7 @@ More Ollama [Python example](https://github.com/ollama/ollama-python/blob/main/e
import { Ollama } from "ollama";
const client = new Ollama();
const results = await client.webSearch("what is ollama?");
const results = await client.webSearch({ query: "what is ollama?" });
console.log(JSON.stringify(results, null, 2));
```
@@ -213,7 +213,7 @@ models](https://ollama.com/models)\n\nAvailable for macOS, Windows, and Linux',
import { Ollama } from "ollama";
const client = new Ollama();
const fetchResult = await client.webFetch("https://ollama.com");
const fetchResult = await client.webFetch({ url: "https://ollama.com" });
console.log(JSON.stringify(fetchResult, null, 2));
```

View File

@@ -8,47 +8,6 @@ title: CLI Reference
ollama run gemma3
```
### Launch integrations
```
ollama launch
```
Configure and launch external applications to use Ollama models. This provides an interactive way to set up and start integrations with supported apps.
#### Supported integrations
- **OpenCode** - Open-source coding assistant
- **Claude Code** - Anthropic's agentic coding tool
- **Codex** - OpenAI's coding assistant
- **Droid** - Factory's AI coding agent
#### Examples
Launch an integration interactively:
```
ollama launch
```
Launch a specific integration:
```
ollama launch claude
```
Launch with a specific model:
```
ollama launch claude --model qwen3-coder
```
Configure without launching:
```
ollama launch droid --config
```
#### Multiline input
For multiline input, you can wrap text with `"""`:

View File

@@ -3,6 +3,8 @@ title: Cloud
sidebarTitle: Cloud
---
<Info>Ollama's cloud is currently in preview.</Info>
## Cloud Models
Ollama's cloud models are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud service while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldn't fit on a personal computer.

View File

@@ -8,7 +8,7 @@ Context length is the maximum number of tokens that the model has access to in m
The default context length in Ollama is 4096 tokens.
</Note>
Tasks which require large context like web search, agents, and coding tools should be set to at least 64000 tokens.
Tasks which require large context like web search, agents, and coding tools should be set to at least 32000 tokens.
## Setting context length
@@ -24,7 +24,7 @@ Change the slider in the Ollama app under settings to your desired context lengt
### CLI
If editing the context length for Ollama is not possible, the context length can also be updated when serving Ollama.
```
OLLAMA_CONTEXT_LENGTH=64000 ollama serve
OLLAMA_CONTEXT_LENGTH=32000 ollama serve
```
### Check allocated context length and model offloading

View File

@@ -32,9 +32,7 @@
"codeblocks": "system"
},
"contextual": {
"options": [
"copy"
]
"options": ["copy"]
},
"navbar": {
"links": [
@@ -54,9 +52,7 @@
"display": "simple"
},
"examples": {
"languages": [
"curl"
]
"languages": ["curl"]
}
},
"redirects": [
@@ -101,21 +97,16 @@
{
"group": "Integrations",
"pages": [
"/integrations/claude-code",
"/integrations/clawdbot",
"/integrations/cline",
"/integrations/vscode",
"/integrations/jetbrains",
"/integrations/codex",
"/integrations/cline",
"/integrations/droid",
"/integrations/goose",
"/integrations/jetbrains",
"/integrations/marimo",
"/integrations/n8n",
"/integrations/onyx",
"/integrations/opencode",
"/integrations/zed",
"/integrations/roo-code",
"/integrations/vscode",
"/integrations/xcode",
"/integrations/zed"
"/integrations/n8n",
"/integrations/xcode"
]
},
{
@@ -148,8 +139,7 @@
"/api/streaming",
"/api/usage",
"/api/errors",
"/api/openai-compatibility",
"/api/anthropic-compatibility"
"/api/openai-compatibility"
]
},
{

View File

@@ -22,7 +22,7 @@ Please refer to the [GPU docs](./gpu).
## How can I specify the context window size?
By default, Ollama uses a context window size of 4096 tokens.
By default, Ollama uses a context window size of 2048 tokens.
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 174 KiB

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 80 KiB

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 230 KiB

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 178 KiB

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 186 KiB

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 100 KiB

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 306 KiB

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 300 KiB

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 211 KiB

View File

@@ -9,7 +9,7 @@ sidebarTitle: Welcome
<CardGroup cols={2}>
<Card title="Quickstart" icon="rocket" href="/quickstart">
Get up and running with your first model or integrate Ollama with your favorite tools
Get up and running with your first model
</Card>
<Card
title="Download Ollama"

View File

@@ -1,75 +0,0 @@
---
title: Claude Code
---
Claude Code is Anthropic's agentic coding tool that can read, modify, and execute code in your working directory.
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `glm-4.7`, `qwen3-coder`, `gpt-oss`.
![Claude Code with Ollama](https://files.ollama.com/claude-code.png)
## Install
Install [Claude Code](https://code.claude.com/docs/en/overview):
<CodeGroup>
```shell macOS / Linux
curl -fsSL https://claude.ai/install.sh | bash
```
```powershell Windows
irm https://claude.ai/install.ps1 | iex
```
</CodeGroup>
## Usage with Ollama
### Quick setup
```shell
ollama launch claude
```
To configure without launching:
```shell
ollama launch claude --config
```
### Manual setup
Claude Code connects to Ollama using the Anthropic-compatible API.
1. Set the environment variables:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_API_KEY=""
export ANTHROPIC_BASE_URL=http://localhost:11434
```
2. Run Claude Code with an Ollama model:
```shell
claude --model gpt-oss:20b
```
Or run with environment variables inline:
```shell
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
```
**Note:** Claude Code requires a large context window. We recommend at least 64k tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
## Recommended Models
- `qwen3-coder`
- `glm-4.7`
- `gpt-oss:20b`
- `gpt-oss:120b`
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).

View File

@@ -1,48 +0,0 @@
---
title: Clawdbot
---
Clawdbot is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.
## Install
Install [Clawdbot](https://clawd.bot/)
```bash
npm install -g clawdbot@latest
```
Then run the onboarding wizard:
```bash
clawdbot onboard --install-daemon
```
<Note>Clawdbot requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
## Usage with Ollama
### Quick setup
```bash
ollama launch clawdbot
```
This configures Clawdbot to use Ollama and starts the gateway.
If the gateway is already running, no changes need to be made as the gateway will auto-reload the changes.
To configure without launching:
```shell
ollama launch clawdbot --config
```
## Recommended Models
- `qwen3-coder`
- `glm-4.7`
- `gpt-oss:20b`
- `gpt-oss:120b`
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).

View File

@@ -13,21 +13,7 @@ npm install -g @openai/codex
## Usage with Ollama
<Note>Codex requires a larger context window. It is recommended to use a context window of at least 64k tokens.</Note>
### Quick setup
```
ollama launch codex
```
To configure without launching:
```shell
ollama launch codex --config
```
### Manual setup
<Note>Codex requires a larger context window. It is recommended to use a context window of at least 32K tokens.</Note>
To use `codex` with Ollama, use the `--oss` flag:

View File

@@ -11,24 +11,10 @@ Install the [Droid CLI](https://factory.ai/):
curl -fsSL https://app.factory.ai/cli | sh
```
<Note>Droid requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
<Note>Droid requires a larger context window. It is recommended to use a context window of at least 32K tokens. See [Context length](/context-length) for more information.</Note>
## Usage with Ollama
### Quick setup
```bash
ollama launch droid
```
To configure without launching:
```shell
ollama launch droid --config
```
### Manual setup
Add a local configuration block to `~/.factory/config.json`:
```json
@@ -87,4 +73,4 @@ Add the cloud configuration block to `~/.factory/config.json`:
}
```
Run `droid` in a new terminal to load the new settings.
Run `droid` in a new terminal to load the new settings.

View File

@@ -1,73 +0,0 @@
---
title: marimo
---
## Install
Install [marimo](https://marimo.io). You can use `pip` or `uv` for this. You
can also use `uv` to create a sandboxed environment for marimo by running:
```
uvx marimo edit --sandbox notebook.py
```
## Usage with Ollama
1. In marimo, go to the user settings and go to the AI tab. From here
you can find and configure Ollama as an AI provider. For local use you
would typically point the base url to `http://localhost:11434/v1`.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-settings.png"
alt="Ollama settings in marimo"
width="50%"
/>
</div>
2. Once the AI provider is set up, you can turn on/off specific AI models you'd like to access.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-models.png"
alt="Selecting an Ollama model"
width="50%"
/>
</div>
3. You can also add a model to the list of available models by scrolling to the bottom and using the UI there.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-add-model.png"
alt="Adding a new Ollama model"
width="50%"
/>
</div>
4. Once configured, you can now use Ollama for AI chats in marimo.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-chat.png"
alt="Configure code completion"
width="50%"
/>
</div>
4. Alternatively, you can now use Ollama for **inline code completion** in marimo. This can be configured in the "AI Features" tab.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-code-completion.png"
alt="Configure code completion"
width="50%"
/>
</div>
## Connecting to ollama.com
1. Sign in to ollama cloud via `ollama signin`
2. In the ollama model settings add a model that ollama hosts, like `gpt-oss:120b`.
3. You can now refer to this model in marimo!

View File

@@ -1,63 +0,0 @@
---
title: Onyx
---
## Overview
[Onyx](http://onyx.app/) is a self-hostable Chat UI that integrates with all Ollama models. Features include:
- Creating custom Agents
- Web search
- Deep Research
- RAG over uploaded documents and connected apps
- Connectors to applications like Google Drive, Email, Slack, etc.
- MCP and OpenAPI Actions support
- Image generation
- User/Groups management, RBAC, SSO, etc.
Onyx can be deployed for single users or large organizations.
## Install Onyx
Deploy Onyx with the [quickstart guide](https://docs.onyx.app/deployment/getting_started/quickstart).
<Info>
Resourcing/scaling docs [here](https://docs.onyx.app/deployment/getting_started/resourcing).
</Info>
## Usage with Ollama
1. Login to your Onyx deployment (create an account first).
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-login.png"
alt="Onyx Login Page"
width="75%"
/>
</div>
2. In the set-up process select `Ollama` as the LLM provider.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-ollama-llm.png"
alt="Onyx Set Up Form"
width="75%"
/>
</div>
3. Provide your **Ollama API URL** and select your models.
<Note>If you're running Onyx in Docker, to access your computer's local network use `http://host.docker.internal` instead of `http://127.0.0.1`.</Note>
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-ollama-form.png"
alt="Selecting Ollama Models"
width="75%"
/>
</div>
You can also easily connect up Onyx Cloud with the `Ollama Cloud` tab of the setup.
## Send your first query
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-query.png"
alt="Onyx Query Example"
width="75%"
/>
</div>

View File

@@ -1,106 +0,0 @@
---
title: OpenCode
---
OpenCode is an open-source AI coding assistant that runs in your terminal.
## Install
Install the [OpenCode CLI](https://opencode.ai):
```bash
curl -fsSL https://opencode.ai/install.sh | bash
```
<Note>OpenCode requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
## Usage with Ollama
### Quick setup
```bash
ollama launch opencode
```
To configure without launching:
```shell
ollama launch opencode --config
```
### Manual setup
Add a configuration block to `~/.config/opencode/opencode.json`:
```json
{
"$schema": "https://opencode.ai/config.json",
"provider": {
"ollama": {
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama",
"options": {
"baseURL": "http://localhost:11434/v1"
},
"models": {
"qwen3-coder": {
"name": "qwen3-coder"
}
}
}
}
}
```
## Cloud Models
`glm-4.7:cloud` is the recommended model for use with OpenCode.
Add the cloud configuration to `~/.config/opencode/opencode.json`:
```json
{
"$schema": "https://opencode.ai/config.json",
"provider": {
"ollama": {
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama",
"options": {
"baseURL": "http://localhost:11434/v1"
},
"models": {
"glm-4.7:cloud": {
"name": "glm-4.7:cloud"
}
}
}
}
}
```
## Connecting to ollama.com
1. Create an [API key](https://ollama.com/settings/keys) from ollama.com and export it as `OLLAMA_API_KEY`.
2. Update `~/.config/opencode/opencode.json` to point to ollama.com:
```json
{
"$schema": "https://opencode.ai/config.json",
"provider": {
"ollama": {
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama Cloud",
"options": {
"baseURL": "https://ollama.com/v1"
},
"models": {
"glm-4.7:cloud": {
"name": "glm-4.7:cloud"
}
}
}
}
}
```
Run `opencode` in a new terminal to load the new settings.

View File

@@ -18,13 +18,13 @@ This quickstart will walk your through running your first model with Ollama. To
<Tab title="CLI">
Open a terminal and run the command:
```sh
```
ollama run gemma3
```
</Tab>
<Tab title="cURL">
```sh
```
ollama pull gemma3
```
@@ -45,13 +45,13 @@ This quickstart will walk your through running your first model with Ollama. To
<Tab title="Python">
Start by downloading a model:
```sh
```
ollama pull gemma3
```
Then install Ollama's Python library:
```sh
```
pip install ollama
```
@@ -101,42 +101,3 @@ This quickstart will walk your through running your first model with Ollama. To
</Tabs>
See a full list of available models [here](https://ollama.com/models).
## Coding
For coding use cases, we recommend using the `glm-4.7-flash` model.
Note: this model requires 23 GB of VRAM with 64000 tokens context length.
```sh
ollama pull glm-4.7-flash
```
Alternatively, you can use a more powerful cloud model (with full context length):
```sh
ollama pull glm-4.7:cloud
```
Use `ollama launch` to quickly set up a coding tool with Ollama models:
```sh
ollama launch
```
### Supported integrations
- [OpenCode](/integrations/opencode) - Open-source coding assistant
- [Claude Code](/integrations/claude-code) - Anthropic's agentic coding tool
- [Codex](/integrations/codex) - OpenAI's coding assistant
- [Droid](/integrations/droid) - Factory's AI coding agent
### Launch with a specific model
```sh
ollama launch claude --model glm-4.7-flash
```
### Configure without launching
```sh
ollama launch claude --config
```

3
docs/troubleshooting.md Normal file
View File

@@ -0,0 +1,3 @@
# Troubleshooting
For troubleshooting, see [https://docs.ollama.com/troubleshooting](https://docs.ollama.com/troubleshooting)

View File

@@ -269,8 +269,6 @@ func (kv KV) OllamaEngineRequired() bool {
"qwen25vl",
"qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe",
"glm4moelite",
"lfm2",
}, kv.Architecture())
}
@@ -858,9 +856,7 @@ func (f GGML) FlashAttention() bool {
return slices.Contains([]string{
"bert",
"gemma3",
"glm4moelite",
"gptoss", "gpt-oss",
"lfm2",
"mistral3",
"olmo3",
"qwen3", "qwen3moe",

View File

@@ -73,18 +73,13 @@ func manhattanDistance[V float32 | float64](v1, v2 []V) V {
}
func TestEmbedCosineDistanceCorrelation(t *testing.T) {
softTimeout, hardTimeout := getTimeouts(t)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
started := time.Now()
for _, model := range libraryEmbedModels {
t.Run(model, func(t *testing.T) {
if time.Since(started) > softTimeout {
t.Skip("skipping - soft timeout exceeded")
}
testCases := []struct {
a string
b string
@@ -494,19 +489,14 @@ func TestEmbedTruncation(t *testing.T) {
// TestEmbedLargeInput tests that embedding models can handle large inputs that would exceed typical batch sizes.
func TestEmbedLargeInput(t *testing.T) {
softTimeout, hardTimeout := getTimeouts(t)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
started := time.Now()
for _, model := range libraryEmbedModels {
model := model
t.Run(model, func(t *testing.T) {
if time.Since(started) > softTimeout {
t.Skip("skipping - soft timeout exceeded")
}
mctx, mcancel := context.WithTimeout(ctx, 2*time.Minute)
defer mcancel()

View File

@@ -1,148 +0,0 @@
//go:build integration
package integration
import (
"context"
"encoding/base64"
"fmt"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func TestImageGeneration(t *testing.T) {
skipUnderMinVRAM(t, 8)
type testCase struct {
imageGenModel string
visionModel string
prompt string
expectedWords []string
}
testCases := []testCase{
{
imageGenModel: "jmorgan/z-image-turbo",
visionModel: "llama3.2-vision",
prompt: "A cartoon style llama flying like a superhero through the air with clouds in the background",
expectedWords: []string{"llama", "flying", "cartoon", "cloud", "sky", "superhero", "air", "animal", "camelid"},
},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("%s->%s", tc.imageGenModel, tc.visionModel), func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Pull both models
if err := PullIfMissing(ctx, client, tc.imageGenModel); err != nil {
t.Fatalf("failed to pull image gen model: %v", err)
}
if err := PullIfMissing(ctx, client, tc.visionModel); err != nil {
t.Fatalf("failed to pull vision model: %v", err)
}
// Generate the image
t.Logf("Generating image with prompt: %s", tc.prompt)
imageBase64, err := generateImage(ctx, client, tc.imageGenModel, tc.prompt)
if err != nil {
if strings.Contains(err.Error(), "image generation not available") {
t.Skip("Target system does not support image generation")
} else if strings.Contains(err.Error(), "executable file not found in") { // Windows pattern, not yet supported
t.Skip("Windows does not support image generation yet")
} else if strings.Contains(err.Error(), "CUDA driver version is insufficient") {
t.Skip("Driver is too old")
} else if strings.Contains(err.Error(), "insufficient memory for image generation") {
t.Skip("insufficient memory for image generation")
} else if strings.Contains(err.Error(), "error while loading shared libraries: libcuda.so.1") { // AMD GPU or CPU
t.Skip("CUDA GPU is not available")
} else if strings.Contains(err.Error(), "ollama-mlx: no such file or directory") {
// most likely linux arm - not supported yet
t.Skip("unsupported architecture")
}
t.Fatalf("failed to generate image: %v", err)
}
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
if err != nil {
t.Fatalf("failed to decode image: %v", err)
}
t.Logf("Generated image: %d bytes", len(imageData))
// Preload vision model and check GPU loading
err = client.Generate(ctx, &api.GenerateRequest{Model: tc.visionModel}, func(response api.GenerateResponse) error { return nil })
if err != nil {
t.Fatalf("failed to load vision model: %v", err)
}
// Use vision model to describe the image
chatReq := api.ChatRequest{
Model: tc.visionModel,
Messages: []api.Message{
{
Role: "user",
Content: "Describe this image in detail. What is shown? What style is it? What is the main subject doing?",
Images: []api.ImageData{imageData},
},
},
Stream: &stream,
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
}
// Verify the vision model's response contains expected keywords
response := DoChat(ctx, t, client, chatReq, tc.expectedWords, 240*time.Second, 30*time.Second)
if response != nil {
t.Logf("Vision model response: %s", response.Content)
// Additional detailed check for keywords
content := strings.ToLower(response.Content)
foundWords := []string{}
missingWords := []string{}
for _, word := range tc.expectedWords {
if strings.Contains(content, word) {
foundWords = append(foundWords, word)
} else {
missingWords = append(missingWords, word)
}
}
t.Logf("Found keywords: %v", foundWords)
if len(missingWords) > 0 {
t.Logf("Missing keywords (at least one was found so test passed): %v", missingWords)
}
}
})
}
}
// generateImage calls the Ollama API to generate an image and returns the base64 image data
func generateImage(ctx context.Context, client *api.Client, model, prompt string) (string, error) {
var imageBase64 string
err := client.Generate(ctx, &api.GenerateRequest{
Model: model,
Prompt: prompt,
}, func(resp api.GenerateResponse) error {
if resp.Image != "" {
imageBase64 = resp.Image
}
return nil
})
if err != nil {
return "", fmt.Errorf("failed to generate image: %w", err)
}
if imageBase64 == "" {
return "", fmt.Errorf("no image data in response")
}
return imageBase64, nil
}

View File

@@ -21,10 +21,9 @@ func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
}
func TestAPIToolCalling(t *testing.T) {
initialTimeout := 90 * time.Second
streamTimeout := 90 * time.Second
softTimeout, hardTimeout := getTimeouts(t)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
initialTimeout := 60 * time.Second
streamTimeout := 60 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
@@ -48,12 +47,8 @@ func TestAPIToolCalling(t *testing.T) {
"granite3.3": 7,
}
started := time.Now()
for _, model := range libraryToolsModels {
t.Run(model, func(t *testing.T) {
if time.Since(started) > softTimeout {
t.Skip("skipping - soft timeout exceeded")
}
if v, ok := minVRAM[model]; ok {
skipUnderMinVRAM(t, v)
}
@@ -136,7 +131,7 @@ func TestAPIToolCalling(t *testing.T) {
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
}
if _, ok := lastToolCall.Function.Arguments.Get("location"); !ok {
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
}
case <-ctx.Done():

View File

@@ -38,7 +38,6 @@ var (
// Note: add newer models at the top of the list to test them first
ollamaEngineChatModels = []string{
"lfm2.5-thinking",
"ministral-3",
"qwen3-coder:30b",
"gpt-oss:20b",
@@ -144,7 +143,6 @@ var (
"granite3.3",
"hermes3",
"internlm2",
"lfm2.5-thinking",
"llama-guard3",
"llama-pro",
"llama2-chinese",
@@ -265,7 +263,6 @@ var (
"snowflake-arctic-embed2",
}
libraryToolsModels = []string{
"lfm2.5-thinking",
"qwen3-vl",
"gpt-oss:20b",
"gpt-oss:120b",

View File

@@ -14,28 +14,25 @@ make -f Makefile.sync apply-patches
### Updating Base Commit
To update to a new base commit:
**Pin to new base commit**
1. **Update FETCH_HEAD** in `Makefile.sync` to the new commit hash.
To change the base commit, update `FETCH_HEAD` in Makefile.sync.
2. **Check for upstreamed patches**: Before applying, review if any patches have been merged upstream. Remove those patches from `./patches/` to avoid conflicts.
When updating to a newer base commit, the existing patches may not apply cleanly and require manual merge resolution.
3. **Apply patches**:
```shell
make -f Makefile.sync apply-patches
```
Start by applying the patches. If any of the patches have conflicts, the `git am` will stop at the first failure.
4. **Resolve conflicts** (if any): When `git am` fails on a patch:
- Fix conflicts in `./vendor/`
- Stage the resolved files: `git -C llama/vendor add <file>`
- Continue: `git -C llama/vendor am --continue`
- Re-run: `make -f Makefile.sync apply-patches`
- Repeat until all patches are applied.
```shell
make -f Makefile.sync apply-patches
```
5. **Regenerate patches and sync**:
```shell
make -f Makefile.sync format-patches sync
```
If there are conflicts, you will see an error message. Resolve the conflicts in `./vendor/`, and continue the patch series with `git am --continue` and rerun `make -f Makefile.sync apply-patches`. Repeat until all patches are successfully applied.
Once all patches are applied, commit the changes to the tracking repository.
```shell
make -f Makefile.sync format-patches sync
```
### Generating Patches

2
llama/build-info.cpp generated vendored
View File

@@ -1,4 +1,4 @@
int LLAMA_BUILD_NUMBER = 0;
char const *LLAMA_COMMIT = "a5bb8ba4c50257437630c136210396810741bbf7";
char const *LLAMA_COMMIT = "ec98e2002";
char const *LLAMA_COMPILER = "";
char const *LLAMA_BUILD_TARGET = "";

View File

@@ -251,7 +251,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
case GGML_SCHED_PRIO_REALTIME: p = -20; break;
}
if (setpriority(PRIO_PROCESS, 0, p) != 0) {
if (!setpriority(PRIO_PROCESS, 0, p)) {
LOG_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno);
return false;
}
@@ -1078,15 +1078,12 @@ struct common_init_result::impl {
impl() = default;
~impl() = default;
// note: the order in which model, context, etc. are declared matters because their destructors will be called bottom-to-top
llama_model_ptr model;
llama_context_ptr context;
std::vector<llama_adapter_lora_ptr> lora;
std::vector<common_sampler_ptr> samplers;
std::vector<llama_sampler_seq_config> samplers_seq_config;
};
common_init_result::common_init_result(common_params & params) :
@@ -1095,9 +1092,9 @@ common_init_result::common_init_result(common_params & params) :
auto cparams = common_context_params_to_llama(params);
if (params.fit_params) {
LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__);
LOG_INF("%s: fitting params to device memory, to report bugs during this step use -fit off (or --verbose if you can't)\n", __func__);
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx,
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
}
@@ -1110,25 +1107,6 @@ common_init_result::common_init_result(common_params & params) :
const llama_vocab * vocab = llama_model_get_vocab(model);
// load and optionally apply lora adapters (must be loaded before context creation)
for (auto & la : params.lora_adapters) {
llama_adapter_lora_ptr lora;
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
if (lora == nullptr) {
LOG_ERR("%s: failed to load lora adapter '%s'\n", __func__, la.path.c_str());
pimpl->model.reset(model);
return;
}
char buf[1024];
la.ptr = lora.get();
llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf));
la.task_name = buf;
llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
la.prompt_prefix = buf;
pimpl->lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
}
// updates params.sampling
// TODO: fix naming
common_init_sampler_from_model(model, params.sampling);
@@ -1163,18 +1141,10 @@ common_init_result::common_init_result(common_params & params) :
// params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
//}
// init the backend samplers as part of the context creation
pimpl->samplers.resize(cparams.n_seq_max);
pimpl->samplers_seq_config.resize(cparams.n_seq_max);
for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) };
}
if (params.sampling.backend_sampling) {
cparams.samplers = pimpl->samplers_seq_config.data();
cparams.n_samplers = pimpl->samplers_seq_config.size();
}
llama_context * lctx = llama_init_from_model(model, cparams);
@@ -1198,12 +1168,6 @@ common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
return pimpl->samplers[seq_id].get();
}
void common_init_result::reset_samplers() {
for (int i = 0; i < (int) pimpl->samplers.size(); ++i) {
llama_sampler_reset(common_sampler_get(pimpl->samplers[i].get()));
}
}
std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
return pimpl->lora;
}
@@ -1279,6 +1243,24 @@ common_init_result_ptr common_init_from_params(common_params & params) {
}
}
// load and optionally apply lora adapters
for (auto & la : params.lora_adapters) {
llama_adapter_lora_ptr lora;
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
if (lora == nullptr) {
LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
return res;
}
char buf[1024];
la.ptr = lora.get();
llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf));
la.task_name = buf;
llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
la.prompt_prefix = buf;
res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters
}
if (!params.lora_init_without_apply) {
common_set_adapter_lora(lctx, params.lora_adapters);
}
@@ -1319,9 +1301,6 @@ common_init_result_ptr common_init_from_params(common_params & params) {
llama_synchronize(lctx);
llama_perf_context_reset(lctx);
llama_set_warmup(lctx, false);
// reset samplers to reset RNG state after warmup to the seeded state
res->reset_samplers();
}
return res;
@@ -1360,12 +1339,14 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
mparams.devices = params.devices.data();
}
mparams.n_gpu_layers = params.n_gpu_layers;
if (params.n_gpu_layers != -1) {
mparams.n_gpu_layers = params.n_gpu_layers;
}
mparams.main_gpu = params.main_gpu;
mparams.split_mode = params.split_mode;
mparams.tensor_split = params.tensor_split;
mparams.use_mmap = params.use_mmap;
mparams.use_direct_io = params.use_direct_io;
mparams.use_mlock = params.use_mlock;
mparams.check_tensors = params.check_tensors;
mparams.use_extra_bufts = !params.no_extra_bufts;

View File

@@ -57,8 +57,6 @@ extern const char * LLAMA_COMMIT;
extern const char * LLAMA_COMPILER;
extern const char * LLAMA_BUILD_TARGET;
const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);
struct common_control_vector_load_info;
//
@@ -82,8 +80,6 @@ int32_t cpu_get_num_math();
//
enum llama_example {
LLAMA_EXAMPLE_BATCHED,
LLAMA_EXAMPLE_DEBUG,
LLAMA_EXAMPLE_COMMON,
LLAMA_EXAMPLE_SPECULATIVE,
LLAMA_EXAMPLE_COMPLETION,
@@ -121,7 +117,6 @@ enum common_sampler_type {
COMMON_SAMPLER_TYPE_INFILL = 9,
COMMON_SAMPLER_TYPE_PENALTIES = 10,
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
COMMON_SAMPLER_TYPE_ADAPTIVE_P = 12,
};
// dimensionality reduction methods, used by cvector-generator
@@ -169,34 +164,32 @@ enum common_params_sampling_config : uint64_t {
struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float xtc_probability = 0.00f; // 0.0 = disabled
float xtc_threshold = 0.10f; // > 0.5 disables XTC
float typ_p = 1.00f; // typical_p, 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
float adaptive_target = -1.0f; // select tokens near this probability (valid range 0.0 to 1.0; negative = disabled)
float adaptive_decay = 0.90f; // EMA decay for adaptation; history ≈ 1/(1-decay) tokens (0.0 - 0.99)
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float top_n_sigma = -1.00f; // -1.0 = disabled
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float xtc_probability = 0.00f; // 0.0 = disabled
float xtc_threshold = 0.10f; // > 0.5 disables XTC
float typ_p = 1.00f; // typical_p, 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float top_n_sigma = -1.00f;// -1.0 = disabled
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
bool ignore_eos = false;
bool no_perf = false; // disable performance metrics
bool no_perf = false; // disable performance metrics
bool timing_per_token = false;
uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
@@ -223,8 +216,6 @@ struct common_params_sampling {
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
bool backend_sampling = false;
bool has_logit_bias() const {
return !logit_bias.empty();
}
@@ -286,7 +277,6 @@ struct common_params_diffusion {
};
// reasoning API response format (not to be confused as chat template's reasoning format)
// only used by server
enum common_reasoning_format {
COMMON_REASONING_FORMAT_NONE,
COMMON_REASONING_FORMAT_AUTO, // Same as deepseek, using `message.reasoning_content`
@@ -339,14 +329,12 @@ struct common_params {
// offload params
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
int32_t n_gpu_layers = -1; // number of layers to store in VRAM, -1 is auto, <= -2 is all
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
bool fit_params = true; // whether to fit unset model/context parameters to free device memory
int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use
// margin per device in bytes for fitting parameters to free memory:
std::vector<size_t> fit_params_target = std::vector<size_t>(llama_max_devices(), 1024 * 1024*1024);
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
bool fit_params = true; // whether to fit unset model/context parameters to free device memory
size_t fit_params_target = 1024 * 1024*1024; // margin per device in bytes for fitting parameters to free memory
int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
@@ -382,11 +370,6 @@ struct common_params {
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
std::string logits_file = ""; // file for saving *all* logits // NOLINT
// llama-debug specific options
std::string logits_output_dir = "data"; // directory for saving logits output files // NOLINT
bool save_logits = false; // whether to save logits to files // NOLINT
std::vector<std::string> tensor_filter; // filter tensor names for debug output (regex) // NOLINT
std::vector<std::string> in_files; // all input files
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
std::vector<llama_model_kv_override> kv_overrides;
@@ -437,8 +420,7 @@ struct common_params {
bool kv_unified = false; // enable unified KV cache
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool use_mmap = true; // enable mmap to use filesystem cache
bool use_direct_io = true; // read from disk without buffering for faster model loading
bool use_mmap = true; // use mmap for faster loads
bool use_mlock = false; // use mlock to keep model in memory
bool verbose_prompt = false; // print prompt tokens before generation
bool display_prompt = true; // print prompt before generation
@@ -482,7 +464,6 @@ struct common_params {
int32_t timeout_write = timeout_read; // http write timeout in seconds
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
bool cache_prompt = true; // whether to enable prompt caching
int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
@@ -494,8 +475,7 @@ struct common_params {
bool enable_chat_template = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
int reasoning_budget = -1;
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
std::vector<std::string> api_keys;
@@ -504,11 +484,8 @@ struct common_params {
std::map<std::string, std::string> default_template_kwargs;
// webui configs
bool webui = true;
std::string webui_config_json;
// "advanced" endpoints are disabled by default for better security
bool webui = true;
bool endpoint_slots = true;
bool endpoint_props = false; // only control POST requests, not GET
bool endpoint_metrics = false;
@@ -708,9 +685,7 @@ struct common_init_result {
llama_model * model();
llama_context * context();
common_sampler * sampler(llama_seq_id seq_id);
void reset_samplers();
std::vector<llama_adapter_lora_ptr> & lora();

View File

@@ -104,9 +104,10 @@ struct ring_buffer {
struct common_sampler {
common_params_sampling params;
struct llama_sampler * grmr;
struct llama_sampler * chain;
bool grammar;
ring_buffer<llama_token> prev;
std::vector<llama_token_data> cur;
@@ -120,34 +121,17 @@ struct common_sampler {
}
void set_logits(struct llama_context * ctx, int idx) {
const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx);
const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx);
const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx);
const auto * logits = llama_get_logits_ith(ctx, idx);
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_vocab = llama_vocab_n_tokens(vocab);
if (sampled_probs) {
const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
cur.resize(sampled_probs_count);
for (uint32_t i = 0; i < sampled_probs_count; ++i) {
cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
}
} else if (sampled_logits) {
const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
cur.resize(sampled_logits_count);
for (uint32_t i = 0; i < sampled_logits_count; i++) {
cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
}
} else {
const auto * logits = llama_get_logits_ith(ctx, idx);
GGML_ASSERT(logits != nullptr);
cur.resize(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
}
cur.resize(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
}
cur_p = { cur.data(), cur.size(), -1, false };
@@ -167,59 +151,54 @@ std::string common_params_sampling::print() const {
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f, adaptive_target = %.3f, adaptive_decay = %.3f",
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
mirostat, mirostat_eta, mirostat_tau, adaptive_target, adaptive_decay);
mirostat, mirostat_eta, mirostat_tau);
return std::string(result);
}
struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) {
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
lparams.no_perf = params.no_perf;
llama_sampler * grmr = nullptr;
llama_sampler * chain = llama_sampler_chain_init(lparams);
bool grammar = false;
std::vector<llama_sampler *> samplers;
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()));
grammar = true;
#else
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
} else {
std::vector<std::string> trigger_patterns;
std::vector<std::string> patterns_anywhere;
std::vector<llama_token> trigger_tokens;
for (const auto & trigger : params.grammar_triggers) {
switch (trigger.type) {
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
{
const auto & word = trigger.value;
trigger_patterns.push_back(regex_escape(word));
patterns_anywhere.push_back(regex_escape(word));
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
{
trigger_patterns.push_back(trigger.value);
patterns_anywhere.push_back(trigger.value);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
{
const auto & pattern = trigger.value;
std::string anchored = "^$";
if (!pattern.empty()) {
anchored = (pattern.front() != '^' ? "^" : "")
+ pattern
+ (pattern.back() != '$' ? "$" : "");
}
trigger_patterns.push_back(anchored);
trigger_patterns.push_back(trigger.value);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
@@ -233,6 +212,10 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}
}
if (!patterns_anywhere.empty()) {
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
}
std::vector<const char *> trigger_patterns_c;
trigger_patterns_c.reserve(trigger_patterns.size());
for (const auto & regex : trigger_patterns) {
@@ -241,12 +224,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
if (!params.grammar.empty()) {
if (params.grammar_lazy) {
grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
trigger_patterns_c.data(), trigger_patterns_c.size(),
trigger_tokens.data(), trigger_tokens.size());
samplers.push_back(
llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
trigger_patterns_c.data(), trigger_patterns_c.size(),
trigger_tokens.data(), trigger_tokens.size()));
} else {
grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"));
}
grammar = true;
}
}
@@ -255,9 +241,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}
if (params.mirostat == 0) {
bool use_adaptive_p = false; // see below
for (const auto & cnstr : params.samplers) {
switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY:
@@ -267,54 +250,43 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
for (const auto & str : params.dry_sequence_breakers) {
c_breakers.push_back(str.c_str());
}
samplers.push_back(llama_sampler_init_dry(vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
}
break;
case COMMON_SAMPLER_TYPE_TOP_K:
samplers.push_back(llama_sampler_init_top_k(params.top_k));
samplers.push_back(llama_sampler_init_top_k (params.top_k));
break;
case COMMON_SAMPLER_TYPE_TOP_P:
samplers.push_back(llama_sampler_init_top_p(params.top_p, params.min_keep));
samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
break;
case COMMON_SAMPLER_TYPE_MIN_P:
samplers.push_back(llama_sampler_init_min_p(params.min_p, params.min_keep));
samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_XTC:
samplers.push_back(llama_sampler_init_xtc(params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
break;
case COMMON_SAMPLER_TYPE_TYPICAL_P:
samplers.push_back(llama_sampler_init_typical(params.typ_p, params.min_keep));
samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TEMPERATURE:
samplers.push_back(llama_sampler_init_temp_ext(params.temp, params.dynatemp_range, params.dynatemp_exponent));
samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break;
case COMMON_SAMPLER_TYPE_INFILL:
samplers.push_back(llama_sampler_init_infill(vocab));
samplers.push_back(llama_sampler_init_infill (vocab));
break;
case COMMON_SAMPLER_TYPE_PENALTIES:
samplers.push_back(llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
break;
case COMMON_SAMPLER_TYPE_ADAPTIVE_P:
// the `adaptive-p` sampler is like `dist` and `mirostat` in that it selects
// a single token, so we will add `dist` at the end of the chain by default,
// unless the user specifically included `adaptive-p`. we set this flag here
// so we know to add the sampler at the very end.
use_adaptive_p = true;
samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
break;
default:
GGML_ASSERT(false && "unknown sampler type");
}
}
if (use_adaptive_p) {
// only if user explicitly included adaptive-p sampler
samplers.push_back(llama_sampler_init_adaptive_p(params.adaptive_target, params.adaptive_decay, params.seed));
} else {
// default: sample from distribution
samplers.push_back(llama_sampler_init_dist(params.seed));
}
samplers.push_back(llama_sampler_init_dist(params.seed));
} else if (params.mirostat == 1) {
samplers.push_back(llama_sampler_init_temp(params.temp));
samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
@@ -329,16 +301,10 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
llama_sampler_chain_add(chain, smpl);
}
if (grmr && params.backend_sampling) {
LOG_WRN("%s: backend sampling is not compatible with grammar, disabling\n", __func__);
params.backend_sampling = false;
}
auto * result = new common_sampler {
/* .params = */ params,
/* .grmr = */ grmr,
/* .chain = */ chain,
/* .grammar = */ grammar,
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},
/* .cur_p = */ {},
@@ -348,45 +314,47 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}
void common_sampler_free(struct common_sampler * gsmpl) {
if (!gsmpl) {
return;
if (gsmpl) {
llama_sampler_free(gsmpl->chain);
delete gsmpl;
}
llama_sampler_free(gsmpl->grmr);
llama_sampler_free(gsmpl->chain);
delete gsmpl;
}
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
if (!gsmpl) {
return;
}
const auto tm = gsmpl->tm();
if (gsmpl->grmr && accept_grammar) {
llama_sampler_accept(gsmpl->grmr, token);
}
if (gsmpl->grammar) {
const int n_smpl = llama_sampler_chain_n(gsmpl->chain);
llama_sampler_accept(gsmpl->chain, token);
for (int i = 0; i < n_smpl; i++) {
auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
// the grammar sampler is always the first one
if (i == 0) {
if (accept_grammar) {
llama_sampler_accept(smpl, token);
}
} else {
llama_sampler_accept(smpl, token);
}
}
} else {
llama_sampler_accept(gsmpl->chain, token);
}
gsmpl->prev.push_back(token);
}
void common_sampler_reset(struct common_sampler * gsmpl) {
if (!gsmpl) {
return;
}
gsmpl->reset();
}
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
return new common_sampler {
/* .params = */ gsmpl->params,
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
/* .chain = */ llama_sampler_clone(gsmpl->chain),
/* .grammar = */ gsmpl->grammar,
/* .prev = */ gsmpl->prev,
/* .cur = */ gsmpl->cur,
/* .cur_p = */ gsmpl->cur_p,
@@ -439,14 +407,10 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
}
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
if (!gsmpl) {
return nullptr;
}
return gsmpl->chain;
}
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
llama_synchronize(ctx);
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
@@ -454,61 +418,11 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
llama_token id = LLAMA_TOKEN_NULL;
auto & grmr = gsmpl->grmr;
auto & chain = gsmpl->chain;
auto & cur_p = gsmpl->cur_p; // initialized by set_logits
// Check if a backend sampler has already sampled a token in which case we
// return that token id directly.
{
id = llama_get_sampled_token_ith(ctx, idx);
if (id != LLAMA_TOKEN_NULL) {
LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
// TODO: simplify
gsmpl->cur.resize(1);
gsmpl->cur[0] = { id, 0.0f, 1.0f };
cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true };
return id;
}
}
gsmpl->set_logits(ctx, idx);
if (grammar_first) {
llama_sampler_apply(grmr, &cur_p);
}
llama_sampler_apply(chain, &cur_p);
id = cur_p.data[cur_p.selected].id;
if (grammar_first) {
return id;
}
// check if it the sampled token fits the grammar (grammar-based rejection sampling)
{
llama_token_data single_token_data = { id, 1.0f, 0.0f };
llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
llama_sampler_apply(grmr, &single_token_data_array);
const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
if (is_valid) {
return id;
}
}
// resampling:
// if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
gsmpl->set_logits(ctx, idx);
llama_sampler_apply(grmr, &cur_p);
llama_sampler_apply(chain, &cur_p);
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
@@ -518,7 +432,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
return id;
}
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft) {
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
std::vector<llama_token> result;
@@ -526,7 +440,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
size_t i = 0;
for (; i < draft.size(); i++) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
common_sampler_accept(gsmpl, id, true);
@@ -538,7 +452,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
}
if (i == draft.size()) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
common_sampler_accept(gsmpl, id, true);
@@ -548,13 +462,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
return result;
}
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) {
std::vector<int> idxs(draft.size() + 1);
for (size_t i = 0; i < idxs.size(); ++i) {
idxs[i] = i;
}
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
}
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
@@ -639,7 +553,6 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_XTC: return 'x';
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
case COMMON_SAMPLER_TYPE_ADAPTIVE_P: return 'a';
default : return '?';
}
}
@@ -656,7 +569,6 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
case COMMON_SAMPLER_TYPE_ADAPTIVE_P: return "adaptive_p";
default : return "";
}
}
@@ -673,7 +585,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
{ "adaptive_p", COMMON_SAMPLER_TYPE_ADAPTIVE_P },
};
// since samplers names are written multiple ways
@@ -689,7 +600,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P },
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
{ "adaptive-p", COMMON_SAMPLER_TYPE_ADAPTIVE_P },
};
std::vector<common_sampler_type> samplers;
@@ -726,7 +636,6 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_ADAPTIVE_P), COMMON_SAMPLER_TYPE_ADAPTIVE_P },
};
std::vector<common_sampler_type> samplers;

View File

@@ -36,8 +36,7 @@ struct common_sampler;
// llama_sampler API overloads
// note: can mutate params in some cases
struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params);
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
void common_sampler_free(struct common_sampler * gsmpl);
@@ -49,7 +48,6 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
// arguments can be nullptr to skip printing
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
// get the underlying llama_sampler_chain
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
// extended sampling implementation:
@@ -59,10 +57,7 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
// - check if the token fits the grammar (if any)
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
//
// if grammar_first is true, the grammar is applied before the samplers (slower)
// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
//
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx);
// generalized version of common_sampler_sample
//
@@ -80,10 +75,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
//
// returns at least 1 token, up to idxs.size()
//
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft);
// assume idxs == [ 0, 1, 2, ..., draft.size() ]
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft);
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);

View File

@@ -21,9 +21,7 @@ struct llama_sampler_deleter {
};
struct llama_adapter_lora_deleter {
void operator()(llama_adapter_lora *) {
// llama_adapter_lora_free is deprecated
}
void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); }
};
typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr;

View File

@@ -286,7 +286,7 @@ extern "C" {
// NULL-terminated list of buffer types to use for tensors that match a pattern
const struct llama_model_tensor_buft_override * tensor_buft_overrides;
int32_t n_gpu_layers; // number of layers to store in VRAM, a negative value means all layers
int32_t n_gpu_layers; // number of layers to store in VRAM
enum llama_split_mode split_mode; // how to split the model across multiple GPUs
// the GPU that is used for the entire model when split_mode is LLAMA_SPLIT_MODE_NONE
@@ -309,7 +309,6 @@ extern "C" {
// Keep the booleans together to avoid misalignment during copy-by-value.
bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible
bool use_direct_io; // use direct io, takes precedence over use_mmap
bool use_mlock; // force system to keep model in RAM
bool check_tensors; // validate model tensor data
bool use_extra_bufts; // use extra buffer types (used for weight repacking)
@@ -317,11 +316,6 @@ extern "C" {
bool no_alloc; // only load metadata and simulate memory allocations
};
struct llama_sampler_seq_config {
llama_seq_id seq_id;
struct llama_sampler * sampler;
};
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
// https://github.com/ggml-org/llama.cpp/pull/7544
struct llama_context_params {
@@ -370,12 +364,6 @@ extern "C" {
bool kv_unified; // use a unified buffer across the input sequences when computing the attention
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
// [EXPERIMENTAL]
// backend sampler chain configuration (make sure the caller keeps the sampler chains alive)
// note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
struct llama_sampler_seq_config * samplers;
size_t n_samplers;
};
// model quantization parameters
@@ -479,24 +467,16 @@ extern "C" {
// Frees all allocated memory
LLAMA_API void llama_free(struct llama_context * ctx);
enum llama_params_fit_status {
LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit
LLAMA_PARAMS_FIT_STATUS_FAILURE = 1, // could not find allocations that are projected to fit
LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occured, e.g. because no model could be found at the specified path
};
// fits mparams and cparams to free device memory (assumes system memory is unlimited)
// - returns true if the parameters could be successfully modified to fit device memory
// - this function is NOT thread safe because it modifies the global llama logger state
// - only parameters that have the same value as in llama_default_model_params are modified
// with the exception of the context size which is modified if and only if equal to 0
LLAMA_API enum llama_params_fit_status llama_params_fit(
// returns true if the parameters could be successfully modified to fit device memory
// this function is NOT thread safe because it modifies the global llama logger state
LLAMA_API bool llama_params_fit(
const char * path_model,
struct llama_model_params * mparams,
struct llama_context_params * cparams,
float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements
struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements
size_t * margins, // margins of memory to leave per device in bytes
size_t margin, // margin of memory to leave per device in bytes
uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use
enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log
@@ -537,7 +517,6 @@ extern "C" {
LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
@@ -621,8 +600,6 @@ extern "C" {
//
// Load a LoRA adapter from file
// The adapter is valid as long as the associated model is not freed
// All adapters must be loaded before context creation
LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init(
struct llama_model * model,
const char * path_lora);
@@ -647,8 +624,7 @@ extern "C" {
// Manually free a LoRA adapter
// NOTE: loaded adapters will be free when the associated model is deleted
LLAMA_API DEPRECATED(void llama_adapter_lora_free(struct llama_adapter_lora * adapter),
"adapters are now freed together with the associated model");
LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
// Get the invocation tokens if the current lora is an alora
LLAMA_API uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter);
@@ -1007,32 +983,6 @@ extern "C" {
// otherwise: float[n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
//
// backend sampling API [EXPERIMENTAL]
// note: use only if the llama_context was created with at least one llama_sampler_seq_config
//
// Get the backend sampled token for the ith token.
// Returns LLAMA_TOKEN_NULL if no token was sampled.
LLAMA_API llama_token llama_get_sampled_token_ith(struct llama_context * ctx, int32_t i);
// Get the backend sampled probabilites for the ith token
// The index matches llama_get_sampled_token_ith().
// Returns NULL if no probabilites were generated.
LLAMA_API float * llama_get_sampled_probs_ith (struct llama_context * ctx, int32_t i);
LLAMA_API uint32_t llama_get_sampled_probs_count_ith(struct llama_context * ctx, int32_t i);
// Get the backend sampled logits for the ith token
// Returns NULL if no logits were sampled.
LLAMA_API float * llama_get_sampled_logits_ith (struct llama_context * ctx, int32_t i);
LLAMA_API uint32_t llama_get_sampled_logits_count_ith(struct llama_context * ctx, int32_t i);
// Get the backend sampled candidates (token ids) for the ith token
// These are needed to map probability/logit indices to vocab token ids.
// Returns NULL if no candidates were sampled.
LLAMA_API llama_token * llama_get_sampled_candidates_ith (struct llama_context * ctx, int32_t i);
LLAMA_API uint32_t llama_get_sampled_candidates_count_ith(struct llama_context * ctx, int32_t i);
//
// Vocab
//
@@ -1204,16 +1154,11 @@ extern "C" {
//
// llama_sampler_free(smpl);
//
// TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
//
typedef void * llama_sampler_context_t;
struct llama_sampler_data {
struct ggml_tensor * logits;
struct ggml_tensor * probs;
struct ggml_tensor * sampled;
struct ggml_tensor * candidates;
};
// user code can implement the interface below in order to create custom llama_sampler
struct llama_sampler_i {
const char * (*name) (const struct llama_sampler * smpl); // can be NULL
@@ -1223,44 +1168,17 @@ extern "C" {
struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL
void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL
// [EXPERIMENTAL]
// backend sampling interface:
// return true if the backend supports all ops needed by the sampler
// note: call once per sampler
bool (*backend_init)(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft);
// call after .backend_apply()
void (*backend_accept)(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct ggml_tensor * selected_token);
// call after .backend_init()
void (*backend_apply)(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_data * data);
// called before graph execution to set inputs for the current ubatch
void (*backend_set_input)(struct llama_sampler * smpl);
// TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
//void (*apply_ggml) (struct llama_sampler * smpl, ...);
};
struct llama_sampler {
struct llama_sampler_i * iface;
llama_sampler_context_t ctx;
const struct llama_sampler_i * iface;
llama_sampler_context_t ctx;
};
// [EXPERIMENTAL]
// attach a sampler to the context
// note: prefer initializing the context with llama_context_params.samplers when possible
LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
// mirror of llama_sampler_i:
LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_i * iface, llama_sampler_context_t ctx);
LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
@@ -1276,15 +1194,7 @@ extern "C" {
// important: takes ownership of the sampler object and will free it when llama_sampler_free is called
LLAMA_API void llama_sampler_chain_add( struct llama_sampler * chain, struct llama_sampler * smpl);
// return NULL if:
// - the sampler is NULL
// - the sampler is not a llama_sampler_chain
// - the index is out of bounds, unless i == -1
// - if i == -1, returns the chain itself (can be used to check if the sampler is a chain)
LLAMA_API struct llama_sampler * llama_sampler_chain_get( struct llama_sampler * chain, int32_t i);
// the total number of samplers in the chain
LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i);
LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain);
// after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed
@@ -1293,9 +1203,7 @@ extern "C" {
// available samplers:
LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
/// seed == LLAMA_DEFAULT_SEED to use a random seed.
LLAMA_API struct llama_sampler * llama_sampler_init_dist(uint32_t seed);
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
/// Setting k <= 0 makes this a noop
@@ -1396,33 +1304,6 @@ extern "C" {
const char ** seq_breakers,
size_t num_breakers);
/// adaptive-p: select tokens near a configurable target probability over time.
///
/// the adaptive-p sampler transforms the token probability distribution to favor tokens
/// that fall near a user-configurable probability target.
///
/// internally, the sampler maintains an exponential moving average of the *ORIGINAL*
/// probabilities of selected tokens at each sampling step. it uses this EMA to compute an
/// adapted target probability at each sampling step, thus maintaining the desired target
/// probability over time.
///
/// adaptive-p selects a token ID rather than just mutating candidates, so it must be last
/// in the sampler chain (like mirostat, dist, greedy).
///
/// only mild truncation before this sampler is recommended. we suggest applying min-p
/// before adaptive-p as the only other active sampler in the chain.
///
/// @param target select tokens near this probability (valid range 0.0 to 1.0; negative = disabled)
/// @param decay EMA decay for adaptation; history ≈ 1/(1-decay) tokens (valid range 0.0 - 0.99)
/// @param seed RNG seed
///
/// ref: https://github.com/ggml-org/llama.cpp/pull/17927
///
LLAMA_API struct llama_sampler * llama_sampler_init_adaptive_p(
float target,
float decay,
uint32_t seed);
LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
int32_t n_vocab,
int32_t n_logit_bias,
@@ -1476,12 +1357,12 @@ extern "C" {
/// @details Build a split GGUF final path for this chunk.
/// llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf"
// Returns the split_path length.
LLAMA_API int32_t llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int32_t split_no, int32_t split_count);
LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count);
/// @details Extract the path prefix from the split_path if and only if the split_no and split_count match.
/// llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0"
// Returns the split_prefix length.
LLAMA_API int32_t llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int32_t split_no, int32_t split_count);
LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count);
// Print system information
LLAMA_API const char * llama_print_system_info(void);

View File

@@ -411,9 +411,6 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
}
}
// register adapter with model
model.loras.insert(&adapter);
LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
}
@@ -471,8 +468,8 @@ int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter,
return snprintf(buf, buf_size, "%s", it->second.c_str());
}
void llama_adapter_lora_free(llama_adapter_lora *) {
// deprecated: adapters are freed by llama_model's destructor
void llama_adapter_lora_free(llama_adapter_lora * adapter) {
delete adapter;
}
uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter) {

View File

@@ -77,10 +77,6 @@ struct llama_adapter_lora {
~llama_adapter_lora() = default;
llama_adapter_lora_weight * get_weight(ggml_tensor * w);
uint32_t get_n_nodes() const {
return ab_map.size() * 6u; // a, b, scale, add, 2 x mul_mat
}
};
using llama_adapter_loras = std::unordered_map<llama_adapter_lora *, float>;

View File

@@ -20,7 +20,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_STARCODER, "starcoder" },
{ LLM_ARCH_REFACT, "refact" },
{ LLM_ARCH_BERT, "bert" },
{ LLM_ARCH_MODERN_BERT, "modern-bert" },
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
{ LLM_ARCH_NEO_BERT, "neo-bert" },
@@ -42,7 +41,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_PHIMOE, "phimoe" },
{ LLM_ARCH_PLAMO, "plamo" },
{ LLM_ARCH_PLAMO2, "plamo2" },
{ LLM_ARCH_PLAMO3, "plamo3" },
{ LLM_ARCH_CODESHELL, "codeshell" },
{ LLM_ARCH_ORION, "orion" },
{ LLM_ARCH_INTERNLM2, "internlm2" },
@@ -81,7 +79,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" },
{ LLM_ARCH_EXAONE, "exaone" },
{ LLM_ARCH_EXAONE4, "exaone4" },
{ LLM_ARCH_EXAONE_MOE, "exaone-moe" },
{ LLM_ARCH_RWKV6, "rwkv6" },
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
{ LLM_ARCH_RWKV7, "rwkv7" },
@@ -118,9 +115,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_RND1, "rnd1" },
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
{ LLM_ARCH_MISTRAL3, "mistral3" },
{ LLM_ARCH_MIMO2, "mimo2" },
{ LLM_ARCH_LLAMA_EMBED, "llama-embed" },
{ LLM_ARCH_MAINCODER, "maincoder" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -154,7 +148,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_VOCAB_SIZE, "%s.vocab_size" },
{ LLM_KV_CONTEXT_LENGTH, "%s.context_length" },
{ LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" },
{ LLM_KV_EMBEDDING_LENGTH_OUT, "%s.embedding_length_out" },
{ LLM_KV_FEATURES_LENGTH, "%s.features_length" },
{ LLM_KV_BLOCK_COUNT, "%s.block_count" },
{ LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" },
@@ -212,7 +205,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_GATE_LORA_RANK, "%s.attention.gate_lora_rank" },
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
{ LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, "%s.attention.sliding_window_pattern" },
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
{ LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
@@ -224,7 +216,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
{ LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" },
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
{ LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
{ LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
@@ -509,7 +500,6 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
case LLM_ARCH_LLAMA:
case LLM_ARCH_DECI:
case LLM_ARCH_MISTRAL3:
case LLM_ARCH_LLAMA_EMBED:
return {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,
@@ -791,20 +781,6 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_CLS,
LLM_TENSOR_CLS_OUT,
};
case LLM_ARCH_MODERN_BERT:
return {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_TOKEN_EMBD_NORM,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_ATTN_QKV,
LLM_TENSOR_FFN_DOWN,
LLM_TENSOR_FFN_UP,
LLM_TENSOR_FFN_NORM,
LLM_TENSOR_CLS,
LLM_TENSOR_CLS_OUT,
};
case LLM_ARCH_JINA_BERT_V2:
return {
LLM_TENSOR_TOKEN_EMBD,
@@ -954,8 +930,6 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_ATTN_K_NORM,
LLM_TENSOR_ATTN_V,
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_ATTN_QKV,
LLM_TENSOR_ATTN_GATE,
LLM_TENSOR_FFN_NORM,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_EXPS,
@@ -1086,22 +1060,6 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_ATTN_POST_NORM,
LLM_TENSOR_FFN_POST_NORM,
};
case LLM_ARCH_PLAMO3:
return {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_QKV,
LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K_NORM,
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_ATTN_POST_NORM,
LLM_TENSOR_FFN_NORM,
LLM_TENSOR_FFN_POST_NORM,
LLM_TENSOR_FFN_DOWN,
LLM_TENSOR_FFN_UP,
};
case LLM_ARCH_CODESHELL:
return {
LLM_TENSOR_TOKEN_EMBD,
@@ -1732,38 +1690,6 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_FFN_UP,
LLM_TENSOR_FFN_POST_NORM,
};
case LLM_ARCH_EXAONE_MOE:
return {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_ROPE_FREQS,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K,
LLM_TENSOR_ATTN_K_NORM,
LLM_TENSOR_ATTN_V,
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_FFN_NORM,
LLM_TENSOR_FFN_GATE,
LLM_TENSOR_FFN_DOWN,
LLM_TENSOR_FFN_UP,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_DOWN_EXPS,
LLM_TENSOR_FFN_UP_EXPS,
LLM_TENSOR_FFN_GATE_SHEXP,
LLM_TENSOR_FFN_UP_SHEXP,
LLM_TENSOR_FFN_DOWN_SHEXP,
LLM_TENSOR_FFN_EXP_PROBS_B,
LLM_TENSOR_NEXTN_EH_PROJ,
LLM_TENSOR_NEXTN_EMBED_TOKENS,
LLM_TENSOR_NEXTN_ENORM,
LLM_TENSOR_NEXTN_HNORM,
LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
};
case LLM_ARCH_RWKV6:
return {
LLM_TENSOR_TOKEN_EMBD,
@@ -2114,7 +2040,6 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM_LFM2,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_DENSE_2_OUT,
};
case LLM_ARCH_LFM2MOE:
return {
@@ -2133,7 +2058,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_SHORTCONV_INPROJ,
LLM_TENSOR_SHORTCONV_OUTPROJ,
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM_LFM2,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_DOWN_EXPS,
@@ -2249,49 +2174,11 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_VISEXP_FFN_DOWN,
LLM_TENSOR_VISEXP_FFN_UP,
};
case LLM_ARCH_MIMO2:
return {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_K,
LLM_TENSOR_ATTN_V,
LLM_TENSOR_ATTN_SINKS,
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_FFN_NORM,
LLM_TENSOR_FFN_GATE,
LLM_TENSOR_FFN_DOWN,
LLM_TENSOR_FFN_UP,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_DOWN_EXPS,
LLM_TENSOR_FFN_UP_EXPS,
LLM_TENSOR_FFN_EXP_PROBS_B,
};
case LLM_ARCH_GPTJ:
case LLM_ARCH_UNKNOWN:
return {
LLM_TENSOR_TOKEN_EMBD,
};
case LLM_ARCH_MAINCODER:
return {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K,
LLM_TENSOR_ATTN_K_NORM,
LLM_TENSOR_ATTN_V,
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_FFN_NORM,
LLM_TENSOR_FFN_GATE,
LLM_TENSOR_FFN_DOWN,
LLM_TENSOR_FFN_UP,
};
case LLM_ARCH_SOLAR:
return {
LLM_TENSOR_TOKEN_EMBD,

View File

@@ -24,7 +24,6 @@ enum llm_arch {
LLM_ARCH_STARCODER,
LLM_ARCH_REFACT,
LLM_ARCH_BERT,
LLM_ARCH_MODERN_BERT,
LLM_ARCH_NOMIC_BERT,
LLM_ARCH_NOMIC_BERT_MOE,
LLM_ARCH_NEO_BERT,
@@ -46,7 +45,6 @@ enum llm_arch {
LLM_ARCH_PHIMOE,
LLM_ARCH_PLAMO,
LLM_ARCH_PLAMO2,
LLM_ARCH_PLAMO3,
LLM_ARCH_CODESHELL,
LLM_ARCH_ORION,
LLM_ARCH_INTERNLM2,
@@ -85,7 +83,6 @@ enum llm_arch {
LLM_ARCH_NEMOTRON_H_MOE,
LLM_ARCH_EXAONE,
LLM_ARCH_EXAONE4,
LLM_ARCH_EXAONE_MOE,
LLM_ARCH_RWKV6,
LLM_ARCH_RWKV6QWEN2,
LLM_ARCH_RWKV7,
@@ -122,9 +119,6 @@ enum llm_arch {
LLM_ARCH_RND1,
LLM_ARCH_PANGU_EMBED,
LLM_ARCH_MISTRAL3,
LLM_ARCH_MIMO2,
LLM_ARCH_LLAMA_EMBED,
LLM_ARCH_MAINCODER,
LLM_ARCH_UNKNOWN,
};
@@ -158,7 +152,6 @@ enum llm_kv {
LLM_KV_VOCAB_SIZE,
LLM_KV_CONTEXT_LENGTH,
LLM_KV_EMBEDDING_LENGTH,
LLM_KV_EMBEDDING_LENGTH_OUT,
LLM_KV_FEATURES_LENGTH,
LLM_KV_BLOCK_COUNT,
LLM_KV_LEADING_DENSE_BLOCK_COUNT,
@@ -216,7 +209,6 @@ enum llm_kv {
LLM_KV_ATTENTION_GATE_LORA_RANK,
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
LLM_KV_ATTENTION_SLIDING_WINDOW,
LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN,
LLM_KV_ATTENTION_SCALE,
LLM_KV_ATTENTION_OUTPUT_SCALE,
LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
@@ -228,7 +220,6 @@ enum llm_kv {
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_DIMENSION_SECTIONS,
LLM_KV_ROPE_FREQ_BASE,
LLM_KV_ROPE_FREQ_BASE_SWA,
LLM_KV_ROPE_SCALE_LINEAR,
LLM_KV_ROPE_SCALING_TYPE,
LLM_KV_ROPE_SCALING_FACTOR,

View File

@@ -57,7 +57,6 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
{ "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
{ "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 },
{ "exaone-moe", LLM_CHAT_TEMPLATE_EXAONE_MOE },
{ "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
@@ -75,7 +74,6 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS },
{ "grok-2", LLM_CHAT_TEMPLATE_GROK_2 },
{ "pangu-embedded", LLM_CHAT_TEMPLATE_PANGU_EMBED },
{ "solar-open", LLM_CHAT_TEMPLATE_SOLAR_OPEN },
};
llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -138,9 +136,6 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
} else if (tmpl_contains("[gMASK]<sop>")) {
return LLM_CHAT_TEMPLATE_CHATGLM_4;
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
if (tmpl_contains("<|tool_declare|>")) {
return LLM_CHAT_TEMPLATE_EXAONE_MOE;
}
return tmpl_contains("</s>") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE;
} else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) {
return LLM_CHAT_TEMPLATE_GLMEDGE;
@@ -221,8 +216,6 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_GROK_2;
} else if (tmpl_contains(LU8("[unused9]系统:[unused10]"))) {
return LLM_CHAT_TEMPLATE_PANGU_EMBED;
} else if (tmpl_contains("<|begin|>") && tmpl_contains("<|end|>") && tmpl_contains("<|content|>")) {
return LLM_CHAT_TEMPLATE_SOLAR_OPEN;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
@@ -580,22 +573,6 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "[|assistant|]";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_MOE) {
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << "<|system|>\n" << trim(message->content) << "<|endofturn|>\n";
} else if (role == "user") {
ss << "<|user|>\n" << trim(message->content) << "<|endofturn|>\n";
} else if (role == "assistant") {
ss << "<|assistant|>\n" << trim(message->content) << "<|endofturn|>\n";
} else if (role == "tool") {
ss << "<|tool|>\n" << trim(message->content) << "<|endofturn|>\n";
}
}
if (add_ass) {
ss << "<|assistant|>\n";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
// this template requires the model to have "\n\n" as EOT token
for (size_t i = 0; i < chat.size(); i++) {
@@ -868,14 +845,6 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "[unused9]助手:";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_SOLAR_OPEN) {
for (auto message : chat) {
std::string role(message->role);
ss << "<|begin|>" << role << "<|content|>" << message->content << "<|end|>";
}
if (add_ass) {
ss << "<|begin|>assistant";
}
} else {
// template not supported
return -1;

View File

@@ -36,7 +36,6 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_MINICPM,
LLM_CHAT_TEMPLATE_EXAONE_3,
LLM_CHAT_TEMPLATE_EXAONE_4,
LLM_CHAT_TEMPLATE_EXAONE_MOE,
LLM_CHAT_TEMPLATE_RWKV_WORLD,
LLM_CHAT_TEMPLATE_GRANITE,
LLM_CHAT_TEMPLATE_GIGACHAT,
@@ -55,7 +54,6 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_SEED_OSS,
LLM_CHAT_TEMPLATE_GROK_2,
LLM_CHAT_TEMPLATE_PANGU_EMBED,
LLM_CHAT_TEMPLATE_SOLAR_OPEN,
LLM_CHAT_TEMPLATE_UNKNOWN,
};

View File

File diff suppressed because it is too large Load Diff

View File

@@ -40,14 +40,6 @@ struct llama_context {
~llama_context();
// reserve a new backend scheduler (if needed)
// for example, when:
// - changing loras
// - changing samplers
// - changing attention type
// - etc.
void sched_reserve();
void synchronize();
const llama_model & get_model() const;
@@ -78,18 +70,6 @@ struct llama_context {
float * get_embeddings_ith(int32_t i);
float * get_embeddings_seq(llama_seq_id seq_id);
llama_token * get_sampled_tokens() const;
llama_token get_sampled_token_ith(int32_t idx);
float * get_sampled_logits_ith(int32_t idx);
size_t get_sampled_logits_count(int32_t idx);
float * get_sampled_probs_ith(int32_t idx);
size_t get_sampled_probs_count(int32_t idx);
const llama_token * get_sampled_candidates_ith(int32_t idx);
size_t get_sampled_candidates_count(int32_t idx);
void attach_threadpool(
ggml_threadpool_t threadpool,
ggml_threadpool_t threadpool_batch);
@@ -212,13 +192,10 @@ private:
// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
uint32_t output_reserve(int32_t n_outputs, const llama_batch & batch);
uint32_t output_reserve(int32_t n_outputs);
void output_reorder();
// map the output row index `i` to batch index
int64_t output_resolve_row(int32_t i) const;
//
// graph
//
@@ -236,8 +213,6 @@ public:
ggml_cgraph * graph_reserve(
uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false, size_t * sizes = nullptr);
bool set_sampler(llama_seq_id seq_id, llama_sampler * sampler);
private:
llm_graph_params graph_params(
llm_graph_result * res,
@@ -277,31 +252,6 @@ private:
size_t embd_size = 0; // capacity (of floats) for embeddings
float * embd = nullptr;
// TODO: simplify
struct sampling_info {
std::map<llama_seq_id, llama_sampler *> samplers;
float * logits = nullptr;
size_t logits_size = 0;
llama_token * sampled = nullptr;
size_t sampled_size = 0;
float * probs = nullptr;
size_t probs_size = 0;
llama_token * candidates = nullptr;
size_t candidates_size = 0;
std::vector<uint32_t> logits_count;
std::vector<uint32_t> probs_count;
std::vector<uint32_t> candidates_count;
std::vector<llama_token> token_ids_full_vocab;
};
sampling_info sampling;
// sequence embeddings output (map of [n_embd] vectors)
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
std::map<llama_seq_id, std::vector<float>> embd_seq;
@@ -322,8 +272,6 @@ private:
ggml_backend_sched_ptr sched;
bool sched_need_reserve = true;
ggml_backend_t backend_cpu = nullptr;
std::vector<ggml_backend_ptr> backends;

View File

@@ -30,12 +30,10 @@ struct llama_cparams {
bool causal_attn;
bool offload_kqv;
bool flash_attn;
bool auto_fa;
bool no_perf;
bool warmup;
bool op_offload;
bool kv_unified;
bool pipeline_parallel;
enum llama_pooling_type pooling_type;

View File

@@ -369,44 +369,6 @@ static void print_rule(
fprintf(file, "\n");
}
//
// Regex utilities
//
size_t llama_grammar_trigger_pattern::find(const std::string & input) const {
auto find_start_pos = [](const std::smatch & match) {
// get from the first matched capturing group to the end of the string
size_t start = std::string::npos;
for (auto i = 1u; i < match.size(); i++) {
if (match.length(i) > 0) {
start = match.position(i);
break;
}
}
if (start == std::string::npos) {
start = match.position(0);
}
return start;
};
if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') {
// match against the entire input
std::smatch match;
if (std::regex_match(input, match, regex)) {
return find_start_pos(match);
}
}
// search anywhere
std::smatch match;
if (std::regex_search(input, match, regex)) {
return find_start_pos(match);
}
return std::string::npos;
}
//
// implementation
//
@@ -1359,10 +1321,21 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
grammar.trigger_buffer += piece;
std::smatch match;
for (const auto & trigger_pattern : grammar.trigger_patterns) {
auto start = trigger_pattern.find(grammar.trigger_buffer);
if (start != std::string::npos) {
if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
grammar.awaiting_trigger = false;
// get from the first matched capturing group to the end of the string
size_t start = std::string::npos;
for (auto i = 1u; i < match.size(); i++) {
if (match.length(i) > 0) {
start = match.position(i);
break;
}
}
if (start == std::string::npos) {
start = match.position(0);
}
// replay tokens that overlap with [start, end)
for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {

View File

@@ -130,8 +130,6 @@ struct llama_grammar_parser {
struct llama_grammar_trigger_pattern {
std::string pattern;
std::regex regex;
size_t find(const std::string & input) const;
};
struct llama_grammar {

View File

@@ -7,13 +7,11 @@
#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include "llama-memory-hybrid.h"
#include "llama-memory-hybrid-iswa.h"
#include "llama-memory-recurrent.h"
#include <cassert>
#include <cmath>
#include <cstring>
#include <unordered_set>
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
if (ubatch->token) {
@@ -23,8 +21,7 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
}
if (ubatch->embd) {
GGML_ASSERT(n_embd == embd->ne[0]);
const int64_t n_embd = embd->ne[0];
const int64_t n_tokens = ubatch->n_tokens;
ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
@@ -34,8 +31,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
bool res = true;
res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
res &= (!params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens);
return res;
}
@@ -65,7 +62,7 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
bool res = true;
res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd;
res &= pos->ne[0] == params.ubatch.n_tokens;
return res;
}
@@ -98,9 +95,11 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
int32_t * data = (int32_t *) pos_bucket->data;
for (int j = 0; j < n_tokens; ++j) {
for (int i = 0; i < n_tokens; ++i) {
data[j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
for (int i = 0; i < n_tokens; ++i) {
data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
}
}
}
}
@@ -323,32 +322,34 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
const int64_t n_tokens = ubatch->n_tokens;
const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
for (int i1 = 0; i1 < n_tokens; ++i1) {
const llama_seq_id s1 = ubatch->seq_id[i1][0];
const llama_pos p1 = ubatch->pos[i1];
for (int h = 0; h < 1; ++h) {
for (int i1 = 0; i1 < n_tokens; ++i1) {
const llama_seq_id s1 = ubatch->seq_id[i1][0];
const llama_pos p1 = ubatch->pos[i1];
const uint64_t idst = i1*n_kv;
const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
for (int i0 = 0; i0 < n_tokens; ++i0) {
const llama_seq_id s0 = ubatch->seq_id[i0][0];
const llama_pos p0 = ubatch->pos[i0];
for (int i0 = 0; i0 < n_tokens; ++i0) {
const llama_seq_id s0 = ubatch->seq_id[i0][0];
const llama_pos p0 = ubatch->pos[i0];
// mask different sequences
if (s0 != s1) {
continue;
// mask different sequences
if (s0 != s1) {
continue;
}
// mask future tokens
if (cparams.causal_attn && p0 > p1) {
continue;
}
// apply SWA if any
if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
continue;
}
data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
}
// mask future tokens
if (cparams.causal_attn && p0 > p1) {
continue;
}
// apply SWA if any
if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
continue;
}
data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
}
}
};
@@ -407,27 +408,6 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
return res;
}
void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) {
mctx->set_input_k_idxs(self_k_idxs, ubatch);
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
this->mctx = mctx;
bool res = true;
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
res &= self_kq_mask->ne[0] == mctx->get_n_kv();
res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
return res;
}
void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
@@ -473,19 +453,27 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
float * data = (float *) cross_kq_mask->data;
for (int i = 0; i < n_tokens; ++i) {
for (int j = 0; j < n_enc; ++j) {
float f = -INFINITY;
for (int h = 0; h < 1; ++h) {
for (int i = 0; i < n_tokens; ++i) {
for (int j = 0; j < n_enc; ++j) {
float f = -INFINITY;
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[i][s];
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[i][s];
if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
f = 0.0f;
if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
f = 0.0f;
}
}
}
data[i*n_enc + j] = f;
data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
}
}
for (int i = n_tokens; i < n_tokens; ++i) {
for (int j = 0; j < n_enc; ++j) {
data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
}
}
}
}
@@ -533,113 +521,6 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
return res;
}
void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
const auto * attn_ctx = mctx->get_attn();
// base tensors may not be allocated if there are no non-SWA attention layers
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
}
// swa tensors may not be allocated if there are no SWA attention layers
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
}
const int64_t n_rs = mctx->get_recr()->get_n_rs();
if (inp_rs->s_copy) {
GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
int32_t * data = (int32_t *) inp_rs->s_copy->data;
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_rs; ++i) {
data[i] = mctx->get_recr()->s_copy(i);
}
}
}
bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_memory_hybrid_iswa_context *>(params.mctx);
this->mctx = mctx;
bool res = true;
const auto * attn_ctx = mctx->get_attn();
// base tensors may not be allocated if there are no non-SWA attention layers
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= inp_attn->self_kq_mask->ne[0] == attn_ctx->get_base()->get_n_kv();
res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
}
// swa tensors may not be allocated if there are no SWA attention layers
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
//res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= inp_attn->self_kq_mask_swa->ne[0] == attn_ctx->get_swa()->get_n_kv();
res &= inp_attn->self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
}
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
res &= inp_rs->head == mctx->get_recr()->get_head();
res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
return res;
}
void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
// set the inputs only for the active samplers in the current ubatch
std::unordered_set<llama_seq_id> active_samplers;
for (uint32_t i = 0; i < ubatch->n_tokens; i++) {
if (ubatch->output[i]) {
llama_seq_id seq_id = ubatch->seq_id[i][0];
active_samplers.insert(seq_id);
}
}
for (auto seq_id : active_samplers) {
if (samplers.find(seq_id) == samplers.end()) {
continue;
}
auto & sampler = samplers[seq_id];
if (sampler->iface->backend_set_input) {
sampler->iface->backend_set_input(sampler);
}
}
}
bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
if (samplers.size() != params.samplers.size()) {
return false;
}
for (const auto & [seq_id, sampler] : params.samplers) {
if (samplers[seq_id] != sampler) {
return false;
}
}
return true;
}
//
// llm_graph_result
//
@@ -656,15 +537,10 @@ int64_t llm_graph_result::get_max_nodes() const {
}
void llm_graph_result::reset() {
t_inp_tokens = nullptr;
t_inp_embd = nullptr;
t_tokens = nullptr;
t_logits = nullptr;
t_embd = nullptr;
t_embd_pooled = nullptr;
t_sampled.clear();
t_sampled_probs.clear();
t_sampled_logits.clear();
t_candidates.clear();
params = {};
@@ -689,38 +565,6 @@ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
}
}
void llm_graph_result::set_outputs() {
if (t_logits != nullptr) {
ggml_set_output(t_logits);
}
if (t_embd != nullptr) {
ggml_set_output(t_embd);
}
if (t_embd_pooled != nullptr) {
ggml_set_output(t_embd_pooled);
}
for (auto & [seq_id, t] : t_sampled) {
if (t != nullptr) {
ggml_set_output(t);
}
}
for (auto & [seq_id, t] : t_sampled_probs) {
if (t != nullptr) {
ggml_set_output(t);
}
}
for (auto & [seq_id, t] : t_sampled_logits) {
if (t != nullptr) {
ggml_set_output(t);
}
}
for (auto & [seq_id, t] : t_candidates) {
if (t != nullptr) {
ggml_set_output(t);
}
}
}
bool llm_graph_result::can_reuse(const llm_graph_params & params) {
if (!this->params.allow_reuse(params)) {
if (debug > 1) {
@@ -802,7 +646,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
loras (params.loras),
mctx (params.mctx),
cross (params.cross),
samplers (params.samplers),
cb_func (params.cb),
res (params.res),
ctx0 (res->get_ctx()),
@@ -1361,29 +1204,17 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
// input embeddings with optional lora
ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
const int64_t n_embd_inp = hparams.n_embd_inp();
const int64_t n_embd = hparams.n_embd;
const int64_t n_embd = hparams.n_embd_inp();
assert(n_embd_inp >= n_embd);
auto inp = std::make_unique<llm_graph_input_embd>();
auto inp = std::make_unique<llm_graph_input_embd>(n_embd_inp);
ggml_tensor * cur = nullptr;
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
cb(inp->tokens, "inp_tokens", -1);
ggml_set_input(inp->tokens);
res->t_inp_tokens = inp->tokens;
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens);
cb(inp->embd, "inp_embd", -1);
ggml_set_input(inp->embd);
// select one of the 2 inputs, based on the batch contents
// ref: https://github.com/ggml-org/llama.cpp/pull/18550
std::array<ggml_tensor *, 2> inps;
// token embeddings path (ubatch.token != nullptr)
{
auto & cur = inps[0];
if (ubatch.token) {
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
//cb(inp->tokens, "inp_tokens", -1);
ggml_set_input(inp->tokens);
res->t_tokens = inp->tokens;
cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
@@ -1404,43 +1235,22 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
cur = ggml_add(ctx0, cur, inpL_delta);
}
if (n_embd_inp != n_embd) {
cur = ggml_pad(ctx0, cur, hparams.n_embd_inp() - n_embd, 0, 0, 0);
}
}
// vector embeddings path (ubatch.embd != nullptr)
{
auto & cur = inps[1];
} else {
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
ggml_set_input(inp->embd);
cur = inp->embd;
}
assert(ggml_are_same_shape (inps[0], inps[1]));
assert(ggml_are_same_stride(inps[0], inps[1]));
ggml_tensor * cur = ggml_build_forward_select(gf, inps.data(), inps.size(), ubatch.token ? 0 : 1);
if (n_embd_inp != n_embd) {
cur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0);
}
res->t_inp_embd = cur;
// For Granite architecture
if (hparams.f_embedding_scale != 0.0f) {
cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
}
cb(cur, "embd", -1);
cb(cur, "inp_embd", -1);
res->add_input(std::move(inp));
// make sure the produced embeddings are immediately materialized in the ggml graph
// ref: https://github.com/ggml-org/llama.cpp/pull/18599
ggml_build_forward_expand(gf, cur);
return cur;
}
@@ -1532,7 +1342,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
//}
const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp();
const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
ggml_set_input(cur);
@@ -1630,11 +1440,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
if (!cparams.offload_kqv) {
// all nodes between the KV store and the attention output are run on the CPU
ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
}
ggml_flash_attn_ext_add_sinks(cur, sinks);
ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
@@ -1844,11 +1649,9 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * v_cur,
ggml_tensor * kq_b,
ggml_tensor * sinks,
ggml_tensor * v_mla, // TODO: remove
ggml_tensor * v_mla,
float kq_scale,
int il) const {
GGML_ASSERT(v_mla == nullptr);
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
// expand k later to enable rope fusion which directly writes into k-v cache
@@ -1891,93 +1694,6 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}
static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
ggml_context * ctx0,
const llama_ubatch & ubatch,
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_context * mctx_cur) {
auto inp = std::make_unique<llm_graph_input_attn_k>(hparams, cparams, mctx_cur);
{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
const auto n_kv = mctx_cur->get_n_kv();
const auto n_tokens = ubatch.n_tokens;
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
return inp;
}
llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const {
const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
return (llm_graph_input_attn_k *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_k * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
ggml_tensor * sinks,
ggml_tensor * v_mla,
float kq_scale,
int il) const {
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
// expand k later to enable rope fusion which directly writes into k-v cache
ggml_build_forward_expand(gf, q_cur);
ggml_build_forward_expand(gf, v_cur);
ggml_build_forward_expand(gf, k_cur);
const auto * mctx_cur = inp->mctx;
// store to KV cache
{
const auto & k_idxs = inp->get_k_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
}
const auto & kq_mask = inp->get_kq_mask();
ggml_tensor * q = q_cur;
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);
if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}
if (wo_b) {
cur = ggml_add(ctx0, cur, wo_b);
}
return cur;
}
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_kv_iswa * inp,
ggml_tensor * wo,
@@ -2118,10 +1834,8 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(inp->self_kq_mask);
ggml_set_name(inp->self_kq_mask, "self_kq_mask");
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
}
{
@@ -2134,10 +1848,8 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(inp->self_kq_mask_swa);
ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
}
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
@@ -2273,62 +1985,17 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
}
llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const {
const auto * mctx_cur = static_cast<const llama_memory_hybrid_iswa_context *>(mctx);
auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
// build iswa attention input
const auto * attn_ctx = mctx_cur->get_attn();
auto inp_attn = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, attn_ctx);
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
{
const auto n_kv = attn_ctx->get_base()->get_n_kv();
inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
inp_attn->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(inp_attn->self_kq_mask);
inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
}
{
const auto n_kv = attn_ctx->get_swa()->get_n_kv();
inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
inp_attn->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(inp_attn->self_kq_mask_swa);
inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
}
auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp));
}
void llm_graph_context::build_dense_out(
ggml_tensor * dense_2,
ggml_tensor * dense_3) const {
if (!cparams.embeddings || !(dense_2 || dense_3)) {
if (!cparams.embeddings || dense_2 == nullptr || dense_3 == nullptr) {
return;
}
ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");
if (dense_2) {
cur = ggml_mul_mat(ctx0, dense_2, cur);
}
if (dense_3) {
cur = ggml_mul_mat(ctx0, dense_3, cur);
}
cur = ggml_mul_mat(ctx0, dense_2, cur);
cur = ggml_mul_mat(ctx0, dense_3, cur);
cb(cur, "result_embd_pooled", -1);
res->t_embd_pooled = cur;
ggml_build_forward_expand(gf, cur);
@@ -2419,87 +2086,6 @@ void llm_graph_context::build_pooling(
ggml_build_forward_expand(gf, cur);
}
void llm_graph_context::build_sampling() const {
if (samplers.empty() || !res->t_logits) {
return;
}
auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
res->add_input(std::move(inp_sampling));
std::map<llama_seq_id, int32_t> seq_to_logit_row;
int32_t logit_row_idx = 0;
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
if (ubatch.output[i]) {
llama_seq_id seq_id = ubatch.seq_id[i][0];
seq_to_logit_row[seq_id] = logit_row_idx;
logit_row_idx++;
}
}
// res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1)
GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");
// add a dummy row of logits
// this trick makes the graph static, regardless of which samplers are activated
// this is important in order to minimize graph reallocations
// TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550)
ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
for (const auto & [seq_id, sampler] : samplers) {
const auto it = seq_to_logit_row.find(seq_id);
// inactive samplers always work on the first row
const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0;
ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
struct llama_sampler_data data = {
/*.logits =*/ logits_seq,
/*.probs =*/ nullptr,
/*.sampled =*/ nullptr,
/*.candidates =*/ nullptr,
};
assert(sampler->iface->backend_apply);
sampler->iface->backend_apply(sampler, ctx0, gf, &data);
if (data.sampled != nullptr) {
res->t_sampled[seq_id] = data.sampled;
ggml_build_forward_expand(gf, data.sampled);
}
if (data.probs != nullptr) {
res->t_sampled_probs[seq_id] = data.probs;
ggml_build_forward_expand(gf, data.probs);
}
if (data.logits != nullptr) {
res->t_sampled_logits[seq_id] = data.logits;
ggml_build_forward_expand(gf, data.logits);
}
if (data.candidates != nullptr) {
res->t_candidates[seq_id] = data.candidates;
ggml_build_forward_expand(gf, data.candidates);
}
}
// TODO: Call llama_sampler_accept_ggml after all samplers have been applied.
/*
for (const auto & [seq_id, sampler] : samplers) {
if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) {
ggml_tensor * selected_token = it->second;
if (selected_token != nullptr) {
llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token);
}
}
}
*/
}
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
// TODO move to hparams if a T5 variant appears that uses a different value
const int64_t max_distance = 128;

View File

@@ -10,7 +10,6 @@
#include <memory>
#include <set>
#include <functional>
#include <map>
struct ggml_cgraph;
struct ggml_context;
@@ -24,7 +23,6 @@ class llama_kv_cache_context;
class llama_kv_cache_iswa_context;
class llama_memory_recurrent_context;
class llama_memory_hybrid_context;
class llama_memory_hybrid_iswa_context;
// certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type {
@@ -106,7 +104,7 @@ using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
class llm_graph_input_embd : public llm_graph_input_i {
public:
llm_graph_input_embd(int64_t n_embd) : n_embd(n_embd) {}
llm_graph_input_embd() = default;
virtual ~llm_graph_input_embd() = default;
void set_input(const llama_ubatch * ubatch) override;
@@ -115,8 +113,6 @@ public:
ggml_tensor * tokens = nullptr; // I32 [n_batch]
ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
const int64_t n_embd = 0;
};
class llm_graph_input_pos : public llm_graph_input_i {
@@ -317,39 +313,6 @@ public:
const llama_kv_cache_context * mctx;
};
// V-less input for the KV cache
// ref: https://github.com/ggml-org/llama.cpp/pull/19067
class llm_graph_input_attn_k : public llm_graph_input_i {
public:
llm_graph_input_attn_k(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_context * mctx) :
hparams(hparams),
cparams(cparams),
mctx(mctx) {
}
~llm_graph_input_attn_k() = default;
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
const llama_hparams hparams;
const llama_cparams cparams;
const llama_kv_cache_context * mctx;
};
class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
public:
llm_graph_input_attn_kv_iswa(
@@ -433,46 +396,6 @@ public:
const llama_memory_hybrid_context * mctx;
};
class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i {
public:
llm_graph_input_mem_hybrid_iswa(
const llama_cparams & cparams,
std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn,
std::unique_ptr<llm_graph_input_rs> inp_rs,
const llama_memory_hybrid_iswa_context * mctx) :
inp_attn(std::move(inp_attn)),
inp_rs(std::move(inp_rs)),
cparams(cparams),
mctx(mctx) { }
virtual ~llm_graph_input_mem_hybrid_iswa() = default;
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn;
std::unique_ptr<llm_graph_input_rs> inp_rs;
llm_graph_input_attn_kv_iswa * get_attn() const { return inp_attn.get(); }
llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
const llama_cparams cparams;
const llama_memory_hybrid_iswa_context * mctx;
};
class llm_graph_input_sampling : public llm_graph_input_i {
public:
llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) :
samplers(std::move(samplers)) { }
virtual ~llm_graph_input_sampling() = default;
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
std::map<llama_seq_id, llama_sampler *> samplers;
};
//
// llm_graph_result
//
@@ -506,23 +429,6 @@ struct llm_graph_params {
const llama_memory_context_i * mctx;
const llama_cross * cross;
std::map<llama_seq_id, llama_sampler *> samplers;
static bool samplers_equal(
const std::map<llama_seq_id, llama_sampler *> & lhs,
const std::map<llama_seq_id, llama_sampler *> & rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
for (const auto & [seq_id, sampler] : lhs) {
auto it = rhs.find(seq_id);
if (it == rhs.end() || it->second != sampler) {
return false;
}
}
return true;
}
uint32_t n_outputs;
llm_graph_cb cb;
@@ -562,36 +468,15 @@ struct llm_graph_params {
return false;
}
if (n_outputs != other.n_outputs) {
return false;
}
if (!samplers_equal(samplers, other.samplers)) {
return false;
}
if (samplers.size() > 0) {
if (!ubatch.data || !other.ubatch.data) {
return false;
}
// check that the outputs are the same for all samplers
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
if (ubatch.output[i] != other.ubatch.output[i] ||
ubatch.seq_id[i][0] != other.ubatch.seq_id[i][0]) {
return false;
}
}
}
return
cparams.embeddings == other.cparams.embeddings &&
cparams.causal_attn == other.cparams.causal_attn &&
arch == other.arch &&
gtype == other.gtype &&
cvec == other.cvec &&
loras == other.loras &&
cross == other.cross;
arch == other.arch &&
gtype == other.gtype &&
cvec == other.cvec &&
loras == other.loras &&
cross == other.cross &&
n_outputs == other.n_outputs;
}
};
@@ -601,7 +486,7 @@ public:
virtual ~llm_graph_result() = default;
ggml_tensor * get_inp_tokens() const { return t_inp_tokens; }
ggml_tensor * get_tokens() const { return t_tokens; }
ggml_tensor * get_logits() const { return t_logits; }
ggml_tensor * get_embd() const { return t_embd; }
ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
@@ -614,7 +499,6 @@ public:
void reset();
void set_inputs(const llama_ubatch * ubatch);
void set_outputs();
// try to update the existing graph result using the new graph parameters in order to reuse it
// this can only be done if we determine that the resulting graph using the new graph parameters
@@ -628,17 +512,11 @@ public:
void set_params(const llm_graph_params & params);
// important graph nodes
ggml_tensor * t_inp_tokens = nullptr;
ggml_tensor * t_inp_embd = nullptr; // [n_embd_inp, n_tokens]
ggml_tensor * t_tokens = nullptr;
ggml_tensor * t_logits = nullptr;
ggml_tensor * t_embd = nullptr;
ggml_tensor * t_embd_pooled = nullptr;
std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
std::map<llama_seq_id, ggml_tensor*> t_candidates;
std::map<llama_seq_id, ggml_tensor*> t_sampled;
std::map<llama_seq_id, ggml_tensor*> t_sampled_probs;
std::vector<llm_graph_input_ptr> inputs;
ggml_context_ptr ctx_compute;
@@ -714,8 +592,6 @@ struct llm_graph_context {
const llama_memory_context_i * mctx;
const llama_cross * cross;
std::map<llama_seq_id, llama_sampler *> samplers;
const llm_graph_cb & cb_func;
llm_graph_result * res;
@@ -866,21 +742,6 @@ struct llm_graph_context {
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * sinks, // [n_head_q]
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] // TODO: remove
float kq_scale,
int il) const;
llm_graph_input_attn_k * build_attn_inp_k() const;
ggml_tensor * build_attn(
llm_graph_input_attn_k * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * sinks, // [n_head_q]
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
@@ -961,8 +822,6 @@ struct llm_graph_context {
llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const;
//
// pooling
//
@@ -973,12 +832,6 @@ struct llm_graph_context {
ggml_tensor * cls_out,
ggml_tensor * cls_out_b) const;
//
// sampling (backend sampling)
//
void build_sampling() const;
//
// dense (out)
//

View File

@@ -72,10 +72,6 @@ uint32_t llama_hparams::n_embd_inp() const {
return n_embd_inp;
}
uint32_t llama_hparams::n_embd_out() const {
return n_embd_out_impl > 0 ? n_embd_out_impl : n_embd;
}
uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
const uint32_t n_head_kv = this->n_head_kv(il);
@@ -183,21 +179,6 @@ bool llama_hparams::is_swa(uint32_t il) const {
GGML_ABORT("fatal error");
}
bool llama_hparams::is_mla() const {
assert((n_embd_head_k_mla_impl == 0 && n_embd_head_v_mla_impl == 0) ||
(n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0));
return n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0;
}
uint32_t llama_hparams::n_embd_head_k_mla() const {
return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k;
}
uint32_t llama_hparams::n_embd_head_v_mla() const {
return is_mla() ? n_embd_head_v_mla_impl : n_embd_head_v;
}
bool llama_hparams::has_kv(uint32_t il) const {
if (n_layer_kv_from_start >= 0) {
if (il < (uint32_t) n_layer_kv_from_start) {
@@ -223,6 +204,42 @@ uint32_t llama_hparams::n_layer_kv() const {
return res;
}
bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
assert(p0 >= 0 && p1 >= 0);
switch (swa_type) {
case LLAMA_SWA_TYPE_NONE:
{
} break;
case LLAMA_SWA_TYPE_STANDARD:
{
if (p1 - p0 >= (int32_t) n_swa) {
return true;
}
} break;
case LLAMA_SWA_TYPE_CHUNKED:
{
const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
if (p0 < pos_chunk_start) {
return true;
}
} break;
case LLAMA_SWA_TYPE_SYMMETRIC:
{
const int32_t half_n_swa = (int32_t) n_swa / 2;
const int32_t pos_diff = p1 - p0;
// Mask if outside the symmetric window
if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
return true;
}
} break;
}
return false;
}
bool llama_hparams::use_mrope() const {
return rope_sections[0] > 0 && rope_sections[1] > 0;
}

View File

@@ -3,7 +3,6 @@
#include "llama.h"
#include <array>
#include <cassert>
// bump if necessary
#define LLAMA_MAX_LAYERS 512
@@ -53,8 +52,8 @@ struct llama_hparams {
uint32_t n_rel_attn_bkts = 0;
// note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
uint32_t n_embd_head_k_mla_impl = 0;
uint32_t n_embd_head_v_mla_impl = 0;
uint32_t n_embd_head_k_mla = 0;
uint32_t n_embd_head_v_mla = 0;
// for WavTokenizer
struct llama_hparams_posnet posnet;
@@ -108,9 +107,9 @@ struct llama_hparams {
float rope_attn_factor = 1.0f;
float rope_freq_base_train;
float rope_freq_base_train_swa = 10000.0f;
float rope_freq_base_train_swa;
float rope_freq_scale_train;
float rope_freq_scale_train_swa = 1.0f;
float rope_freq_scale_train_swa;
uint32_t n_ctx_orig_yarn;
float rope_yarn_log_mul = 0.0f;
@@ -126,11 +125,10 @@ struct llama_hparams {
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
// the size of the sliding window (0 - no SWA)
uint32_t n_swa = 0;
// if swa_layers[il] == 1, then layer il is SWA
// if swa_layers[il] == 0, then layer il is dense (i.e. non-SWA)
// if swa_layers[il] == true, then layer il is SWA
// if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
// by default, all layers are dense
// note: using uint32_t type for compatibility reason
std::array<uint32_t, LLAMA_MAX_LAYERS> swa_layers;
std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
// for State Space Models
uint32_t ssm_d_conv = 0;
@@ -165,9 +163,6 @@ struct llama_hparams {
// for Classifiers
uint32_t n_cls_out = 1;
// output embedding dimension (0 = use n_embd)
uint32_t n_embd_out_impl = 0;
// llama4 smallthinker
uint32_t n_moe_layer_step = 0;
uint32_t n_no_rope_layer_step = 4;
@@ -240,9 +235,6 @@ struct llama_hparams {
// dimension of main + auxiliary input embeddings
uint32_t n_embd_inp() const;
// dimension of output embeddings
uint32_t n_embd_out() const;
// dimension of key embeddings across all k-v heads
uint32_t n_embd_k_gqa(uint32_t il = 0) const;
@@ -274,57 +266,15 @@ struct llama_hparams {
bool is_swa(uint32_t il) const;
// note: currently only support if either all or none of the layers are MLA
bool is_mla() const;
uint32_t n_embd_head_k_mla() const;
uint32_t n_embd_head_v_mla() const;
bool has_kv(uint32_t il) const;
// number of layers for which has_kv() returns true
uint32_t n_layer_kv() const;
// note that this function uses different SWA parameters from those in the hparams
// note: inlined on purpose for performance reasons
// TODO: think of a better place for this function
// TODO: pack the SWA params in a struct?
static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
assert(p0 >= 0 && p1 >= 0);
switch (swa_type) {
case LLAMA_SWA_TYPE_NONE:
{
} break;
case LLAMA_SWA_TYPE_STANDARD:
{
if (p1 - p0 >= (int32_t) n_swa) {
return true;
}
} break;
case LLAMA_SWA_TYPE_CHUNKED:
{
const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
if (p0 < pos_chunk_start) {
return true;
}
} break;
case LLAMA_SWA_TYPE_SYMMETRIC:
{
const int32_t half_n_swa = (int32_t) n_swa / 2;
const int32_t pos_diff = p1 - p0;
// Mask if outside the symmetric window
if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
return true;
}
} break;
}
return false;
}
static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
bool use_mrope() const;
};

View File

@@ -97,8 +97,6 @@ llama_kv_cache::llama_kv_cache(
__func__, hparams.n_embd_v_gqa_max());
}
const bool is_mla = hparams.is_mla();
for (uint32_t il = 0; il < hparams.n_layer; il++) {
if (!hparams.has_kv(il)) {
LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
@@ -132,21 +130,18 @@ llama_kv_cache::llama_kv_cache(
throw std::runtime_error("failed to create ggml context for kv cache");
}
const bool has_k = true;
const bool has_v = !is_mla;
ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream) : nullptr;
ggml_tensor * v = has_v ? ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream) : nullptr;
has_k && ggml_format_name(k, "cache_k_l%d", il);
has_v && ggml_format_name(v, "cache_v_l%d", il);
ggml_format_name(k, "cache_k_l%d", il);
ggml_format_name(v, "cache_v_l%d", il);
std::vector<ggml_tensor *> k_stream;
std::vector<ggml_tensor *> v_stream;
for (uint32_t s = 0; s < n_stream; ++s) {
k_stream.push_back(has_k ? ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]) : nullptr);
v_stream.push_back(has_v ? ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]) : nullptr);
k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
}
map_layer_ids[il] = layers.size();
@@ -652,10 +647,7 @@ bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_co
const auto & layer = layers[il];
ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);
if (layer.v_stream[ssrc]) {
ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
}
ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
}
}
}
@@ -860,7 +852,7 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
const llama_seq_id seq_id_cell = cells.seq_get(idx);
// SWA mask
if (llama_hparams::is_masked_swa(n_swa, swa_type, pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
can_use = true;
}
}
@@ -1245,197 +1237,6 @@ void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
}
}
struct args_set_input_kq_mask {
const llama_hparams & hparams;
const llama_ubatch * ubatch;
const std::vector<llama_kv_cells> & v_cells;
const std::vector<uint32_t> & seq_to_stream;
uint32_t n_swa;
llama_swa_type swa_type;
int64_t n_kv;
int64_t n_stream;
int64_t n_tps;
};
template<bool causal, bool swa, bool is_2d, bool alibi>
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
//const auto & hparams = args.hparams;
const auto & ubatch = args.ubatch;
const auto & v_cells = args.v_cells;
const auto & seq_to_stream = args.seq_to_stream;
const uint32_t n_swa = args.n_swa;
const llama_swa_type swa_type = args.swa_type;
const int64_t n_kv = args.n_kv;
const int64_t n_stream = args.n_stream;
const int64_t n_tps = args.n_tps;
// the min position in the batch for each sequence
llama_pos seq_pos_min[LLAMA_MAX_SEQ];
std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX);
for (uint32_t i = 0; i < ubatch->n_tokens; ++i) {
const llama_seq_id seq_id = ubatch->seq_id[i][0];
seq_pos_min[seq_id] = std::min(seq_pos_min[seq_id], ubatch->pos[i]);
}
for (uint32_t s = 0; s < n_stream; ++s) {
// bookeeping of the KQ mask cells that could change for other tokens of the same sequence
std::unordered_map<llama_seq_id, uint32_t> seq_srct;
std::unordered_map<llama_seq_id, std::vector<uint32_t>> seq_idxs;
for (uint32_t ii = 0; ii < n_tps; ++ii) {
const uint32_t i = s*n_tps + ii;
const llama_seq_id seq_id = ubatch->seq_id[i][0];
const auto & cells = v_cells.at(seq_to_stream[seq_id]);
llama_pos p0 = -1;
const llama_pos p1 = ubatch->pos[i];
// for M-RoPE
const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
const uint64_t idst = n_kv*i;
// for tokens of the same sequence, the mask is mostly the same, so we can reuse it
// the only cells that could change are the ones that are with similar positions as the
// ones in the batch (i.e. due to causal masking, SWA, etc.)
// keep track of those cells and shortcut the loop to save time
// note: this optimization is not compatible with Alibi position encoding
// ref: https://github.com/ggml-org/llama.cpp/pull/18842
bool prev = false;
auto & idxs = seq_idxs[seq_id];
if (!alibi) {
if (seq_srct.find(seq_id) != seq_srct.end()) {
const uint32_t srct = seq_srct[seq_id];
const uint64_t idst_prev = n_kv*srct;
std::copy(data + idst_prev, data + idst_prev + n_kv, data + idst);
prev = true;
} else {
idxs.clear();
idxs.reserve(ubatch->n_tokens + n_swa + 32);
seq_srct[seq_id] = i;
}
}
for (uint32_t jj = 0; jj < n_kv; ++jj) {
uint32_t j = jj;
// we have an exiting mask for this sequence -> update just seq_idxs
if (!alibi) {
if (prev) {
if (jj >= idxs.size()) {
break;
}
j = idxs[jj];
}
}
if (cells.is_empty(j)) {
goto skip;
}
// mask the token if not the same sequence
if (!cells.seq_has(j, seq_id)) {
goto skip;
}
p0 = cells.pos_get(j);
if (!alibi) {
if (!prev) {
// record all cells for which: p0 >= seq_pos_min[seq_id] - n_swa - 32
if (p0 + (int32_t) (n_swa + 32) >= seq_pos_min[seq_id]) {
idxs.push_back(j);
}
}
}
if (causal) {
// mask future tokens
if (p0 > p1) {
goto skip;
}
// M-RoPE causal mask
if (is_2d) {
if (p0 == p1) {
const auto & p0_ext = cells.ext_get(j);
if (p0_ext.is_2d_gt(p1_x, p1_y)) {
goto skip;
}
}
}
}
// apply SWA if any
if (swa) {
if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
goto skip;
}
}
if (alibi) {
data[idst + j] = -std::abs(p0 - p1);
} else {
data[idst + j] = 0.0f;
}
continue;
skip:
data[idst + j] = -INFINITY;
}
}
}
}
template<bool causal, bool swa, bool is_2d>
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
const bool alibi = args.hparams.use_alibi;
if (alibi) {
set_input_kq_mask_impl<causal, swa, is_2d, true> (args, data);
} else {
set_input_kq_mask_impl<causal, swa, is_2d, false>(args, data);
}
}
template<bool causal, bool swa>
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
const bool is_2d = args.ubatch->is_pos_2d();
if (is_2d) {
set_input_kq_mask_impl<causal, swa, true> (args, data);
} else {
set_input_kq_mask_impl<causal, swa, false>(args, data);
}
}
template<bool causal>
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE;
if (swa) {
set_input_kq_mask_impl<causal, true> (args, data);
} else {
set_input_kq_mask_impl<causal, false>(args, data);
}
}
void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
const uint32_t n_tokens = ubatch->n_tokens;
@@ -1450,29 +1251,74 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
// n_tps == n_tokens_per_stream
const int64_t n_tps = n_tokens/n_stream;
//const int64_t t_start = ggml_time_us();
std::fill(data, data + ggml_nelements(dst), -INFINITY);
const args_set_input_kq_mask args = {
/*.hparams =*/ hparams,
/*.ubatch =*/ ubatch,
/*.v_cells =*/ v_cells,
/*.seq_to_stream =*/ seq_to_stream,
/*.n_swa =*/ n_swa,
/*.swa_type =*/ swa_type,
/*.n_kv =*/ n_kv,
/*.n_stream =*/ n_stream,
/*.n_tps =*/ n_tps,
};
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
// Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
// Causal mask:
// xxx-------
// xxxx------
// xxxxx-----
// Non-causal mask:
// xxxxx-----
// xxxxx-----
// xxxxx-----
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
// TODO: optimize this section
for (uint32_t h = 0; h < 1; ++h) {
for (uint32_t s = 0; s < n_stream; ++s) {
for (uint32_t ii = 0; ii < n_tps; ++ii) {
const uint32_t i = s*n_tps + ii;
if (causal_attn) {
set_input_kq_mask_impl<true> (args, data);
} else {
set_input_kq_mask_impl<false>(args, data);
const llama_seq_id seq_id = ubatch->seq_id[i][0];
const auto & cells = v_cells[seq_to_stream[seq_id]];
const llama_pos p1 = ubatch->pos[i];
// for M-RoPE
const bool is_2d = ubatch->is_pos_2d();
const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii);
for (uint32_t j = 0; j < n_kv; ++j) {
if (cells.is_empty(j)) {
continue;
}
// mask the token if not the same sequence
if (!cells.seq_has(j, seq_id)) {
continue;
}
const llama_pos p0 = cells.pos_get(j);
// mask future tokens
if (causal_attn && p0 > p1) {
continue;
}
// M-RoPE causal mask
if (causal_attn && is_2d && p0 == p1) {
const auto & p0_ext = cells.ext_get(j);
if (p0_ext.is_2d_gt(p1_x, p1_y)) {
continue;
}
}
// apply SWA if any
if (is_masked_swa(p0, p1)) {
continue;
}
data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
}
}
}
}
//const int64_t t_end = ggml_time_us();
//LLAMA_LOG_ERROR("%s: kq mask time: %0.3f ms\n", __func__, (t_end - t_start)/1000.0);
}
void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
@@ -1524,7 +1370,7 @@ size_t llama_kv_cache::size_v_bytes() const {
size_t size_v_bytes = 0;
for (const auto & layer : layers) {
size_v_bytes += layer.v ? ggml_nbytes(layer.v) : 0;
size_v_bytes += ggml_nbytes(layer.v);
}
return size_v_bytes;
@@ -1602,10 +1448,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
const auto & n_embd_head_k = hparams.n_embd_head_k;
//const auto & n_embd_head_v = hparams.n_embd_head_v;
const auto & n_rot = hparams.n_rot;
const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0;
auto inp = std::make_unique<llm_graph_input_k_shift>(this);
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
@@ -1626,10 +1468,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
ggml_tensor * k =
ggml_view_3d(ctx, layer.k,
n_rot, n_head_kv, get_size()*n_stream,
n_embd_head_k, n_head_kv, get_size()*n_stream,
ggml_row_size(layer.k->type, n_embd_head_k),
ggml_row_size(layer.k->type, n_embd_k_gqa),
ggml_row_size(layer.k->type, n_embd_nope));
0);
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
@@ -1641,6 +1483,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
return gf;
}
bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
}
void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
GGML_UNUSED(flags);
@@ -1806,9 +1652,6 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
auto * v = layer.v_stream[cr.strm];
if (!v) {
continue;
}
// Write value type
const int32_t v_type_i = (int32_t) v->type;
@@ -1835,9 +1678,6 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
auto * v = layer.v_stream[cr.strm];
if (!v) {
continue;
}
// Write value type
const int32_t v_type_i = (int32_t) v->type;
@@ -2041,9 +1881,6 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
auto * v = layer.v_stream[strm];
if (!v) {
continue;
}
// Read type of value
int32_t v_type_i_ref;
@@ -2085,9 +1922,6 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
auto * v = layer.v_stream[strm];
if (!v) {
continue;
}
// Read type of value
int32_t v_type_i_ref;

View File

@@ -257,6 +257,8 @@ private:
size_t size_k_bytes() const;
size_t size_v_bytes() const;
bool is_masked_swa(llama_pos p0, llama_pos p1) const;
ggml_tensor * build_rope_shift(
const llama_cparams & cparams,
ggml_context * ctx,
@@ -303,7 +305,7 @@ public:
bool do_shift,
stream_copy_info sc_info);
// used to create a batch processing context from a batch
// used to create a batch procesing context from a batch
llama_kv_cache_context(
llama_kv_cache * kv,
slot_info_vec_t sinfos,

View File

@@ -1,275 +0,0 @@
#include "llama-memory-hybrid-iswa.h"
#include "llama-impl.h"
#include "llama-model.h"
#include "llama-context.h"
//
// llama_memory_hybrid_iswa
//
llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
const llama_model & model,
/* attn */
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool swa_full,
uint32_t kv_size,
uint32_t n_ubatch,
uint32_t n_pad,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
bool offload,
bool unified,
/* layer filters */
const layer_filter_cb & filter_attn,
const layer_filter_cb & filter_recr) :
hparams(model.hparams),
mem_attn(new llama_kv_cache_iswa(
model,
type_k,
type_v,
v_trans,
offload,
swa_full,
unified,
kv_size,
n_seq_max,
n_ubatch,
n_pad,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recurrent(il); }
: filter_attn,
nullptr
)),
mem_recr(new llama_memory_recurrent(
model,
type_r,
type_s,
offload,
rs_size,
n_seq_max,
filter_recr == nullptr ?
[&](int32_t il) { return hparams.is_recurrent(il); }
: filter_recr
)) {}
llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
do {
balloc.split_reset();
// follow the recurrent pattern for creating the ubatch splits
std::vector<llama_ubatch> ubatches;
while (true) {
llama_ubatch ubatch;
if (embd_all) {
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
// TODO: non-sequential equal split can be done if using unified KV cache
// for simplicity, we always use sequential equal split for now
ubatch = balloc.split_equal(n_ubatch, true);
}
if (ubatch.n_tokens == 0) {
break;
}
ubatches.push_back(std::move(ubatch)); // NOLINT
}
if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
break;
}
// prepare the recurrent batches first
if (!mem_recr->prepare(ubatches)) {
// TODO: will the recurrent cache be in an undefined context at this point?
LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
// prepare the attention cache (iswa version returns both base and swa slot infos)
auto sinfos_base = mem_attn->get_base()->prepare(ubatches);
if (sinfos_base.empty()) {
LLAMA_LOG_ERROR("%s: failed to prepare attention base ubatches\n", __func__);
return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
auto sinfos_swa = mem_attn->get_swa()->prepare(ubatches);
if (sinfos_swa.empty()) {
LLAMA_LOG_ERROR("%s: failed to prepare attention swa ubatches\n", __func__);
return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
return std::make_unique<llama_memory_hybrid_iswa_context>(
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
} while(false);
return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
llama_memory_context_ptr llama_memory_hybrid_iswa::init_full() {
return std::make_unique<llama_memory_hybrid_iswa_context>(this);
}
llama_memory_context_ptr llama_memory_hybrid_iswa::init_update(llama_context * lctx, bool optimize) {
return std::make_unique<llama_memory_hybrid_iswa_context>(this, lctx, optimize);
}
bool llama_memory_hybrid_iswa::get_can_shift() const {
// Shifting is trivially supported for recurrent
return mem_attn->get_can_shift();
}
void llama_memory_hybrid_iswa::clear(bool data) {
mem_attn->clear(data);
mem_recr->clear(data);
}
bool llama_memory_hybrid_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
// Try removing from the recurrent cache first since it may fail. If it does
// fail, the cache will not have been mutated.
if (!mem_recr->seq_rm(seq_id, p0, p1)) {
return false;
}
return mem_attn->seq_rm(seq_id, p0, p1);
}
void llama_memory_hybrid_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}
void llama_memory_hybrid_iswa::seq_keep(llama_seq_id seq_id) {
mem_attn->seq_keep(seq_id);
mem_recr->seq_keep(seq_id);
}
void llama_memory_hybrid_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
mem_attn->seq_add(seq_id, p0, p1, shift);
mem_recr->seq_add(seq_id, p0, p1, shift);
}
void llama_memory_hybrid_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
mem_attn->seq_div(seq_id, p0, p1, d);
mem_recr->seq_div(seq_id, p0, p1, d);
}
llama_pos llama_memory_hybrid_iswa::seq_pos_min(llama_seq_id seq_id) const {
// the min of the total cache is the max of the two caches' min values
return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id));
}
llama_pos llama_memory_hybrid_iswa::seq_pos_max(llama_seq_id seq_id) const {
// the max of the total cache is the min of the two caches' max values
return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
}
std::map<ggml_backend_buffer_type_t, size_t> llama_memory_hybrid_iswa::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, size_t> mb = mem_attn->memory_breakdown();
for (const auto & buft_size : mem_recr->memory_breakdown()) {
mb[buft_size.first] += buft_size.second;
}
return mb;
}
void llama_memory_hybrid_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
mem_attn->state_write(io, seq_id, flags);
mem_recr->state_write(io, seq_id, flags);
}
void llama_memory_hybrid_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
mem_attn->state_read(io, seq_id, flags);
mem_recr->state_read(io, seq_id, flags);
}
llama_kv_cache_iswa * llama_memory_hybrid_iswa::get_mem_attn() const {
return mem_attn.get();
}
llama_memory_recurrent * llama_memory_hybrid_iswa::get_mem_recr() const {
return mem_recr.get();
}
//
// llama_memory_hybrid_iswa_context
//
llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_status status) : status(status) {}
llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem) :
ctx_attn(mem->get_mem_attn()->init_full()),
ctx_recr(mem->get_mem_recr()->init_full()),
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(
llama_memory_hybrid_iswa * mem,
llama_context * lctx,
bool optimize) :
ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(
llama_memory_hybrid_iswa * mem,
slot_info_vec_t sinfos_base,
slot_info_vec_t sinfos_swa,
std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal
ctx_attn(new llama_kv_cache_iswa_context(mem->get_mem_attn(), std::move(sinfos_base), std::move(sinfos_swa), this->ubatches)),
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
bool llama_memory_hybrid_iswa_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
ctx_attn->next();
ctx_recr->next();
if (++i_next >= ubatches.size()) {
return false;
}
return true;
}
bool llama_memory_hybrid_iswa_context::apply() {
assert(!llama_memory_status_is_fail(status));
bool res = true;
res = res & ctx_attn->apply();
res = res & ctx_recr->apply();
return res;
}
llama_memory_status llama_memory_hybrid_iswa_context::get_status() const {
return status;
}
const llama_ubatch & llama_memory_hybrid_iswa_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next];
}
const llama_kv_cache_iswa_context * llama_memory_hybrid_iswa_context::get_attn() const {
return static_cast<const llama_kv_cache_iswa_context *>(ctx_attn.get());
}
const llama_memory_recurrent_context * llama_memory_hybrid_iswa_context::get_recr() const {
return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
}

View File

@@ -1,140 +0,0 @@
#pragma once
#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cache-iswa.h"
#include "llama-memory.h"
#include "llama-memory-recurrent.h"
#include <memory>
#include <vector>
//
// llama_memory_hybrid_iswa
//
// utilizes instances of llama_memory_recurrent and llama_kv_cache_iswa to
// support models where each layer may be either attention-based (with SWA support) or recurrent
class llama_memory_hybrid_iswa : public llama_memory_i {
public:
llama_memory_hybrid_iswa(
const llama_model & model,
/* attn */
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool swa_full,
uint32_t kv_size,
uint32_t n_ubatch,
uint32_t n_pad,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
bool offload,
bool unified,
/* layer filters */
const layer_filter_cb & filter_attn = nullptr,
const layer_filter_cb & filter_recr = nullptr);
~llama_memory_hybrid_iswa() = default;
//
// llama_memory_i
//
llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) override;
llama_memory_context_ptr init_full() override;
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override;
void clear(bool data) override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
//
// llama_memory_hybrid_iswa specific API
//
llama_kv_cache_iswa * get_mem_attn() const;
llama_memory_recurrent * get_mem_recr() const;
private:
const llama_hparams & hparams;
const std::unique_ptr<llama_kv_cache_iswa> mem_attn;
const std::unique_ptr<llama_memory_recurrent> mem_recr;
};
class llama_memory_hybrid_iswa_context : public llama_memory_context_i {
public:
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
// init failure
explicit llama_memory_hybrid_iswa_context(llama_memory_status status);
// init full
explicit llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem);
// init update
explicit llama_memory_hybrid_iswa_context(
llama_memory_hybrid_iswa * mem,
llama_context * lctx,
bool optimize);
// init success
llama_memory_hybrid_iswa_context(
llama_memory_hybrid_iswa * mem,
slot_info_vec_t sinfos_base,
slot_info_vec_t sinfos_swa,
std::vector<llama_ubatch> ubatches);
~llama_memory_hybrid_iswa_context() = default;
bool next() override;
bool apply() override;
llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override;
//
// llama_memory_hybrid_iswa_context
//
const llama_kv_cache_iswa_context * get_attn() const;
const llama_memory_recurrent_context * get_recr() const;
private:
// the index of the next ubatch to process
size_t i_next = 0;
std::vector<llama_ubatch> ubatches;
const llama_memory_context_ptr ctx_attn;
const llama_memory_context_ptr ctx_recr;
const llama_memory_status status;
};

Some files were not shown because too many files have changed in this diff Show More