Compare commits

41 commits · mlx-gpu-cd ... v0.14.3-rc

| SHA1 |
|---|
| 68e00c7c36 |
| 4f138a1749 |
| 03bf241c33 |
| a887406c24 |
| d51e95ba7e |
| 3d01f2aa34 |
| 634c416645 |
| 57de86cc61 |
| 12719b6e87 |
| a077d996e3 |
| c23d5095de |
| 7601f0e93e |
| aad3f03890 |
| 55d0b6e8b9 |
| 38eac40d56 |
| 80f3f1bc25 |
| b1a0db547b |
| 75d7b5f926 |
| 349d814814 |
| c8743031e0 |
| 4adb9cf4bb |
| 74f475e735 |
| 875cecba74 |
| 7d411a4686 |
| 02a2401596 |
| e4b488a7b5 |
| 98079ddd79 |
| d70942f47b |
| 58e4701557 |
| dbf47ee55a |
| af7ea6e96e |
| 8f1e0140e7 |
| 35c3c9e3c2 |
| d06acbcb19 |
| 9667c2282f |
| a937a68317 |
| 2185112d84 |
| 91926601dc |
| 361d6c16c2 |
| 7e2496e88e |
| 5b84e29882 |
2  .github/ISSUE_TEMPLATE/10_bug_report.yml  (vendored)

@@ -13,7 +13,7 @@ body:
     id: logs
     attributes:
       label: Relevant log output
-      description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
+      description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.mdx#how-to-troubleshoot-issues) for details.
       render: shell
     validations:
       required: false
6  .github/workflows/release.yaml  (vendored)

@@ -372,13 +372,17 @@ jobs:
           outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
           cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
           cache-to: type=inline
+      - name: Deduplicate CUDA libraries
+        run: |
+          ./scripts/deduplicate_cuda_libs.sh dist/${{ matrix.os }}-${{ matrix.arch }}
       - run: |
           for COMPONENT in bin/* lib/ollama/*; do
             case "$COMPONENT" in
-              bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
+              bin/ollama*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
               lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
               lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
               lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
+              lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
               lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
               lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
               lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
@@ -48,9 +48,10 @@ if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
     set(GGML_CPU_ALL_VARIANTS ON)
 endif()

+if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
 if(APPLE)
     set(CMAKE_BUILD_RPATH "@loader_path")
     set(CMAKE_INSTALL_RPATH "@loader_path")
     set(CMAKE_BUILD_WITH_INSTALL_RPATH ON)
 endif()

 set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)

@@ -189,13 +190,21 @@ if(MLX_ENGINE)
     install(TARGETS mlx mlxc
         RUNTIME_DEPENDENCIES
             DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
-            PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
+            PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
             PRE_EXCLUDE_REGEXES ".*"
         RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
         LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
         FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
     )

+    # Install the Metal library for macOS arm64 (must be colocated with the binary)
+    # Metal backend is only built for arm64, not x86_64
+    if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
+        install(FILES ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/backend/metal/kernels/mlx.metallib
+            DESTINATION ${OLLAMA_INSTALL_DIR}
+            COMPONENT MLX)
+    endif()
+
     # Manually install cudart and cublas since they might not be picked up as direct dependencies
     if(CUDAToolkit_FOUND)
         file(GLOB CUDART_LIBS
14  Dockerfile

@@ -32,7 +32,7 @@ ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
 FROM --platform=linux/arm64 almalinux:8 AS base-arm64
 # install epel-release for ccache
 RUN yum install -y yum-utils epel-release \
-    && dnf install -y clang ccache \
+    && dnf install -y clang ccache git \
     && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
 ENV CC=clang CXX=clang++

@@ -149,6 +149,7 @@ COPY CMakeLists.txt CMakePresets.json .
 COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
 COPY x/ml/backend/mlx x/ml/backend/mlx
 COPY go.mod go.sum .
+COPY MLX_VERSION .
 RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
 ENV PATH=/usr/local/go/bin:$PATH
 RUN go mod download

@@ -156,11 +157,6 @@ RUN --mount=type=cache,target=/root/.ccache \
     cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
     && cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
     && cmake --install build --component MLX --strip --parallel ${PARALLEL}
-COPY . .
-ARG GOFLAGS="'-ldflags=-w -s'"
-ENV CGO_ENABLED=1
-ARG CGO_CFLAGS
-ARG CGO_CXXFLAGS

 FROM base AS build
 WORKDIR /go/src/github.com/ollama/ollama

@@ -169,12 +165,14 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
 ENV PATH=/usr/local/go/bin:$PATH
 RUN go mod download
 COPY . .
+# Clone mlx-c headers for CGO (version from MLX_VERSION file)
+RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
 ARG GOFLAGS="'-ldflags=-w -s'"
 ENV CGO_ENABLED=1
-ARG CGO_CFLAGS
+ENV CGO_CFLAGS="-I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
 ARG CGO_CXXFLAGS
 RUN --mount=type=cache,target=/root/.cache/go-build \
-    go build -trimpath -buildmode=pie -o /bin/ollama .
+    go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .

 FROM --platform=linux/amd64 scratch AS amd64
 # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
1  MLX_VERSION  (new file)

@@ -0,0 +1 @@
+v0.4.1
43  README.md

@@ -48,7 +48,7 @@ ollama run gemma3

 ## Model library

-Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library 'ollama model library')
+Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library "ollama model library")

 Here are some example models that can be downloaded:

@@ -79,7 +79,7 @@ Here are some example models that can be downloaded:
 | Code Llama | 7B | 3.8GB | `ollama run codellama` |
 | Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
 | LLaVA | 7B | 4.5GB | `ollama run llava` |
-| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
+| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |

 > [!NOTE]
 > You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.

@@ -260,6 +260,38 @@ Finally, in a separate shell, run a model:
 ./ollama run llama3.2
 ```

+## Building with MLX (experimental)
+
+First build the MLX libraries:
+
+```shell
+cmake --preset MLX
+cmake --build --preset MLX --parallel
+cmake --install build --component MLX
+```
+
+When building with the `-tags mlx` flag, the main `ollama` binary includes MLX support for experimental features like image generation:
+
+```shell
+go build -tags mlx .
+```
+
+Finally, start the server:
+
+```
+./ollama serve
+```
+
+### Building MLX with CUDA
+
+When building with CUDA, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with default architectures:
+
+```shell
+cmake --preset 'MLX CUDA 13'
+cmake --build --preset 'MLX CUDA 13' --parallel
+cmake --install build --component MLX
+```
+
 ## REST API

 Ollama has a REST API for running and managing models.

@@ -290,6 +322,7 @@ See the [API documentation](./docs/api.md) for all endpoints.

 ### Web & Desktop

+- [Onyx](https://github.com/onyx-dot-app/onyx)
 - [Open WebUI](https://github.com/open-webui/open-webui)
 - [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
 - [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)

@@ -421,7 +454,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
 - [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
 - [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
-- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
+- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
 - [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
 - [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
 - [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)

@@ -493,7 +526,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 ### Database

 - [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector)
-  - [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
+  - [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
 - [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
 - [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
 - [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases)

@@ -636,6 +669,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.

 ### Observability

+- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
 - [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
 - [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.

@@ -644,4 +678,5 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.

 ### Security

+- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)
@@ -165,7 +165,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 	return nil
 }

-const maxBufferSize = 512 * format.KiloByte
+const maxBufferSize = 8 * format.MegaByte

 func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
 	var buf io.Reader
28  api/types.go

@@ -127,6 +127,20 @@ type GenerateRequest struct {
 	// each with an associated log probability. Only applies when Logprobs is true.
 	// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
 	TopLogprobs int `json:"top_logprobs,omitempty"`
+
+	// Experimental: Image generation fields (may change or be removed)
+
+	// Width is the width of the generated image in pixels.
+	// Only used for image generation models.
+	Width int32 `json:"width,omitempty"`
+
+	// Height is the height of the generated image in pixels.
+	// Only used for image generation models.
+	Height int32 `json:"height,omitempty"`
+
+	// Steps is the number of diffusion steps for image generation.
+	// Only used for image generation models.
+	Steps int32 `json:"steps,omitempty"`
 }

 // ChatRequest describes a request sent by [Client.Chat].

@@ -860,6 +874,20 @@ type GenerateResponse struct {
 	// Logprobs contains log probability information for the generated tokens,
 	// if requested via the Logprobs parameter.
 	Logprobs []Logprob `json:"logprobs,omitempty"`
+
+	// Experimental: Image generation fields (may change or be removed)
+
+	// Image contains a base64-encoded generated image.
+	// Only present for image generation models.
+	Image string `json:"image,omitempty"`
+
+	// Completed is the number of completed steps in image generation.
+	// Only present for image generation models during streaming.
+	Completed int64 `json:"completed,omitempty"`
+
+	// Total is the total number of steps for image generation.
+	// Only present for image generation models during streaming.
+	Total int64 `json:"total,omitempty"`
 }

 // ModelDetails provides details about a model.
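The request and response fields above are plain additions to the public `api` package, so the experimental image-generation flow can be driven with the existing Go client. A minimal sketch under stated assumptions — the model name `x/z-image-turbo` is the hypothetical example used in the docs later in this diff, and the only new pieces are the `Width`/`Height`/`Steps` and `Image`/`Completed`/`Total` fields shown above:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.GenerateRequest{
		Model:  "x/z-image-turbo", // hypothetical image-generation model
		Prompt: "a sunset over mountains",
		// Experimental image-generation fields added in this diff.
		Width:  1024,
		Height: 768,
		Steps:  20,
	}

	err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
		if !resp.Done {
			// While streaming, Completed/Total report diffusion-step progress.
			fmt.Printf("step %d/%d\n", resp.Completed, resp.Total)
			return nil
		}
		// The final response carries the base64-encoded image.
		fmt.Printf("received %d bytes of base64 image data\n", len(resp.Image))
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```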
@@ -14,6 +14,7 @@ extern NSString *SystemWidePath;
 @interface AppDelegate () <NSWindowDelegate, WKNavigationDelegate, WKUIDelegate>
 @property(strong, nonatomic) NSStatusItem *statusItem;
 @property(assign, nonatomic) BOOL updateAvailable;
+@property(assign, nonatomic) BOOL systemShutdownInProgress;
 @end

 @implementation AppDelegate

@@ -40,6 +41,13 @@ bool firstTimeRun,startHidden; // Set in run before initialization
 }

 - (void)applicationDidFinishLaunching:(NSNotification *)aNotification {
+  // Register for system shutdown/restart notification so we can allow termination
+  [[[NSWorkspace sharedWorkspace] notificationCenter]
+      addObserver:self
+         selector:@selector(systemWillPowerOff:)
+             name:NSWorkspaceWillPowerOffNotification
+           object:nil];
+
   // if we're in development mode, set the app icon
   NSString *bundlePath = [[NSBundle mainBundle] bundlePath];
   if (![bundlePath hasSuffix:@".app"]) {

@@ -278,7 +286,18 @@ bool firstTimeRun,startHidden; // Set in run before initialization
   [NSApp activateIgnoringOtherApps:YES];
 }

+- (void)systemWillPowerOff:(NSNotification *)notification {
+  // Set flag so applicationShouldTerminate: knows to allow termination.
+  // The system will call applicationShouldTerminate: after posting this notification.
+  self.systemShutdownInProgress = YES;
+}
+
 - (NSApplicationTerminateReply)applicationShouldTerminate:(NSApplication *)sender {
+  // Allow termination if the system is shutting down or restarting
+  if (self.systemShutdownInProgress) {
+    return NSTerminateNow;
+  }
+  // Otherwise just hide the app (for Cmd+Q, close button, etc.)
   [NSApp hide:nil];
   [NSApp setActivationPolicy:NSApplicationActivationPolicyAccessory];
   return NSTerminateCancel;
135  cmd/cmd.go

@@ -46,8 +46,9 @@ import (
 	"github.com/ollama/ollama/types/syncmap"
 	"github.com/ollama/ollama/version"
 	xcmd "github.com/ollama/ollama/x/cmd"
+	"github.com/ollama/ollama/x/create"
+	xcreateclient "github.com/ollama/ollama/x/create/client"
 	"github.com/ollama/ollama/x/imagegen"
-	imagegenclient "github.com/ollama/ollama/x/imagegen/client"
 )

 const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"

@@ -93,14 +94,87 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	p := progress.NewProgress(os.Stderr)
 	defer p.Stop()

+	// Validate model name early to fail fast
+	modelName := args[0]
+	name := model.ParseName(modelName)
+	if !name.IsValid() {
+		return fmt.Errorf("invalid model name: %s", modelName)
+	}
+
+	// Check for --experimental flag for safetensors model creation
+	experimental, _ := cmd.Flags().GetBool("experimental")
+	if experimental {
+		// Get Modelfile content - either from -f flag or default to "FROM ."
+		var reader io.Reader
+		filename, err := getModelfileName(cmd)
+		if os.IsNotExist(err) || filename == "" {
+			// No Modelfile specified or found - use default
+			reader = strings.NewReader("FROM .\n")
+		} else if err != nil {
+			return err
+		} else {
+			f, err := os.Open(filename)
+			if err != nil {
+				return err
+			}
+			defer f.Close()
+			reader = f
+		}
+
+		// Parse the Modelfile
+		modelfile, err := parser.ParseFile(reader)
+		if err != nil {
+			return fmt.Errorf("failed to parse Modelfile: %w", err)
+		}
+
+		// Extract FROM path and configuration
+		var modelDir string
+		mfConfig := &xcreateclient.ModelfileConfig{}
+
+		for _, cmd := range modelfile.Commands {
+			switch cmd.Name {
+			case "model":
+				modelDir = cmd.Args
+			case "template":
+				mfConfig.Template = cmd.Args
+			case "system":
+				mfConfig.System = cmd.Args
+			case "license":
+				mfConfig.License = cmd.Args
+			}
+		}
+
+		if modelDir == "" {
+			modelDir = "."
+		}
+
+		// Resolve relative paths based on Modelfile location
+		if !filepath.IsAbs(modelDir) && filename != "" {
+			modelDir = filepath.Join(filepath.Dir(filename), modelDir)
+		}
+
+		quantize, _ := cmd.Flags().GetString("quantize")
+		return xcreateclient.CreateModel(xcreateclient.CreateOptions{
+			ModelName: modelName,
+			ModelDir:  modelDir,
+			Quantize:  quantize,
+			Modelfile: mfConfig,
+		}, p)
+	}
+
 	var reader io.Reader

 	filename, err := getModelfileName(cmd)
 	if os.IsNotExist(err) {
 		if filename == "" {
 			// No Modelfile found - check if current directory is an image gen model
-			if imagegen.IsTensorModelDir(".") {
-				return imagegenclient.CreateModel(args[0], ".", p)
+			if create.IsTensorModelDir(".") {
+				quantize, _ := cmd.Flags().GetString("quantize")
+				return xcreateclient.CreateModel(xcreateclient.CreateOptions{
+					ModelName: modelName,
+					ModelDir:  ".",
+					Quantize:  quantize,
+				}, p)
 			}
 			reader = strings.NewReader("FROM .\n")
 		} else {

@@ -133,7 +207,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	}
 	spinner.Stop()

-	req.Model = args[0]
+	req.Model = modelName
 	quantize, _ := cmd.Flags().GetString("quantize")
 	if quantize != "" {
 		req.Quantize = quantize

@@ -464,14 +538,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {

 	name := args[0]

-	// Check if this is a known image generation model (skip Show/Pull)
-	if imagegen.HasTensorLayers(name) {
-		if opts.Prompt == "" && !interactive {
-			return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
-		}
-		return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
-	}
-
 	info, err := func() (*api.ShowResponse, error) {
 		showReq := &api.ShowRequest{Name: name}
 		info, err := client.Show(cmd.Context(), showReq)

@@ -533,9 +599,18 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 		return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
 	}

+	// Check if this is an image generation model
+	if slices.Contains(info.Capabilities, model.CapabilityImage) {
+		if opts.Prompt == "" && !interactive {
+			return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
+		}
+		return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
+	}
+
 	// Check for experimental flag
 	isExperimental, _ := cmd.Flags().GetBool("experimental")
 	yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
+	enableWebsearch, _ := cmd.Flags().GetBool("experimental-websearch")

 	if interactive {
 		if err := loadOrUnloadModel(cmd, &opts); err != nil {

@@ -565,7 +640,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {

 	// Use experimental agent loop with tools
 	if isExperimental {
-		return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode)
+		return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode, enableWebsearch)
 	}

 	return generateInteractive(cmd, opts)

@@ -671,7 +746,11 @@ func PushHandler(cmd *cobra.Command, args []string) error {

 			bar, ok := bars[resp.Digest]
 			if !ok {
-				bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
+				msg := resp.Status
+				if msg == "" {
+					msg = fmt.Sprintf("pushing %s...", resp.Digest[7:19])
+				}
+				bar = progress.NewBar(msg, resp.Total, resp.Completed)
 				bars[resp.Digest] = bar
 				p.Add(resp.Digest, bar)
 			}

@@ -820,11 +899,11 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
 	for _, arg := range args {
 		// Unload the model if it's running before deletion
 		if err := loadOrUnloadModel(cmd, &runOptions{
-			Model:     args[0],
+			Model:     arg,
 			KeepAlive: &api.Duration{Duration: 0},
 		}); err != nil {
 			if !strings.Contains(strings.ToLower(err.Error()), "not found") {
-				fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", args[0])
+				fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", arg)
 			}
 		}

@@ -837,11 +916,6 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
 }

 func ShowHandler(cmd *cobra.Command, args []string) error {
-	// Check if this is an image generation model
-	if imagegen.HasTensorLayers(args[0]) {
-		return imagegen.Show(args[0], os.Stdout)
-	}
-
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err

@@ -1741,15 +1815,22 @@ func NewCLI() *cobra.Command {
 	rootCmd.Flags().BoolP("version", "v", false, "Show version information")

 	createCmd := &cobra.Command{
-		Use:     "create MODEL",
-		Short:   "Create a model",
-		Args:    cobra.ExactArgs(1),
-		PreRunE: checkServerHeartbeat,
-		RunE:    CreateHandler,
+		Use:   "create MODEL",
+		Short: "Create a model",
+		Args:  cobra.ExactArgs(1),
+		PreRunE: func(cmd *cobra.Command, args []string) error {
+			// Skip server check for experimental mode (writes directly to disk)
+			if experimental, _ := cmd.Flags().GetBool("experimental"); experimental {
+				return nil
+			}
+			return checkServerHeartbeat(cmd, args)
+		},
+		RunE: CreateHandler,
 	}

 	createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")")
 	createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
+	createCmd.Flags().Bool("experimental", false, "Enable experimental safetensors model creation")

 	showCmd := &cobra.Command{
 		Use:     "show MODEL",

@@ -1786,6 +1867,7 @@ func NewCLI() *cobra.Command {
 	runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
 	runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
 	runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
+	runCmd.Flags().Bool("experimental-websearch", false, "Enable web search tool in experimental mode")

 	// Image generation flags (width, height, steps, seed, etc.)
 	imagegen.RegisterFlags(runCmd)

@@ -1903,6 +1985,7 @@ func NewCLI() *cobra.Command {
 	} {
 		switch cmd {
 		case runCmd:
+			imagegen.AppendFlagsDocs(cmd)
 			appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]})
 		case serveCmd:
 			appendEnvDocs(cmd, []envconfig.EnvVar{
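Read together, the `CreateHandler` and `NewCLI` changes above wire up an experimental create path that parses the Modelfile client-side and skips the server heartbeat. A hedged sketch of the resulting CLI usage, built only from the flags registered in this diff (the model name and paths are made up):

```shell
# Create a model from safetensors in the current directory;
# with --experimental the server heartbeat check is skipped
# and the model is written directly to disk.
ollama create my-model --experimental

# Point at a Modelfile and quantize during creation.
ollama create my-model --experimental -f ./Modelfile -q q4_K_M
```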
@@ -1547,6 +1547,79 @@ func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) {
 	}
 }

+func TestShowInfoImageGen(t *testing.T) {
+	var b bytes.Buffer
+	err := showInfo(&api.ShowResponse{
+		Details: api.ModelDetails{
+			Family:            "ZImagePipeline",
+			ParameterSize:     "10.3B",
+			QuantizationLevel: "FP8",
+		},
+		Capabilities: []model.Capability{model.CapabilityImage},
+		Requires:     "0.14.0",
+	}, false, &b)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	expect := " Model\n" +
+		" architecture ZImagePipeline \n" +
+		" parameters 10.3B \n" +
+		" quantization FP8 \n" +
+		" requires 0.14.0 \n" +
+		"\n" +
+		" Capabilities\n" +
+		" image \n" +
+		"\n"
+	if diff := cmp.Diff(expect, b.String()); diff != "" {
+		t.Errorf("unexpected output (-want +got):\n%s", diff)
+	}
+}
+
+func TestPushProgressMessage(t *testing.T) {
+	tests := []struct {
+		name    string
+		status  string
+		digest  string
+		wantMsg string
+	}{
+		{
+			name:    "uses status when provided",
+			status:  "uploading model",
+			digest:  "sha256:abc123456789def",
+			wantMsg: "uploading model",
+		},
+		{
+			name:    "falls back to digest when status empty",
+			status:  "",
+			digest:  "sha256:abc123456789def",
+			wantMsg: "pushing abc123456789...",
+		},
+		{
+			name:    "handles short digest gracefully",
+			status:  "",
+			digest:  "sha256:abc",
+			wantMsg: "pushing sha256:abc...",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			msg := tt.status
+			if msg == "" {
+				if len(tt.digest) >= 19 {
+					msg = fmt.Sprintf("pushing %s...", tt.digest[7:19])
+				} else {
+					msg = fmt.Sprintf("pushing %s...", tt.digest)
+				}
+			}
+			if msg != tt.wantMsg {
+				t.Errorf("got %q, want %q", msg, tt.wantMsg)
+			}
+		})
+	}
+}
+
 func TestRunOptions_Copy_Independence(t *testing.T) {
 	// Test that modifications to original don't affect copy
 	originalThink := &api.ThinkValue{Value: "original"}
@@ -116,7 +116,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 		Prompt:         ">>> ",
 		AltPrompt:      "... ",
 		Placeholder:    "Send a message (/? for help)",
-		AltPlaceholder: `Use """ to end multi-line input`,
+		AltPlaceholder: "Press Enter to send",
 	})
 	if err != nil {
 		return err
@@ -311,6 +311,8 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
 		conv = &deepseekocr{}
 	case "DeepseekV3ForCausalLM":
 		conv = &deepseek2Model{}
+	case "Glm4MoeLiteForCausalLM":
+		conv = &glm4MoeLiteModel{}
 	default:
 		return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
 	}
150  convert/convert_glm4moelite.go  (new file)

@@ -0,0 +1,150 @@
+package convert
+
+import (
+	"cmp"
+	"fmt"
+	"log/slog"
+	"regexp"
+	"strconv"
+
+	"github.com/ollama/ollama/fs/ggml"
+)
+
+type glm4MoeLiteModel struct {
+	ModelParameters
+	MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
+	HiddenSize            uint32  `json:"hidden_size"`
+	HiddenLayers          uint32  `json:"num_hidden_layers"`
+	IntermediateSize      uint32  `json:"intermediate_size"`
+	NumAttentionHeads     uint32  `json:"num_attention_heads"`
+	NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
+	RMSNormEPS            float32 `json:"rms_norm_eps"`
+
+	RopeTheta     float32 `json:"rope_theta"`
+	QKNopeHeadDim uint32  `json:"qk_nope_head_dim"`
+	QKRopeHeadDim uint32  `json:"qk_rope_head_dim"`
+	KVLoraRank    uint32  `json:"kv_lora_rank"`
+	QLoraRank     uint32  `json:"q_lora_rank"`
+	VHeadDim      uint32  `json:"v_head_dim"`
+
+	ExpertCount            uint32  `json:"n_routed_experts"`
+	ExpertSharedCount      uint32  `json:"n_shared_experts"`
+	ExpertIntermediateSize uint32  `json:"moe_intermediate_size"`
+	ExpertUsedCount        uint32  `json:"num_experts_per_tok"`
+	ExpertWeightsNorm      bool    `json:"norm_topk_prob"`
+	ExpertWeightsScale     float32 `json:"routed_scaling_factor"`
+
+	LeadingDenseBlockCount uint32 `json:"first_k_dense_replace"`
+}
+
+func (p *glm4MoeLiteModel) KV(t *Tokenizer) KV {
+	kv := p.ModelParameters.KV(t)
+	kv["general.architecture"] = "glm4moelite"
+	kv["general.type"] = "model"
+	kv["glm4moelite.block_count"] = p.HiddenLayers
+
+	numHeads := p.NumAttentionHeads
+	numKVHeads := p.NumKeyValueHeads
+
+	kv["glm4moelite.attention.head_count"] = numHeads
+	kv["glm4moelite.attention.head_count_kv"] = numKVHeads
+	kv["glm4moelite.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim
+	kv["glm4moelite.attention.kv_lora_rank"] = p.KVLoraRank
+	kv["glm4moelite.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
+	kv["glm4moelite.attention.q_lora_rank"] = p.QLoraRank
+	kv["glm4moelite.attention.value_length"] = p.VHeadDim
+	kv["glm4moelite.context_length"] = p.MaxPositionEmbeddings
+	kv["glm4moelite.embedding_length"] = p.HiddenSize
+	kv["glm4moelite.expert_count"] = p.ExpertCount
+	kv["glm4moelite.expert_feed_forward_length"] = p.ExpertIntermediateSize
+	kv["glm4moelite.expert_shared_count"] = p.ExpertSharedCount
+
+	kv["glm4moelite.expert_gating_func"] = uint32(2)
+	kv["glm4moelite.expert_used_count"] = p.ExpertUsedCount
+	kv["glm4moelite.expert_weights_norm"] = p.ExpertWeightsNorm
+	kv["glm4moelite.expert_weights_scale"] = p.ExpertWeightsScale
+	kv["glm4moelite.feed_forward_length"] = p.IntermediateSize
+	kv["glm4moelite.leading_dense_block_count"] = p.LeadingDenseBlockCount
+
+	kv["glm4moelite.rope.dimension_count"] = p.QKRopeHeadDim
+	kv["glm4moelite.rope.freq_base"] = cmp.Or(p.RopeTheta, float32(1000000.0))
+
+	kv["tokenizer.ggml.pre"] = "glm4"
+
+	return kv
+}
+
+func (p *glm4MoeLiteModel) Replacements() []string {
+	return []string{
+		"lm_head", "output",
+		"model.embed_tokens", "token_embd",
+		"model.norm", "output_norm",
+		"model.layers", "blk",
+		"input_layernorm", "attn_norm",
+		"self_attn.kv_a_proj_with_mqa", "attn_kv_a_mqa",
+		"self_attn.kv_a_layernorm", "attn_kv_a_norm",
+		"self_attn.kv_b_proj", "attn_kv_b",
+		"self_attn.q_a_proj", "attn_q_a",
+		"self_attn.q_a_layernorm", "attn_q_a_norm",
+		"self_attn.q_b_proj", "attn_q_b",
+		"self_attn.o_proj", "attn_output",
+		"post_attention_layernorm", "ffn_norm",
+		"mlp.shared_experts.down_proj", "ffn_down_shexp",
+		"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
+		"mlp.shared_experts.up_proj", "ffn_up_shexp",
+		"mlp.gate_proj", "ffn_gate",
+		"mlp.down_proj", "ffn_down",
+		"mlp.up_proj", "ffn_up",
+		"mlp.gate.e_score_correction_bias", "exp_probs_b.bias",
+		"mlp.gate", "ffn_gate_inp",
+	}
+}
+
+func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
+	merges := make([]merge, p.HiddenLayers*3)
+	for i := range p.HiddenLayers {
+		merges[i*3+0] = merge{
+			fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
+			fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
+		}
+		merges[i*3+1] = merge{
+			fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
+			fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
+		}
+		merges[i*3+2] = merge{
+			fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
+			fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
+		}
+	}
+
+	skipLayer := func(n string, minValue uint32) bool {
+		re := regexp.MustCompile(`^blk\.(\d+)`)
+		matches := re.FindStringSubmatch(n)
+		if matches == nil {
+			return false
+		}
+
+		blkNum, err := strconv.Atoi(matches[1])
+		if err != nil {
+			return false
+		}
+
+		return uint32(blkNum) >= minValue
+	}
+
+	out, s = mergeTensors(s, merges...)
+	for _, t := range s {
+		// skip any additional layers (such as the Multi-Token Prediction layer)
+		if skipLayer(t.Name(), p.HiddenLayers) {
+			slog.Debug("skipping layer", "name", t.Name())
+			continue
+		}
+		out = append(out, &ggml.Tensor{
+			Name:     t.Name(),
+			Kind:     t.Kind(),
+			Shape:    t.Shape(),
+			WriterTo: t,
+		})
+	}
+	return out
+}
62  docs/api.md

@@ -16,6 +16,7 @@
 - [Generate Embeddings](#generate-embeddings)
 - [List Running Models](#list-running-models)
 - [Version](#version)
+- [Experimental: Image Generation](#image-generation-experimental)

 ## Conventions

@@ -58,6 +59,15 @@ Advanced parameters (optional):
 - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
 - `context` (deprecated): the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory

+Experimental image generation parameters (for image generation models only):
+
+> [!WARNING]
+> These parameters are experimental and may change in future versions.
+
+- `width`: width of the generated image in pixels
+- `height`: height of the generated image in pixels
+- `steps`: number of diffusion steps
+
 #### Structured outputs

 Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [structured outputs](#request-structured-outputs) example below.

@@ -1867,3 +1877,55 @@ curl http://localhost:11434/api/version
   "version": "0.5.1"
 }
 ```

+## Experimental Features
+
+### Image Generation (Experimental)
+
+> [!WARNING]
+> Image generation is experimental and may change in future versions.
+
+Image generation is now supported through the standard `/api/generate` endpoint when using image generation models. The API automatically detects when an image generation model is being used.
+
+See the [Generate a completion](#generate-a-completion) section for the full API documentation. The experimental image generation parameters (`width`, `height`, `steps`) are documented there.
+
+#### Example
+
+##### Request
+
+```shell
+curl http://localhost:11434/api/generate -d '{
+  "model": "x/z-image-turbo",
+  "prompt": "a sunset over mountains",
+  "width": 1024,
+  "height": 768
+}'
+```
+
+##### Response (streaming)
+
+Progress updates during generation:
+
+```json
+{
+  "model": "x/z-image-turbo",
+  "created_at": "2024-01-15T10:30:00.000000Z",
+  "completed": 5,
+  "total": 20,
+  "done": false
+}
+```
+
+##### Final Response
+
+```json
+{
+  "model": "x/z-image-turbo",
+  "created_at": "2024-01-15T10:30:15.000000Z",
+  "image": "iVBORw0KGgoAAAANSUhEUg...",
+  "done": true,
+  "done_reason": "stop",
+  "total_duration": 15000000000,
+  "load_duration": 2000000000
+}
+```
@@ -21,6 +21,7 @@ ollama pull glm-4.7:cloud
 To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:

 ```shell
+export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
 export ANTHROPIC_BASE_URL=http://localhost:11434
 export ANTHROPIC_API_KEY=ollama # required but ignored
 ```

@@ -247,12 +248,13 @@ curl -X POST http://localhost:11434/v1/messages \
 [Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:

 ```shell
-ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
+ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
 ```

 Or set the environment variables in your shell profile:

 ```shell
+export ANTHROPIC_AUTH_TOKEN=ollama
 export ANTHROPIC_BASE_URL=http://localhost:11434
 export ANTHROPIC_API_KEY=ollama
 ```
@@ -275,6 +275,73 @@ curl -X POST http://localhost:11434/v1/chat/completions \
 - [x] `dimensions`
 - [ ] `user`

+### `/v1/images/generations` (experimental)
+
+> Note: This endpoint is experimental and may change or be removed in future versions.
+
+Generate images using image generation models.
+
+<CodeGroup dropdown>
+
+```python images.py
+from openai import OpenAI
+
+client = OpenAI(
+    base_url='http://localhost:11434/v1/',
+    api_key='ollama',  # required but ignored
+)
+
+response = client.images.generate(
+    model='x/z-image-turbo',
+    prompt='A cute robot learning to paint',
+    size='1024x1024',
+    response_format='b64_json',
+)
+print(response.data[0].b64_json[:50] + '...')
+```
+
+```javascript images.js
+import OpenAI from "openai";
+
+const openai = new OpenAI({
+  baseURL: "http://localhost:11434/v1/",
+  apiKey: "ollama", // required but ignored
+});
+
+const response = await openai.images.generate({
+  model: "x/z-image-turbo",
+  prompt: "A cute robot learning to paint",
+  size: "1024x1024",
+  response_format: "b64_json",
+});
+
+console.log(response.data[0].b64_json.slice(0, 50) + "...");
+```
+
+```shell images.sh
+curl -X POST http://localhost:11434/v1/images/generations \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "x/z-image-turbo",
+    "prompt": "A cute robot learning to paint",
+    "size": "1024x1024",
+    "response_format": "b64_json"
+  }'
+```
+
+</CodeGroup>
+
+#### Supported request fields
+
+- [x] `model`
+- [x] `prompt`
+- [x] `size` (e.g. "1024x1024")
+- [x] `response_format` (only `b64_json` supported)
+- [ ] `n`
+- [ ] `quality`
+- [ ] `style`
+- [ ] `user`
+
 ### `/v1/responses`

 > Note: Added in Ollama v0.13.3
@@ -110,7 +110,7 @@ More Ollama [Python example](https://github.com/ollama/ollama-python/blob/main/e
 import { Ollama } from "ollama";

 const client = new Ollama();
-const results = await client.webSearch({ query: "what is ollama?" });
+const results = await client.webSearch("what is ollama?");
 console.log(JSON.stringify(results, null, 2));
 ```

@@ -213,7 +213,7 @@ models](https://ollama.com/models)\n\nAvailable for macOS, Windows, and Linux',
 import { Ollama } from "ollama";

 const client = new Ollama();
-const fetchResult = await client.webFetch({ url: "https://ollama.com" });
+const fetchResult = await client.webFetch("https://ollama.com");
 console.log(JSON.stringify(fetchResult, null, 2));
 ```
@@ -111,7 +111,9 @@
         "/integrations/zed",
         "/integrations/roo-code",
         "/integrations/n8n",
-        "/integrations/xcode"
+        "/integrations/xcode",
+        "/integrations/onyx",
+        "/integrations/marimo"
       ]
     },
     {
@@ -22,7 +22,7 @@ Please refer to the [GPU docs](./gpu).

 ## How can I specify the context window size?

-By default, Ollama uses a context window size of 2048 tokens.
+By default, Ollama uses a context window size of 4096 tokens.

 This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
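For reference, setting an 8K default context window with the environment variable described above looks like this (an illustrative sketch, not part of the diff):

```shell
# 8K context window (8192 tokens) as the server-wide default
OLLAMA_CONTEXT_LENGTH=8192 ollama serve
```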
BIN  docs/images/marimo-add-model.png  (new file, 174 KiB)
BIN  docs/images/marimo-chat.png  (new file, 80 KiB)
BIN  docs/images/marimo-code-completion.png  (new file, 230 KiB)
BIN  docs/images/marimo-models.png  (new file, 178 KiB)
BIN  docs/images/marimo-settings.png  (new file, 186 KiB)
BIN  docs/images/onyx-login.png  (new file, 100 KiB)
BIN  docs/images/onyx-ollama-form.png  (new file, 306 KiB)
BIN  docs/images/onyx-ollama-llm.png  (new file, 300 KiB)
BIN  docs/images/onyx-query.png  (new file, 211 KiB)
@@ -2,6 +2,12 @@
 title: Claude Code
 ---

+Claude Code is Anthropic's agentic coding tool that can read, modify, and execute code in your working directory.
+
+Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `qwen3-coder`, `gpt-oss:20b`, or other models.
+
+
+
 ## Install

 Install [Claude Code](https://code.claude.com/docs/en/overview):

@@ -25,22 +31,24 @@ Claude Code connects to Ollama using the Anthropic-compatible API.
 1. Set the environment variables:

    ```shell
+   export ANTHROPIC_AUTH_TOKEN=ollama
    export ANTHROPIC_BASE_URL=http://localhost:11434
    export ANTHROPIC_API_KEY=ollama
    ```

 2. Run Claude Code with an Ollama model:

    ```shell
-   claude --model qwen3-coder
+   claude --model gpt-oss:20b
    ```

 Or run with environment variables inline:

 ```shell
-ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
+ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model gpt-oss:20b
 ```

 **Note:** Claude Code requires a large context window. We recommend at least 32K tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.

 ## Connecting to ollama.com

 1. Create an [API key](https://ollama.com/settings/keys) on ollama.com

@@ -67,3 +75,4 @@ claude --model glm-4.7:cloud
 ### Local models
 - `qwen3-coder` - Excellent for coding tasks
 - `gpt-oss:20b` - Strong general-purpose model
+- `gpt-oss:120b` - Larger general-purpose model for more complex tasks
73  docs/integrations/marimo.mdx  (new file)

@@ -0,0 +1,73 @@
+---
+title: marimo
+---
+
+## Install
+
+Install [marimo](https://marimo.io). You can use `pip` or `uv` for this. You
+can also use `uv` to create a sandboxed environment for marimo by running:
+
+```
+uvx marimo edit --sandbox notebook.py
+```
+
+## Usage with Ollama
+
+1. In marimo, go to the user settings and go to the AI tab. From here
+   you can find and configure Ollama as an AI provider. For local use you
+   would typically point the base url to `http://localhost:11434/v1`.
+
+<div style={{ display: 'flex', justifyContent: 'center' }}>
+  <img
+    src="/images/marimo-settings.png"
+    alt="Ollama settings in marimo"
+    width="50%"
+  />
+</div>
+
+2. Once the AI provider is set up, you can turn on/off specific AI models you'd like to access.
+
+<div style={{ display: 'flex', justifyContent: 'center' }}>
+  <img
+    src="/images/marimo-models.png"
+    alt="Selecting an Ollama model"
+    width="50%"
+  />
+</div>
+
+3. You can also add a model to the list of available models by scrolling to the bottom and using the UI there.
+
+<div style={{ display: 'flex', justifyContent: 'center' }}>
+  <img
+    src="/images/marimo-add-model.png"
+    alt="Adding a new Ollama model"
+    width="50%"
+  />
+</div>
+
+4. Once configured, you can now use Ollama for AI chats in marimo.
+
+<div style={{ display: 'flex', justifyContent: 'center' }}>
+  <img
+    src="/images/marimo-chat.png"
+    alt="Chatting with an Ollama model"
+    width="50%"
+  />
+</div>
+
+5. Alternatively, you can use Ollama for **inline code completion** in marimo. This can be configured in the "AI Features" tab.
+
+<div style={{ display: 'flex', justifyContent: 'center' }}>
+  <img
+    src="/images/marimo-code-completion.png"
+    alt="Configure code completion"
+    width="50%"
+  />
+</div>
+
+## Connecting to ollama.com
+
+1. Sign in to ollama cloud via `ollama signin`
+2. In the ollama model settings add a model that ollama hosts, like `gpt-oss:120b`.
+3. You can now refer to this model in marimo!
63  docs/integrations/onyx.mdx  (new file)

@@ -0,0 +1,63 @@
+---
+title: Onyx
+---
+
+## Overview
+
+[Onyx](http://onyx.app/) is a self-hostable Chat UI that integrates with all Ollama models. Features include:
+
+- Creating custom Agents
+- Web search
+- Deep Research
+- RAG over uploaded documents and connected apps
+- Connectors to applications like Google Drive, Email, Slack, etc.
+- MCP and OpenAPI Actions support
+- Image generation
+- User/Groups management, RBAC, SSO, etc.
+
+Onyx can be deployed for single users or large organizations.
+
+## Install Onyx
+
+Deploy Onyx with the [quickstart guide](https://docs.onyx.app/deployment/getting_started/quickstart).
+
+<Info>
+  Resourcing/scaling docs [here](https://docs.onyx.app/deployment/getting_started/resourcing).
+</Info>
+
+## Usage with Ollama
+
+1. Log in to your Onyx deployment (create an account first).
+
+<div style={{ display: 'flex', justifyContent: 'center' }}>
+  <img
+    src="/images/onyx-login.png"
+    alt="Onyx Login Page"
+    width="75%"
+  />
+</div>
+
+2. In the set-up process select `Ollama` as the LLM provider.
+
+<div style={{ display: 'flex', justifyContent: 'center' }}>
+  <img
+    src="/images/onyx-ollama-llm.png"
+    alt="Onyx Set Up Form"
+    width="75%"
+  />
+</div>
+
+3. Provide your **Ollama API URL** and select your models.
+
+<Note>If you're running Onyx in Docker, to access your computer's local network use `http://host.docker.internal` instead of `http://127.0.0.1`.</Note>
+
+<div style={{ display: 'flex', justifyContent: 'center' }}>
+  <img
+    src="/images/onyx-ollama-form.png"
+    alt="Selecting Ollama Models"
+    width="75%"
+  />
+</div>
+
+You can also easily connect up Onyx Cloud with the `Ollama Cloud` tab of the setup.
+
+## Send your first query
+
+<div style={{ display: 'flex', justifyContent: 'center' }}>
+  <img
+    src="/images/onyx-query.png"
+    alt="Onyx Query Example"
+    width="75%"
+  />
+</div>
@@ -1,5 +1,5 @@
 ---
-title: "Linux"
+title: Linux
 ---

 ## Install

@@ -13,14 +13,15 @@ curl -fsSL https://ollama.com/install.sh | sh
 ## Manual install

 <Note>
-  If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
+  If you are upgrading from a prior version, you should remove the old libraries
+  with `sudo rm -rf /usr/lib/ollama` first.
 </Note>

 Download and extract the package:

 ```shell
-curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
-    | sudo tar zx -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
+    | sudo tar x -C /usr
 ```

 Start Ollama:

@@ -40,8 +41,8 @@ ollama -v
 If you have an AMD GPU, also download and extract the additional ROCm package:

 ```shell
-curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
-    | sudo tar zx -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
+    | sudo tar x -C /usr
 ```

 ### ARM64 install

@@ -49,8 +50,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
 Download and extract the ARM64-specific package:

 ```shell
-curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \
-    | sudo tar zx -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
+    | sudo tar x -C /usr
 ```

 ### Adding Ollama as a startup service (recommended)

@@ -112,7 +113,11 @@ sudo systemctl status ollama
 ```

 <Note>
-  While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
+  While AMD has contributed the `amdgpu` driver upstream to the official linux
+  kernel source, the version is older and may not support all ROCm features. We
+  recommend you install the latest driver from
+  https://www.amd.com/en/support/linux-drivers for best support of your Radeon
+  GPU.
 </Note>

 ## Customizing

@@ -141,8 +146,8 @@ curl -fsSL https://ollama.com/install.sh | sh
 Or by re-downloading Ollama:

 ```shell
-curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
-    | sudo tar zx -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
+    | sudo tar x -C /usr
 ```

 ## Installing specific versions

@@ -191,4 +196,4 @@ Remove the downloaded models and Ollama service user and group:
 sudo userdel ollama
 sudo groupdel ollama
 sudo rm -r /usr/share/ollama
-```
+```
@@ -1,3 +0,0 @@
-# Troubleshooting
-
-For troubleshooting, see [https://docs.ollama.com/troubleshooting](https://docs.ollama.com/troubleshooting)
@@ -269,6 +269,7 @@ func (kv KV) OllamaEngineRequired() bool {
 		"qwen25vl",
 		"qwen3", "qwen3moe",
 		"qwen3vl", "qwen3vlmoe",
+		"glm4moelite",
 	}, kv.Architecture())
 }

@@ -856,6 +857,7 @@ func (f GGML) FlashAttention() bool {
 	return slices.Contains([]string{
 		"bert",
 		"gemma3",
+		"glm4moelite",
 		"gptoss", "gpt-oss",
 		"mistral3",
 		"olmo3",
18  go.mod

@@ -15,8 +15,8 @@ require (
 	github.com/spf13/cobra v1.7.0
 	github.com/stretchr/testify v1.9.0
 	github.com/x448/float16 v0.8.4
-	golang.org/x/sync v0.19.0
-	golang.org/x/sys v0.39.0
+	golang.org/x/sync v0.17.0
+	golang.org/x/sys v0.37.0
 )

 require (

@@ -30,8 +30,8 @@ require (
 	github.com/tkrajina/typescriptify-golang-structs v0.2.0
 	github.com/wk8/go-ordered-map/v2 v2.1.8
 	golang.org/x/image v0.22.0
-	golang.org/x/mod v0.31.0
-	golang.org/x/tools v0.40.0
+	golang.org/x/mod v0.30.0
+	golang.org/x/tools v0.38.0
 	gonum.org/v1/gonum v0.15.0
 )

@@ -81,11 +81,11 @@ require (
 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 	github.com/ugorji/go/codec v1.2.12 // indirect
 	golang.org/x/arch v0.8.0 // indirect
-	golang.org/x/crypto v0.46.0
-	golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93
-	golang.org/x/net v0.48.0 // indirect
-	golang.org/x/term v0.38.0
-	golang.org/x/text v0.32.0
+	golang.org/x/crypto v0.43.0
+	golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
+	golang.org/x/net v0.46.0 // indirect
+	golang.org/x/term v0.36.0
+	golang.org/x/text v0.30.0
 	google.golang.org/protobuf v1.34.1
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
36
go.sum
@@ -233,16 +233,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 h1:fQsdNF2N+/YewlRZiricy4P1iimyPKZ/xwniHj8Q2a0=
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93/go.mod h1:EPRbTFwzwjXj9NpYyyrvenVh9Y+GFeEvMNh7Xuz7xgU=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
@@ -264,8 +264,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -278,8 +278,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -289,8 +289,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -306,17 +306,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -330,8 +330,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

@@ -131,7 +131,7 @@ func TestAPIToolCalling(t *testing.T) {
				t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
			}

			if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
			if _, ok := lastToolCall.Function.Arguments.Get("location"); !ok {
				t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
			}
		case <-ctx.Done():

@@ -1464,6 +1464,12 @@ type CompletionRequest struct {

	// TopLogprobs specifies the number of most likely alternative tokens to return (0-20)
	TopLogprobs int

	// Image generation fields
	Width  int32 `json:"width,omitempty"`
	Height int32 `json:"height,omitempty"`
	Steps  int32 `json:"steps,omitempty"`
	Seed   int64 `json:"seed,omitempty"`
}

// DoneReason represents the reason why a completion response is done
@@ -1512,6 +1518,15 @@ type CompletionResponse struct {

	// Logprobs contains log probability information if requested
	Logprobs []Logprob `json:"logprobs,omitempty"`

	// Image contains base64-encoded image data for image generation
	Image string `json:"image,omitempty"`

	// Step is the current step in image generation
	Step int `json:"step,omitempty"`

	// TotalSteps is the total number of steps for image generation
	TotalSteps int `json:"total_steps,omitempty"`
}

func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {

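The new fields extend the existing completion path to cover image generation. A minimal sketch of how a caller might drive it (illustrative only: the `server` value and the `saveImage` helper are assumptions, not part of this diff; only the `CompletionRequest`/`CompletionResponse` fields and the `Completion` signature come from the code above):

```go
// Hypothetical caller of the extended completion API.
req := CompletionRequest{
	Prompt: "a beautiful sunset",
	Width:  512,
	Height: 768,
	Steps:  20,
	Seed:   42,
}
err := server.Completion(ctx, req, func(resp CompletionResponse) {
	if !resp.Done {
		// Progress callbacks carry the current and total step counts.
		fmt.Printf("step %d/%d\n", resp.Step, resp.TotalSteps)
		return
	}
	if resp.Image != "" {
		saveImage(resp.Image) // base64-encoded payload (hypothetical helper)
	}
})
```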
@@ -118,6 +118,9 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
			return
		}

		// Set think to nil when being used with Anthropic API to connect to tools like claude code
		c.Set("relax_thinking", true)

		var b bytes.Buffer
		if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))

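For context, a handler downstream of this middleware could read the flag like so (a sketch; the actual consumer of `relax_thinking` is not shown in this diff):

```go
if _, ok := c.Get("relax_thinking"); ok {
	// Relax strict thinking validation for Anthropic-style clients
	// such as Claude Code (behavior inferred from the comment above).
}
```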
@@ -582,3 +582,26 @@ func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
		})
	}
}

func TestAnthropicMessagesMiddleware_SetsRelaxThinkingFlag(t *testing.T) {
	gin.SetMode(gin.TestMode)

	var flagSet bool
	router := gin.New()
	router.Use(AnthropicMessagesMiddleware())
	router.POST("/v1/messages", func(c *gin.Context) {
		_, flagSet = c.Get("relax_thinking")
		c.Status(http.StatusOK)
	})

	body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
	req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")

	resp := httptest.NewRecorder()
	router.ServeHTTP(resp, req)

	if !flagSet {
		t.Error("expected relax_thinking flag to be set in context")
	}
}

@@ -8,6 +8,7 @@ import (
	"math/rand"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"

@@ -441,6 +442,7 @@ type ResponsesWriter struct {
	stream     bool
	responseID string
	itemID     string
	request    openai.ResponsesRequest
}

func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
@@ -478,7 +480,9 @@ func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {

	// Non-streaming response
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse)
	response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse, w.request)
	completedAt := time.Now().Unix()
	response.CompletedAt = &completedAt
	return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
}

@@ -523,11 +527,12 @@ func ResponsesMiddleware() gin.HandlerFunc {

		w := &ResponsesWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			converter:  openai.NewResponsesStreamConverter(responseID, itemID, req.Model),
			converter:  openai.NewResponsesStreamConverter(responseID, itemID, req.Model, req),
			model:      req.Model,
			stream:     streamRequested,
			responseID: responseID,
			itemID:     itemID,
			request:    req,
		}

		// Set headers based on streaming mode
@@ -541,3 +546,66 @@ func ResponsesMiddleware() gin.HandlerFunc {
		c.Next()
	}
}

type ImageWriter struct {
	BaseWriter
}

func (w *ImageWriter) writeResponse(data []byte) (int, error) {
	var generateResponse api.GenerateResponse
	if err := json.Unmarshal(data, &generateResponse); err != nil {
		return 0, err
	}

	// Only write response when done with image
	if generateResponse.Done && generateResponse.Image != "" {
		w.ResponseWriter.Header().Set("Content-Type", "application/json")
		return len(data), json.NewEncoder(w.ResponseWriter).Encode(openai.ToImageGenerationResponse(generateResponse))
	}

	return len(data), nil
}

func (w *ImageWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(data)
	}

	return w.writeResponse(data)
}

func ImageGenerationsMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req openai.ImageGenerationRequest
		if err := c.ShouldBindJSON(&req); err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
			return
		}

		if req.Prompt == "" {
			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
			return
		}

		if req.Model == "" {
			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
			return
		}

		var b bytes.Buffer
		if err := json.NewEncoder(&b).Encode(openai.FromImageGenerationRequest(req)); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

		w := &ImageWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
		}

		c.Writer = w
		c.Next()
	}
}

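A client-side sketch of exercising this middleware over the OpenAI-compatible surface. The `/v1/images/generations` path is an assumption based on the OpenAI API shape (route registration is not shown in this diff); the request and response fields come from the code above:

```go
resp, err := http.Post(
	"http://localhost:11434/v1/images/generations", // assumed route
	"application/json",
	strings.NewReader(`{"model": "test-model", "prompt": "a beautiful sunset", "size": "512x768"}`),
)
if err != nil {
	log.Fatal(err)
}
defer resp.Body.Close()

var out struct {
	Created int64 `json:"created"`
	Data    []struct {
		B64JSON string `json:"b64_json"`
	} `json:"data"`
}
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
	log.Fatal(err)
}
// out.Data[0].B64JSON holds the base64-encoded image.
```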
@@ -961,3 +961,154 @@ func TestRetrieveMiddleware(t *testing.T) {
		}
	}
}

func TestImageGenerationsMiddleware(t *testing.T) {
	type testCase struct {
		name string
		body string
		req  api.GenerateRequest
		err  openai.ErrorResponse
	}

	var capturedRequest *api.GenerateRequest

	testCases := []testCase{
		{
			name: "image generation basic",
			body: `{
"model": "test-model",
"prompt": "a beautiful sunset"
}`,
			req: api.GenerateRequest{
				Model:  "test-model",
				Prompt: "a beautiful sunset",
			},
		},
		{
			name: "image generation with size",
			body: `{
"model": "test-model",
"prompt": "a beautiful sunset",
"size": "512x768"
}`,
			req: api.GenerateRequest{
				Model:  "test-model",
				Prompt: "a beautiful sunset",
				Width:  512,
				Height: 768,
			},
		},
		{
			name: "image generation missing prompt",
			body: `{
"model": "test-model"
}`,
			err: openai.ErrorResponse{
				Error: openai.Error{
					Message: "prompt is required",
					Type:    "invalid_request_error",
				},
			},
		},
		{
			name: "image generation missing model",
			body: `{
"prompt": "a beautiful sunset"
}`,
			err: openai.ErrorResponse{
				Error: openai.Error{
					Message: "model is required",
					Type:    "invalid_request_error",
				},
			},
		},
	}

	endpoint := func(c *gin.Context) {
		c.Status(http.StatusOK)
	}

	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(ImageGenerationsMiddleware(), captureRequestMiddleware(&capturedRequest))
	router.Handle(http.MethodPost, "/api/generate", endpoint)

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
			req.Header.Set("Content-Type", "application/json")

			defer func() { capturedRequest = nil }()

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

			if tc.err.Error.Message != "" {
				var errResp openai.ErrorResponse
				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
					t.Fatal(err)
				}
				if diff := cmp.Diff(tc.err, errResp); diff != "" {
					t.Fatalf("errors did not match:\n%s", diff)
				}
				return
			}

			if resp.Code != http.StatusOK {
				t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
			}

			if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
				t.Fatalf("requests did not match:\n%s", diff)
			}
		})
	}
}

func TestImageWriterResponse(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Test that ImageWriter transforms GenerateResponse to OpenAI format
	endpoint := func(c *gin.Context) {
		resp := api.GenerateResponse{
			Model:     "test-model",
			CreatedAt: time.Unix(1234567890, 0).UTC(),
			Done:      true,
			Image:     "dGVzdC1pbWFnZS1kYXRh", // base64 of "test-image-data"
		}
		data, _ := json.Marshal(resp)
		c.Writer.Write(append(data, '\n'))
	}

	router := gin.New()
	router.Use(ImageGenerationsMiddleware())
	router.Handle(http.MethodPost, "/api/generate", endpoint)

	body := `{"model": "test-model", "prompt": "test"}`
	req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")

	resp := httptest.NewRecorder()
	router.ServeHTTP(resp, req)

	if resp.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
	}

	var imageResp openai.ImageGenerationResponse
	if err := json.Unmarshal(resp.Body.Bytes(), &imageResp); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}

	if imageResp.Created != 1234567890 {
		t.Errorf("expected created 1234567890, got %d", imageResp.Created)
	}

	if len(imageResp.Data) != 1 {
		t.Fatalf("expected 1 image, got %d", len(imageResp.Data))
	}

	if imageResp.Data[0].B64JSON != "dGVzdC1pbWFnZS1kYXRh" {
		t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON)
	}
}

304
model/models/glm4moelite/model.go
Normal file
@@ -0,0 +1,304 @@
package glm4moelite

import (
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)

type Options struct {
	numExpertsUsed      int
	numExperts          int
	normTopKProb        bool
	routedScalingFactor float32

	kvLoraRank,
	qkNopeHeadDim,
	qkRopeHeadDim,
	kqNopeHeadDim,
	qkHeadDim int
	qLoraRank int
	vHeadDim  int

	hiddenSize,
	numHeads,
	numKVHeads int

	eps,
	ropeBase float32
	kqScale  float64
}

func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
	return nn.RoPE(ctx, t, p, o.qkRopeHeadDim, o.ropeBase, 1.0)
}

type Attention struct {
	Q *nn.Linear `gguf:"attn_q"`

	QA     *nn.Linear  `gguf:"attn_q_a"`
	QANorm *nn.RMSNorm `gguf:"attn_q_a_norm"`
	QB     *nn.Linear  `gguf:"attn_q_b"`

	KVA     *nn.Linear  `gguf:"attn_kv_a_mqa"`
	KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
	KVB     *nn.Linear  `gguf:"attn_kv_b"`

	Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
}

func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
	seqLength := hiddenStates.Dim(1)

	var query ml.Tensor
	if opts.qLoraRank == 0 {
		query = attn.Q.Forward(ctx, hiddenStates)
	} else {
		query = attn.QA.Forward(ctx, hiddenStates)
		query = attn.QANorm.Forward(ctx, query, opts.eps)
		query = attn.QB.Forward(ctx, query)
	}

	query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength)
	queryChunks := query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim)

	compressedKV := attn.KVA.Forward(ctx, hiddenStates)
	kPass := compressedKV.Slice(ctx, 0, 0, opts.kvLoraRank, 1)
	kRot := compressedKV.View(ctx,
		opts.kvLoraRank*compressedKV.Stride(0), opts.qkRopeHeadDim,
		compressedKV.Stride(1), 1,
		compressedKV.Stride(1), compressedKV.Dim(1),
	)

	qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
	kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
	kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
	kPass = attn.KVB.Forward(ctx, kPass)

	kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
	kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)

	kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
	query = qRot.Concat(ctx, queryChunks[0], 0)
	key := kRot.Concat(ctx, kvChunks[0], 0)
	attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)

	attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
	return attn.Output.Forward(ctx, attention)
}

type MLP interface {
	Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
}

type sparse struct {
	Router       *nn.Linear `gguf:"ffn_gate_inp"`
	Gate         *nn.Linear `gguf:"ffn_gate_exps"`
	Up           *nn.Linear `gguf:"ffn_up_exps"`
	Down         *nn.Linear `gguf:"ffn_down_exps"`
	SharedExpert *dense     `gguf:",suf:_shexp"`
	ExpProbsBias ml.Tensor  `gguf:"exp_probs_b.bias,alt:exp_probs_b"`
}

func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml.Tensor, opts *Options) ml.Tensor {
	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))

	upStates := moe.Up.Weight.MulmatID(ctx, hiddenStates, topKIndices)
	hiddenStates = moe.Gate.Weight.MulmatID(ctx, hiddenStates, topKIndices)
	hiddenStates = hiddenStates.SILU(ctx, upStates)

	experts := moe.Down.Weight.MulmatID(ctx, hiddenStates, topKIndices)
	experts = experts.Mul(ctx, topKWeights)

	nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
	for i := 1; i < opts.numExpertsUsed; i++ {
		nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
	}
	return nextStates
}

func (moe *sparse) topKIndices(ctx ml.Context, scores ml.Tensor, opts *Options) ml.Tensor {
	if moe.ExpProbsBias != nil {
		scores = scores.Add(ctx, moe.ExpProbsBias)
	}
	topKIndices := scores.TopK(ctx, opts.numExpertsUsed)
	return topKIndices
}

func (moe *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
	residuals := hiddenStates

	routerLogits := moe.Router.Forward(ctx, hiddenStates)
	scores := routerLogits.Sigmoid(ctx)
	topKIndices := moe.topKIndices(ctx, scores, opts)
	topKWeights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, topKIndices)

	if opts.normTopKProb {
		topKWeights = topKWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1))
		topKWeights = topKWeights.Div(ctx, topKWeights.SumRows(ctx))
		topKWeights = topKWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1))
	}

	topKWeights = topKWeights.Scale(ctx, float64(opts.routedScalingFactor))
	hiddenStates = moe.Moe(ctx, hiddenStates, topKIndices, topKWeights, opts)
	sharedExpertResult := moe.SharedExpert.Forward(ctx, residuals, opts)

	hiddenStates = hiddenStates.Add(ctx, sharedExpertResult)
	return hiddenStates
}

type dense struct {
	Gate *nn.Linear `gguf:"ffn_gate"`
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
}

func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
	return mlp.Down.Forward(ctx, hiddenStates)
}

type Layer struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	Attention     *Attention

	MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
	MLP     MLP
}

func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
	residual := hiddenStates
	hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, cache, opts)

	if outputs != nil {
		hiddenStates = hiddenStates.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenStates = hiddenStates.Add(ctx, residual)
	residual = hiddenStates

	hiddenStates = t.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = t.MLP.Forward(ctx, hiddenStates, opts)
	hiddenStates = hiddenStates.Add(ctx, residual)
	return hiddenStates
}

type Model struct {
	model.Base
	model.BytePairEncoding

	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []Layer       `gguf:"blk"`

	OutputNorm *nn.RMSNorm `gguf:"output_norm"`
	Output     *nn.Linear  `gguf:"output,alt:token_embd"`

	*Options
}

func New(c fs.Config) (model.Model, error) {
	layers := make([]Layer, c.Uint("block_count"))

	firstDenseLayerIndex := int(c.Uint("leading_dense_block_count"))
	for i := range layers {
		if i < firstDenseLayerIndex {
			layers[i].MLP = &dense{}
		} else {
			layers[i].MLP = &sparse{}
		}
	}

	keyLength := int(c.Uint("attention.key_length"))
	valueLength := int(c.Uint("attention.value_length"))

	kqScale := 1.0 / math.Sqrt(float64(keyLength))

	var pre []string
	switch c.String("tokenizer.ggml.pre") {
	case "glm4":
		pre = []string{
			`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
		}
	default:
		return nil, model.ErrUnsupportedTokenizer
	}

	m := Model{
		BytePairEncoding: model.NewBytePairEncoding(
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Types:  c.Ints("tokenizer.ggml.token_type"),
				Merges: c.Strings("tokenizer.ggml.merges"),
				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
				BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
				EOS: append(
					[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
					c.Ints("tokenizer.ggml.eos_token_ids")...,
				),
			},
			pre...,
		),
		Layers: layers,
		Options: &Options{
			hiddenSize:     int(c.Uint("embedding_length")),
			numHeads:       int(c.Uint("attention.head_count")),
			numKVHeads:     int(c.Uint("attention.head_count_kv")),
			eps:            c.Float("attention.layer_norm_rms_epsilon"),
			ropeBase:       c.Float("rope.freq_base"),
			numExperts:     int(c.Uint("expert_count")),
			numExpertsUsed: int(c.Uint("expert_used_count")),
			normTopKProb:   c.Bool("expert_weights_norm", true),

			qLoraRank:     int(c.Uint("attention.q_lora_rank")),
			kvLoraRank:    int(c.Uint("attention.kv_lora_rank")),
			qkHeadDim:     keyLength,
			vHeadDim:      valueLength,
			qkRopeHeadDim: int(c.Uint("rope.dimension_count")),
			qkNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
			kqNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),

			routedScalingFactor: c.Float("expert_weights_scale"),

			kqScale: kqScale,
		},
	}

	m.Cache = kvcache.NewCausalCache(m.Shift)
	return &m, nil
}

func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}

func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
	positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))

	hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)

	for i, layer := range m.Layers {
		m.Cache.SetLayer(i)

		var outputs ml.Tensor
		if i == len(m.Layers)-1 {
			outputs = batch.Outputs
		}

		hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
	}

	hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
	return m.Output.Forward(ctx, hiddenStates), nil
}

func init() {
	model.Register("glm4moelite", New)
}
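The head-dimension bookkeeping above follows an MLA-style split: each query/key head is the concatenation of a rotary slice and a pass-through ("nope") slice. A worked example with illustrative numbers (not taken from a real GGUF config):

```go
keyLength := 128 // attention.key_length
ropeDims := 64   // rope.dimension_count

qkRopeHeadDim := ropeDims             // slice of each head that receives RoPE
qkNopeHeadDim := keyLength - ropeDims // slice passed through unrotated

// Attention.Forward splits each head with ChunkSections(ctx, 0,
// qkNopeHeadDim, qkRopeHeadDim) and re-concatenates after rotation,
// so the two slices must add back up to the full head width:
fmt.Println(qkNopeHeadDim+qkRopeHeadDim == keyLength) // true
```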
@@ -7,6 +7,7 @@ import (
	_ "github.com/ollama/ollama/model/models/gemma2"
	_ "github.com/ollama/ollama/model/models/gemma3"
	_ "github.com/ollama/ollama/model/models/gemma3n"
	_ "github.com/ollama/ollama/model/models/glm4moelite"
	_ "github.com/ollama/ollama/model/models/gptoss"
	_ "github.com/ollama/ollama/model/models/llama"
	_ "github.com/ollama/ollama/model/models/llama4"

410
model/parsers/glm46.go
Normal file
@@ -0,0 +1,410 @@
package parsers

import (
	"context"
	"encoding/xml"
	"fmt"
	"log/slog"
	"strings"
	"unicode"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/logutil"
)

type glm46ParserState int

const (
	glm46ParserState_LookingForThinkingOpen glm46ParserState = iota
	glm46ParserState_ThinkingStartedEatingWhitespace
	glm46ParserState_CollectingThinking
	glm46ParserState_ThinkingDoneEatingWhitespace
	glm46ParserState_CollectingContent
	glm46ParserState_ToolStartedEatingWhitespace
	glm46ParserState_CollectingToolContent
)

const (
	glm46ThinkingOpenTag  = "<think>"
	glm46ThinkingCloseTag = "</think>"
	glm46ToolOpenTag      = "<tool_call>"
	glm46ToolCloseTag     = "</tool_call>"
)

type GLM46Parser struct {
	state  glm46ParserState
	buffer strings.Builder
	tools  []api.Tool
}

func (p *GLM46Parser) HasToolSupport() bool {
	return true
}

func (p *GLM46Parser) HasThinkingSupport() bool {
	return true
}

// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
	p.tools = tools
	return tools
}

type glm46Event interface {
	isGLM46Event()
}

type glm46EventContent struct {
	content string
}

func (glm46EventContent) isGLM46Event() {}

type glm46EventRawToolCall struct {
	raw string
}

func (glm46EventRawToolCall) isGLM46Event() {}

type glm46EventThinkingContent struct {
	content string
}

func (glm46EventThinkingContent) isGLM46Event() {}

func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
	p.buffer.WriteString(s)
	events := p.parseEvents()

	var toolCalls []api.ToolCall
	var contentSb strings.Builder
	var thinkingSb strings.Builder

	for _, event := range events {
		switch event := event.(type) {
		case glm46EventRawToolCall:
			toolCall, err := parseGLM46ToolCall(event, p.tools)
			if err != nil {
				slog.Warn("glm-4.6 tool call parsing failed", "error", err)
				return "", "", nil, err
			}
			toolCalls = append(toolCalls, toolCall)
		case glm46EventThinkingContent:
			thinkingSb.WriteString(event.content)
		case glm46EventContent:
			// TODO(drifkin): if the same turn contains multiple interleaved content
			// events, we naively append them together here.
			contentSb.WriteString(event.content)
		}
	}

	return contentSb.String(), thinkingSb.String(), toolCalls, nil
}

func (p *GLM46Parser) parseEvents() []glm46Event {
	var all []glm46Event

	keepLooping := true
	for keepLooping {
		var events []glm46Event
		events, keepLooping = p.eat()
		if len(events) > 0 {
			all = append(all, events...)
		}
	}

	if len(all) > 0 {
		slog.Log(context.TODO(), logutil.LevelTrace, "glm-4.6 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
	}

	return all
}

// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer
// and transitions to the next state. Returns (nil, false) if only whitespace remains
// in the buffer (needs more input), or (nil, true) if we successfully transitioned.
func (p *GLM46Parser) eatLeadingWhitespaceAndTransitionTo(nextState glm46ParserState) ([]glm46Event, bool) {
	trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
	p.buffer.Reset()
	if trimmed == "" {
		return nil, false // Still only whitespace, keep waiting for more input
	}
	p.state = nextState
	p.buffer.WriteString(trimmed)
	return nil, true // Successfully transitioned
}

// glm46SplitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace),
// the content after (optionally trimmed of leading whitespace), and updates the buffer
func glm46SplitAtTag(p *GLM46Parser, tag string, trimAfter bool) (string, string) {
	split := strings.SplitN(p.buffer.String(), tag, 2)
	before := split[0]
	before = strings.TrimRightFunc(before, unicode.IsSpace)
	after := split[1]
	if trimAfter {
		after = strings.TrimLeftFunc(after, unicode.IsSpace)
	}
	p.buffer.Reset()
	p.buffer.WriteString(after)
	return before, after
}

func (p *GLM46Parser) eat() ([]glm46Event, bool) {
	var events []glm46Event

	switch p.state {
	case glm46ParserState_LookingForThinkingOpen:
		trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
		if strings.HasPrefix(trimmed, glm46ThinkingOpenTag) {
			// Found <think> opening tag
			after := strings.TrimPrefix(trimmed, glm46ThinkingOpenTag)
			after = strings.TrimLeftFunc(after, unicode.IsSpace)
			p.buffer.Reset()
			p.buffer.WriteString(after)
			if after == "" {
				p.state = glm46ParserState_ThinkingStartedEatingWhitespace
			} else {
				p.state = glm46ParserState_CollectingThinking
			}
			return events, true
		} else if strings.HasPrefix(glm46ThinkingOpenTag, trimmed) {
			// Partial opening tag seen, keep accumulating
			return events, false
		} else if trimmed == "" {
			// Only whitespace, keep accumulating
			return events, false
		} else {
			// No thinking tag found, skip to content collection
			p.state = glm46ParserState_CollectingContent
			// Don't trim - we want to keep the original content
			return events, true
		}

	case glm46ParserState_ThinkingStartedEatingWhitespace:
		return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingThinking)

	case glm46ParserState_CollectingThinking:
		acc := p.buffer.String()
		if strings.Contains(acc, glm46ThinkingCloseTag) {
			thinking, remaining := glm46SplitAtTag(p, glm46ThinkingCloseTag, true)
			if len(thinking) > 0 {
				events = append(events, glm46EventThinkingContent{content: thinking})
			}
			if remaining == "" {
				p.state = glm46ParserState_ThinkingDoneEatingWhitespace
			} else {
				p.state = glm46ParserState_CollectingContent
			}
			return events, true
		} else if overlapLen := overlap(acc, glm46ThinkingCloseTag); overlapLen > 0 {
			// Partial closing tag - withhold it along with any trailing whitespace before it
			beforePartialTag := acc[:len(acc)-overlapLen]
			trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
			ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen

			unambiguous := acc[:ambiguousStart]
			ambiguous := acc[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, glm46EventThinkingContent{content: unambiguous})
			}
			return events, false
		} else {
			// Pure thinking content - withhold trailing whitespace (might precede closing tag)
			whitespaceLen := trailingWhitespaceLen(acc)
			ambiguousStart := len(acc) - whitespaceLen

			unambiguous := acc[:ambiguousStart]
			ambiguous := acc[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, glm46EventThinkingContent{content: unambiguous})
			}
			return events, false
		}

	case glm46ParserState_ThinkingDoneEatingWhitespace:
		return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingContent)

	case glm46ParserState_CollectingContent:
		if strings.Contains(p.buffer.String(), glm46ToolOpenTag) {
			before, after := glm46SplitAtTag(p, glm46ToolOpenTag, true)
			if len(before) > 0 {
				events = append(events, glm46EventContent{content: before})
			}
			if after == "" {
				p.state = glm46ParserState_ToolStartedEatingWhitespace
			} else {
				p.state = glm46ParserState_CollectingToolContent
			}
			return events, true
		} else if overlapLen := overlap(p.buffer.String(), glm46ToolOpenTag); overlapLen > 0 {
			beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
			trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
			ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen

			unambiguous := p.buffer.String()[:ambiguousStart]
			ambiguous := p.buffer.String()[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, glm46EventContent{content: unambiguous})
			}
			return events, false
		} else {
			whitespaceLen := trailingWhitespaceLen(p.buffer.String())
			ambiguousStart := len(p.buffer.String()) - whitespaceLen

			unambiguous := p.buffer.String()[:ambiguousStart]
			ambiguous := p.buffer.String()[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, glm46EventContent{content: unambiguous})
			}
			return events, false
		}

	case glm46ParserState_ToolStartedEatingWhitespace:
		return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingToolContent)

	case glm46ParserState_CollectingToolContent:
		acc := p.buffer.String()
		if strings.Contains(acc, glm46ToolCloseTag) {
			toolContent, _ := glm46SplitAtTag(p, glm46ToolCloseTag, true)
			if len(toolContent) == 0 {
				slog.Warn("glm46 tool call closing tag found but no content before it")
			}
			events = append(events, glm46EventRawToolCall{raw: toolContent})
			p.state = glm46ParserState_CollectingContent
			return events, true
		} else {
			// Keep accumulating - tool calls are not streamed
			// We just wait for the closing tag
			return events, false
		}

	default:
		panic("unreachable")
	}
}

// GLMToolCallXML represents the structure of a GLM-4.6 tool call for XML parsing
type GLMToolCallXML struct {
	XMLName xml.Name `xml:"tool_call"`
	Content string   `xml:",chardata"` // Function name (text nodes between tags)
	Keys    []string `xml:"arg_key"`   // All arg_key elements in document order
	Values  []string `xml:"arg_value"` // All arg_value elements in document order
}

// escapeGLM46Content escapes XML entities in text content while preserving arg_key/arg_value tags
func escapeGLM46Content(s string) string {
	var result strings.Builder
	inTag := false

	for i := range len(s) {
		ch := s[i]

		if ch == '<' {
			// Check if this is a known tag
			if strings.HasPrefix(s[i:], "<arg_key>") ||
				strings.HasPrefix(s[i:], "</arg_key>") ||
				strings.HasPrefix(s[i:], "<arg_value>") ||
				strings.HasPrefix(s[i:], "</arg_value>") {
				inTag = true
			}
		}

		if inTag {
			result.WriteByte(ch)
			if ch == '>' {
				inTag = false
			}
		} else {
			// Escape special characters in text content
			switch ch {
			case '&':
				result.WriteString("&amp;")
			case '<':
				result.WriteString("&lt;")
			case '>':
				result.WriteString("&gt;")
			default:
				result.WriteByte(ch)
			}
		}
	}

	return result.String()
}

func parseGLM46ToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
	// Escape any unescaped entities in text content
	// We need to escape text between tags, but not the tags themselves
	escaped := escapeGLM46Content(raw.raw)

	// Wrap the content in a root element to make it valid XML
	xmlString := "<tool_call>" + escaped + "</tool_call>"

	// Parse XML into struct
	var parsed GLMToolCallXML
	if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
		return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
	}

	// Extract and trim function name
	functionName := strings.TrimSpace(parsed.Content)
	if functionName == "" {
		return api.ToolCall{}, fmt.Errorf("empty function name")
	}

	// Verify keys and values are paired correctly
	if len(parsed.Keys) != len(parsed.Values) {
		return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
	}

	// Find the matching tool to get parameter types
	var matchedTool *api.Tool
	for i := range tools {
		if tools[i].Function.Name == functionName {
			matchedTool = &tools[i]
			break
		}
	}

	// Build arguments map by pairing keys and values
	toolCall := api.ToolCall{
		Function: api.ToolCallFunction{
			Name:      functionName,
			Arguments: api.NewToolCallFunctionArguments(),
		},
	}

	for i := range parsed.Keys {
		key := strings.TrimSpace(parsed.Keys[i])
		value := parsed.Values[i] // Don't trim here - parseValue handles it

		// Look up parameter type
		var paramType api.PropertyType
		if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
			if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
				// Handle anyOf by collecting all types from the union
				if len(prop.AnyOf) > 0 {
					for _, anyOfProp := range prop.AnyOf {
						paramType = append(paramType, anyOfProp.Type...)
					}
				} else {
					paramType = prop.Type
				}
			}
		}

		// Parse value with type coercion
		toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
	}

	return toolCall, nil
}
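A sketch of driving the parser incrementally, the same way the streaming tests below exercise it (the chunk boundaries are arbitrary and the nil arguments stand in for real tool definitions):

```go
p := &GLM46Parser{}
p.Init(nil, nil, nil) // no tool definitions in this sketch

for _, chunk := range []string{"<think>plan", "ning</think>Hel", "lo"} {
	content, thinking, calls, err := p.Add(chunk, false)
	if err != nil {
		log.Fatal(err)
	}
	// Partial tags and trailing whitespace are withheld until they can be
	// classified, so early chunks may yield empty strings here.
	fmt.Printf("content=%q thinking=%q calls=%d\n", content, thinking, len(calls))
}
```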
862
model/parsers/glm46_test.go
Normal file
@@ -0,0 +1,862 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestGLM46ParserStreaming(t *testing.T) {
|
||||
type step struct {
|
||||
input string
|
||||
wantEvents []glm46Event
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
steps []step
|
||||
only bool
|
||||
}{
|
||||
{
|
||||
desc: "leading whitespace before think tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: " \n\t ",
|
||||
wantEvents: []glm46Event{},
|
||||
},
|
||||
{
|
||||
input: "<think>thinking</think>",
|
||||
wantEvents: []glm46Event{glm46EventThinkingContent{content: "thinking"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "think tag with whitespace inside",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think> \n thinking content \n </think>regular content",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventThinkingContent{content: "thinking content"},
|
||||
glm46EventContent{content: "regular content"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with leading whitespace after opening tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think></think><tool_call> \n test \n </tool_call>",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventRawToolCall{raw: "test"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "simple thinking then content",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>I am thinking</think>Now I respond",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventThinkingContent{content: "I am thinking"},
|
||||
glm46EventContent{content: "Now I respond"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "streamed thinking content",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>hello",
|
||||
wantEvents: []glm46Event{glm46EventThinkingContent{content: "hello"}},
|
||||
},
|
||||
{
|
||||
input: " world",
|
||||
wantEvents: []glm46Event{glm46EventThinkingContent{content: " world"}},
|
||||
},
|
||||
{
|
||||
input: "</think>content",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventContent{content: "content"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "content before tool call",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>Let me call a tool</think>here is text<tool_call>",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventThinkingContent{content: "Let me call a tool"},
|
||||
glm46EventContent{content: "here is text"},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "function_name\n<arg_key>param</arg_key>\n<arg_value>value</arg_value>\n</tool_call>",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventRawToolCall{raw: "function_name\n<arg_key>param</arg_key>\n<arg_value>value</arg_value>"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with content after",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>thinking</think><tool_call>test</tool_call>after tool",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventThinkingContent{content: "thinking"},
|
||||
glm46EventRawToolCall{raw: "test"},
|
||||
glm46EventContent{content: "after tool"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "trailing whitespace between content and tool call is trimmed",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>thinking</think>content\n \t <tool_call>test</tool_call>",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventThinkingContent{content: "thinking"},
|
||||
glm46EventContent{content: "content"},
|
||||
glm46EventRawToolCall{raw: "test"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "trailing whitespace between tool call and content is trimmed",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>think</think><tool_call>test</tool_call>\n\t after",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventThinkingContent{content: "think"},
|
||||
glm46EventRawToolCall{raw: "test"},
|
||||
glm46EventContent{content: "after"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "split thinking close tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>thinking content</th",
|
||||
wantEvents: []glm46Event{glm46EventThinkingContent{content: "thinking content"}},
|
||||
},
|
||||
{
|
||||
input: "ink>after",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventContent{content: "after"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "split thinking open tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: " <thi",
|
||||
wantEvents: []glm46Event{},
|
||||
},
|
||||
{
|
||||
input: "nk>content</think>",
|
||||
wantEvents: []glm46Event{glm46EventThinkingContent{content: "content"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "split tool open tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>think</think>content<tool",
|
||||
wantEvents: []glm46Event{glm46EventThinkingContent{content: "think"}, glm46EventContent{content: "content"}},
|
||||
},
|
||||
{
|
||||
input: "_call>inside",
|
||||
wantEvents: []glm46Event{},
|
||||
},
|
||||
{
|
||||
input: "</tool_call>",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventRawToolCall{raw: "inside"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "partial thinking close tag fakeout",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>content</th",
|
||||
wantEvents: []glm46Event{glm46EventThinkingContent{content: "content"}},
|
||||
},
|
||||
{
|
||||
input: "ought more",
|
||||
wantEvents: []glm46Event{glm46EventThinkingContent{content: "</thought more"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "partial thinking open tag fakeout",
|
||||
steps: []step{
|
||||
{
|
||||
input: " <thi",
|
||||
wantEvents: []glm46Event{},
|
||||
},
|
||||
{
|
||||
input: "nking is fun",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventContent{content: " <thinking is fun"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "partial tool open tag fakeout",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think></think>content\n<tool",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventContent{content: "content"},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: " fakeout",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventContent{content: "\n<tool fakeout"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "partial tool close tag fakeout",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think></think><tool_call>content</tool",
|
||||
wantEvents: []glm46Event{},
|
||||
},
|
||||
{
|
||||
input: " fakeout",
|
||||
wantEvents: []glm46Event{},
|
||||
},
|
||||
{
|
||||
input: "</tool_call>",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventRawToolCall{raw: "content</tool fakeout"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "empty thinking tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think></think>content here",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventContent{content: "content here"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "multiple tool calls in sequence",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>think</think><tool_call>first</tool_call>between<tool_call>second</tool_call>end",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventThinkingContent{content: "think"},
|
||||
glm46EventRawToolCall{raw: "first"},
|
||||
glm46EventContent{content: "between"},
|
||||
glm46EventRawToolCall{raw: "second"},
|
||||
glm46EventContent{content: "end"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "no thinking tag - direct to content",
|
||||
steps: []step{
|
||||
{
|
||||
input: "just content here",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventContent{content: "just content here"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "no thinking tag - skip to content then tool call",
|
||||
steps: []step{
|
||||
{
|
||||
input: "Here's the answer:<tool_call>test</tool_call>done",
|
||||
					wantEvents: []glm46Event{
						glm46EventContent{content: "Here's the answer:"},
						glm46EventRawToolCall{raw: "test"},
						glm46EventContent{content: "done"},
					},
				},
			},
		},
		{
			desc: "no thinking tag - whitespace preserved when no tags",
			steps: []step{
				{
					input: " \n content with leading whitespace",
					wantEvents: []glm46Event{
						glm46EventContent{content: " \n content with leading whitespace"},
					},
				},
			},
		},
		{
			desc: "whitespace after think close tag gets eaten",
			steps: []step{
				{
					input: "<think>thinking</think> \n\t content",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "thinking"},
						glm46EventContent{content: "content"},
					},
				},
			},
		},
		{
			desc: "whitespace after tool_call close tag gets eaten",
			steps: []step{
				{
					input: "<think></think><tool_call>test</tool_call> \n\t content",
					wantEvents: []glm46Event{
						glm46EventRawToolCall{raw: "test"},
						glm46EventContent{content: "content"},
					},
				},
			},
		},
		{
			desc: "thinking content withholds trailing whitespace (single chunk)",
			steps: []step{
				{
					input: "<think>thinking content ",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "thinking content"},
					},
				},
				{
					input: "</think>after",
					wantEvents: []glm46Event{
						glm46EventContent{content: "after"},
					},
				},
			},
		},
		{
			desc: "thinking content withholds trailing whitespace with newlines",
			steps: []step{
				{
					input: "<think>thinking\n\n ",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "thinking"},
					},
				},
				{
					input: "</think>content",
					wantEvents: []glm46Event{
						glm46EventContent{content: "content"},
					},
				},
			},
		},
		{
			desc: "thinking content trailing whitespace emitted when more content arrives",
			steps: []step{
				{
					input: "<think>thinking ",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "thinking"},
					},
				},
				{
					input: "more thinking",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: " more thinking"},
					},
				},
				{
					input: "</think>",
					wantEvents: []glm46Event{},
				},
			},
		},
		{
			desc: "thinking content withholds trailing whitespace before partial close tag",
			steps: []step{
				{
					input: "<think>thinking </th",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "thinking"},
					},
				},
				{
					input: "ink>content",
					wantEvents: []glm46Event{
						glm46EventContent{content: "content"},
					},
				},
			},
		},
	}

	anyOnlies := false
	for _, tc := range cases {
		if tc.only {
			anyOnlies = true
		}
	}

	for _, tc := range cases {
		if anyOnlies && !tc.only {
			continue
		}

		t.Run(tc.desc, func(t *testing.T) {
			parser := GLM46Parser{}

			for i, step := range tc.steps {
				parser.buffer.WriteString(step.input)
				gotEvents := parser.parseEvents()

				if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
					// avoid deep equal on empty vs. nil slices
					continue
				}

				if !reflect.DeepEqual(gotEvents, step.wantEvents) {
					t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
				}
			}
		})
	}
}

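The last several cases above pin down one buffering rule: trailing whitespace in thinking content is held back until the parser knows whether a close tag or more content follows. A minimal sketch of the whitespace half of that rule, assuming the usual strings/unicode imports (the package appears to have its own trailingWhitespaceLen helper; this is only an illustration of its presumed behavior, and the real parser additionally holds back partial close tags):

// trailingWhitespaceLen reports how many bytes of trailing whitespace s has;
// the parser keeps that suffix buffered instead of emitting it immediately.
func trailingWhitespaceLen(s string) int {
	trimmed := strings.TrimRightFunc(s, unicode.IsSpace)
	return len(s) - len(trimmed)
}
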
// TestGLMToolCallXMLOrderPreservation verifies that xml.Unmarshal preserves
// document order when collecting multiple elements with the same tag name into slices.
// This is a critical assumption for the GLM-4.6 parser's struct-based approach.
func TestGLMToolCallXMLOrderPreservation(t *testing.T) {
	testCases := []struct {
		name string
		xml string
		wantKeys []string
		wantValues []string
	}{
		{
			name: "alternating keys and values",
			xml: `<tool_call>
function_name
<arg_key>first</arg_key>
<arg_value>A</arg_value>
<arg_key>second</arg_key>
<arg_value>B</arg_value>
<arg_key>third</arg_key>
<arg_value>C</arg_value>
</tool_call>`,
			wantKeys: []string{"first", "second", "third"},
			wantValues: []string{"A", "B", "C"},
		},
		{
			name: "all keys then all values",
			xml: `<tool_call>
function_name
<arg_key>key1</arg_key>
<arg_key>key2</arg_key>
<arg_key>key3</arg_key>
<arg_value>val1</arg_value>
<arg_value>val2</arg_value>
<arg_value>val3</arg_value>
</tool_call>`,
			wantKeys: []string{"key1", "key2", "key3"},
			wantValues: []string{"val1", "val2", "val3"},
		},
		{
			name: "mixed grouping",
			xml: `<tool_call>
function_name
<arg_key>a</arg_key>
<arg_value>1</arg_value>
<arg_key>b</arg_key>
<arg_key>c</arg_key>
<arg_value>2</arg_value>
<arg_value>3</arg_value>
</tool_call>`,
			wantKeys: []string{"a", "b", "c"},
			wantValues: []string{"1", "2", "3"},
		},
		{
			name: "reverse order - all values then all keys",
			xml: `<tool_call>
function_name
<arg_value>X</arg_value>
<arg_value>Y</arg_value>
<arg_value>Z</arg_value>
<arg_key>x</arg_key>
<arg_key>y</arg_key>
<arg_key>z</arg_key>
</tool_call>`,
			wantKeys: []string{"x", "y", "z"},
			wantValues: []string{"X", "Y", "Z"},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			var parsed GLMToolCallXML
			err := xml.Unmarshal([]byte(tc.xml), &parsed)
			if err != nil {
				t.Fatalf("failed to unmarshal XML: %v", err)
			}

			if !reflect.DeepEqual(parsed.Keys, tc.wantKeys) {
				t.Errorf("Keys order mismatch:\ngot: %v\nwant: %v", parsed.Keys, tc.wantKeys)
			}

			if !reflect.DeepEqual(parsed.Values, tc.wantValues) {
				t.Errorf("Values order mismatch:\ngot: %v\nwant: %v", parsed.Values, tc.wantValues)
			}
		})
	}
}

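The property being tested holds because encoding/xml appends repeated elements to a slice field in document order. The GLMToolCallXML definition is outside this diff; a plausible shape, assumed here purely for reference, would be:

// Assumed shape, not part of this diff: repeated <arg_key>/<arg_value>
// elements land in Keys/Values in the order they appear in the document,
// which is what lets the parser zip them back into ordered pairs.
type GLMToolCallXML struct {
	Keys   []string `xml:"arg_key"`
	Values []string `xml:"arg_value"`
}
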
func TestGLM46ToolCallParsing(t *testing.T) {
	type testCase struct {
		name string
		rawToolCall string
		tools []api.Tool
		wantToolCall api.ToolCall
	}

	cases := []testCase{
		{
			name: "simple tool call",
			tools: []api.Tool{},
			rawToolCall: `get-current-weather
<arg_key>location</arg_key>
<arg_value>New York, NY</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "get-current-weather",
					Arguments: args(`{"location": "New York, NY", "unit": "celsius"}`),
				},
			},
		},
		{
			name: "tool call with typed parameters",
			tools: []api.Tool{
				tool("calculate", map[string]api.ToolProperty{
					"x": {Type: api.PropertyType{"number"}},
					"y": {Type: api.PropertyType{"integer"}},
					"enabled": {Type: api.PropertyType{"boolean"}},
					"items": {Type: api.PropertyType{"array"}},
				}),
			},
			rawToolCall: `calculate
<arg_key>x</arg_key>
<arg_value>3.14</arg_value>
<arg_key>y</arg_key>
<arg_value>42</arg_value>
<arg_key>enabled</arg_key>
<arg_value>true</arg_value>
<arg_key>items</arg_key>
<arg_value>["a", "b", "c"]</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "calculate",
					Arguments: args(`{"enabled": true, "items": ["a", "b", "c"], "x": 3.14, "y": 42}`),
				},
			},
		},
		{
			name: "function name with whitespace",
			tools: []api.Tool{},
			rawToolCall: ` get-weather
<arg_key>city</arg_key>
<arg_value>Paris</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "get-weather",
					Arguments: args(`{"city": "Paris"}`),
				},
			},
		},
		{
			name: "values with special characters",
			tools: []api.Tool{},
			rawToolCall: `execute-command
<arg_key>command</arg_key>
<arg_value>ls && echo "done"</arg_value>
<arg_key>message</arg_key>
<arg_value>a < b and c > d</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "execute-command",
					Arguments: args(`{"command": "ls && echo \"done\"", "message": "a < b and c > d"}`),
				},
			},
		},
		{
			name: "unicode in function names and values",
			tools: []api.Tool{},
			rawToolCall: `获取天气
<arg_key>城市</arg_key>
<arg_value>北京</arg_value>
<arg_key>message</arg_key>
<arg_value>Hello! 你好! 🌟</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "获取天气",
					Arguments: args(`{"message": "Hello! 你好! 🌟", "城市": "北京"}`),
				},
			},
		},
		{
			name: "empty value",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>param1</arg_key>
<arg_value></arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test-function",
					Arguments: args(`{"param1": ""}`),
				},
			},
		},
		{
			name: "special chars in arg_key names",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>param<1></arg_key>
<arg_value>value1</arg_value>
<arg_key>a&b</arg_key>
<arg_value>value2</arg_value>
<arg_key>x>y</arg_key>
<arg_value>value3</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test-function",
					Arguments: args(`{"a&b": "value2", "param<1>": "value1", "x>y": "value3"}`),
				},
			},
		},
		{
			name: "multiple consecutive ampersands",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>param</arg_key>
<arg_value>test &&&& more</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test-function",
					Arguments: args(`{"param": "test &&&& more"}`),
				},
			},
		},
		{
			name: "mixed special chars together",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>param</arg_key>
<arg_value><>&<>&</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test-function",
					Arguments: args(`{"param": "<>&<>&"}`),
				},
			},
		},
		{
			name: "newlines and tabs in parameter values",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>multiline</arg_key>
<arg_value>line1
	indented line2
line3</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test-function",
					Arguments: args(`{"multiline": "line1\n\tindented line2\nline3"}`),
				},
			},
		},
		{
			name: "single and double quotes in values",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>quotes</arg_key>
<arg_value>She said "Hello's there!"</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test-function",
					Arguments: args(`{"quotes": "She said \"Hello's there!\""}`),
				},
			},
		},
		{
			name: "CDATA-like content that should be treated as text",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>cdata</arg_key>
<arg_value><![CDATA[not actual cdata]]></arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test-function",
					Arguments: args(`{"cdata": "<![CDATA[not actual cdata]]>"}`),
				},
			},
		},
		{
			name: "all special XML entities",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>entities</arg_key>
<arg_value><>&'"</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test-function",
					Arguments: args(`{"entities": "<>&'\""}`),
				},
			},
		},
		{
			name: "order preservation with multiple parameters",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>first</arg_key>
<arg_value>value1</arg_value>
<arg_key>second</arg_key>
<arg_value>value2</arg_value>
<arg_key>third</arg_key>
<arg_value>value3</arg_value>
<arg_key>fourth</arg_key>
<arg_value>value4</arg_value>
<arg_key>fifth</arg_key>
<arg_value>value5</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test-function",
					Arguments: args(`{"fifth": "value5", "first": "value1", "fourth": "value4", "second": "value2", "third": "value3"}`),
				},
			},
		},
		{
			name: "order preservation with identical key names but different positions",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>param</arg_key>
<arg_value>first occurrence</arg_value>
<arg_key>other</arg_key>
<arg_value>middle</arg_value>
<arg_key>param</arg_key>
<arg_value>second occurrence</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test-function",
					// Later occurrence should overwrite earlier one
					Arguments: args(`{"other": "middle", "param": "second occurrence"}`),
				},
			},
		},
		{
			name: "array with mixed types",
			tools: []api.Tool{
				tool("process", map[string]api.ToolProperty{
					"items": {Type: api.PropertyType{"array"}},
				}),
			},
			rawToolCall: `process
<arg_key>items</arg_key>
<arg_value>[1, "hello", true, null]</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "process",
					Arguments: args(`{"items": [1, "hello", true, null]}`),
				},
			},
		},
		{
			name: "empty array",
			tools: []api.Tool{
				tool("test", map[string]api.ToolProperty{
					"tags": {Type: api.PropertyType{"array"}},
				}),
			},
			rawToolCall: `test
<arg_key>tags</arg_key>
<arg_value>[]</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test",
					Arguments: args(`{"tags": []}`),
				},
			},
		},
		{
			name: "anyOf array or string - with array of objects",
			tools: []api.Tool{
				tool("TodoWrite", map[string]api.ToolProperty{
					"todos": {AnyOf: []api.ToolProperty{{Type: api.PropertyType{"array"}}, {Type: api.PropertyType{"string"}}}},
				}),
			},
			// <tool_call>TodoWrite
			// <arg_key>todos</arg_key>
			// <arg_value>[{"content": "Set up HTML file and basic structure", "id": "1", "priority": "high", "status": "pending"}, {"content": "Create 3D scene with Three.js", "id": "2", "priority": "high", "status": "pending"}, {"content": "Implement terrain generation with blocks", "id": "3", "priority": "high", "status": "pending"}, {"content": "Add player controls (movement, camera)", "id": "4", "priority": "high", "status": "pending"}, {"content": "Implement block placement/destruction", "id": "5", "priority": "medium", "status": "pending"}, {"content": "Add lighting and textures", "id": "6", "priority": "medium", "status": "pending"}, {"content": "Test and optimize performance", "id": "7", "priority": "low", "status": "pending"}]</arg_value>
			// </tool_call>
			rawToolCall: `TodoWrite
<arg_key>todos</arg_key>
<arg_value>[{"content": "task 1", "status": "pending", "priority": "high", "id": "1"}, {"content": "task 2", "status": "completed", "priority": "low", "id": "2"}]</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "TodoWrite",
					Arguments: args(`{"todos": [{"content": "task 1", "id": "1", "priority": "high", "status": "pending"}, {"content": "task 2", "id": "2", "priority": "low", "status": "completed"}]}`),
				},
			},
		},
		{
			name: "anyOf array or string - with plain string",
			tools: []api.Tool{
				tool("TodoWrite", map[string]api.ToolProperty{
					"todos": {Type: api.PropertyType{"array", "string"}},
				}),
			},
			rawToolCall: `TodoWrite
<arg_key>todos</arg_key>
<arg_value>Error: could not load todos</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "TodoWrite",
					Arguments: args(`{"todos": "Error: could not load todos"}`),
				},
			},
		},
	}

	for i, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotToolCall, err := parseGLM46ToolCall(glm46EventRawToolCall{raw: tc.rawToolCall}, tc.tools)
			if err != nil {
				t.Errorf("case %d (%s): %v", i, tc.name, err)
			}
			if !toolCallEqual(gotToolCall, tc.wantToolCall) {
				t.Errorf("case %d (%s): got tool call %#v, want %#v", i, tc.name, gotToolCall, tc.wantToolCall)
			}
		})
	}
}
20
model/parsers/glm47.go
Normal file
@@ -0,0 +1,20 @@
package parsers

import "github.com/ollama/ollama/api"

// GLM47Parser extends GLM46Parser with thinking-aware initialization.
// GLM-4.7's prompt ends with <think> when thinking is enabled, so the parser
// must start in CollectingThinking state (the model outputs thinking content directly).
type GLM47Parser struct {
	GLM46Parser
}

func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
	p.tools = tools
	// When thinking is enabled (nil or true), the prompt ends with <think>,
	// so model output starts directly with thinking content (no opening tag).
	if thinkValue == nil || thinkValue.Bool() {
		p.state = glm46ParserState_CollectingThinking
	}
	return tools
}
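A quick sketch of how a caller would drive this parser over a simulated stream, assuming fmt is imported; the chunk strings are illustrative, not real model output:

// Sketch only: feed chunks to the parser, marking the final chunk done.
func exampleGLM47Streaming() {
	p := &GLM47Parser{}
	p.Init(nil, nil, nil) // thinking defaults to enabled, so parsing starts in thinking state

	chunks := []string{"plan ", "steps</think>", "final answer"}
	for i, chunk := range chunks {
		content, thinking, _, err := p.Add(chunk, i == len(chunks)-1)
		if err != nil {
			panic(err)
		}
		fmt.Printf("content=%q thinking=%q\n", content, thinking)
	}
}
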
99
model/parsers/glm47_test.go
Normal file
@@ -0,0 +1,99 @@
package parsers

import (
	"reflect"
	"testing"

	"github.com/ollama/ollama/api"
)

func TestGLM47ParserAdd(t *testing.T) {
	parser := GLM47Parser{}
	parser.Init([]api.Tool{
		tool("calculate", map[string]api.ToolProperty{
			"count": {Type: api.PropertyType{"integer"}},
			"enabled": {Type: api.PropertyType{"boolean"}},
		}),
	}, nil, nil)

	// When thinking is enabled (thinkValue nil), the prompt ends with <think>,
	// so the model output does NOT include the opening <think> tag.
	content, thinking, calls, err := parser.Add("plan</think>Answer<tool_call>calculate<arg_key>count</arg_key><arg_value>3</arg_value><arg_key>enabled</arg_key><arg_value>true</arg_value></tool_call>", true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}
	if thinking != "plan" {
		t.Fatalf("expected thinking 'plan', got %q", thinking)
	}
	if content != "Answer" {
		t.Fatalf("expected content 'Answer', got %q", content)
	}
	if len(calls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(calls))
	}
	expectedArgs := args(`{"count": 3, "enabled": true}`)
	if !toolCallEqual(api.ToolCall{Function: api.ToolCallFunction{Arguments: calls[0].Function.Arguments}}, api.ToolCall{Function: api.ToolCallFunction{Arguments: expectedArgs}}) {
		t.Fatalf("expected args %#v, got %#v", expectedArgs.ToMap(), calls[0].Function.Arguments.ToMap())
	}
}

func TestGLM47ParserNoThinkingContent(t *testing.T) {
	parser := GLM47Parser{}
	parser.Init(nil, nil, nil)

	// When thinking is enabled but model has no thinking to output,
	// it should output </think> immediately followed by content.
	content, thinking, calls, err := parser.Add("</think>Plain answer", true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}
	if thinking != "" {
		t.Fatalf("expected empty thinking, got %q", thinking)
	}
	if content != "Plain answer" {
		t.Fatalf("expected content 'Plain answer', got %q", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}
}

func TestGLM47ParserThinkingDisabled(t *testing.T) {
	parser := GLM47Parser{}
	// When thinking is disabled, parser stays in LookingForThinkingOpen state
	parser.Init(nil, nil, &api.ThinkValue{Value: false})

	// Model outputs plain content (prompt ended with </think>)
	content, thinking, calls, err := parser.Add("Plain answer", true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}
	if thinking != "" {
		t.Fatalf("expected empty thinking, got %q", thinking)
	}
	if content != "Plain answer" {
		t.Fatalf("expected content 'Plain answer', got %q", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}
}

func TestGLM47ParserToolCallEscaping(t *testing.T) {
	toolCall, err := parseGLM46ToolCall(glm46EventRawToolCall{raw: `exec
<arg_key>expr</arg_key>
<arg_value>a < b && c > d</arg_value>`}, nil)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	expected := api.ToolCall{
		Function: api.ToolCallFunction{
			Name: "exec",
			Arguments: args(`{"expr": "a < b && c > d"}`),
		},
	}
	if !reflect.DeepEqual(toolCall, expected) {
		t.Fatalf("expected %#v, got %#v", expected, toolCall)
	}
}
@@ -1,7 +1,6 @@
package parsers

import (
	"regexp"
	"strings"
	"unicode"

@@ -14,243 +13,114 @@ const (
	Nemotron3NanoCollectingThinking Nemotron3NanoParserState = iota
	Nemotron3NanoSkipWhitespaceAfterThinking
	Nemotron3NanoCollectingContent
	Nemotron3NanoCollectingToolCalls
)

const (
	nemotronThinkClose = "</think>"
	nemotronToolCallOpen = "<tool_call>"
	nemotronToolCallClose = "</tool_call>"
	nemotronThinkClose = "</think>"
	nemotronToolCallOpen = "<tool_call>"
)

type Nemotron3NanoParser struct {
	state Nemotron3NanoParserState
	buffer strings.Builder
	tools []api.Tool
	state Nemotron3NanoParserState
	buffer strings.Builder
	toolParser *Qwen3CoderParser
}

func (p *Nemotron3NanoParser) HasToolSupport() bool { return true }
func (p *Nemotron3NanoParser) HasThinkingSupport() bool { return true }

func (p *Nemotron3NanoParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
	p.tools = tools
	p.toolParser = &Qwen3CoderParser{}
	p.toolParser.Init(tools, nil, nil)

	// thinking is enabled if user requests it
	thinkingEnabled := thinkValue != nil && thinkValue.Bool()

	prefill := lastMessage != nil && lastMessage.Role == "assistant"

	if !thinkingEnabled {
	if !thinkingEnabled || (prefill && lastMessage.Content != "") {
		p.state = Nemotron3NanoCollectingContent
		return tools
	} else {
		p.state = Nemotron3NanoCollectingThinking
	}

	if prefill && lastMessage.Content != "" {
		p.state = Nemotron3NanoCollectingContent
		return tools
	}

	p.state = Nemotron3NanoCollectingThinking
	return tools
}

type nemotronEvent interface {
	isNemotronEvent()
}

type nemotronEventThinkingContent struct {
	content string
}

type nemotronEventContent struct {
	content string
}

type nemotronEventToolCall struct {
	toolCall api.ToolCall
}

func (nemotronEventThinkingContent) isNemotronEvent() {}
func (nemotronEventContent) isNemotronEvent() {}
func (nemotronEventToolCall) isNemotronEvent() {}

func (p *Nemotron3NanoParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
	p.buffer.WriteString(s)
	events := p.parseEvents()

	var toolCalls []api.ToolCall
	var contentSb strings.Builder
	var thinkingSb strings.Builder
	for _, event := range events {
		switch event := event.(type) {
		case nemotronEventToolCall:
			toolCalls = append(toolCalls, event.toolCall)
		case nemotronEventThinkingContent:
			thinkingSb.WriteString(event.content)
		case nemotronEventContent:
			contentSb.WriteString(event.content)
		}
	if p.state == Nemotron3NanoCollectingContent {
		return p.toolParser.Add(s, done)
	}

	return contentSb.String(), thinkingSb.String(), toolCalls, nil
}

func (p *Nemotron3NanoParser) parseEvents() []nemotronEvent {
	var all []nemotronEvent

	keepLooping := true
	for keepLooping {
		var events []nemotronEvent
		events, keepLooping = p.eat()
		if len(events) > 0 {
			all = append(all, events...)
		}
	}

	return all
}

// emitWithPartialCheck extracts unambiguous content before a potential partial tag
func (p *Nemotron3NanoParser) emitWithPartialCheck(bufStr, tag string) (unambiguous, ambiguous string) {
	if overlapLen := overlap(bufStr, tag); overlapLen > 0 {
		beforePartialTag := bufStr[:len(bufStr)-overlapLen]
		trailingLen := trailingWhitespaceLen(beforePartialTag)
		return bufStr[:len(beforePartialTag)-trailingLen], bufStr[len(beforePartialTag)-trailingLen:]
	}
	wsLen := trailingWhitespaceLen(bufStr)
	return bufStr[:len(bufStr)-wsLen], bufStr[len(bufStr)-wsLen:]
}

func (p *Nemotron3NanoParser) eat() ([]nemotronEvent, bool) {
	bufStr := p.buffer.String()
	if bufStr == "" {
		return nil, false
	}

	switch p.state {
	case Nemotron3NanoCollectingThinking:
		if strings.Contains(bufStr, nemotronThinkClose) {
			split := strings.SplitN(bufStr, nemotronThinkClose, 2)
			thinking := strings.TrimRightFunc(split[0], unicode.IsSpace)
			p.buffer.Reset()
			remainder := strings.TrimLeftFunc(split[1], unicode.IsSpace)
			p.buffer.WriteString(remainder)
			// Transition to whitespace-skipping state if buffer is empty,
			// otherwise go directly to content collection
			if remainder == "" {
				p.state = Nemotron3NanoSkipWhitespaceAfterThinking
			} else {
				p.state = Nemotron3NanoCollectingContent
			}
			if thinking != "" {
				return []nemotronEvent{nemotronEventThinkingContent{content: thinking}}, true
			}
			return nil, true
		}
		unambig, ambig := p.emitWithPartialCheck(bufStr, nemotronThinkClose)
		p.buffer.Reset()
		p.buffer.WriteString(ambig)
		if unambig != "" {
			return []nemotronEvent{nemotronEventThinkingContent{content: unambig}}, false
		}
		return nil, false

	// We only want to skip whitespace between thinking and content
	case Nemotron3NanoSkipWhitespaceAfterThinking:
		bufStr = strings.TrimLeftFunc(bufStr, unicode.IsSpace)
		p.buffer.Reset()
		p.buffer.WriteString(bufStr)
		if bufStr == "" {
			return nil, false
	if p.state == Nemotron3NanoSkipWhitespaceAfterThinking {
		s = strings.TrimLeftFunc(s, unicode.IsSpace)
		if s == "" {
			return "", "", nil, nil
		}
		p.state = Nemotron3NanoCollectingContent
		return nil, true
		return p.toolParser.Add(s, done)
	}

	case Nemotron3NanoCollectingContent:
		if strings.Contains(bufStr, nemotronToolCallOpen) {
			split := strings.SplitN(bufStr, nemotronToolCallOpen, 2)
			content := strings.TrimRightFunc(split[0], unicode.IsSpace)
			p.buffer.Reset()
			p.buffer.WriteString(split[1])
			p.state = Nemotron3NanoCollectingToolCalls
			if content != "" {
				return []nemotronEvent{nemotronEventContent{content: content}}, true
			}
			return nil, true
		}
		unambig, ambig := p.emitWithPartialCheck(bufStr, nemotronToolCallOpen)
	// Nemotron3NanoCollectingThinking - buffer and look for end markers
	p.buffer.WriteString(s)
	bufStr := p.buffer.String()

	// Look for end of thinking: </think> or <tool_call> (model may skip </think>)
	thinkIdx := strings.Index(bufStr, nemotronThinkClose)
	toolIdx := strings.Index(bufStr, nemotronToolCallOpen)

	var endIdx int = -1
	var remainder string

	if thinkIdx != -1 && (toolIdx == -1 || thinkIdx < toolIdx) {
		endIdx = thinkIdx
		remainder = strings.TrimLeftFunc(bufStr[thinkIdx+len(nemotronThinkClose):], unicode.IsSpace)
	} else if toolIdx != -1 {
		endIdx = toolIdx
		remainder = bufStr[toolIdx:] // Include <tool_call> tag
	}

	if endIdx != -1 {
		thinking = strings.TrimRightFunc(bufStr[:endIdx], unicode.IsSpace)
		p.buffer.Reset()
		p.buffer.WriteString(ambig)
		if unambig != "" {
			return []nemotronEvent{nemotronEventContent{content: unambig}}, false

		if remainder == "" {
			p.state = Nemotron3NanoSkipWhitespaceAfterThinking
		} else {
			p.state = Nemotron3NanoCollectingContent
			content, _, calls, err = p.toolParser.Add(remainder, done)
		}
		return nil, false

	case Nemotron3NanoCollectingToolCalls:
		if strings.Contains(bufStr, nemotronToolCallClose) {
			split := strings.SplitN(bufStr, nemotronToolCallClose, 2)
			remaining := strings.TrimLeftFunc(split[1], unicode.IsSpace)
			p.buffer.Reset()
			p.buffer.WriteString(remaining)

			var events []nemotronEvent
			if tc, err := p.parseToolCall(split[0]); err == nil {
				events = append(events, nemotronEventToolCall{toolCall: tc})
			}

			if !strings.Contains(remaining, nemotronToolCallOpen) {
				p.state = Nemotron3NanoCollectingContent
			}
			return events, true
		}
		return nil, false
		return content, thinking, calls, err
	}

	return nil, false
	// No end marker - emit unambiguous thinking
	thinking = p.emitThinking(bufStr)
	return "", thinking, nil, nil
}

var (
	nemotronFunctionRegex = regexp.MustCompile(`<function=([^>]+)>`)
	nemotronParameterRegex = regexp.MustCompile(`<parameter=([^>]+)>\n?([\s\S]*?)\n?</parameter>`)
)
// emitThinking returns unambiguous thinking content, keeping potential partial tags in buffer
func (p *Nemotron3NanoParser) emitThinking(bufStr string) string {
	// Check for partial </think> or <tool_call> at end
	thinkOverlap := overlap(bufStr, nemotronThinkClose)
	toolOverlap := overlap(bufStr, nemotronToolCallOpen)
	maxOverlap := max(thinkOverlap, toolOverlap)

func (p *Nemotron3NanoParser) parseToolCall(content string) (api.ToolCall, error) {
	toolCall := api.ToolCall{}

	// Extract function name
	fnMatch := nemotronFunctionRegex.FindStringSubmatch(content)
	if len(fnMatch) < 2 {
		return toolCall, nil
	}
	toolCall.Function.Name = fnMatch[1]

	// Extract parameters
	toolCall.Function.Arguments = api.NewToolCallFunctionArguments()
	paramMatches := nemotronParameterRegex.FindAllStringSubmatch(content, -1)
	for _, match := range paramMatches {
		if len(match) >= 3 {
			paramName := match[1]
			paramValue := strings.TrimSpace(match[2])

			// Try to parse as typed value based on tool definition
			toolCall.Function.Arguments.Set(paramName, p.parseParamValue(paramName, paramValue))
		}
	if maxOverlap > 0 {
		unambiguous := bufStr[:len(bufStr)-maxOverlap]
		unambiguous = strings.TrimRightFunc(unambiguous, unicode.IsSpace)
		p.buffer.Reset()
		p.buffer.WriteString(bufStr[len(bufStr)-maxOverlap:])
		return unambiguous
	}

	return toolCall, nil
}

func (p *Nemotron3NanoParser) parseParamValue(paramName string, raw string) any {
	// Find the matching tool to get parameter type
	var paramType api.PropertyType
	for _, tool := range p.tools {
		if tool.Function.Parameters.Properties != nil {
			if prop, ok := tool.Function.Parameters.Properties.Get(paramName); ok {
				paramType = prop.Type
				break
			}
		}
	}

	return parseValue(raw, paramType)
	// No partial tags - emit all but trailing whitespace
	wsLen := trailingWhitespaceLen(bufStr)
	if wsLen > 0 {
		unambiguous := bufStr[:len(bufStr)-wsLen]
		p.buffer.Reset()
		p.buffer.WriteString(bufStr[len(bufStr)-wsLen:])
		return unambiguous
	}

	// Nothing to hold back
	p.buffer.Reset()
	return bufStr
}

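Both emitWithPartialCheck and emitThinking above depend on an overlap helper defined elsewhere in the package. Its presumed semantics, the length of the longest suffix of the buffer that is also a prefix of the tag, can be sketched as follows (a sketch of assumed behavior, not the package's actual implementation):

// overlap reports the length of the longest suffix of s that is also a
// prefix of tag, i.e. how many trailing bytes of s might be the start of tag.
func overlap(s, tag string) int {
	best := 0
	for k := 1; k <= len(tag) && k <= len(s); k++ {
		if strings.HasSuffix(s, tag[:k]) {
			best = k
		}
	}
	return best
}
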
@@ -8,6 +8,8 @@ import (
	"github.com/ollama/ollama/api"
)

// TestNemotron3NanoParser tests Nemotron-specific behavior (thinking support).
// Tool call parsing is tested in qwen3coder_test.go since Nemotron delegates to Qwen3CoderParser.
func TestNemotron3NanoParser(t *testing.T) {
	tests := []struct {
		name string
@@ -17,18 +19,6 @@ func TestNemotron3NanoParser(t *testing.T) {
		expectedThinking string
		expectedCalls []api.ToolCall
	}{
		{
			name: "simple content - no thinking",
			input: "Hello, how can I help you?",
			thinkValue: nil,
			expectedContent: "Hello, how can I help you?",
		},
		{
			name: "simple content - thinking disabled",
			input: "Hello, how can I help you?",
			thinkValue: &api.ThinkValue{Value: false},
			expectedContent: "Hello, how can I help you?",
		},
		{
			name: "thinking then content",
			input: "Let me think about this...</think>\nHere is my answer.",
@@ -43,69 +33,6 @@ func TestNemotron3NanoParser(t *testing.T) {
			expectedThinking: "Step 1: Analyze\nStep 2: Process\nStep 3: Conclude",
			expectedContent: "The answer is 42.",
		},
		{
			name: "simple tool call",
			input: "<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>",
			thinkValue: nil,
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "get_weather",
						Arguments: testArgs(map[string]any{"city": "Paris"}),
					},
				},
			},
		},
		{
			name: "content then tool call",
			input: "Let me check the weather.\n<tool_call>\n<function=get_weather>\n<parameter=city>\nNYC\n</parameter>\n</function>\n</tool_call>",
			thinkValue: nil,
			expectedContent: "Let me check the weather.",
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "get_weather",
						Arguments: testArgs(map[string]any{"city": "NYC"}),
					},
				},
			},
		},
		{
			name: "tool call with multiple parameters",
			input: "<tool_call>\n<function=book_flight>\n<parameter=from>\nSFO\n</parameter>\n<parameter=to>\nNYC\n</parameter>\n</function>\n</tool_call>",
			thinkValue: nil,
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "book_flight",
						Arguments: testArgs(map[string]any{
							"from": "SFO",
							"to": "NYC",
						}),
					},
				},
			},
		},
		{
			name: "multiple tool calls",
			input: "<tool_call>\n<function=get_weather>\n<parameter=city>\nSan Francisco\n</parameter>\n</function>\n</tool_call>\n" +
				"<tool_call>\n<function=get_weather>\n<parameter=city>\nNew York\n</parameter>\n</function>\n</tool_call>",
			thinkValue: nil,
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "get_weather",
						Arguments: testArgs(map[string]any{"city": "San Francisco"}),
					},
				},
				{
					Function: api.ToolCallFunction{
						Name: "get_weather",
						Arguments: testArgs(map[string]any{"city": "New York"}),
					},
				},
			},
		},
		{
			name: "thinking then tool call",
			input: "I should check the weather...</think>\n<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>",
@@ -135,19 +62,6 @@ func TestNemotron3NanoParser(t *testing.T) {
				},
			},
		},
		{
			name: "tool call with multiline parameter value",
			input: "<tool_call>\n<function=create_note>\n<parameter=content>\nLine 1\nLine 2\nLine 3\n</parameter>\n</function>\n</tool_call>",
			thinkValue: nil,
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "create_note",
						Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}),
					},
				},
			},
		},
		{
			name: "empty thinking block - immediate close",
			input: "</think>\nHere is my answer.",
@@ -161,18 +75,6 @@ func TestNemotron3NanoParser(t *testing.T) {
			thinkValue: &api.ThinkValue{Value: false},
			expectedContent: "</think>\nSome content after spurious tag.",
		},
		{
			name: "tool call with no function name - returns empty tool call",
			input: "<tool_call>\n<function=>\n</function>\n</tool_call>",
			thinkValue: nil,
			expectedCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "", Arguments: api.NewToolCallFunctionArguments()}}},
		},
		{
			name: "content with newlines preserved",
			input: "Line 1\n\nLine 2\n\n\nLine 3",
			thinkValue: nil,
			expectedContent: "Line 1\n\nLine 2\n\n\nLine 3",
		},
		{
			name: "thinking with only whitespace after close tag",
			input: "My thoughts...</think> \n\t\n Content here.",
@@ -180,25 +82,6 @@ func TestNemotron3NanoParser(t *testing.T) {
			expectedThinking: "My thoughts...",
			expectedContent: "Content here.",
		},
		{
			name: "unicode content",
			input: "Hello 世界! 🌍 Ñoño",
			thinkValue: nil,
			expectedContent: "Hello 世界! 🌍 Ñoño",
		},
		{
			name: "tool call with numeric parameter",
			input: "<tool_call>\n<function=set_temp>\n<parameter=value>\n42\n</parameter>\n</function>\n</tool_call>",
			thinkValue: nil,
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "set_temp",
						Arguments: testArgs(map[string]any{"value": "42"}),
					},
				},
			},
		},
	}

	for _, tt := range tests {
@@ -233,6 +116,8 @@ func TestNemotron3NanoParser(t *testing.T) {
	}
}

// TestNemotron3NanoParser_Streaming tests streaming behavior for thinking support.
// Tool call streaming is tested in qwen3coder_test.go.
func TestNemotron3NanoParser_Streaming(t *testing.T) {
	tests := []struct {
		name string
@@ -242,18 +127,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
		expectedThinking string
		expectedCalls []api.ToolCall
	}{
		{
			name: "streaming content character by character",
			chunks: []string{"H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"},
			thinkValue: nil,
			expectedContent: "Hello, world!",
		},
		{
			name: "streaming content small tokens",
			chunks: []string{"Hel", "lo", ", ", "how ", "can", " I", " help", " you", " today", "?"},
			thinkValue: nil,
			expectedContent: "Hello, how can I help you today?",
		},
		{
			name: "streaming thinking then content - granular",
			chunks: []string{"Let", " me", " th", "ink", " about", " this", "...", "<", "/", "think", ">", "\n", "Here", " is", " my", " answer", "."},
@@ -268,45 +141,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
			expectedThinking: "Step 1: Analyze\nStep 2: Process",
			expectedContent: "The answer.",
		},
		{
			name: "streaming tool call - highly granular",
			chunks: []string{"<", "tool", "_", "call", ">", "\n", "<", "func", "tion", "=", "get", "_", "weather", ">", "\n", "<", "param", "eter", "=", "city", ">", "\n", "Par", "is", "\n", "</", "param", "eter", ">", "\n", "</", "func", "tion", ">", "\n", "</", "tool", "_", "call", ">"},
			thinkValue: nil,
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "get_weather",
						Arguments: testArgs(map[string]any{"city": "Paris"}),
					},
				},
			},
		},
		{
			name: "streaming content then tool call - granular",
			chunks: []string{"Let", " me", " check", " the", " weather", ".", "\n<", "tool_call", ">", "\n", "<function=", "get_weather", ">", "\n", "<parameter=", "city", ">", "\n", "NYC", "\n", "</parameter>", "\n", "</function>", "\n", "</tool_call>"},
			thinkValue: nil,
			expectedContent: "Let me check the weather.",
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "get_weather",
						Arguments: testArgs(map[string]any{"city": "NYC"}),
					},
				},
			},
		},
		{
			name: "tool call tag split character by character",
			chunks: []string{"<", "t", "o", "o", "l", "_", "c", "a", "l", "l", ">", "\n", "<", "f", "u", "n", "c", "t", "i", "o", "n", "=", "t", "e", "s", "t", ">", "\n", "<", "/", "f", "u", "n", "c", "t", "i", "o", "n", ">", "\n", "<", "/", "t", "o", "o", "l", "_", "c", "a", "l", "l", ">"},
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "test",
						Arguments: api.NewToolCallFunctionArguments(),
					},
				},
			},
		},
		{
			name: "thinking close tag split character by character",
			chunks: []string{"I", "'", "m", " ", "t", "h", "i", "n", "k", "i", "n", "g", ".", ".", ".", "<", "/", "t", "h", "i", "n", "k", ">", "\n", "D", "o", "n", "e", "!"},
@@ -321,22 +155,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
			expectedThinking: "Thinking...",
			expectedContent: "Content here.",
		},
		{
			name: "tool call with multiple parameters - streaming",
			chunks: []string{"<tool_", "call>\n", "<function", "=book_", "flight>", "\n<para", "meter=", "from>\n", "SFO\n", "</param", "eter>", "\n<param", "eter=to", ">\nNYC", "\n</para", "meter>", "\n</func", "tion>\n", "</tool_", "call>"},
			thinkValue: nil,
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "book_flight",
						Arguments: testArgs(map[string]any{
							"from": "SFO",
							"to": "NYC",
						}),
					},
				},
			},
		},
		{
			name: "thinking then content then tool call - streaming",
			chunks: []string{"Ana", "lyzing", " your", " request", "...", "</", "think", ">\n", "I'll", " check", " that", " for", " you", ".", "\n", "<tool", "_call", ">\n", "<function", "=search", ">\n", "<parameter", "=query", ">\n", "test", " query", "\n</", "parameter", ">\n", "</function", ">\n", "</tool", "_call", ">"},
@@ -352,45 +170,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
				},
			},
		},
		{
			name: "multiple tool calls - streaming",
			chunks: []string{
				"<tool_call>", "\n", "<function=", "get_weather>", "\n",
				"<parameter=", "city>\n", "San Fran", "cisco\n", "</parameter>", "\n",
				"</function>", "\n", "</tool_call>", "\n",
				"<tool_", "call>\n", "<function", "=get_weather", ">\n",
				"<param", "eter=city", ">\nNew", " York\n", "</parameter>\n",
				"</function>\n", "</tool_call>",
			},
			thinkValue: nil,
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "get_weather",
						Arguments: testArgs(map[string]any{"city": "San Francisco"}),
					},
				},
				{
					Function: api.ToolCallFunction{
						Name: "get_weather",
						Arguments: testArgs(map[string]any{"city": "New York"}),
					},
				},
			},
		},
		{
			name: "tool call with multiline parameter - streaming",
			chunks: []string{"<tool_call>\n", "<function=", "create_note>\n", "<parameter=", "content>\n", "Line 1", "\nLine", " 2\n", "Line 3", "\n</parameter>\n", "</function>\n", "</tool_call>"},
			thinkValue: nil,
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "create_note",
						Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}),
					},
				},
			},
		},
		{
			name: "empty thinking block",
			chunks: []string{"</think>", "\n", "Just content."},
@@ -398,12 +177,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
			expectedThinking: "",
			expectedContent: "Just content.",
		},
		{
			name: "empty input chunks interspersed",
			chunks: []string{"Hello", "", " ", "", "world", "", "!"},
			thinkValue: nil,
			expectedContent: "Hello world!",
		},
		{
			name: "tool call immediately after think close - no content",
			chunks: []string{"Analyzing...", "</think>", "\n", "<tool_call>", "\n<function=test>\n</function>\n", "</tool_call>"},
@@ -418,25 +191,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
				},
			},
		},
		{
			name: "tool call with empty parameter value",
			chunks: []string{"<tool_call>\n<function=test>\n<parameter=name>\n", "\n</parameter>\n</function>\n</tool_call>"},
			thinkValue: nil,
			expectedCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "test",
						Arguments: testArgs(map[string]any{"name": ""}),
					},
				},
			},
		},
		{
			name: "partial tool call tag at end - buffered",
			chunks: []string{"Here's some content", "<tool"},
			thinkValue: nil,
			expectedContent: "Here's some content",
		},
	}

	for _, tt := range tests {
@@ -572,3 +326,65 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) {
		t.Errorf("calls mismatch (-got +want):\n%s", diff)
	}
}

// TestNemotron3NanoParser_ToolCallWithoutThinkClose tests the case where thinking is enabled
// but the model outputs content + tool call WITHOUT the </think> tag.
// The parser should still parse the tool call (content before is treated as thinking).
func TestNemotron3NanoParser_ToolCallWithoutThinkClose(t *testing.T) {
	chunks := []string{
		"Let", " me", " analyze", " this", ".", "\n",
		"<tool_call>", "\n",
		"<function=get_weather>", "\n",
		"<parameter=city>", "Paris", "</parameter>", "\n",
		"</function>", "\n",
		"</tool_call>",
	}

	p := &Nemotron3NanoParser{}
	p.Init(nil, nil, &api.ThinkValue{Value: true}) // thinking ENABLED but model doesn't output </think>

	var allContent string
	var allThinking string
	var allCalls []api.ToolCall

	for _, chunk := range chunks {
		content, thinking, calls, err := p.Add(chunk, false)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		allContent += content
		allThinking += thinking
		allCalls = append(allCalls, calls...)
	}

	// Drain
	content, thinking, calls, err := p.Add("", true)
	if err != nil {
		t.Fatalf("unexpected error on done: %v", err)
	}
	allContent += content
	allThinking += thinking
	allCalls = append(allCalls, calls...)

	// The parser was in thinking mode, so text before <tool_call> is emitted as thinking.
	expectedThinking := "Let me analyze this."

	expectedCalls := []api.ToolCall{
		{
			Function: api.ToolCallFunction{
				Name: "get_weather",
				Arguments: testArgs(map[string]any{"city": "Paris"}),
			},
		},
	}

	if allContent != "" {
		t.Errorf("expected no content (text was streamed as thinking), got: %q", allContent)
	}
	if diff := cmp.Diff(allThinking, expectedThinking); diff != "" {
		t.Errorf("thinking mismatch (-got +want):\n%s", diff)
	}
	if diff := cmp.Diff(allCalls, expectedCalls, argsComparer); diff != "" {
		t.Errorf("calls mismatch (-got +want):\n%s", diff)
	}
}

@@ -68,6 +68,8 @@ func ParserForName(name string) Parser {
		return &Nemotron3NanoParser{}
	case "functiongemma":
		return &FunctionGemmaParser{}
	case "glm-4.7":
		return &GLM47Parser{}
	default:
		return nil
	}

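For context, callers resolve a parser through this registry by model family name. A minimal usage sketch within the same package (helper name illustrative, not part of this diff):

// mustParser returns the registered parser for name; sketch only.
func mustParser(name string) Parser {
	p := ParserForName(name)
	if p == nil {
		panic("no parser registered for " + name)
	}
	return p
}
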
@@ -91,6 +91,37 @@ func TestQwenParserStreaming(t *testing.T) {
				},
			},
		},
		{
			desc: "tool call tags split character by character",
			steps: []step{
				{input: "<", wantEvents: []qwenEvent{}},
				{input: "t", wantEvents: []qwenEvent{}},
				{input: "o", wantEvents: []qwenEvent{}},
				{input: "o", wantEvents: []qwenEvent{}},
				{input: "l", wantEvents: []qwenEvent{}},
				{input: "_", wantEvents: []qwenEvent{}},
				{input: "c", wantEvents: []qwenEvent{}},
				{input: "a", wantEvents: []qwenEvent{}},
				{input: "l", wantEvents: []qwenEvent{}},
				{input: "l", wantEvents: []qwenEvent{}},
				{input: ">", wantEvents: []qwenEvent{}},
				{input: "a", wantEvents: []qwenEvent{}},
				{input: "b", wantEvents: []qwenEvent{}},
				{input: "c", wantEvents: []qwenEvent{}},
				{input: "<", wantEvents: []qwenEvent{}},
				{input: "/", wantEvents: []qwenEvent{}},
				{input: "t", wantEvents: []qwenEvent{}},
				{input: "o", wantEvents: []qwenEvent{}},
				{input: "o", wantEvents: []qwenEvent{}},
				{input: "l", wantEvents: []qwenEvent{}},
				{input: "_", wantEvents: []qwenEvent{}},
				{input: "c", wantEvents: []qwenEvent{}},
				{input: "a", wantEvents: []qwenEvent{}},
				{input: "l", wantEvents: []qwenEvent{}},
				{input: "l", wantEvents: []qwenEvent{}},
				{input: ">", wantEvents: []qwenEvent{qwenEventRawToolCall{raw: "abc"}}},
			},
		},
		{
			desc: "trailing whitespace between content and tool call",
			steps: []step{

@@ -96,3 +96,11 @@ func testArgs(m map[string]any) api.ToolCallFunctionArguments {
	}
	return args
}

func args(s string) api.ToolCallFunctionArguments {
	var result api.ToolCallFunctionArguments
	if err := json.Unmarshal([]byte(s), &result); err != nil {
		panic("invalid JSON in args(): " + err.Error())
	}
	return result
}

110
model/renderers/glm46.go
Normal file
@@ -0,0 +1,110 @@
package renderers

import (
	"encoding/json"
	"fmt"
	"strings"

	"github.com/ollama/ollama/api"
)

type GLM46Renderer struct{}

func (r *GLM46Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
	var sb strings.Builder

	sb.WriteString("[gMASK]<sop>")

	var lastUserIndex int
	for i, message := range messages {
		if message.Role == "user" {
			lastUserIndex = i
		}
	}

	if len(tools) > 0 {
		sb.WriteString("<|system|>\n")
		sb.WriteString("# Tools\n\n")
		sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
		sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
		sb.WriteString("<tools>\n")
		for _, tool := range tools {
			d, _ := json.Marshal(tool)
			sb.WriteString(string(d) + "\n")
		}
		sb.WriteString("</tools>\n\n")
		sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
		sb.WriteString("<tool_call>{function-name}\n")
		sb.WriteString("<arg_key>{arg-key-1}</arg_key>\n")
		sb.WriteString("<arg_value>{arg-value-1}</arg_value>\n")
		sb.WriteString("<arg_key>{arg-key-2}</arg_key>\n")
		sb.WriteString("<arg_value>{arg-value-2}</arg_value>\n")
		sb.WriteString("...\n")
		sb.WriteString("</tool_call>")
	}

	for i, message := range messages {
		switch message.Role {
		case "user":
			sb.WriteString("<|user|>\n")
			sb.WriteString(message.Content)
			if thinkValue != nil && !thinkValue.Bool() && !strings.HasSuffix(message.Content, "/nothink") {
				sb.WriteString("/nothink")
			}
		case "assistant":
			sb.WriteString("<|assistant|>")
			if i > lastUserIndex {
				if message.Thinking != "" {
					sb.WriteString("\n<think>" + message.Thinking + "</think>")
				} else {
					sb.WriteString("\n<think></think>")
				}
			}
			if message.Content != "" {
				sb.WriteString("\n" + message.Content)
			}
			if len(message.ToolCalls) > 0 {
				for _, toolCall := range message.ToolCalls {
					sb.WriteString("\n<tool_call>" + toolCall.Function.Name + "\n")
					for key, value := range toolCall.Function.Arguments.All() {
						sb.WriteString("<arg_key>" + key + "</arg_key>\n")

						var valueStr string
						if str, ok := value.(string); ok {
							valueStr = str
						} else {
							jsonBytes, err := json.Marshal(value)
							if err != nil {
								valueStr = fmt.Sprintf("%v", value)
							} else {
								valueStr = string(jsonBytes)
							}
						}

						sb.WriteString("<arg_value>" + valueStr + "</arg_value>\n")
					}

					sb.WriteString("</tool_call>")
				}
			}
		case "tool":
			if i == 0 || messages[i-1].Role != "tool" {
				sb.WriteString("<|observation|>")
			}
			sb.WriteString("\n<tool_response>\n")
			sb.WriteString(message.Content)
			sb.WriteString("\n</tool_response>")
		case "system":
			sb.WriteString("<|system|>\n")
			sb.WriteString(message.Content)
		}
	}

	// Add generation prompt
	sb.WriteString("<|assistant|>")
	if thinkValue != nil && !thinkValue.Bool() {
		sb.WriteString("\n<think></think>\n")
	}

	return sb.String(), nil
}
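If Arguments.All() iterates over a map-like container, the <arg_key>/<arg_value> emission order is nondeterministic, which is presumably why the tool-call test cases in the file below are skipped with "tool call ordering not guaranteed yet". A sketch of a deterministic alternative for the inner loop above, assuming All() yields key/value pairs and sort is imported:

// Sketch: collect keys, sort them, then emit arguments in a stable order.
keys := make([]string, 0)
vals := map[string]any{}
for k, v := range toolCall.Function.Arguments.All() {
	keys = append(keys, k)
	vals[k] = v
}
sort.Strings(keys)
for _, k := range keys {
	sb.WriteString("<arg_key>" + k + "</arg_key>\n")
	// ... marshal vals[k] to valueStr as in the loop above ...
}
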
223
model/renderers/glm46_test.go
Normal file
@@ -0,0 +1,223 @@
package renderers

import (
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/ollama/ollama/api"
)

func TestGLM46Renderer(t *testing.T) {
	tests := []struct {
		name string
		messages []api.Message
		tools []api.Tool
		thinkValue *api.ThinkValue
		expected string
		skip string
	}{
		{
			name: "basic",
			messages: []api.Message{
				{Role: "user", Content: "Hello, how are you?"},
			},
			expected: `[gMASK]<sop><|user|>
Hello, how are you?<|assistant|>`,
		},
		{
			name: "basic with system message",
			messages: []api.Message{
				{Role: "system", Content: "You are a helpful assistant."},
				{Role: "user", Content: "Hello, how are you?"},
			},
			expected: `[gMASK]<sop><|system|>
You are a helpful assistant.<|user|>
Hello, how are you?<|assistant|>`,
		},
		{
			name: "basic with user assistant user",
			messages: []api.Message{
				{Role: "user", Content: "What is the capital of France?"},
				{Role: "assistant", Thinking: "Let me analyze the request...", Content: "The capital of France is Paris."},
				{Role: "user", Content: "Fantastic!"},
			},
			expected: `[gMASK]<sop><|user|>
What is the capital of France?<|assistant|>
The capital of France is Paris.<|user|>
Fantastic!<|assistant|>`,
		},
		{
			skip: "tool call ordering not guaranteed yet",
			name: "tools",
			messages: []api.Message{
				{Role: "system", Content: "You are a helpful assistant with access to tools."},
				{Role: "user", Content: "What is the weather like in Tokyo?"},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name: "get_weather",
						Description: "Get the current weather in a given location",
						Parameters: api.ToolFunctionParameters{
							Type: "object",
							Required: []string{"location"},
							Properties: propsMap(`{"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}`),
						},
					},
				},
			},
			expected: `[gMASK]<sop><|system|>
# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
</tools>

For each function call, output the function name and arguments within the following XML format:
<tool_call>{function-name}
<arg_key>{arg-key-1}</arg_key>
<arg_value>{arg-value-1}</arg_value>
<arg_key>{arg-key-2}</arg_key>
<arg_value>{arg-value-2}</arg_value>
...
</tool_call><|system|>
You are a helpful assistant with access to tools.<|user|>
What is the weather like in Tokyo?<|assistant|>`,
		},
		{
			skip: "tool call ordering not guaranteed yet",
			name: "tool calls",
			messages: []api.Message{
				{Role: "system", Content: "You are a helpful assistant with access to tools."},
				{Role: "user", Content: "What is the weather like in Tokyo?"},
				{
					Role: "assistant",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: args(`{"location": "Tokyo, Japan", "unit": "celsius"}`),
							},
						},
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: args(`{"location": "Japan", "unit": "fahrenheit"}`),
							},
						},
					},
				},
				{
					Role: "tool",
					Content: "{\"temperature\": 22, \"weather\": \"partly cloudy\", \"humidity\": 65}",
					ToolName: "get_weather",
				},
				{
					Role: "tool",
					Content: "{\"temperature\": 68, \"weather\": \"sunny\", \"humidity\": 75}",
					ToolName: "get_weather",
				},
				{
					Role: "assistant",
					Content: "The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.",
				},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name: "get_weather",
						Description: "Get the current weather in a given location",
						Parameters: api.ToolFunctionParameters{
							Type: "object",
							Required: []string{"location"},
							Properties: propsMap(`{"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}`),
						},
					},
				},
			},
			expected: `[gMASK]<sop><|system|>
# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
</tools>

For each function call, output the function name and arguments within the following XML format:
<tool_call>{function-name}
<arg_key>{arg-key-1}</arg_key>
<arg_value>{arg-value-1}</arg_value>
<arg_key>{arg-key-2}</arg_key>
<arg_value>{arg-value-2}</arg_value>
...
</tool_call><|system|>
You are a helpful assistant with access to tools.<|user|>
What is the weather like in Tokyo?<|assistant|>
<think></think>
<tool_call>get_weather
<arg_key>location</arg_key>
<arg_value>Tokyo, Japan</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>
</tool_call>
<tool_call>get_weather
<arg_key>location</arg_key>
<arg_value>Japan</arg_value>
<arg_key>unit</arg_key>
<arg_value>fahrenheit</arg_value>
</tool_call><|observation|>
<tool_response>
{"temperature": 22, "weather": "partly cloudy", "humidity": 65}
</tool_response>
<tool_response>
{"temperature": 68, "weather": "sunny", "humidity": 75}
</tool_response><|assistant|>
<think></think>
The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.<|assistant|>`,
		},
		{
			name: "think true",
			messages: []api.Message{
				{Role: "user", Content: "Hello, how are you?"},
			},
			thinkValue: &api.ThinkValue{Value: true},
			expected: `[gMASK]<sop><|user|>
Hello, how are you?<|assistant|>`,
		},
		{
			name: "think false",
			messages: []api.Message{
				{Role: "user", Content: "Hello, how are you?"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected: `[gMASK]<sop><|user|>
Hello, how are you?/nothink<|assistant|>
<think></think>
`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if tt.skip != "" {
				t.Skip(tt.skip)
			}
			renderer := &GLM46Renderer{}
			rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
			if err != nil {
				t.Fatal(err)
			}
			if diff := cmp.Diff(rendered, tt.expected); diff != "" {
				t.Errorf("mismatch (-got +want):\n%s", diff)
				t.Logf("Got:\n%s", rendered)
				t.Logf("Expected:\n%s", tt.expected)
			}
		})
	}
}
170
model/renderers/glm47.go
Normal file
@@ -0,0 +1,170 @@
package renderers

import (
    "encoding/json"
    "fmt"
    "strings"

    "github.com/ollama/ollama/api"
)

// GLM47Renderer renders messages for GLM-4.7 models.
//
// GLM-4.7 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode):
//
// 1. INTERLEAVED THINKING
// The model thinks between tool calls and after receiving tool results.
// This enables complex step-by-step reasoning: interpreting each tool output
// before deciding what to do next. Thinking blocks are preserved and returned
// with tool results to maintain reasoning continuity.
//
// 2. PRESERVED THINKING
// The model retains reasoning content from previous assistant turns in context.
// This preserves reasoning continuity across multi-turn conversations. The
// upstream API has a "clear_thinking" parameter to control this:
//   - clear_thinking=true: clears reasoning from previous turns (outputs </think>)
//   - clear_thinking=false: preserves <think>...</think> blocks from previous turns
//
// 3. TURN-LEVEL THINKING
// Controls whether the model should reason on each turn. The upstream API
// uses an "enable_thinking" parameter:
//   - enable_thinking=true: outputs <think> to start reasoning
//   - enable_thinking=false: outputs </think> to skip reasoning
//
// OLLAMA DEFAULTS:
//   - Thinking is ENABLED by default (thinkValue=nil or true outputs <think>)
//   - Thinking is PRESERVED by default (reasoning content from previous turns is always
//     included in <think>...</think> blocks, equivalent to clear_thinking=false)
//   - Users can disable thinking per-turn via thinkValue=false
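//
// For example, a single user message with thinking enabled renders as
// "[gMASK]<sop><|user|>Hello<|assistant|><think>" (see glm47_test.go).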
type GLM47Renderer struct{}

func (r *GLM47Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
    var sb strings.Builder

    sb.WriteString("[gMASK]<sop>")

    if len(tools) > 0 {
        sb.WriteString("<|system|>\n")
        sb.WriteString("# Tools\n\n")
        sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
        sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
        sb.WriteString("<tools>\n")
        for _, tool := range tools {
            d, _ := json.Marshal(tool)
            sb.WriteString(formatGLM47ToolJSON(d))
            sb.WriteString("\n")
        }
        sb.WriteString("</tools>\n\n")
        sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
        sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
    }

    think := true
    if thinkValue != nil && !thinkValue.Bool() {
        think = false
    }

    for i, message := range messages {
        switch message.Role {
        case "user":
            sb.WriteString("<|user|>")
            sb.WriteString(message.Content)
        case "assistant":
            sb.WriteString("<|assistant|>")
            if message.Thinking != "" {
                sb.WriteString("<think>" + message.Thinking + "</think>")
            } else {
                sb.WriteString("</think>")
            }
            if message.Content != "" {
                sb.WriteString(message.Content)
            }
            if len(message.ToolCalls) > 0 {
                for _, toolCall := range message.ToolCalls {
                    sb.WriteString("<tool_call>" + toolCall.Function.Name)
                    sb.WriteString(renderGLM47ToolArguments(toolCall.Function.Arguments))
                    sb.WriteString("</tool_call>")
                }
            }
        case "tool":
            if i == 0 || messages[i-1].Role != "tool" {
                sb.WriteString("<|observation|>")
            }
            sb.WriteString("<tool_response>")
            sb.WriteString(message.Content)
            sb.WriteString("</tool_response>")
        case "system":
            sb.WriteString("<|system|>")
            sb.WriteString(message.Content)
        }
    }

    sb.WriteString("<|assistant|>")
    if think {
        sb.WriteString("<think>")
    } else {
        sb.WriteString("</think>")
    }

    return sb.String(), nil
}
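
// renderGLM47ToolArguments renders tool-call arguments as alternating
// <arg_key>/<arg_value> pairs; string values are used as-is and other
// values are JSON-encoded. For example, {"location": "Tokyo"} becomes
// <arg_key>location</arg_key><arg_value>Tokyo</arg_value>.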
func renderGLM47ToolArguments(args api.ToolCallFunctionArguments) string {
    var sb strings.Builder
    for key, value := range args.All() {
        sb.WriteString("<arg_key>" + key + "</arg_key>")
        var valueStr string
        if str, ok := value.(string); ok {
            valueStr = str
        } else {
            jsonBytes, err := json.Marshal(value)
            if err != nil {
                valueStr = fmt.Sprintf("%v", value)
            } else {
                valueStr = string(jsonBytes)
            }
        }

        sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
    }

    return sb.String()
}
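
// formatGLM47ToolJSON inserts a space after each ':' and ',' that occurs
// outside a JSON string, so marshaled tool JSON matches the spaced form the
// prompt expects: {"a":1} becomes {"a": 1}, while "x:y" inside a string is
// left untouched.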
func formatGLM47ToolJSON(raw []byte) string {
    var sb strings.Builder
    sb.Grow(len(raw) + len(raw)/10)

    inString := false
    escaped := false
    for i := range raw {
        ch := raw[i]
        sb.WriteByte(ch)

        if inString {
            if escaped {
                escaped = false
                continue
            }
            if ch == '\\' {
                escaped = true
                continue
            }
            if ch == '"' {
                inString = false
            }
            continue
        }

        if ch == '"' {
            inString = true
            continue
        }

        if ch == ':' || ch == ',' {
            sb.WriteByte(' ')
        }
    }

    return sb.String()
}
191
model/renderers/glm47_test.go
Normal file
@@ -0,0 +1,191 @@
package renderers

import (
    "testing"

    "github.com/google/go-cmp/cmp"
    "github.com/ollama/ollama/api"
)

func TestGLM47Renderer(t *testing.T) {
    tests := []struct {
        name       string
        messages   []api.Message
        tools      []api.Tool
        thinkValue *api.ThinkValue
        expected   string
    }{
        {
            name: "basic user message",
            messages: []api.Message{
                {Role: "user", Content: "Hello"},
            },
            expected: "[gMASK]<sop><|user|>Hello<|assistant|><think>",
        },
        {
            name: "thinking disabled",
            messages: []api.Message{
                {Role: "user", Content: "Hello"},
            },
            thinkValue: &api.ThinkValue{Value: false},
            expected:   "[gMASK]<sop><|user|>Hello<|assistant|></think>",
        },
        {
            name: "system and user",
            messages: []api.Message{
                {Role: "system", Content: "You are helpful."},
                {Role: "user", Content: "Hello"},
            },
            expected: "[gMASK]<sop><|system|>You are helpful.<|user|>Hello<|assistant|><think>",
        },
        {
            name: "multi-turn conversation",
            messages: []api.Message{
                {Role: "user", Content: "Hi"},
                {Role: "assistant", Content: "Hello there"},
                {Role: "user", Content: "How are you?"},
            },
            expected: "[gMASK]<sop><|user|>Hi<|assistant|></think>Hello there<|user|>How are you?<|assistant|><think>",
        },
        {
            name: "assistant with reasoning_content",
            messages: []api.Message{
                {Role: "user", Content: "Answer with reasoning."},
                {Role: "assistant", Thinking: "Plan.", Content: "Done."},
            },
            expected: "[gMASK]<sop><|user|>Answer with reasoning.<|assistant|><think>Plan.</think>Done.<|assistant|><think>",
        },
        {
            name: "tool call with empty content",
            messages: []api.Message{
                {Role: "user", Content: "Weather?"},
                {
                    Role: "assistant",
                    ToolCalls: []api.ToolCall{
                        {
                            Function: api.ToolCallFunction{
                                Name:      "get_weather",
                                Arguments: args(`{"location": "Tokyo", "unit": "celsius"}`),
                            },
                        },
                    },
                },
                {Role: "tool", Content: `{"temperature":22}`},
            },
            tools: []api.Tool{
                {
                    Type: "function",
                    Function: api.ToolFunction{
                        Name:        "get_weather",
                        Description: "Get weather",
                        Parameters: api.ToolFunctionParameters{
                            Type:       "object",
                            Required:   []string{"location"},
                            Properties: propsMap(`{"location": {"type": "string"}}`),
                        },
                    },
                },
            },
            expected: "[gMASK]<sop><|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call><|user|>Weather?<|assistant|></think><tool_call>get_weather<arg_key>location</arg_key><arg_value>Tokyo</arg_value><arg_key>unit</arg_key><arg_value>celsius</arg_value></tool_call><|observation|><tool_response>{\"temperature\":22}</tool_response><|assistant|><think>",
        },
        {
            name: "tool call with content",
            messages: []api.Message{
                {Role: "user", Content: "Weather?"},
                {
                    Role:    "assistant",
                    Content: "Let me check",
                    ToolCalls: []api.ToolCall{
                        {
                            Function: api.ToolCallFunction{
                                Name:      "get_weather",
                                Arguments: args(`{"location": "Tokyo"}`),
                            },
                        },
                    },
                },
                {Role: "tool", Content: `{"temperature":22}`},
                {Role: "assistant", Content: "It is 22C."},
            },
            tools: []api.Tool{
                {
                    Type: "function",
                    Function: api.ToolFunction{
                        Name:        "get_weather",
                        Description: "Get weather",
                        Parameters: api.ToolFunctionParameters{
                            Type:       "object",
                            Required:   []string{"location"},
                            Properties: propsMap(`{"location": {"type": "string"}}`),
                        },
                    },
                },
            },
            expected: "[gMASK]<sop><|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call><|user|>Weather?<|assistant|></think>Let me check<tool_call>get_weather<arg_key>location</arg_key><arg_value>Tokyo</arg_value></tool_call><|observation|><tool_response>{\"temperature\":22}</tool_response><|assistant|></think>It is 22C.<|assistant|><think>",
        },
        {
            name: "multiple tool calls and responses",
            messages: []api.Message{
                {Role: "user", Content: "Compare weather"},
                {
                    Role: "assistant",
                    ToolCalls: []api.ToolCall{
                        {
                            Function: api.ToolCallFunction{
                                Name:      "get_weather",
                                Arguments: args(`{"location": "Tokyo"}`),
                            },
                        },
                        {
                            Function: api.ToolCallFunction{
                                Name:      "get_weather",
                                Arguments: args(`{"location": "Paris"}`),
                            },
                        },
                    },
                },
                {Role: "tool", Content: `{"temperature":22}`},
                {Role: "tool", Content: `{"temperature":18}`},
            },
            tools: []api.Tool{
                {
                    Type: "function",
                    Function: api.ToolFunction{
                        Name:        "get_weather",
                        Description: "Get weather",
                        Parameters: api.ToolFunctionParameters{
                            Type:       "object",
                            Required:   []string{"location"},
                            Properties: propsMap(`{"location": {"type": "string"}}`),
                        },
                    },
                },
            },
            expected: "[gMASK]<sop><|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call><|user|>Compare weather<|assistant|></think><tool_call>get_weather<arg_key>location</arg_key><arg_value>Tokyo</arg_value></tool_call><tool_call>get_weather<arg_key>location</arg_key><arg_value>Paris</arg_value></tool_call><|observation|><tool_response>{\"temperature\":22}</tool_response><tool_response>{\"temperature\":18}</tool_response><|assistant|><think>",
        },
        {
            name: "preserved thinking in multi-turn",
            messages: []api.Message{
                {Role: "user", Content: "Think step by step"},
                {Role: "assistant", Thinking: "Let me think...", Content: "Here's my answer."},
                {Role: "user", Content: "Continue"},
            },
            expected: "[gMASK]<sop><|user|>Think step by step<|assistant|><think>Let me think...</think>Here's my answer.<|user|>Continue<|assistant|><think>",
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            renderer := &GLM47Renderer{}
            rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
            if err != nil {
                t.Fatal(err)
            }
            if diff := cmp.Diff(rendered, tt.expected); diff != "" {
                t.Errorf("mismatch (-got +want):\n%s", diff)
                t.Logf("Got:\n%s", rendered)
                t.Logf("Expected:\n%s", tt.expected)
            }
        })
    }
}
@@ -80,6 +80,8 @@ func rendererForName(name string) Renderer {
        return &Nemotron3NanoRenderer{}
    case "functiongemma":
        return &FunctionGemmaRenderer{}
    case "glm-4.7":
        return &GLM47Renderer{}
    default:
        return nil
    }

@@ -1,6 +1,26 @@
package renderers

-import "github.com/ollama/ollama/api"
+import (
+    "encoding/json"
+
+    "github.com/ollama/ollama/api"
+)

func args(s string) api.ToolCallFunctionArguments {
    var result api.ToolCallFunctionArguments
    if err := json.Unmarshal([]byte(s), &result); err != nil {
        panic("invalid JSON in args(): " + err.Error())
    }
    return result
}

func propsMap(s string) *api.ToolPropertiesMap {
    var result api.ToolPropertiesMap
    if err := json.Unmarshal([]byte(s), &result); err != nil {
        panic("invalid JSON in propsMap(): " + err.Error())
    }
    return &result
}

// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {

@@ -630,6 +630,10 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {

// decodeImageURL decodes a base64 data URI into raw image bytes.
func decodeImageURL(url string) (api.ImageData, error) {
    if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
        return nil, errors.New("image URLs are not currently supported, please use base64 encoded data instead")
    }

    types := []string{"jpeg", "jpg", "png", "webp"}

    // Support blank mime type to match /api/chat's behavior of taking just unadorned base64
@@ -733,3 +737,60 @@ func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
        DebugRenderOnly: r.DebugRenderOnly,
    }, nil
}

// ImageGenerationRequest is an OpenAI-compatible image generation request.
type ImageGenerationRequest struct {
    Model          string `json:"model"`
    Prompt         string `json:"prompt"`
    N              int    `json:"n,omitempty"`
    Size           string `json:"size,omitempty"`
    ResponseFormat string `json:"response_format,omitempty"`
    Seed           *int64 `json:"seed,omitempty"`
}

// ImageGenerationResponse is an OpenAI-compatible image generation response.
type ImageGenerationResponse struct {
    Created int64            `json:"created"`
    Data    []ImageURLOrData `json:"data"`
}

// ImageURLOrData contains either a URL or base64-encoded image data.
type ImageURLOrData struct {
    URL     string `json:"url,omitempty"`
    B64JSON string `json:"b64_json,omitempty"`
}

// FromImageGenerationRequest converts an OpenAI image generation request to an Ollama GenerateRequest.
func FromImageGenerationRequest(r ImageGenerationRequest) api.GenerateRequest {
    req := api.GenerateRequest{
        Model:  r.Model,
        Prompt: r.Prompt,
    }
    // Parse size if provided (e.g., "1024x768")
    if r.Size != "" {
        var w, h int32
        if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil {
            req.Width = w
            req.Height = h
        }
    }
    if r.Seed != nil {
        if req.Options == nil {
            req.Options = map[string]any{}
        }
        req.Options["seed"] = *r.Seed
    }
    return req
}
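
// For example, a request with Size "1024x768" and Seed 42 yields a
// GenerateRequest with Width=1024, Height=768, and Options["seed"]=42.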

// ToImageGenerationResponse converts an Ollama GenerateResponse to an OpenAI ImageGenerationResponse.
func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationResponse {
    var data []ImageURLOrData
    if resp.Image != "" {
        data = []ImageURLOrData{{B64JSON: resp.Image}}
    }
    return ImageGenerationResponse{
        Created: resp.CreatedAt.Unix(),
        Data:    data,
    }
}
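
// The resulting JSON has the OpenAI images shape, e.g.
// {"created": 1717000000, "data": [{"b64_json": "<base64 image data>"}]}.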
@@ -4,6 +4,7 @@ import (
    "encoding/json"
    "fmt"
    "math/rand"
    "time"

    "github.com/ollama/ollama/api"
)
@@ -265,9 +266,9 @@ type ResponsesText struct {
type ResponsesTool struct {
    Type string `json:"type"` // "function"
    Name string `json:"name"`
-    Description string         `json:"description,omitempty"`
-    Strict      bool           `json:"strict,omitempty"`
-    Parameters  map[string]any `json:"parameters,omitempty"`
+    Description *string        `json:"description"` // nullable but required
+    Strict      *bool          `json:"strict"`      // nullable but required
+    Parameters  map[string]any `json:"parameters"`  // nullable but required
}

type ResponsesRequest struct {
@@ -475,11 +476,16 @@ func convertTool(t ResponsesTool) (api.Tool, error) {
        }
    }

    var description string
    if t.Description != nil {
        description = *t.Description
    }

    return api.Tool{
        Type: t.Type,
        Function: api.ToolFunction{
            Name:        t.Name,
-            Description: t.Description,
+            Description: description,
            Parameters:  params,
        },
    }, nil
@@ -516,17 +522,60 @@ func convertInputMessage(m ResponsesInputMessage) (api.Message, error) {

// Response types for the Responses API

// ResponsesTextField represents the text output configuration in the response.
type ResponsesTextField struct {
    Format ResponsesTextFormat `json:"format"`
}

// ResponsesReasoningOutput represents reasoning configuration in the response.
type ResponsesReasoningOutput struct {
    Effort  *string `json:"effort,omitempty"`
    Summary *string `json:"summary,omitempty"`
}

// ResponsesError represents an error in the response.
type ResponsesError struct {
    Code    string `json:"code"`
    Message string `json:"message"`
}

// ResponsesIncompleteDetails represents details about why a response was incomplete.
type ResponsesIncompleteDetails struct {
    Reason string `json:"reason"`
}

type ResponsesResponse struct {
-    ID        string                `json:"id"`
-    Object    string                `json:"object"`
-    CreatedAt int64                 `json:"created_at"`
-    Status    string                `json:"status"`
-    Model     string                `json:"model"`
-    Output    []ResponsesOutputItem `json:"output"`
-    Usage     *ResponsesUsage       `json:"usage,omitempty"`
    // TODO(drifkin): add `temperature` and `top_p` to the response, but this
    // requires additional plumbing to find the effective values since the
    // defaults can come from the model or the request
+    ID                 string                      `json:"id"`
+    Object             string                      `json:"object"`
+    CreatedAt          int64                       `json:"created_at"`
+    CompletedAt        *int64                      `json:"completed_at"`
+    Status             string                      `json:"status"`
+    IncompleteDetails  *ResponsesIncompleteDetails `json:"incomplete_details"`
+    Model              string                      `json:"model"`
+    PreviousResponseID *string                     `json:"previous_response_id"`
+    Instructions       *string                     `json:"instructions"`
+    Output             []ResponsesOutputItem       `json:"output"`
+    Error              *ResponsesError             `json:"error"`
+    Tools              []ResponsesTool             `json:"tools"`
+    ToolChoice         any                         `json:"tool_choice"`
+    Truncation         string                      `json:"truncation"`
+    ParallelToolCalls  bool                        `json:"parallel_tool_calls"`
+    Text               ResponsesTextField          `json:"text"`
+    TopP               float64                     `json:"top_p"`
+    PresencePenalty    float64                     `json:"presence_penalty"`
+    FrequencyPenalty   float64                     `json:"frequency_penalty"`
+    TopLogprobs        int                         `json:"top_logprobs"`
+    Temperature        float64                     `json:"temperature"`
+    Reasoning          *ResponsesReasoningOutput   `json:"reasoning"`
+    Usage              *ResponsesUsage             `json:"usage"`
+    MaxOutputTokens    *int                        `json:"max_output_tokens"`
+    MaxToolCalls       *int                        `json:"max_tool_calls"`
+    Store              bool                        `json:"store"`
+    Background         bool                        `json:"background"`
+    ServiceTier        string                      `json:"service_tier"`
+    Metadata           map[string]any              `json:"metadata"`
+    SafetyIdentifier   *string                     `json:"safety_identifier"`
+    PromptCacheKey     *string                     `json:"prompt_cache_key"`
}

type ResponsesOutputItem struct {
@@ -550,18 +599,39 @@ type ResponsesReasoningSummary struct {
}

type ResponsesOutputContent struct {
-    Type string `json:"type"` // "output_text"
-    Text string `json:"text"`
+    Type        string `json:"type"` // "output_text"
+    Text        string `json:"text"`
+    Annotations []any  `json:"annotations"`
+    Logprobs    []any  `json:"logprobs"`
}

type ResponsesInputTokensDetails struct {
    CachedTokens int `json:"cached_tokens"`
}

type ResponsesOutputTokensDetails struct {
    ReasoningTokens int `json:"reasoning_tokens"`
}

type ResponsesUsage struct {
-    InputTokens  int `json:"input_tokens"`
-    OutputTokens int `json:"output_tokens"`
-    TotalTokens  int `json:"total_tokens"`
+    InputTokens         int                          `json:"input_tokens"`
+    OutputTokens        int                          `json:"output_tokens"`
+    TotalTokens         int                          `json:"total_tokens"`
+    InputTokensDetails  ResponsesInputTokensDetails  `json:"input_tokens_details"`
+    OutputTokensDetails ResponsesOutputTokensDetails `json:"output_tokens_details"`
}

-// ToResponse converts an api.ChatResponse to a Responses API response
-func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse) ResponsesResponse {
+// derefFloat64 returns the value of a float64 pointer, or a default if nil.
+func derefFloat64(p *float64, def float64) float64 {
+    if p != nil {
+        return *p
+    }
+    return def
+}
+
+// ToResponse converts an api.ChatResponse to a Responses API response.
+// The request is used to echo back request parameters in the response.
+func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse, request ResponsesRequest) ResponsesResponse {
    var output []ResponsesOutputItem

    // Add reasoning item if thinking is present
@@ -585,6 +655,7 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse)
        output = append(output, ResponsesOutputItem{
            ID:        fmt.Sprintf("fc_%s_%d", responseID, i),
            Type:      "function_call",
            Status:    "completed",
            CallID:    tc.ID,
            Name:      tc.Function.Name,
            Arguments: tc.Function.Arguments,
@@ -598,25 +669,90 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse)
            Role: "assistant",
            Content: []ResponsesOutputContent{
                {
-                    Type: "output_text",
-                    Text: chatResponse.Message.Content,
+                    Type:        "output_text",
+                    Text:        chatResponse.Message.Content,
+                    Annotations: []any{},
+                    Logprobs:    []any{},
                },
            },
        })
    }

+    var instructions *string
+    if request.Instructions != "" {
+        instructions = &request.Instructions
+    }
+
+    // Build truncation with default
+    truncation := "disabled"
+    if request.Truncation != nil {
+        truncation = *request.Truncation
+    }
+
+    tools := request.Tools
+    if tools == nil {
+        tools = []ResponsesTool{}
+    }
+
+    text := ResponsesTextField{
+        Format: ResponsesTextFormat{Type: "text"},
+    }
+    if request.Text != nil && request.Text.Format != nil {
+        text.Format = *request.Text.Format
+    }
+
+    // Build reasoning output from request
+    var reasoning *ResponsesReasoningOutput
+    if request.Reasoning.Effort != "" || request.Reasoning.Summary != "" {
+        reasoning = &ResponsesReasoningOutput{}
+        if request.Reasoning.Effort != "" {
+            reasoning.Effort = &request.Reasoning.Effort
+        }
+        if request.Reasoning.Summary != "" {
+            reasoning.Summary = &request.Reasoning.Summary
+        }
+    }

    return ResponsesResponse{
-        ID:        responseID,
-        Object:    "response",
-        CreatedAt: chatResponse.CreatedAt.Unix(),
-        Status:    "completed",
-        Model:     model,
-        Output:    output,
+        ID:                 responseID,
+        Object:             "response",
+        CreatedAt:          chatResponse.CreatedAt.Unix(),
+        CompletedAt:        nil, // Set by middleware when writing final response
+        Status:             "completed",
+        IncompleteDetails:  nil, // Only populated if response incomplete
+        Model:              model,
+        PreviousResponseID: nil, // Not supported
+        Instructions:       instructions,
+        Output:             output,
+        Error:              nil, // Only populated on failure
+        Tools:              tools,
+        ToolChoice:         "auto", // Default value
+        Truncation:         truncation,
+        ParallelToolCalls:  true, // Default value
+        Text:               text,
+        TopP:               derefFloat64(request.TopP, 1.0),
+        PresencePenalty:    0, // Default value
+        FrequencyPenalty:   0, // Default value
+        TopLogprobs:        0, // Default value
+        Temperature:        derefFloat64(request.Temperature, 1.0),
+        Reasoning:          reasoning,
+        Usage: &ResponsesUsage{
+            InputTokens:  chatResponse.PromptEvalCount,
+            OutputTokens: chatResponse.EvalCount,
+            TotalTokens:  chatResponse.PromptEvalCount + chatResponse.EvalCount,
+            // TODO(drifkin): wire through the actual values
+            InputTokensDetails: ResponsesInputTokensDetails{CachedTokens: 0},
+            // TODO(drifkin): wire through the actual values
+            OutputTokensDetails: ResponsesOutputTokensDetails{ReasoningTokens: 0},
+        },
+        MaxOutputTokens:  request.MaxOutputTokens,
+        MaxToolCalls:     nil,   // Not supported
+        Store:            false, // We don't store responses
+        Background:       request.Background,
+        ServiceTier:      "default", // Default value
+        Metadata:         map[string]any{},
+        SafetyIdentifier: nil, // Not supported
+        PromptCacheKey:   nil, // Not supported
    }
}
|
    responseID string
    itemID     string
    model      string
    request    ResponsesRequest

    // State tracking (mutated across Process calls)
    firstWrite bool
@@ -668,11 +805,12 @@ func (c *ResponsesStreamConverter) newEvent(eventType string, data map[string]an
}

// NewResponsesStreamConverter creates a new converter with the given configuration.
-func NewResponsesStreamConverter(responseID, itemID, model string) *ResponsesStreamConverter {
+func NewResponsesStreamConverter(responseID, itemID, model string, request ResponsesRequest) *ResponsesStreamConverter {
    return &ResponsesStreamConverter{
        responseID: responseID,
        itemID:     itemID,
        model:      model,
        request:    request,
        firstWrite: true,
    }
}
@@ -717,25 +855,120 @@ func (c *ResponsesStreamConverter) Process(r api.ChatResponse) []ResponsesStream
    return events
}

// buildResponseObject creates a full response object with all required fields for streaming events.
func (c *ResponsesStreamConverter) buildResponseObject(status string, output []any, usage map[string]any) map[string]any {
    var instructions any = nil
    if c.request.Instructions != "" {
        instructions = c.request.Instructions
    }

    truncation := "disabled"
    if c.request.Truncation != nil {
        truncation = *c.request.Truncation
    }

    var tools []any
    if c.request.Tools != nil {
        for _, t := range c.request.Tools {
            tools = append(tools, map[string]any{
                "type":        t.Type,
                "name":        t.Name,
                "description": t.Description,
                "strict":      t.Strict,
                "parameters":  t.Parameters,
            })
        }
    }
    if tools == nil {
        tools = []any{}
    }

    textFormat := map[string]any{"type": "text"}
    if c.request.Text != nil && c.request.Text.Format != nil {
        textFormat = map[string]any{
            "type": c.request.Text.Format.Type,
        }
        if c.request.Text.Format.Name != "" {
            textFormat["name"] = c.request.Text.Format.Name
        }
        if c.request.Text.Format.Schema != nil {
            textFormat["schema"] = c.request.Text.Format.Schema
        }
        if c.request.Text.Format.Strict != nil {
            textFormat["strict"] = *c.request.Text.Format.Strict
        }
    }

    var reasoning any = nil
    if c.request.Reasoning.Effort != "" || c.request.Reasoning.Summary != "" {
        r := map[string]any{}
        if c.request.Reasoning.Effort != "" {
            r["effort"] = c.request.Reasoning.Effort
        } else {
            r["effort"] = nil
        }
        if c.request.Reasoning.Summary != "" {
            r["summary"] = c.request.Reasoning.Summary
        } else {
            r["summary"] = nil
        }
        reasoning = r
    }

    // Build top_p and temperature with defaults
    topP := 1.0
    if c.request.TopP != nil {
        topP = *c.request.TopP
    }
    temperature := 1.0
    if c.request.Temperature != nil {
        temperature = *c.request.Temperature
    }

    return map[string]any{
        "id":                   c.responseID,
        "object":               "response",
        "created_at":           time.Now().Unix(),
        "completed_at":         nil,
        "status":               status,
        "incomplete_details":   nil,
        "model":                c.model,
        "previous_response_id": nil,
        "instructions":         instructions,
        "output":               output,
        "error":                nil,
        "tools":                tools,
        "tool_choice":          "auto",
        "truncation":           truncation,
        "parallel_tool_calls":  true,
        "text":                 map[string]any{"format": textFormat},
        "top_p":                topP,
        "presence_penalty":     0,
        "frequency_penalty":    0,
        "top_logprobs":         0,
        "temperature":          temperature,
        "reasoning":            reasoning,
        "usage":                usage,
        "max_output_tokens":    c.request.MaxOutputTokens,
        "max_tool_calls":       nil,
        "store":                false,
        "background":           c.request.Background,
        "service_tier":         "default",
        "metadata":             map[string]any{},
        "safety_identifier":    nil,
        "prompt_cache_key":     nil,
    }
}

func (c *ResponsesStreamConverter) createResponseCreatedEvent() ResponsesStreamEvent {
    return c.newEvent("response.created", map[string]any{
-        "response": map[string]any{
-            "id":     c.responseID,
-            "object": "response",
-            "status": "in_progress",
-            "output": []any{},
-        },
+        "response": c.buildResponseObject("in_progress", []any{}, nil),
    })
}

func (c *ResponsesStreamConverter) createResponseInProgressEvent() ResponsesStreamEvent {
    return c.newEvent("response.in_progress", map[string]any{
-        "response": map[string]any{
-            "id":     c.responseID,
-            "object": "response",
-            "status": "in_progress",
-            "output": []any{},
-        },
+        "response": c.buildResponseObject("in_progress", []any{}, nil),
    })
}

@@ -762,9 +995,10 @@ func (c *ResponsesStreamConverter) processThinking(thinking string) []ResponsesS

    // Emit delta
    events = append(events, c.newEvent("response.reasoning_summary_text.delta", map[string]any{
-        "item_id":      c.reasoningItemID,
-        "output_index": c.outputIndex,
-        "delta":        thinking,
+        "item_id":       c.reasoningItemID,
+        "output_index":  c.outputIndex,
+        "summary_index": 0,
+        "delta":         thinking,
    }))

    // TODO(drifkin): consider adding
@@ -783,9 +1017,10 @@ func (c *ResponsesStreamConverter) finishReasoning() []ResponsesStreamEvent {

    events := []ResponsesStreamEvent{
        c.newEvent("response.reasoning_summary_text.done", map[string]any{
-            "item_id":      c.reasoningItemID,
-            "output_index": c.outputIndex,
-            "text":         c.accumulatedThinking,
+            "item_id":       c.reasoningItemID,
+            "output_index":  c.outputIndex,
+            "summary_index": 0,
+            "text":          c.accumulatedThinking,
        }),
        c.newEvent("response.output_item.done", map[string]any{
            "output_index": c.outputIndex,
@@ -898,8 +1133,10 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
            "output_index":  c.outputIndex,
            "content_index": c.contentIndex,
            "part": map[string]any{
-                "type": "output_text",
-                "text": "",
+                "type":        "output_text",
+                "text":        "",
+                "annotations": []any{},
+                "logprobs":    []any{},
            },
        }))
    }
@@ -913,6 +1150,7 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
        "output_index":  c.outputIndex,
        "content_index": 0,
        "delta":         content,
        "logprobs":      []any{},
    }))

    return events
@@ -944,8 +1182,10 @@ func (c *ResponsesStreamConverter) buildFinalOutput() []any {
            "status": "completed",
            "role":   "assistant",
            "content": []map[string]any{{
-                "type": "output_text",
-                "text": c.accumulatedText,
+                "type":        "output_text",
+                "text":        c.accumulatedText,
+                "annotations": []any{},
+                "logprobs":    []any{},
            }},
        })
    }
@@ -967,6 +1207,7 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
        "output_index":  c.outputIndex,
        "content_index": 0,
        "text":          c.accumulatedText,
        "logprobs":      []any{},
    }))

    // response.content_part.done
@@ -975,8 +1216,10 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
        "output_index":  c.outputIndex,
        "content_index": 0,
        "part": map[string]any{
-            "type": "output_text",
-            "text": c.accumulatedText,
+            "type":        "output_text",
+            "text":        c.accumulatedText,
+            "annotations": []any{},
+            "logprobs":    []any{},
        },
    }))

@@ -989,26 +1232,31 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
                "status": "completed",
                "role":   "assistant",
                "content": []map[string]any{{
-                    "type": "output_text",
-                    "text": c.accumulatedText,
+                    "type":        "output_text",
+                    "text":        c.accumulatedText,
+                    "annotations": []any{},
+                    "logprobs":    []any{},
                }},
            },
        }))
    }

    // response.completed
-    events = append(events, c.newEvent("response.completed", map[string]any{
-        "response": map[string]any{
-            "id":     c.responseID,
-            "object": "response",
-            "status": "completed",
-            "output": c.buildFinalOutput(),
-            "usage": map[string]any{
-                "input_tokens":  r.PromptEvalCount,
-                "output_tokens": r.EvalCount,
-                "total_tokens":  r.PromptEvalCount + r.EvalCount,
-            },
+    usage := map[string]any{
+        "input_tokens":  r.PromptEvalCount,
+        "output_tokens": r.EvalCount,
+        "total_tokens":  r.PromptEvalCount + r.EvalCount,
+        "input_tokens_details": map[string]any{
+            "cached_tokens": 0,
+        },
+        "output_tokens_details": map[string]any{
+            "reasoning_tokens": 0,
+        },
+    }
+    response := c.buildResponseObject("completed", c.buildFinalOutput(), usage)
+    response["completed_at"] = time.Now().Unix()
+    events = append(events, c.newEvent("response.completed", map[string]any{
+        "response": response,
    }))

    return events

@@ -850,7 +850,7 @@ func TestFromResponsesRequest_Images(t *testing.T) {
}

func TestResponsesStreamConverter_TextOnly(t *testing.T) {
-    converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
+    converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

    // First chunk with content
    events := converter.Process(api.ChatResponse{
@@ -916,7 +916,7 @@ func TestResponsesStreamConverter_TextOnly(t *testing.T) {
}

func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
-    converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
+    converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

    events := converter.Process(api.ChatResponse{
        Message: api.Message{
@@ -952,7 +952,7 @@ func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
}

func TestResponsesStreamConverter_Reasoning(t *testing.T) {
-    converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
+    converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

    // First chunk with thinking
    events := converter.Process(api.ChatResponse{
@@ -1267,7 +1267,7 @@ func TestToResponse_WithReasoning(t *testing.T) {
            Content: "The answer is 42",
        },
        Done: true,
-    })
+    }, ResponsesRequest{})

    // Should have 2 output items: reasoning + message
    if len(response.Output) != 2 {
@@ -1638,7 +1638,7 @@ func TestFromResponsesRequest_ShorthandFormats(t *testing.T) {

func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
    // Verify that response.output_item.done includes content field for messages
-    converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
+    converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

    // First chunk
    converter.Process(api.ChatResponse{
@@ -1686,7 +1686,7 @@ func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {

func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T) {
    // Verify that response.completed includes the output array
-    converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
+    converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

    // Process some content
    converter.Process(api.ChatResponse{
@@ -1730,7 +1730,7 @@ func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T)

func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
    // Verify that response.created includes an empty output array
-    converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
+    converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

    events := converter.Process(api.ChatResponse{
        Message: api.Message{Content: "Hi"},
@@ -1757,7 +1757,7 @@ func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {

func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
    // Verify that events include incrementing sequence numbers
-    converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
+    converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

    events := converter.Process(api.ChatResponse{
        Message: api.Message{Content: "Hello"},
@@ -1791,7 +1791,7 @@ func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {

func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) {
    // Verify that function call items include status field
-    converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
+    converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

    events := converter.Process(api.ChatResponse{
        Message: api.Message{

@@ -5,6 +5,7 @@ import (
    "fmt"
    "io"
    "os"
    "strings"
)

type Prompt struct {
@@ -36,10 +37,11 @@ type Terminal struct {
}

type Instance struct {
-    Prompt   *Prompt
-    Terminal *Terminal
-    History  *History
-    Pasting  bool
+    Prompt      *Prompt
+    Terminal    *Terminal
+    History     *History
+    Pasting     bool
+    pastedLines []string
}

func New(prompt Prompt) (*Instance, error) {
@@ -174,6 +176,8 @@ func (i *Instance) Readline() (string, error) {
        case CharEsc:
            esc = true
        case CharInterrupt:
            i.pastedLines = nil
            i.Prompt.UseAlt = false
            return "", ErrInterrupt
        case CharPrev:
            i.historyPrev(buf, &currentLineBuf)
@@ -188,7 +192,23 @@ func (i *Instance) Readline() (string, error) {
        case CharForward:
            buf.MoveRight()
        case CharBackspace, CharCtrlH:
-            buf.Remove()
+            if buf.IsEmpty() && len(i.pastedLines) > 0 {
+                lastIdx := len(i.pastedLines) - 1
+                prevLine := i.pastedLines[lastIdx]
+                i.pastedLines = i.pastedLines[:lastIdx]
+                fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + ClearToEOL)
+                if len(i.pastedLines) == 0 {
+                    fmt.Print(i.Prompt.Prompt)
+                    i.Prompt.UseAlt = false
+                } else {
+                    fmt.Print(i.Prompt.AltPrompt)
+                }
+                for _, r := range prevLine {
+                    buf.Add(r)
+                }
+            } else {
+                buf.Remove()
+            }
        case CharTab:
            // todo: convert back to real tabs
            for range 8 {
@@ -211,13 +231,28 @@ func (i *Instance) Readline() (string, error) {
        case CharCtrlZ:
            fd := os.Stdin.Fd()
            return handleCharCtrlZ(fd, i.Terminal.termios)
-        case CharEnter, CharCtrlJ:
+        case CharCtrlJ:
+            i.pastedLines = append(i.pastedLines, buf.String())
+            buf.Buf.Clear()
+            buf.Pos = 0
+            buf.DisplayPos = 0
+            buf.LineHasSpace.Clear()
+            fmt.Println()
+            fmt.Print(i.Prompt.AltPrompt)
+            i.Prompt.UseAlt = true
+            continue
+        case CharEnter:
            output := buf.String()
+            if len(i.pastedLines) > 0 {
+                output = strings.Join(i.pastedLines, "\n") + "\n" + output
+                i.pastedLines = nil
+            }
            if output != "" {
                i.History.Add(output)
            }
            buf.MoveToEnd()
            fmt.Println()
            i.Prompt.UseAlt = false

            return output, nil
        default:

@@ -60,7 +60,7 @@ _build_darwin() {
        cmake --install $BUILD_DIR --component MLX
        # Override CGO flags to point to the amd64 build directory
        MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
-        MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Accelerate -mmacosx-version-min=14.0"
+        MLX_CGO_LDFLAGS="-ldl -lc++ -framework Accelerate -mmacosx-version-min=14.0"
    else
        BUILD_DIR=build
        cmake --preset MLX \
@@ -71,10 +71,12 @@ _build_darwin() {
        cmake --install $BUILD_DIR --component MLX
        # Use default CGO flags from mlx.go for arm64
        MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
-        MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
+        MLX_CGO_LDFLAGS="-lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
    fi
-    GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/imagegen ./x/imagegen/cmd/engine
-    GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
+    GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX .
+    # Copy MLX libraries to same directory as executable for dlopen
+    cp $INSTALL_PREFIX/lib/ollama/libmlxc.dylib $INSTALL_PREFIX/
+    cp $INSTALL_PREFIX/lib/ollama/libmlx.dylib $INSTALL_PREFIX/
    done
}

@@ -82,19 +84,17 @@ _sign_darwin() {
    status "Creating universal binary..."
    mkdir -p dist/darwin
    lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
-    lipo -create -output dist/darwin/imagegen dist/darwin-*/imagegen
    chmod +x dist/darwin/ollama
-    chmod +x dist/darwin/imagegen

    if [ -n "$APPLE_IDENTITY" ]; then
-        for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/imagegen; do
+        for F in dist/darwin/ollama dist/darwin-*/lib/ollama/*; do
            codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
        done

        # create a temporary zip for notarization
        TEMP=$(mktemp -u).zip
        ditto -c -k --keepParent dist/darwin/ollama "$TEMP"
-        xcrun notarytool submit "$TEMP" --wait --timeout 10m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
+        xcrun notarytool submit "$TEMP" --wait --timeout 20m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
        rm -f "$TEMP"
    fi

@@ -154,38 +154,38 @@ _build_macapp() {
    mkdir -p dist/Ollama.app/Contents/Resources
    if [ -d dist/darwin-amd64 ]; then
        lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
-        lipo -create -output dist/Ollama.app/Contents/Resources/imagegen dist/darwin-amd64/imagegen dist/darwin-arm64/imagegen
+        for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
+            lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
+        done
-        cp dist/darwin-*/lib/ollama/*.so dist/darwin-*/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
+        cp dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
+        # Copy MLX metallib (architecture-independent, just use arm64 version)
+        cp dist/darwin-arm64/lib/ollama/*.metallib dist/Ollama.app/Contents/Resources/ 2>/dev/null || true
    else
        cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
        cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
    fi
-    cp -a dist/darwin/imagegen dist/Ollama.app/Contents/Resources/imagegen
    chmod a+x dist/Ollama.app/Contents/Resources/ollama

    # Sign
    if [ -n "$APPLE_IDENTITY" ]; then
        codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
-        for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/imagegen ; do
+        for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib ; do
            codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
        done
        codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
    fi

    rm -f dist/Ollama-darwin.zip
-    ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
-    (cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama imagegen *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
+    ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
+    (cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz

    # Notarize and Staple
    if [ -n "$APPLE_IDENTITY" ]; then
-        $(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 10m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
+        $(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
        rm -f dist/Ollama-darwin.zip
        $(xcrun -f stapler) staple dist/Ollama.app
-        ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
+        ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip

        rm -f dist/Ollama.dmg

@@ -206,7 +206,7 @@ _build_macapp() {
        rm -f dist/rw*.dmg

        codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.dmg
-        $(xcrun -f notarytool) submit dist/Ollama.dmg --wait --timeout 10m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
+        $(xcrun -f notarytool) submit dist/Ollama.dmg --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
        $(xcrun -f stapler) staple dist/Ollama.dmg
    else
        echo "WARNING: Code signing disabled, this bundle will not work for upgrade testing"

@@ -48,53 +48,12 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then
        .
fi

-# Deduplicate CUDA libraries across mlx_* and cuda_* directories
-deduplicate_cuda_libs() {
-    local base_dir="$1"
-    echo "Deduplicating CUDA libraries in ${base_dir}..."
-
-    # Find all mlx_cuda_* directories
-    for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
-        [ -d "${mlx_dir}" ] || continue
-
-        # Extract CUDA version (e.g., v12, v13)
-        cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
-        cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
-
-        # Skip if corresponding cuda_* directory doesn't exist
-        [ -d "${cuda_dir}" ] || continue
-
-        echo "  Checking ${mlx_dir} against ${cuda_dir}..."
-
-        # Find all .so* files in mlx directory
-        find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
-            filename=$(basename "${mlx_file}")
-            cuda_file="${cuda_dir}/${filename}"
-
-            # Skip if file doesn't exist in cuda directory
-            [ -f "${cuda_file}" ] || continue
-
-            # Compare checksums
-            mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
-            cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
-
-            if [ "${mlx_sum}" = "${cuda_sum}" ]; then
-                echo "    Deduplicating ${filename}"
-                # Calculate relative path from mlx_dir to cuda_dir
-                rel_path="../cuda_${cuda_version}/${filename}"
-                rm -f "${mlx_file}"
-                ln -s "${rel_path}" "${mlx_file}"
-            fi
-        done
-    done
-}
-
# Run deduplication for each platform output directory
if echo $PLATFORM | grep "," > /dev/null ; then
-    deduplicate_cuda_libs "./dist/linux_amd64"
-    deduplicate_cuda_libs "./dist/linux_arm64"
+    $(dirname $0)/deduplicate_cuda_libs.sh "./dist/linux_amd64"
+    $(dirname $0)/deduplicate_cuda_libs.sh "./dist/linux_arm64"
elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then
-    deduplicate_cuda_libs "./dist"
+    $(dirname $0)/deduplicate_cuda_libs.sh "./dist"
fi

# buildx behavior changes for single vs. multiplatform

60
scripts/deduplicate_cuda_libs.sh
Executable file
@@ -0,0 +1,60 @@
#!/bin/sh
#
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
# This script finds identical .so* files in mlx_cuda_* directories that exist
# in corresponding cuda_* directories and replaces them with symlinks.
#

set -eu

if [ $# -eq 0 ]; then
echo "ERROR: No directory specified" >&2
echo "Usage: $0 <base_directory>" >&2
exit 1
fi

base_dir="$1"

if [ ! -d "${base_dir}" ]; then
echo "ERROR: Directory ${base_dir} does not exist" >&2
exit 1
fi

echo "Deduplicating CUDA libraries in ${base_dir}..."

# Find all mlx_cuda_* directories
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
[ -d "${mlx_dir}" ] || continue

# Extract CUDA version (e.g., v12, v13)
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"

# Skip if corresponding cuda_* directory doesn't exist
[ -d "${cuda_dir}" ] || continue

echo " Checking ${mlx_dir} against ${cuda_dir}..."

# Find all .so* files in mlx directory
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
filename=$(basename "${mlx_file}")
cuda_file="${cuda_dir}/${filename}"

# Skip if file doesn't exist in cuda directory
[ -f "${cuda_file}" ] || continue

# Compare checksums
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')

if [ "${mlx_sum}" = "${cuda_sum}" ]; then
echo " Deduplicating ${filename}"
# Calculate relative path from mlx_dir to cuda_dir
rel_path="../cuda_${cuda_version}/${filename}"
rm -f "${mlx_file}"
ln -s "${rel_path}" "${mlx_file}"
fi
done
done

echo "Deduplication complete"
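
The script above hinges on one invariant: an `mlx_cuda_*` build may ship the exact same shared objects as its sibling `cuda_*` directory, so byte-identical files can be collapsed into relative symlinks. A minimal Go sketch of the same checksum-and-symlink pass, for readers who want the logic outside of shell (the paths and the `.so` name filter are assumptions mirroring the script, not part of the release tooling):

```go
package main

import (
	"crypto/sha256"
	"fmt"
	"io"
	"io/fs"
	"os"
	"path/filepath"
	"strings"
)

// fileSHA256 returns the hex digest of a file's contents.
func fileSHA256(path string) (string, error) {
	f, err := os.Open(path)
	if err != nil {
		return "", err
	}
	defer f.Close()
	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		return "", err
	}
	return fmt.Sprintf("%x", h.Sum(nil)), nil
}

// dedupe replaces files in mlxDir that are byte-identical to their
// counterparts in cudaDir with relative symlinks, like the shell script.
func dedupe(mlxDir, cudaDir, cudaVersion string) error {
	return filepath.WalkDir(mlxDir, func(path string, d fs.DirEntry, err error) error {
		if err != nil || d.IsDir() || !strings.Contains(d.Name(), ".so") {
			return err
		}
		twin := filepath.Join(cudaDir, d.Name())
		if _, err := os.Stat(twin); err != nil {
			return nil // no counterpart in cuda_*; keep the copy
		}
		a, aerr := fileSHA256(path)
		b, berr := fileSHA256(twin)
		if aerr != nil || berr != nil || a != b {
			return nil // unreadable or different builds; leave untouched
		}
		if err := os.Remove(path); err != nil {
			return err
		}
		// Relative link so the layout survives moving the install root.
		return os.Symlink(filepath.Join("..", "cuda_"+cudaVersion, d.Name()), path)
	})
}

func main() {
	// Placeholder paths mirroring the script's lib/ollama layout.
	if err := dedupe("dist/lib/ollama/mlx_cuda_v13", "dist/lib/ollama/cuda_v13", "v13"); err != nil {
		fmt.Fprintln(os.Stderr, err)
	}
}
```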
@@ -50,12 +50,17 @@ func (r registryChallenge) URL() (*url.URL, error) {
return redirectURL, nil
}

func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (string, error) {
func getAuthorizationToken(ctx context.Context, challenge registryChallenge, originalHost string) (string, error) {
redirectURL, err := challenge.URL()
if err != nil {
return "", err
}

// Validate that the realm host matches the original request host to prevent sending tokens cross-origin.
if redirectURL.Host != originalHost {
return "", fmt.Errorf("realm host %q does not match original host %q", redirectURL.Host, originalHost)
}

sha256sum := sha256.Sum256(nil)
data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))

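The new `originalHost` argument closes a token-leak hole: a registry response could previously point the client at an arbitrary realm URL, and the signed authorization request would be sent there. Pinning the realm host to the host that issued the challenge stops that. A standalone sketch of the check (illustrative names, not the server's internal API):

```go
package main

import (
	"fmt"
	"net/url"
)

// validateRealm rejects a WWW-Authenticate realm that would send
// credentials to a different host than the one originally contacted.
func validateRealm(realm, originalHost string) error {
	u, err := url.Parse(realm)
	if err != nil {
		return err
	}
	if u.Host != originalHost {
		return fmt.Errorf("realm host %q does not match original host %q", u.Host, originalHost)
	}
	return nil
}

func main() {
	// Same-origin realm: accepted.
	fmt.Println(validateRealm("https://registry.example.com/token", "registry.example.com"))
	// Cross-origin realm: rejected with a "does not match" error.
	fmt.Println(validateRealm("https://evil.example.net/token", "registry.example.com"))
}
```

Note that the comparison includes the port (`url.URL.Host` is `host:port`), which is exactly what the new tests exercise with `localhost:5000` vs `localhost:6000`.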
113
server/auth_test.go
Normal file
@@ -0,0 +1,113 @@
package server

import (
"context"
"strings"
"testing"
"time"
)

func TestGetAuthorizationTokenRejectsCrossDomain(t *testing.T) {
tests := []struct {
realm string
originalHost string
wantMismatch bool
}{
{"https://example.com/token", "example.com", false},
{"https://example.com/token", "other.com", true},
{"https://example.com/token", "localhost:8000", true},
{"https://localhost:5000/token", "localhost:5000", false},
{"https://localhost:5000/token", "localhost:6000", true},
}

for _, tt := range tests {
t.Run(tt.originalHost, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

challenge := registryChallenge{Realm: tt.realm, Service: "test", Scope: "repo:x:pull"}
_, err := getAuthorizationToken(ctx, challenge, tt.originalHost)

isMismatch := err != nil && strings.Contains(err.Error(), "does not match")
if tt.wantMismatch && !isMismatch {
t.Errorf("expected domain mismatch error, got: %v", err)
}
if !tt.wantMismatch && isMismatch {
t.Errorf("unexpected domain mismatch error: %v", err)
}
})
}
}

func TestParseRegistryChallenge(t *testing.T) {
tests := []struct {
input string
wantRealm, wantService, wantScope string
}{
{
`Bearer realm="https://auth.example.com/token",service="registry",scope="repo:foo:pull"`,
"https://auth.example.com/token", "registry", "repo:foo:pull",
},
{
`Bearer realm="https://r.ollama.ai/v2/token",service="ollama",scope="-"`,
"https://r.ollama.ai/v2/token", "ollama", "-",
},
{"", "", "", ""},
}

for _, tt := range tests {
result := parseRegistryChallenge(tt.input)
if result.Realm != tt.wantRealm || result.Service != tt.wantService || result.Scope != tt.wantScope {
t.Errorf("parseRegistryChallenge(%q) = {%q, %q, %q}, want {%q, %q, %q}",
tt.input, result.Realm, result.Service, result.Scope,
tt.wantRealm, tt.wantService, tt.wantScope)
}
}
}

func TestRegistryChallengeURL(t *testing.T) {
challenge := registryChallenge{
Realm: "https://auth.example.com/token",
Service: "registry",
Scope: "repo:foo:pull repo:bar:push",
}

u, err := challenge.URL()
if err != nil {
t.Fatalf("URL() error: %v", err)
}

if u.Host != "auth.example.com" {
t.Errorf("host = %q, want %q", u.Host, "auth.example.com")
}
if u.Path != "/token" {
t.Errorf("path = %q, want %q", u.Path, "/token")
}

q := u.Query()
if q.Get("service") != "registry" {
t.Errorf("service = %q, want %q", q.Get("service"), "registry")
}
if scopes := q["scope"]; len(scopes) != 2 {
t.Errorf("scope count = %d, want 2", len(scopes))
}
if q.Get("ts") == "" {
t.Error("missing ts")
}
if q.Get("nonce") == "" {
t.Error("missing nonce")
}

// Nonces should differ between calls
u2, _ := challenge.URL()
if q.Get("nonce") == u2.Query().Get("nonce") {
t.Error("nonce should be unique per call")
}
}

func TestRegistryChallengeURLInvalid(t *testing.T) {
challenge := registryChallenge{Realm: "://invalid"}
if _, err := challenge.URL(); err == nil {
t.Error("expected error for invalid URL")
}
}

@@ -41,6 +41,7 @@ var (
errCapabilityVision = errors.New("vision")
errCapabilityEmbedding = errors.New("embedding")
errCapabilityThinking = errors.New("thinking")
errCapabilityImage = errors.New("image generation")
errInsecureProtocol = errors.New("insecure protocol http")
)

@@ -76,7 +77,7 @@ func (m *Model) Capabilities() []model.Capability {

// Check for image generation model via config capabilities
if slices.Contains(m.Config.Capabilities, "image") {
return []model.Capability{model.CapabilityImageGeneration}
return []model.Capability{model.CapabilityImage}
}

// Check for completion capability
@@ -159,6 +160,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
model.CapabilityVision: errCapabilityVision,
model.CapabilityEmbedding: errCapabilityEmbedding,
model.CapabilityThinking: errCapabilityThinking,
model.CapabilityImage: errCapabilityImage,
}

for _, cap := range want {
@@ -775,7 +777,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
})
}, base.Host)
}

if err := transfer.Download(ctx, transfer.DownloadOptions{
@@ -850,7 +852,7 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
})
}, base.Host)
}

return transfer.Upload(ctx, transfer.UploadOptions{
@@ -916,7 +918,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR

// Handle authentication error with one retry
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
token, err := getAuthorizationToken(ctx, challenge)
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
if err != nil {
return nil, err
}

@@ -54,7 +54,7 @@ func TestModelCapabilities(t *testing.T) {
Capabilities: []string{"image"},
},
},
expectedCaps: []model.Capability{model.CapabilityImageGeneration},
expectedCaps: []model.Capability{model.CapabilityImage},
},
{
name: "model with completion capability",
@@ -242,6 +242,24 @@ func TestModelCheckCapabilities(t *testing.T) {
checkCaps: []model.Capability{"unknown"},
expectedErrMsg: "unknown capability",
},
{
name: "model missing image generation capability",
model: Model{
ModelPath: completionModelPath,
Template: chatTemplate,
},
checkCaps: []model.Capability{model.CapabilityImage},
expectedErrMsg: "does not support image generation",
},
{
name: "model with image generation capability",
model: Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
},
checkCaps: []model.Capability{model.CapabilityImage},
},
}

for _, tt := range tests {

@@ -47,16 +47,40 @@ func (m *Manifest) Remove() error {
}

func (m *Manifest) RemoveLayers() error {
for _, layer := range append(m.Layers, m.Config) {
if layer.Digest != "" {
if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
slog.Debug("layer does not exist", "digest", layer.Digest)
} else if err != nil {
return err
ms, err := Manifests(true)
if err != nil {
return err
}

// Build set of digests still in use by other manifests
inUse := make(map[string]struct{})
for _, other := range ms {
for _, layer := range append(other.Layers, other.Config) {
if layer.Digest != "" {
inUse[layer.Digest] = struct{}{}
}
}
}

// Remove layers not used by any other manifest
for _, layer := range append(m.Layers, m.Config) {
if layer.Digest == "" {
continue
}
if _, used := inUse[layer.Digest]; used {
continue
}
blob, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(blob); errors.Is(err, os.ErrNotExist) {
slog.Debug("layer does not exist", "digest", layer.Digest)
} else if err != nil {
return err
}
}

return nil
}

191
server/routes.go
@@ -51,7 +51,7 @@ import (
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen"
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
xserver "github.com/ollama/ollama/x/server"
)

const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
@@ -164,29 +164,6 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return runner.llama, model, &opts, nil
}

// ScheduleImageGenRunner schedules an image generation model runner.
// This implements the imagegenapi.RunnerScheduler interface.
func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) {
m := &Model{
Name: modelName,
ShortName: modelName,
ModelPath: modelName, // For image gen, ModelPath is just the model name
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}

runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err := <-errCh:
return nil, err
}

return runner.llama, nil
}

func signinURL() (string, error) {
pubKey, err := auth.GetPublicKey()
if err != nil {
@@ -214,12 +191,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}

// Check if this is a known image generation model
if imagegen.ResolveModelName(req.Model) != "" {
imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
return
}

name := model.ParseName(req.Model)
if !name.IsValid() {
// Ideally this is "invalid model name" but we're keeping with
@@ -344,7 +315,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}

// expire the runner
// expire the runner if unload is requested (empty prompt, keep alive is 0)
if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
s.sched.expireRunner(m)

@@ -358,6 +329,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}

// Handle image generation models
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
s.handleImageGenerate(c, req, name.String(), checkpointStart)
return
}

if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
return
@@ -1124,6 +1101,31 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
QuantizationLevel: m.Config.FileType,
}

// For image generation models, populate details from imagegen package
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
modelDetails.Family = info.Architecture
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
modelDetails.QuantizationLevel = info.Quantization
}
}

// For safetensors LLM models (experimental), populate details from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
if arch, ok := info["general.architecture"].(string); ok && arch != "" {
modelDetails.Family = arch
}
if paramCount, ok := info["general.parameter_count"].(int64); ok && paramCount > 0 {
modelDetails.ParameterSize = format.HumanNumber(uint64(paramCount))
}
}
// Get torch_dtype directly from config.json for quantization level
if dtype, err := xserver.GetSafetensorsDtype(name.String()); err == nil && dtype != "" {
modelDetails.QuantizationLevel = dtype
}
}

if req.System != "" {
m.System = req.System
}
@@ -1206,6 +1208,30 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
return resp, nil
}

if slices.Contains(m.Capabilities(), model.CapabilityImage) {
// Populate tensor info if verbose
if req.Verbose {
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
resp.Tensors = tensors
}
}
return resp, nil
}

// For safetensors LLM models (experimental), populate ModelInfo from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
resp.ModelInfo = info
}
// Populate tensor info if verbose
if req.Verbose {
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
resp.Tensors = tensors
}
}
return resp, nil
}

kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
if err != nil {
return nil, err
@@ -1574,13 +1600,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
// OpenAI-compatible image generation endpoint
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)

// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)

// Experimental image generation support
imagegenapi.RegisterRoutes(r, s)

if rc != nil {
// wrap old with new
rs := &registry.Local{
@@ -2059,8 +2084,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
} else {
if req.Think != nil && req.Think.Bool() {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
return
// Set think to nil when being used with Anthropic API to connect to tools like claude code
if _, ok := c.Get("relax_thinking"); ok {
slog.Warn("model does not support thinking, relaxing thinking to nil", "model", req.Model)
req.Think = nil
} else {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
return
}
}
}

@@ -2441,3 +2472,91 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
}
return msgs
}

// handleImageGenerate handles image generation requests within GenerateHandler.
// This is called when the model has the Image capability.
func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, modelName string, checkpointStart time.Time) {
// Validate image dimensions
const maxDimension int32 = 4096
if req.Width > maxDimension || req.Height > maxDimension {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("width and height must be <= %d", maxDimension)})
return
}

// Schedule the runner for image generation
runner, _, _, err := s.scheduleRunner(c.Request.Context(), modelName, []model.Capability{model.CapabilityImage}, nil, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
}

checkpointLoaded := time.Now()

// Handle load-only request (empty prompt)
if req.Prompt == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: true,
DoneReason: "load",
})
return
}

// Set headers for streaming response
c.Header("Content-Type", "application/x-ndjson")

// Get seed from options if provided
var seed int64
if s, ok := req.Options["seed"]; ok {
switch v := s.(type) {
case int:
seed = int64(v)
case int64:
seed = v
case float64:
seed = int64(v)
}
}

var streamStarted bool
if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: seed,
}, func(cr llm.CompletionResponse) {
streamStarted = true
res := api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: cr.Done,
}

if cr.TotalSteps > 0 {
res.Completed = int64(cr.Step)
res.Total = int64(cr.TotalSteps)
}

if cr.Image != "" {
res.Image = cr.Image
}

if cr.Done {
res.DoneReason = cr.DoneReason.String()
res.Metrics.TotalDuration = time.Since(checkpointStart)
res.Metrics.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}

data, _ := json.Marshal(res)
c.Writer.Write(append(data, '\n'))
c.Writer.Flush()
}); err != nil {
// Only send JSON error if streaming hasn't started yet
// (once streaming starts, headers are committed and we can't change status code)
if !streamStarted {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
}
}

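Taken together, the handler streams NDJSON progress frames and returns a final `Image` payload once `Done` is set. A hedged client-side sketch using the `api` package (the model name is a placeholder, and the assumption that `Image` carries base64-encoded PNG data is mine, inferred from the string field):

```go
package main

import (
	"context"
	"encoding/base64"
	"fmt"
	"log"
	"os"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.GenerateRequest{
		Model:  "my-image-model", // placeholder: any model with the "image" capability
		Prompt: "a lighthouse at dusk, oil painting",
		Width:  1024,
		Height: 1024,
		Steps:  30,
	}

	err = client.Generate(context.Background(), req, func(r api.GenerateResponse) error {
		if r.Total > 0 {
			fmt.Printf("\rstep %d/%d", r.Completed, r.Total) // per-step progress frames
		}
		if r.Done && r.Image != "" {
			// Assumption: the image payload is base64-encoded.
			data, err := base64.StdEncoding.DecodeString(r.Image)
			if err != nil {
				return err
			}
			return os.WriteFile("out.png", data, 0o644)
		}
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```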
@@ -2101,3 +2101,95 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
}
})
}

func TestGenerateUnload(t *testing.T) {
gin.SetMode(gin.TestMode)

var loadFnCalled bool

s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mockRunner{}),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
loadFnCalled = true
req.successCh <- &runnerRef{llama: &mockRunner{}}
return false
},
},
}

go s.sched.Run(t.Context())

_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})

w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test",
Files: map[string]string{"file.gguf": digest},
Stream: &stream,
})

if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}

t.Run("unload with empty prompt and keepalive 0", func(t *testing.T) {
loadFnCalled = false

w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "",
KeepAlive: &api.Duration{Duration: 0},
Stream: &stream,
})

if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}

var resp api.GenerateResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}

if resp.DoneReason != "unload" {
t.Errorf("expected done_reason 'unload', got %q", resp.DoneReason)
}

if !resp.Done {
t.Error("expected done to be true")
}

if loadFnCalled {
t.Error("expected model NOT to be loaded for unload request, but loadFn was called")
}
})
}

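The test pins down the unload contract: an empty prompt plus `keep_alive: 0` must return `done_reason: "unload"` without ever scheduling a load. From the client side that looks like this (model name is a placeholder):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Empty prompt + zero keep-alive asks the server to unload the model.
	req := &api.GenerateRequest{
		Model:     "llama3.2", // placeholder
		Prompt:    "",
		KeepAlive: &api.Duration{Duration: 0},
	}

	err = client.Generate(context.Background(), req, func(r api.GenerateResponse) error {
		fmt.Println("done_reason:", r.DoneReason) // expected: "unload"
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```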
@@ -574,7 +574,8 @@ func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
Options: &req.opts,
loading: false,
sessionDuration: sessionDuration,
refCount: 1,
totalSize: server.TotalSize(),
vramSize: server.VRAMSize(),
}

s.loadedMu.Lock()

@@ -6,7 +6,6 @@ import (
"errors"
"log/slog"
"os"
"slices"
"testing"
"time"

@@ -17,7 +16,6 @@ import (
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
)

func TestMain(m *testing.M) {
@@ -807,32 +805,8 @@ func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return n
func (s *mockLlm) HasExited() bool { return false }
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }

// TestImageGenCapabilityDetection verifies that models with "image" capability
// are correctly identified and routed differently from language models.
func TestImageGenCapabilityDetection(t *testing.T) {
// Model with image capability should be detected
imageModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}
require.True(t, slices.Contains(imageModel.Config.Capabilities, "image"))

// Model without image capability should not be detected
langModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"completion"},
},
}
require.False(t, slices.Contains(langModel.Config.Capabilities, "image"))

// Empty capabilities should not match
emptyModel := &Model{}
require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image"))
}

// TestImageGenRunnerCanBeEvicted verifies that an image generation model
// loaded in the scheduler can be evicted by a language model request.
// loaded in the scheduler can be evicted when idle.
func TestImageGenRunnerCanBeEvicted(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
@@ -864,3 +838,59 @@ func TestImageGenRunnerCanBeEvicted(t *testing.T) {
require.NotNil(t, runner)
require.Equal(t, "/fake/image/model", runner.modelPath)
}

// TestImageGenSchedulerCoexistence verifies that image generation models
// can coexist with language models in the scheduler and VRAM is tracked correctly.
func TestImageGenSchedulerCoexistence(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()

s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getSystemInfoFn = getSystemInfoFn

// Load both an imagegen runner and a language model runner
imageGenRunner := &runnerRef{
model: &Model{Name: "flux", ModelPath: "/fake/flux/model"},
modelPath: "/fake/flux/model",
llama: &mockLlm{vramSize: 8 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 8 * format.GigaByte}},
sessionDuration: 10 * time.Millisecond,
numParallel: 1,
refCount: 0,
}

langModelRunner := &runnerRef{
model: &Model{Name: "llama3", ModelPath: "/fake/llama3/model"},
modelPath: "/fake/llama3/model",
llama: &mockLlm{vramSize: 4 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 4 * format.GigaByte}},
sessionDuration: 10 * time.Millisecond,
numParallel: 1,
refCount: 0,
}

s.loadedMu.Lock()
s.loaded["/fake/flux/model"] = imageGenRunner
s.loaded["/fake/llama3/model"] = langModelRunner
s.loadedMu.Unlock()

// Verify both are loaded
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
require.NotNil(t, s.loaded["/fake/flux/model"])
require.NotNil(t, s.loaded["/fake/llama3/model"])
s.loadedMu.Unlock()

// Verify updateFreeSpace accounts for both
gpus := []ml.DeviceInfo{
{
DeviceID: ml.DeviceID{Library: "Metal"},
TotalMemory: 24 * format.GigaByte,
FreeMemory: 24 * format.GigaByte,
},
}
s.updateFreeSpace(gpus)

// Free memory should be reduced by both models
expectedFree := uint64(24*format.GigaByte) - uint64(8*format.GigaByte) - uint64(4*format.GigaByte)
require.Equal(t, expectedFree, gpus[0].FreeMemory)
}

@@ -279,7 +279,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
case resp.StatusCode == http.StatusUnauthorized:
w.Rollback()
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
token, err := getAuthorizationToken(ctx, challenge)
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
if err != nil {
return err
}

@@ -9,7 +9,7 @@ const (
CapabilityVision = Capability("vision")
CapabilityEmbedding = Capability("embedding")
CapabilityThinking = Capability("thinking")
CapabilityImageGeneration = Capability("image")
CapabilityImage = Capability("image")
)

func (c Capability) String() string {

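Only the Go identifier changes here; the wire value stays `"image"`, so stored manifests and older clients are unaffected. Code that gates on the capability does so via `/api/show`, for example (model name is a placeholder):

```go
package main

import (
	"context"
	"fmt"
	"log"
	"slices"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/types/model"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	resp, err := client.Show(context.Background(), &api.ShowRequest{Model: "my-image-model"})
	if err != nil {
		log.Fatal(err)
	}

	if slices.Contains(resp.Capabilities, model.CapabilityImage) {
		fmt.Println("model can generate images")
	}
}
```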
24
x/README.md
@@ -1,24 +0,0 @@
# Experimental Features

## MLX Backend

We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx)

Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors. To build:

```
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install --component MLX
go build -tags mlx .
```

On linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled.

## Image Generation

Based on the experimental MLX backend, we're working on adding imagegen support. After running the cmake commands above:

```
go build -o imagegen ./x/imagegen/cmd/engine
```

@@ -41,6 +41,7 @@ var optionLabels = []string{
var toolDisplayNames = map[string]string{
"bash": "Bash",
"web_search": "Web Search",
"web_fetch": "Web Fetch",
}

// ToolDisplayName returns the human-readable display name for a tool.
@@ -565,6 +566,16 @@ func formatToolDisplay(toolName string, args map[string]any) string {
}
}

// For web fetch, show URL and internet notice
if toolName == "web_fetch" {
if url, ok := args["url"].(string); ok {
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
sb.WriteString(fmt.Sprintf("URL: %s\n", url))
sb.WriteString("Uses internet via ollama.com")
return sb.String()
}
}

// Generic display
sb.WriteString(fmt.Sprintf("Tool: %s", displayName))
if len(args) > 0 {
@@ -1017,6 +1028,16 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
}
}

if toolName == "web_fetch" {
if url, ok := args["url"].(string); ok {
// Truncate long URLs
if len(url) > 50 {
url = url[:47] + "..."
}
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, url)
}
}

return fmt.Sprintf("\033[1m%s:\033[0m %s", label, displayName)
}

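One caveat on the approval line above: `url[:47]` slices bytes, so a multi-byte UTF-8 character straddling the cut renders as mojibake. If that ever matters for display, a rune-safe variant is straightforward (illustrative helper, not part of this change):

```go
// truncate shortens s to at most max characters, appending "..." when cut.
// Converting to runes avoids splitting a multi-byte UTF-8 sequence.
func truncate(s string, max int) string {
	r := []rune(s)
	if len(r) <= max {
		return s
	}
	return string(r[:max-3]) + "..."
}
```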
308
x/cmd/run.go
@@ -9,6 +9,7 @@ import (
"net/url"
"os"
"os/signal"
"slices"
"strings"
"syscall"
"time"
@@ -130,6 +131,7 @@ type RunOptions struct {
KeepAlive *api.Duration
Think *api.ThinkValue
HideThinking bool
Verbose bool

// Agent fields (managed externally for session persistence)
Tools *tools.Registry
@@ -178,6 +180,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
var thinkTagClosed bool = false
var pendingToolCalls []api.ToolCall
var consecutiveErrors int // Track consecutive 500 errors for retry limit
var latest api.ChatResponse

role := "assistant"
messages := opts.Messages
@@ -187,6 +190,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
p.StopAndClear()
}

latest = response
role = response.Message.Role
if response.Message.Thinking != "" && !opts.HideThinking {
if !thinkTagOpened {
@@ -483,6 +487,10 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
fmt.Println()
}

if opts.Verbose {
latest.Summary()
}

return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
}

@@ -634,12 +642,13 @@ func checkModelCapabilities(ctx context.Context, modelName string) (supportsTool
// GenerateInteractive runs an interactive agent session.
// This is called from cmd.go when --experimental flag is set.
// If yoloMode is true, all tool approvals are skipped.
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool) error {
// If enableWebsearch is true, the web search tool is registered.
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool, enableWebsearch bool) error {
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
AltPrompt: "... ",
Placeholder: "Send a message (/? for help)",
AltPlaceholder: `Use """ to end multi-line input`,
AltPlaceholder: "Press Enter to send",
})
if err != nil {
return err
@@ -660,6 +669,12 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
if supportsTools {
toolRegistry = tools.DefaultRegistry()

// Register web search and web fetch tools if enabled via flag
if enableWebsearch {
toolRegistry.RegisterWebSearch()
toolRegistry.RegisterWebFetch()
}

if toolRegistry.Has("bash") {
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "This experimental version of Ollama has the \033[1mbash\033[0m tool enabled.")
@@ -667,6 +682,11 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
fmt.Fprintln(os.Stderr)
}

if toolRegistry.Has("web_search") || toolRegistry.Has("web_fetch") {
fmt.Fprintln(os.Stderr, "The \033[1mWeb Search\033[0m and \033[1mWeb Fetch\033[0m tools are enabled. Models can search and fetch web content via ollama.com.")
fmt.Fprintln(os.Stderr)
}

if yoloMode {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m yolo mode - all tool approvals will be skipped\n")
}
@@ -677,6 +697,8 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op

var messages []api.Message
var sb strings.Builder
var format string
var system string

for {
line, err := scanner.Readline()
@@ -688,6 +710,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
if line == "" {
fmt.Println("\nUse Ctrl + d or /bye to exit.")
}
scanner.Prompt.UseAlt = false
sb.Reset()
continue
case err != nil:
@@ -707,6 +730,10 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
continue
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set Set session variables")
fmt.Fprintln(os.Stderr, " /show Show model information")
fmt.Fprintln(os.Stderr, " /load Load a different model")
fmt.Fprintln(os.Stderr, " /save Save session as a model")
fmt.Fprintln(os.Stderr, " /tools Show available tools and approvals")
fmt.Fprintln(os.Stderr, " /clear Clear session context and approvals")
fmt.Fprintln(os.Stderr, " /bye Exit")
@@ -716,6 +743,280 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
fmt.Fprintln(os.Stderr, " Ctrl+O Expand last tool output")
fmt.Fprintln(os.Stderr, "")
continue
case strings.HasPrefix(line, "/set"):
args := strings.Fields(line)
if len(args) > 1 {
switch args[1] {
case "history":
scanner.HistoryEnable()
case "nohistory":
scanner.HistoryDisable()
case "wordwrap":
wordWrap = true
fmt.Println("Set 'wordwrap' mode.")
case "nowordwrap":
wordWrap = false
fmt.Println("Set 'nowordwrap' mode.")
case "verbose":
if err := cmd.Flags().Set("verbose", "true"); err != nil {
return err
}
fmt.Println("Set 'verbose' mode.")
case "quiet":
if err := cmd.Flags().Set("verbose", "false"); err != nil {
return err
}
fmt.Println("Set 'quiet' mode.")
case "think":
thinkValue := api.ThinkValue{Value: true}
var maybeLevel string
if len(args) > 2 {
maybeLevel = args[2]
}
if maybeLevel != "" {
thinkValue.Value = maybeLevel
}
think = &thinkValue
// Check if model supports thinking
if client, err := api.ClientFromEnvironment(); err == nil {
if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil {
if !slices.Contains(resp.Capabilities, model.CapabilityThinking) {
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName)
}
}
}
if maybeLevel != "" {
fmt.Printf("Set 'think' mode to '%s'.\n", maybeLevel)
} else {
fmt.Println("Set 'think' mode.")
}
case "nothink":
think = &api.ThinkValue{Value: false}
// Check if model supports thinking
if client, err := api.ClientFromEnvironment(); err == nil {
if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil {
if !slices.Contains(resp.Capabilities, model.CapabilityThinking) {
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName)
}
}
}
fmt.Println("Set 'nothink' mode.")
case "format":
if len(args) < 3 || args[2] != "json" {
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
} else {
format = args[2]
fmt.Printf("Set format to '%s' mode.\n", args[2])
}
case "noformat":
format = ""
fmt.Println("Disabled format.")
case "parameter":
if len(args) < 4 {
fmt.Println("Usage: /set parameter <name> <value>")
continue
}
params := args[3:]
fp, err := api.FormatParams(map[string][]string{args[2]: params})
if err != nil {
fmt.Printf("Couldn't set parameter: %q\n", err)
continue
}
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
options[args[2]] = fp[args[2]]
case "system":
if len(args) < 3 {
fmt.Println("Usage: /set system <message>")
continue
}

system = strings.Join(args[2:], " ")
newMessage := api.Message{Role: "system", Content: system}
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
continue
default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
}
} else {
fmt.Println("Usage: /set <parameter|system|history|format|wordwrap|think|verbose> [value]")
}
continue
case strings.HasPrefix(line, "/show"):
args := strings.Fields(line)
if len(args) > 1 {
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
continue
}
req := &api.ShowRequest{
Name: modelName,
Options: options,
}
resp, err := client.Show(cmd.Context(), req)
if err != nil {
fmt.Println("error: couldn't get model")
continue
}

switch args[1] {
case "info":
fmt.Fprintf(os.Stderr, " Model\n")
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Name", modelName)
if resp.Details.Family != "" {
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Family", resp.Details.Family)
}
if resp.Details.ParameterSize != "" {
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Parameter Size", resp.Details.ParameterSize)
}
if resp.Details.QuantizationLevel != "" {
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Quantization", resp.Details.QuantizationLevel)
}
if len(resp.Capabilities) > 0 {
caps := make([]string, len(resp.Capabilities))
for i, c := range resp.Capabilities {
caps[i] = string(c)
}
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Capabilities", strings.Join(caps, ", "))
}
fmt.Fprintln(os.Stderr)
case "license":
if resp.License == "" {
fmt.Println("No license was specified for this model.")
} else {
fmt.Println(resp.License)
}
case "modelfile":
fmt.Println(resp.Modelfile)
case "parameters":
fmt.Println("Model defined parameters:")
if resp.Parameters == "" {
fmt.Println(" No additional parameters were specified.")
} else {
for _, l := range strings.Split(resp.Parameters, "\n") {
fmt.Printf(" %s\n", l)
}
}
if len(options) > 0 {
fmt.Println("\nUser defined parameters:")
for k, v := range options {
fmt.Printf(" %-30s %v\n", k, v)
}
}
case "system":
switch {
case system != "":
fmt.Println(system + "\n")
case resp.System != "":
fmt.Println(resp.System + "\n")
default:
fmt.Println("No system message was specified for this model.")
}
case "template":
if resp.Template != "" {
fmt.Println(resp.Template)
} else {
fmt.Println("No prompt template was specified for this model.")
}
default:
fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1])
}
} else {
fmt.Println("Usage: /show <info|license|modelfile|parameters|system|template>")
}
continue
case strings.HasPrefix(line, "/load"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage: /load <modelname>")
continue
}
newModelName := args[1]
fmt.Printf("Loading model '%s'\n", newModelName)

// Create progress spinner
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)

// Get client
client, err := api.ClientFromEnvironment()
if err != nil {
p.StopAndClear()
fmt.Println("error: couldn't connect to ollama server")
continue
}

// Check if model exists and get its info
info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: newModelName})
if err != nil {
p.StopAndClear()
if strings.Contains(err.Error(), "not found") {
fmt.Printf("Couldn't find model '%s'\n", newModelName)
} else {
fmt.Printf("error: %v\n", err)
}
continue
}

// For cloud models, no need to preload
if info.RemoteHost == "" {
// Preload the model by sending an empty generate request
req := &api.GenerateRequest{
Model: newModelName,
Think: think,
}
err = client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error {
return nil
})
if err != nil {
p.StopAndClear()
if strings.Contains(err.Error(), "not found") {
fmt.Printf("Couldn't find model '%s'\n", newModelName)
} else if strings.Contains(err.Error(), "does not support thinking") {
fmt.Printf("error: %v\n", err)
} else {
fmt.Printf("error loading model: %v\n", err)
}
continue
}
}

p.StopAndClear()
modelName = newModelName
messages = []api.Message{}
approval.Reset()
continue
case strings.HasPrefix(line, "/save"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage: /save <modelname>")
continue
}
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
continue
}
req := &api.CreateRequest{
Model: args[1],
From: modelName,
Parameters: options,
Messages: messages,
}
fn := func(resp api.ProgressResponse) error { return nil }
err = client.Create(cmd.Context(), req, fn)
if err != nil {
fmt.Printf("error: %v\n", err)
continue
}
fmt.Printf("Created new model '%s'\n", args[1])
continue
case strings.HasPrefix(line, "/"):
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
continue
@@ -727,10 +1028,12 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
newMessage := api.Message{Role: "user", Content: sb.String()}
messages = append(messages, newMessage)

verbose, _ := cmd.Flags().GetBool("verbose")
opts := RunOptions{
Model: modelName,
Messages: messages,
WordWrap: wordWrap,
Format: format,
Options: options,
Think: think,
HideThinking: hideThinking,
@@ -738,6 +1041,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
Tools: toolRegistry,
Approval: approval,
YoloMode: yoloMode,
Verbose: verbose,
}

assistant, err := Chat(cmd.Context(), opts)

282
x/create/client/create.go
Normal file
@@ -0,0 +1,282 @@
|
||||
// Package client provides client-side model creation for safetensors-based models.
|
||||
//
|
||||
// This package is in x/ because the safetensors model storage format is under development.
|
||||
// It also exists to break an import cycle: server imports x/create, so x/create
|
||||
// cannot import server. This sub-package can import server because server doesn't
|
||||
// import it.
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/create"
|
||||
)
|
||||
|
||||
// MinOllamaVersion is the minimum Ollama version required for safetensors models.
|
||||
const MinOllamaVersion = "0.14.0"
|
||||
|
||||
// ModelfileConfig holds configuration extracted from a Modelfile.
|
||||
type ModelfileConfig struct {
|
||||
Template string
|
||||
System string
|
||||
License string
|
||||
}
|
||||
|
||||
// CreateOptions holds all options for model creation.
|
||||
type CreateOptions struct {
|
||||
ModelName string
|
||||
ModelDir string
|
||||
Quantize string // "fp8" for quantization
|
||||
Modelfile *ModelfileConfig // template/system/license from Modelfile
|
||||
}
|
||||
|
||||
// CreateModel imports a model from a local directory.
|
||||
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
|
||||
// Automatically detects model type (safetensors LLM vs image gen) and routes accordingly.
|
||||
func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
||||
// Detect model type
|
||||
isSafetensors := create.IsSafetensorsModelDir(opts.ModelDir)
|
||||
isImageGen := create.IsTensorModelDir(opts.ModelDir)
|
||||
|
||||
if !isSafetensors && !isImageGen {
|
||||
return fmt.Errorf("%s is not a supported model directory (needs config.json + *.safetensors or model_index.json)", opts.ModelDir)
|
||||
}
|
||||
|
||||
// Determine model type settings
|
||||
var modelType, spinnerKey string
|
||||
var capabilities []string
|
||||
if isSafetensors {
|
||||
modelType = "safetensors model"
|
||||
spinnerKey = "create"
|
||||
capabilities = []string{"completion"}
|
||||
} else {
|
||||
modelType = "image generation model"
|
||||
spinnerKey = "imagegen"
|
||||
capabilities = []string{"image"}
|
||||
}
|
||||
|
||||
// Set up progress spinner
|
||||
statusMsg := "importing " + modelType
|
||||
spinner := progress.NewSpinner(statusMsg)
|
||||
p.Add(spinnerKey, spinner)
|
||||
|
||||
progressFn := func(msg string) {
|
||||
spinner.Stop()
|
||||
statusMsg = msg
|
||||
spinner = progress.NewSpinner(statusMsg)
|
||||
p.Add(spinnerKey, spinner)
|
||||
}
|
||||
|
||||
// Create the model using shared callbacks
|
||||
var err error
|
||||
if isSafetensors {
|
||||
err = create.CreateSafetensorsModel(
|
||||
opts.ModelName, opts.ModelDir, opts.Quantize,
|
||||
newLayerCreator(), newTensorLayerCreator(),
|
||||
newManifestWriter(opts, capabilities),
|
||||
progressFn,
|
||||
)
|
||||
} else {
|
||||
err = create.CreateImageGenModel(
|
||||
opts.ModelName, opts.ModelDir, opts.Quantize,
|
||||
newLayerCreator(), newTensorLayerCreator(),
|
||||
newManifestWriter(opts, capabilities),
|
||||
progressFn,
|
||||
)
|
||||
}
|
||||
|
||||
spinner.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Created %s '%s'\n", modelType, opts.ModelName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
|
||||
func newLayerCreator() create.LayerCreator {
|
||||
return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, mediaType)
|
||||
if err != nil {
|
||||
return create.LayerInfo{}, err
|
||||
}
|
||||
|
||||
return create.LayerInfo{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// newTensorLayerCreator returns a QuantizingTensorLayerCreator callback for creating tensor layers.
|
||||
// When quantize is non-empty, returns multiple layers (weight + scales + optional qbias).
|
||||
func newTensorLayerCreator() create.QuantizingTensorLayerCreator {
|
||||
return func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) {
|
||||
if quantize != "" {
|
||||
return createQuantizedLayers(r, name, dtype, shape, quantize)
|
||||
}
|
||||
return createUnquantizedLayer(r, name)
|
||||
}
|
||||
}
|
||||
|
||||
// createQuantizedLayers quantizes a tensor and returns the resulting layers.
|
||||
func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) {
|
||||
if !QuantizeSupported() {
|
||||
return nil, fmt.Errorf("quantization requires MLX support")
|
||||
}
|
||||
|
||||
// Quantize the tensor
|
||||
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape, quantize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
|
||||
}
|
||||
|
||||
// Create layer for quantized weight
|
||||
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create layer for scales
|
||||
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers := []create.LayerInfo{
|
||||
{
|
||||
Digest: weightLayer.Digest,
|
||||
Size: weightLayer.Size,
|
||||
MediaType: weightLayer.MediaType,
|
||||
Name: name,
|
||||
},
|
||||
{
|
||||
Digest: scalesLayer.Digest,
|
||||
Size: scalesLayer.Size,
|
||||
MediaType: scalesLayer.MediaType,
|
||||
Name: name + "_scale",
|
||||
},
|
||||
}
|
||||
|
||||
// Add qbiases layer if present (affine mode)
|
||||
if qbiasData != nil {
|
||||
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layers = append(layers, create.LayerInfo{
|
||||
Digest: qbiasLayer.Digest,
|
||||
Size: qbiasLayer.Size,
|
||||
MediaType: qbiasLayer.MediaType,
|
||||
Name: name + "_qbias",
|
||||
})
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
// createUnquantizedLayer creates a single tensor layer without quantization.
|
||||
func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []create.LayerInfo{
|
||||
{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newManifestWriter returns a ManifestWriter callback for writing the model manifest.
|
||||
func newManifestWriter(opts CreateOptions, capabilities []string) create.ManifestWriter {
|
||||
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
|
||||
name := model.ParseName(modelName)
|
||||
if !name.IsValid() {
|
||||
return fmt.Errorf("invalid model name: %s", modelName)
|
||||
}
|
||||
|
||||
// Create config blob with version requirement
|
||||
configData := model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: capabilities,
|
||||
Requires: MinOllamaVersion,
|
||||
}
|
||||
configJSON, err := json.Marshal(configData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
// Create config layer blob
|
||||
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create config layer: %w", err)
|
||||
}
|
||||
|
||||
// Convert LayerInfo to server.Layer
|
||||
serverLayers := make([]server.Layer, 0, len(layers))
|
||||
for _, l := range layers {
|
||||
serverLayers = append(serverLayers, server.Layer{
|
||||
MediaType: l.MediaType,
|
||||
Digest: l.Digest,
|
||||
Size: l.Size,
|
||||
Name: l.Name,
|
||||
})
|
||||
}
|
||||
|
||||
// Add Modelfile layers if present
|
||||
if opts.Modelfile != nil {
|
||||
modelfileLayers, err := createModelfileLayers(opts.Modelfile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serverLayers = append(serverLayers, modelfileLayers...)
|
||||
}
|
||||
|
||||
return server.WriteManifest(name, configLayer, serverLayers)
|
||||
}
|
||||
}
|
||||
|
||||
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
|
||||
func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
|
||||
var layers []server.Layer
|
||||
|
||||
if mf.Template != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create template layer: %w", err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
if mf.System != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create system layer: %w", err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
if mf.License != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create license layer: %w", err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
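The naming contract above is what consumers rely on to reassemble a quantized tensor: the base blob plus its "_scale" and, in affine mode, "_qbias" companions. A minimal sketch of that lookup (the helper below is hypothetical, not part of this diff):

// quantizedLayerNames is an illustrative helper; the suffixes match the
// Name fields set in the quantized-layer code above.
func quantizedLayerNames(tensor string) (weight, scale, qbias string) {
	return tensor, tensor + "_scale", tensor + "_qbias"
}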
146
x/create/client/create_test.go
Normal file
@@ -0,0 +1,146 @@
package client

import (
	"testing"
)

func TestModelfileConfig(t *testing.T) {
	// Test that ModelfileConfig struct works as expected
	config := &ModelfileConfig{
		Template: "{{ .Prompt }}",
		System:   "You are a helpful assistant.",
		License:  "MIT",
	}

	if config.Template != "{{ .Prompt }}" {
		t.Errorf("Template = %q, want %q", config.Template, "{{ .Prompt }}")
	}
	if config.System != "You are a helpful assistant." {
		t.Errorf("System = %q, want %q", config.System, "You are a helpful assistant.")
	}
	if config.License != "MIT" {
		t.Errorf("License = %q, want %q", config.License, "MIT")
	}
}

func TestModelfileConfig_Empty(t *testing.T) {
	config := &ModelfileConfig{}

	if config.Template != "" {
		t.Errorf("Template should be empty, got %q", config.Template)
	}
	if config.System != "" {
		t.Errorf("System should be empty, got %q", config.System)
	}
	if config.License != "" {
		t.Errorf("License should be empty, got %q", config.License)
	}
}

func TestModelfileConfig_PartialFields(t *testing.T) {
	// Test config with only some fields set
	config := &ModelfileConfig{
		Template: "{{ .Prompt }}",
		// System and License intentionally empty
	}

	if config.Template == "" {
		t.Error("Template should not be empty")
	}
	if config.System != "" {
		t.Error("System should be empty")
	}
	if config.License != "" {
		t.Error("License should be empty")
	}
}

func TestMinOllamaVersion(t *testing.T) {
	// Verify the minimum version constant is set
	if MinOllamaVersion == "" {
		t.Error("MinOllamaVersion should not be empty")
	}
	if MinOllamaVersion != "0.14.0" {
		t.Errorf("MinOllamaVersion = %q, want %q", MinOllamaVersion, "0.14.0")
	}
}

func TestCreateModel_InvalidDir(t *testing.T) {
	// Test that CreateModel returns error for invalid directory
	err := CreateModel(CreateOptions{
		ModelName: "test-model",
		ModelDir:  "/nonexistent/path",
	}, nil)
	if err == nil {
		t.Error("expected error for nonexistent directory, got nil")
	}
}

func TestCreateModel_NotSafetensorsDir(t *testing.T) {
	// Test that CreateModel returns error for directory without safetensors
	dir := t.TempDir()

	err := CreateModel(CreateOptions{
		ModelName: "test-model",
		ModelDir:  dir,
	}, nil)
	if err == nil {
		t.Error("expected error for empty directory, got nil")
	}
}

func TestCreateOptions(t *testing.T) {
	opts := CreateOptions{
		ModelName: "my-model",
		ModelDir:  "/path/to/model",
		Quantize:  "fp8",
		Modelfile: &ModelfileConfig{
			Template: "test",
			System:   "system",
			License:  "MIT",
		},
	}

	if opts.ModelName != "my-model" {
		t.Errorf("ModelName = %q, want %q", opts.ModelName, "my-model")
	}
	if opts.ModelDir != "/path/to/model" {
		t.Errorf("ModelDir = %q, want %q", opts.ModelDir, "/path/to/model")
	}
	if opts.Quantize != "fp8" {
		t.Errorf("Quantize = %q, want %q", opts.Quantize, "fp8")
	}
	if opts.Modelfile == nil {
		t.Error("Modelfile should not be nil")
	}
	if opts.Modelfile.Template != "test" {
		t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test")
	}
}

func TestCreateOptions_Defaults(t *testing.T) {
	opts := CreateOptions{
		ModelName: "test",
		ModelDir:  "/tmp",
	}

	// Quantize should default to empty
	if opts.Quantize != "" {
		t.Errorf("Quantize should be empty by default, got %q", opts.Quantize)
	}

	// Modelfile should default to nil
	if opts.Modelfile != nil {
		t.Error("Modelfile should be nil by default")
	}
}

func TestQuantizeSupported(t *testing.T) {
	// This just verifies the function exists and returns a boolean
	// The actual value depends on build tags (mlx vs non-mlx)
	supported := QuantizeSupported()

	// In non-mlx builds, this should be false
	// We can't easily test both cases, so just verify it returns something
	_ = supported
}
130
x/create/client/quantize.go
Normal file
@@ -0,0 +1,130 @@
//go:build mlx

package client

import (
	"fmt"
	"io"
	"os"
	"path/filepath"

	"github.com/ollama/ollama/x/imagegen/mlx"
)

// quantizeTensor loads a tensor from safetensors format, quantizes it,
// and returns safetensors data for the quantized weights, scales, and biases.
// Supported quantization types: "fp8" (affine 8-bit)
|
||||
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
	tmpDir := ensureTempDir()

	// Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
	tmpFile, err := os.CreateTemp(tmpDir, "quant-input-*.safetensors")
	if err != nil {
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to create temp file: %w", err)
	}
	tmpPath := tmpFile.Name()
	defer os.Remove(tmpPath)

	if _, err := io.Copy(tmpFile, r); err != nil {
		tmpFile.Close()
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to write temp file: %w", err)
	}
	tmpFile.Close()

	// Load the tensor using MLX's native loader
	st, err := mlx.LoadSafetensorsNative(tmpPath)
	if err != nil {
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to load safetensors: %w", err)
	}
	defer st.Free()

	// Get the tensor (it's stored as "data" in our minimal safetensors format)
	arr := st.Get("data")
	if arr == nil {
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("tensor 'data' not found in safetensors")
	}

	// Convert to BFloat16 if needed (quantize expects float type)
	if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
		arr = mlx.AsType(arr, mlx.DtypeBFloat16)
		mlx.Eval(arr)
	}

	// Quantize based on quantization type
	var qweight, scales, qbiases *mlx.Array
	switch quantize {
	case "fp4":
		// affine mode: group_size=32, bits=4
		qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
	case "fp8":
		// affine mode: group_size=32, bits=8
		qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")
	default:
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
	}

	// Eval and make contiguous for data access
	qweight = mlx.Contiguous(qweight)
	scales = mlx.Contiguous(scales)
	if qbiases != nil {
		qbiases = mlx.Contiguous(qbiases)
		mlx.Eval(qweight, scales, qbiases)
	} else {
		mlx.Eval(qweight, scales)
	}

	// Get shapes
	qweightShape = qweight.Shape()
	scalesShape = scales.Shape()

	// Save quantized weight using MLX's native safetensors (correctly handles uint32 dtype)
	qweightPath := filepath.Join(tmpDir, "qweight.safetensors")
	defer os.Remove(qweightPath)
	if err := mlx.SaveSafetensors(qweightPath, map[string]*mlx.Array{"data": qweight}); err != nil {
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save quantized weight: %w", err)
	}
	qweightData, err = os.ReadFile(qweightPath)
	if err != nil {
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read quantized weight: %w", err)
	}

	// Save scales using MLX's native safetensors
	scalesPath := filepath.Join(tmpDir, "scales.safetensors")
	defer os.Remove(scalesPath)
	if err := mlx.SaveSafetensors(scalesPath, map[string]*mlx.Array{"data": scales}); err != nil {
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save scales: %w", err)
	}
	scalesData, err = os.ReadFile(scalesPath)
	if err != nil {
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read scales: %w", err)
	}

	// Affine mode returns qbiases for zero-point offset
	if qbiases != nil {
		qbiasShape = qbiases.Shape()
		qbiasPath := filepath.Join(tmpDir, "qbias.safetensors")
		defer os.Remove(qbiasPath)
		if err := mlx.SaveSafetensors(qbiasPath, map[string]*mlx.Array{"data": qbiases}); err != nil {
			return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save qbiases: %w", err)
		}
		qbiasData, err = os.ReadFile(qbiasPath)
		if err != nil {
			return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read qbiases: %w", err)
		}
	}

	return qweightData, scalesData, qbiasData, qweightShape, scalesShape, qbiasShape, nil
}

// QuantizeSupported returns true if quantization is supported (MLX build)
func QuantizeSupported() bool {
	return true
}

// ensureTempDir creates the temp directory for quantization if it doesn't exist
func ensureTempDir() string {
	tmpDir := filepath.Join(os.TempDir(), "ollama-quantize")
	os.MkdirAll(tmpDir, 0o755)
	return tmpDir
}
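For intuition, in affine mode each group of group_size consecutive values shares one scale and one bias, so a value dequantizes as w ≈ scale*q + bias. A rough sketch of that reconstruction (illustrative only: it assumes already-unpacked integer values, whereas MLX packs 4- or 8-bit values into uint32 words):

// dequantAffine is an illustrative inverse of affine quantization.
// q holds unpacked quantized values; scales and biases hold one entry
// per group of groupSize weights (32 in the code above).
func dequantAffine(q []uint8, scales, biases []float32, groupSize int) []float32 {
	out := make([]float32, len(q))
	for i, v := range q {
		g := i / groupSize
		out[i] = scales[g]*float32(v) + biases[g]
	}
	return out
}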
18
x/create/client/quantize_stub.go
Normal file
@@ -0,0 +1,18 @@
//go:build !mlx

package client

import (
	"fmt"
	"io"
)

// quantizeTensor is not available without MLX
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
	return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
}

// QuantizeSupported returns false when MLX is not available
func QuantizeSupported() bool {
	return false
}
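Which of the two quantizeTensor implementations is compiled is selected by the mlx build tag declared at the top of each file, e.g.:

go build -tags mlx ./x/create/...   # compiles quantize.go (MLX-backed)
go build ./x/create/...             # compiles quantize_stub.go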
399
x/create/create.go
Normal file
@@ -0,0 +1,399 @@
package create

import (
	"encoding/json"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"slices"
	"strings"

	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/x/imagegen/safetensors"
)

// ModelConfig represents the config blob stored with a model.
type ModelConfig struct {
	ModelFormat  string   `json:"model_format"`
	Capabilities []string `json:"capabilities"`
}

// Manifest represents the manifest JSON structure.
type Manifest struct {
	SchemaVersion int             `json:"schemaVersion"`
	MediaType     string          `json:"mediaType"`
	Config        ManifestLayer   `json:"config"`
	Layers        []ManifestLayer `json:"layers"`
}

// ManifestLayer represents a layer in the manifest.
type ManifestLayer struct {
	MediaType string `json:"mediaType"`
	Digest    string `json:"digest"`
	Size      int64  `json:"size"`
	Name      string `json:"name,omitempty"`
}

// defaultManifestDir returns the manifest storage directory.
func defaultManifestDir() string {
	return filepath.Join(envconfig.Models(), "manifests")
}

// defaultBlobDir returns the blob storage directory.
func defaultBlobDir() string {
	return filepath.Join(envconfig.Models(), "blobs")
}

// resolveManifestPath converts a model name to a manifest file path.
func resolveManifestPath(modelName string) string {
	host := "registry.ollama.ai"
	namespace := "library"
	name := modelName
	tag := "latest"

	if idx := strings.LastIndex(name, ":"); idx != -1 {
		tag = name[idx+1:]
		name = name[:idx]
	}

	parts := strings.Split(name, "/")
	switch len(parts) {
	case 3:
		host = parts[0]
		namespace = parts[1]
		name = parts[2]
	case 2:
		namespace = parts[0]
		name = parts[1]
	}

	return filepath.Join(defaultManifestDir(), host, namespace, name, tag)
}
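For example, given the defaults above (paths shown relative to the models directory returned by envconfig.Models()):

// resolveManifestPath("llama2")            -> manifests/registry.ollama.ai/library/llama2/latest
// resolveManifestPath("llama2:7b")         -> manifests/registry.ollama.ai/library/llama2/7b
// resolveManifestPath("myuser/mymodel:v1") -> manifests/registry.ollama.ai/myuser/mymodel/v1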

// loadManifest loads a manifest for the given model name.
func loadManifest(modelName string) (*Manifest, error) {
	manifestPath := resolveManifestPath(modelName)

	data, err := os.ReadFile(manifestPath)
	if err != nil {
		return nil, err
	}

	var manifest Manifest
	if err := json.Unmarshal(data, &manifest); err != nil {
		return nil, err
	}

	return &manifest, nil
}

// loadModelConfig loads the config blob for a model.
func loadModelConfig(modelName string) (*ModelConfig, error) {
	manifest, err := loadManifest(modelName)
	if err != nil {
		return nil, err
	}

	// Read the config blob
	blobName := strings.Replace(manifest.Config.Digest, ":", "-", 1)
	blobPath := filepath.Join(defaultBlobDir(), blobName)

	data, err := os.ReadFile(blobPath)
	if err != nil {
		return nil, err
	}

	var config ModelConfig
	if err := json.Unmarshal(data, &config); err != nil {
		return nil, err
	}

	return &config, nil
}

// IsSafetensorsModel checks if a model was created with the experimental
// safetensors builder by checking the model format in the config.
func IsSafetensorsModel(modelName string) bool {
	config, err := loadModelConfig(modelName)
	if err != nil {
		return false
	}
	return config.ModelFormat == "safetensors"
}

// IsSafetensorsLLMModel checks if a model is a safetensors LLM model
// (has completion capability, not image generation).
func IsSafetensorsLLMModel(modelName string) bool {
	config, err := loadModelConfig(modelName)
	if err != nil {
		return false
	}
	return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "completion")
}

// IsImageGenModel checks if a model is an image generation model
// (has image capability).
func IsImageGenModel(modelName string) bool {
	config, err := loadModelConfig(modelName)
	if err != nil {
		return false
	}
	return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "image")
}

// GetModelArchitecture returns the architecture from the model's config.json layer.
func GetModelArchitecture(modelName string) (string, error) {
	manifest, err := loadManifest(modelName)
	if err != nil {
		return "", err
	}

	// Find the config.json layer
	for _, layer := range manifest.Layers {
		if layer.Name == "config.json" && layer.MediaType == "application/vnd.ollama.image.json" {
			blobName := strings.Replace(layer.Digest, ":", "-", 1)
			blobPath := filepath.Join(defaultBlobDir(), blobName)

			data, err := os.ReadFile(blobPath)
			if err != nil {
				return "", err
			}

			var cfg struct {
				Architectures []string `json:"architectures"`
				ModelType     string   `json:"model_type"`
			}
			if err := json.Unmarshal(data, &cfg); err != nil {
				return "", err
			}

			// Prefer model_type, fall back to first architecture
			if cfg.ModelType != "" {
				return cfg.ModelType, nil
			}
			if len(cfg.Architectures) > 0 {
				return cfg.Architectures[0], nil
			}
		}
	}

	return "", fmt.Errorf("architecture not found in model config")
}

// IsTensorModelDir checks if the directory contains a diffusers-style tensor model
// by looking for model_index.json, which is the standard diffusers pipeline config.
func IsTensorModelDir(dir string) bool {
	_, err := os.Stat(filepath.Join(dir, "model_index.json"))
	return err == nil
}

// IsSafetensorsModelDir checks if the directory contains a standard safetensors model
// by looking for config.json and at least one .safetensors file.
func IsSafetensorsModelDir(dir string) bool {
	// Must have config.json
	if _, err := os.Stat(filepath.Join(dir, "config.json")); err != nil {
		return false
	}

	// Must have at least one .safetensors file
	entries, err := os.ReadDir(dir)
	if err != nil {
		return false
	}

	for _, entry := range entries {
		if strings.HasSuffix(entry.Name(), ".safetensors") {
			return true
		}
	}

	return false
}

// LayerInfo holds metadata for a created layer.
type LayerInfo struct {
	Digest    string
	Size      int64
	MediaType string
	Name      string // Path-style name: "component/tensor" or "path/to/config.json"
}

// LayerCreator is called to create a blob layer.
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)

// TensorLayerCreator creates a tensor blob layer with metadata.
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)

// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
// When quantize is non-empty (e.g., "fp8"), returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error)

// ManifestWriter writes the manifest file.
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error

// ShouldQuantize returns true if a tensor should be quantized.
// For image gen models (component non-empty): quantizes linear weights, skipping VAE, embeddings, norms.
// For LLM models (component empty): quantizes linear weights, skipping embeddings, norms, and small tensors.
func ShouldQuantize(name, component string) bool {
	// Image gen specific: skip VAE entirely
	if component == "vae" {
		return false
	}

	// Skip embeddings
	if strings.Contains(name, "embed") {
		return false
	}

	// Skip layer norms and RMS norms
	if strings.Contains(name, "norm") || strings.Contains(name, "ln_") || strings.Contains(name, "layernorm") {
		return false
	}

	// Skip biases
	if strings.HasSuffix(name, ".bias") {
		return false
	}

	// Only quantize weights
	return strings.HasSuffix(name, ".weight")
}

// ShouldQuantizeTensor returns true if a tensor should be quantized based on name and shape.
// This is a more detailed check that also considers tensor dimensions.
func ShouldQuantizeTensor(name string, shape []int32) bool {
	// Use basic name-based check first
	if !ShouldQuantize(name, "") {
		return false
	}

	// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
	if len(shape) != 2 {
		return false
	}

	// Skip small tensors (less than 1024 elements) - not worth quantizing
	if int64(shape[0])*int64(shape[1]) < 1024 {
		return false
	}

	// MLX quantization requires last dimension to be divisible by group size (32)
	if shape[len(shape)-1]%32 != 0 {
		return false
	}

	return true
}

// CreateSafetensorsModel imports a standard safetensors model from a directory.
// This handles Hugging Face style models with config.json and *.safetensors files.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is non-empty (e.g., "fp8"), eligible tensors will be quantized.
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
	var layers []LayerInfo
	var configLayer LayerInfo

	entries, err := os.ReadDir(modelDir)
	if err != nil {
		return fmt.Errorf("failed to read directory: %w", err)
	}

	// Process all safetensors files
	for _, entry := range entries {
		if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") {
			continue
		}

		stPath := filepath.Join(modelDir, entry.Name())

		// Extract individual tensors from safetensors file
		extractor, err := safetensors.OpenForExtraction(stPath)
		if err != nil {
			return fmt.Errorf("failed to open %s: %w", stPath, err)
		}

		tensorNames := extractor.ListTensors()
		quantizeMsg := ""
		if quantize != "" {
			quantizeMsg = fmt.Sprintf(", quantizing to %s", quantize)
		}
		fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg))

		for _, tensorName := range tensorNames {
			td, err := extractor.GetTensor(tensorName)
			if err != nil {
				extractor.Close()
				return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
			}

			// Determine quantization type for this tensor (empty string if not quantizing)
			quantizeType := ""
			if quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape) {
				quantizeType = quantize
			}

			// Store as minimal safetensors format (88 bytes header overhead)
			// This enables native mmap loading via mlx_load_safetensors
			// createTensorLayer returns multiple layers if quantizing (weight + scales)
			newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, quantizeType)
			if err != nil {
				extractor.Close()
				return fmt.Errorf("failed to create layer for %s: %w", tensorName, err)
			}
			layers = append(layers, newLayers...)
		}

		extractor.Close()
	}

	// Process all JSON config files
	for _, entry := range entries {
		if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
			continue
		}

		// Skip the index file as we don't need it after extraction
		if entry.Name() == "model.safetensors.index.json" {
			continue
		}

		cfgPath := entry.Name()
		fullPath := filepath.Join(modelDir, cfgPath)

		fn(fmt.Sprintf("importing config %s", cfgPath))

		f, err := os.Open(fullPath)
		if err != nil {
			return fmt.Errorf("failed to open %s: %w", cfgPath, err)
		}

		layer, err := createLayer(f, "application/vnd.ollama.image.json", cfgPath)
		f.Close()
		if err != nil {
			return fmt.Errorf("failed to create layer for %s: %w", cfgPath, err)
		}

		// Use config.json as the config layer
		if cfgPath == "config.json" {
			configLayer = layer
		}

		layers = append(layers, layer)
	}

	if configLayer.Digest == "" {
		return fmt.Errorf("config.json not found in %s", modelDir)
	}

	fn(fmt.Sprintf("writing manifest for %s", modelName))

	if err := writeManifest(modelName, configLayer, layers); err != nil {
		return fmt.Errorf("failed to write manifest: %w", err)
	}

	fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
	return nil
}
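To make the callback plumbing concrete, here is a minimal sketch of invoking CreateSafetensorsModel with stub callbacks; it mirrors the test doubles used in the test file below, and the digest and paths are placeholders for illustration, not a real blob store:

// Sketch only: real callers back these callbacks with the blob store.
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
	data, err := io.ReadAll(r)
	if err != nil {
		return LayerInfo{}, err
	}
	return LayerInfo{Digest: "sha256:...", Size: int64(len(data)), MediaType: mediaType, Name: name}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
	io.Copy(io.Discard, r) // drain the tensor data
	return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
err := CreateSafetensorsModel("my-model", "/path/to/hf-model", "fp8",
	createLayer, createTensorLayer, writeManifest,
	func(status string) { fmt.Println(status) })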
752
x/create/create_test.go
Normal file
@@ -0,0 +1,752 @@
package create

import (
	"bytes"
	"encoding/binary"
	"encoding/json"
	"io"
	"os"
	"path/filepath"
	"strings"
	"testing"
)

func TestIsTensorModelDir(t *testing.T) {
	tests := []struct {
		name     string
		setup    func(dir string) error
		expected bool
	}{
		{
			name: "valid diffusers model with model_index.json",
			setup: func(dir string) error {
				return os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(`{"_class_name": "FluxPipeline"}`), 0o644)
			},
			expected: true,
		},
		{
			name: "empty directory",
			setup: func(dir string) error {
				return nil
			},
			expected: false,
		},
		{
			name: "directory with other files but no model_index.json",
			setup: func(dir string) error {
				return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644)
			},
			expected: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			dir := t.TempDir()
			if err := tt.setup(dir); err != nil {
				t.Fatalf("setup failed: %v", err)
			}

			got := IsTensorModelDir(dir)
			if got != tt.expected {
				t.Errorf("IsTensorModelDir() = %v, want %v", got, tt.expected)
			}
		})
	}
}

func TestIsSafetensorsModelDir(t *testing.T) {
	tests := []struct {
		name     string
		setup    func(dir string) error
		expected bool
	}{
		{
			name: "valid safetensors model with config.json and .safetensors file",
			setup: func(dir string) error {
				if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type": "gemma3"}`), 0o644); err != nil {
					return err
				}
				return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0o644)
			},
			expected: true,
		},
		{
			name: "config.json only, no safetensors files",
			setup: func(dir string) error {
				return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644)
			},
			expected: false,
		},
		{
			name: "safetensors file only, no config.json",
			setup: func(dir string) error {
				return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0o644)
			},
			expected: false,
		},
		{
			name: "empty directory",
			setup: func(dir string) error {
				return nil
			},
			expected: false,
		},
		{
			name: "multiple safetensors files with config.json",
			setup: func(dir string) error {
				if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil {
					return err
				}
				if err := os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("dummy"), 0o644); err != nil {
					return err
				}
				return os.WriteFile(filepath.Join(dir, "model-00002-of-00002.safetensors"), []byte("dummy"), 0o644)
			},
			expected: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			dir := t.TempDir()
			if err := tt.setup(dir); err != nil {
				t.Fatalf("setup failed: %v", err)
			}

			got := IsSafetensorsModelDir(dir)
			if got != tt.expected {
				t.Errorf("IsSafetensorsModelDir() = %v, want %v", got, tt.expected)
			}
		})
	}
}

func TestIsSafetensorsModelDir_NonexistentDir(t *testing.T) {
	got := IsSafetensorsModelDir("/nonexistent/path/that/does/not/exist")
	if got != false {
		t.Errorf("IsSafetensorsModelDir() = %v for nonexistent dir, want false", got)
	}
}

// createMinimalSafetensors creates a minimal valid safetensors file with one tensor
func createMinimalSafetensors(t *testing.T, path string) {
	t.Helper()

	// Create a minimal safetensors file with a single float32 tensor
	header := map[string]interface{}{
		"test_tensor": map[string]interface{}{
			"dtype":        "F32",
			"shape":        []int{2, 2},
			"data_offsets": []int{0, 16}, // 4 float32 values = 16 bytes
		},
	}
	headerJSON, err := json.Marshal(header)
	if err != nil {
		t.Fatalf("failed to marshal header: %v", err)
	}

	// Pad header to 8-byte alignment
	padding := (8 - len(headerJSON)%8) % 8
	headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)

	// Write file
	f, err := os.Create(path)
	if err != nil {
		t.Fatalf("failed to create file: %v", err)
	}
	defer f.Close()

	// Write header size (8 bytes, little endian)
	if err := binary.Write(f, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
		t.Fatalf("failed to write header size: %v", err)
	}

	// Write header
	if _, err := f.Write(headerJSON); err != nil {
		t.Fatalf("failed to write header: %v", err)
	}

	// Write tensor data (16 bytes of zeros for 4 float32 values)
	if _, err := f.Write(make([]byte, 16)); err != nil {
		t.Fatalf("failed to write tensor data: %v", err)
	}
}

func TestCreateSafetensorsModel(t *testing.T) {
	dir := t.TempDir()

	// Create config.json
	configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
	if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
		t.Fatalf("failed to write config.json: %v", err)
	}

	// Create a minimal safetensors file
	createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))

	// Track what was created
	var createdLayers []LayerInfo
	var manifestWritten bool
	var manifestModelName string
	var manifestConfigLayer LayerInfo
	var manifestLayers []LayerInfo
	var statusMessages []string

	// Mock callbacks
	createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		data, err := io.ReadAll(r)
		if err != nil {
			return LayerInfo{}, err
		}
		layer := LayerInfo{
			Digest:    "sha256:test",
			Size:      int64(len(data)),
			MediaType: mediaType,
			Name:      name,
		}
		createdLayers = append(createdLayers, layer)
		return layer, nil
	}

	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		data, err := io.ReadAll(r)
		if err != nil {
			return nil, err
		}
		layer := LayerInfo{
			Digest:    "sha256:tensor_" + name,
			Size:      int64(len(data)),
			MediaType: "application/vnd.ollama.image.tensor",
			Name:      name,
		}
		createdLayers = append(createdLayers, layer)
		return []LayerInfo{layer}, nil
	}

	writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
		manifestWritten = true
		manifestModelName = modelName
		manifestConfigLayer = config
		manifestLayers = layers
		return nil
	}

	progressFn := func(status string) {
		statusMessages = append(statusMessages, status)
	}

	// Run CreateSafetensorsModel
	err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
	if err != nil {
		t.Fatalf("CreateSafetensorsModel failed: %v", err)
	}

	// Verify manifest was written
	if !manifestWritten {
		t.Error("manifest was not written")
	}

	if manifestModelName != "test-model" {
		t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-model")
	}

	// Verify config layer was set
	if manifestConfigLayer.Name != "config.json" {
		t.Errorf("config layer name = %q, want %q", manifestConfigLayer.Name, "config.json")
	}

	// Verify we have at least one tensor and one config layer
	hasTensor := false
	hasConfig := false
	for _, layer := range manifestLayers {
		if layer.Name == "test_tensor" {
			hasTensor = true
		}
		if layer.Name == "config.json" {
			hasConfig = true
		}
	}

	if !hasTensor {
		t.Error("no tensor layer found in manifest")
	}
	if !hasConfig {
		t.Error("no config layer found in manifest")
	}

	// Verify status messages were sent
	if len(statusMessages) == 0 {
		t.Error("no status messages received")
	}
}

func TestCreateSafetensorsModel_NoConfigJson(t *testing.T) {
	dir := t.TempDir()

	// Create only a safetensors file, no config.json
	createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))

	// Mock callbacks (minimal)
	createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		io.ReadAll(r)
		return LayerInfo{Name: name}, nil
	}
	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		io.ReadAll(r)
		return []LayerInfo{{Name: name}}, nil
	}
	writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
		return nil
	}
	progressFn := func(status string) {}

	err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
	if err == nil {
		t.Error("expected error for missing config.json, got nil")
	}
}

func TestCreateSafetensorsModel_EmptyDir(t *testing.T) {
	dir := t.TempDir()

	// Mock callbacks
	createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		return LayerInfo{}, nil
	}
	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		return []LayerInfo{{}}, nil
	}
	writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
		return nil
	}
	progressFn := func(status string) {}

	err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
	if err == nil {
		t.Error("expected error for empty directory, got nil")
	}
}

func TestCreateSafetensorsModel_SkipsIndexJson(t *testing.T) {
	dir := t.TempDir()

	// Create config.json
	if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil {
		t.Fatalf("failed to write config.json: %v", err)
	}

	// Create model.safetensors.index.json (should be skipped)
	indexJSON := `{"metadata": {"total_size": 100}, "weight_map": {}}`
	if err := os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(indexJSON), 0o644); err != nil {
		t.Fatalf("failed to write index.json: %v", err)
	}

	// Create a minimal safetensors file
	createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))

	var configNames []string

	createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		io.ReadAll(r)
		configNames = append(configNames, name)
		return LayerInfo{Name: name, Digest: "sha256:test"}, nil
	}
	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		io.ReadAll(r)
		return []LayerInfo{{Name: name}}, nil
	}
	writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
		return nil
	}
	progressFn := func(status string) {}

	err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
	if err != nil {
		t.Fatalf("CreateSafetensorsModel failed: %v", err)
	}

	// Verify model.safetensors.index.json was not included
	for _, name := range configNames {
		if name == "model.safetensors.index.json" {
			t.Error("model.safetensors.index.json should have been skipped")
		}
	}
}

func TestResolveManifestPath(t *testing.T) {
	tests := []struct {
		name      string
		modelName string
		wantParts []string // Parts that should appear in the path
	}{
		{
			name:      "simple model name",
			modelName: "llama2",
			wantParts: []string{"registry.ollama.ai", "library", "llama2", "latest"},
		},
		{
			name:      "model name with tag",
			modelName: "llama2:7b",
			wantParts: []string{"registry.ollama.ai", "library", "llama2", "7b"},
		},
		{
			name:      "model name with namespace",
			modelName: "myuser/mymodel",
			wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "latest"},
		},
		{
			name:      "model name with namespace and tag",
			modelName: "myuser/mymodel:v1",
			wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "v1"},
		},
		{
			name:      "fully qualified model name",
			modelName: "registry.example.com/namespace/model:tag",
			wantParts: []string{"registry.example.com", "namespace", "model", "tag"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := resolveManifestPath(tt.modelName)

			for _, part := range tt.wantParts {
				if !strings.Contains(got, part) {
					t.Errorf("resolveManifestPath(%q) = %q, missing part %q", tt.modelName, got, part)
				}
			}
		})
	}
}

func TestLayerInfo(t *testing.T) {
	layer := LayerInfo{
		Digest:    "sha256:abc123",
		Size:      1024,
		MediaType: "application/vnd.ollama.image.tensor",
		Name:      "model.weight",
	}

	if layer.Digest != "sha256:abc123" {
		t.Errorf("Digest = %q, want %q", layer.Digest, "sha256:abc123")
	}
	if layer.Size != 1024 {
		t.Errorf("Size = %d, want %d", layer.Size, 1024)
	}
	if layer.MediaType != "application/vnd.ollama.image.tensor" {
		t.Errorf("MediaType = %q, want %q", layer.MediaType, "application/vnd.ollama.image.tensor")
	}
	if layer.Name != "model.weight" {
		t.Errorf("Name = %q, want %q", layer.Name, "model.weight")
	}
}

func TestModelConfig(t *testing.T) {
	config := ModelConfig{
		ModelFormat:  "safetensors",
		Capabilities: []string{"completion", "chat"},
	}

	if config.ModelFormat != "safetensors" {
		t.Errorf("ModelFormat = %q, want %q", config.ModelFormat, "safetensors")
	}
	if len(config.Capabilities) != 2 {
		t.Errorf("Capabilities length = %d, want %d", len(config.Capabilities), 2)
	}
}

func TestManifest(t *testing.T) {
	manifest := Manifest{
		SchemaVersion: 2,
		MediaType:     "application/vnd.oci.image.manifest.v1+json",
		Config: ManifestLayer{
			MediaType: "application/vnd.docker.container.image.v1+json",
			Digest:    "sha256:config",
			Size:      100,
		},
		Layers: []ManifestLayer{
			{
				MediaType: "application/vnd.ollama.image.tensor",
				Digest:    "sha256:layer1",
				Size:      1000,
				Name:      "weight.bin",
			},
		},
	}

	if manifest.SchemaVersion != 2 {
		t.Errorf("SchemaVersion = %d, want %d", manifest.SchemaVersion, 2)
	}
	if manifest.Config.Digest != "sha256:config" {
		t.Errorf("Config.Digest = %q, want %q", manifest.Config.Digest, "sha256:config")
	}
	if len(manifest.Layers) != 1 {
		t.Errorf("Layers length = %d, want %d", len(manifest.Layers), 1)
	}
	if manifest.Layers[0].Name != "weight.bin" {
		t.Errorf("Layers[0].Name = %q, want %q", manifest.Layers[0].Name, "weight.bin")
	}
}

func TestShouldQuantize(t *testing.T) {
	tests := []struct {
		name      string
		tensor    string
		component string
		want      bool
	}{
		// VAE component should never be quantized
		{"vae weight", "decoder.weight", "vae", false},
		{"vae bias", "decoder.bias", "vae", false},

		// Embeddings should not be quantized
		{"embedding weight", "embed_tokens.weight", "", false},
		{"embedding in name", "token_embedding.weight", "", false},

		// Norms should not be quantized
		{"layer norm", "layer_norm.weight", "", false},
		{"rms norm", "rms_norm.weight", "", false},
		{"ln prefix", "ln_1.weight", "", false},
		{"layernorm in name", "input_layernorm.weight", "", false},

		// Biases should not be quantized
		{"bias tensor", "attention.bias", "", false},
		{"proj bias", "o_proj.bias", "", false},

		// Linear weights should be quantized
		{"linear weight", "q_proj.weight", "", true},
		{"attention weight", "self_attn.weight", "", true},
		{"mlp weight", "mlp.gate_proj.weight", "", true},

		// Transformer component weights should be quantized
		{"transformer weight", "layers.0.weight", "transformer", true},
		{"text_encoder weight", "encoder.weight", "text_encoder", true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := ShouldQuantize(tt.tensor, tt.component)
			if got != tt.want {
				t.Errorf("ShouldQuantize(%q, %q) = %v, want %v", tt.tensor, tt.component, got, tt.want)
			}
		})
	}
}

func TestShouldQuantizeTensor(t *testing.T) {
	tests := []struct {
		name   string
		tensor string
		shape  []int32
		want   bool
	}{
		// 2D tensors with sufficient size should be quantized
		{"large 2D weight", "q_proj.weight", []int32{4096, 4096}, true},
		{"medium 2D weight", "small_proj.weight", []int32{128, 128}, true},

		// Small tensors should not be quantized (< 1024 elements)
		{"tiny 2D weight", "tiny.weight", []int32{16, 16}, false},
		{"small 2D weight", "small.weight", []int32{31, 31}, false},

		// 1D tensors should not be quantized
		{"1D tensor", "layer_norm.weight", []int32{4096}, false},

		// 3D+ tensors should not be quantized
		{"3D tensor", "conv.weight", []int32{64, 64, 3}, false},
		{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, false},

		// Embeddings should not be quantized regardless of shape
		{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, false},

		// Norms should not be quantized regardless of shape
		{"norm 2D", "layer_norm.weight", []int32{4096, 1}, false},

		// Biases should not be quantized
		{"bias 2D", "proj.bias", []int32{4096, 1}, false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := ShouldQuantizeTensor(tt.tensor, tt.shape)
			if got != tt.want {
				t.Errorf("ShouldQuantizeTensor(%q, %v) = %v, want %v", tt.tensor, tt.shape, got, tt.want)
			}
		})
	}
}

func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
	dir := t.TempDir()

	// Create config.json
	configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
	if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
		t.Fatalf("failed to write config.json: %v", err)
	}

	// Create a minimal safetensors file
	createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))

	var quantizeRequested []string

	createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		io.ReadAll(r)
		return LayerInfo{Name: name, Digest: "sha256:test"}, nil
	}

	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		io.ReadAll(r)
		quantizeRequested = append(quantizeRequested, quantize)
		return []LayerInfo{{Name: name}}, nil
	}

	writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
		return nil
	}

	progressFn := func(status string) {}

	// Run with quantize enabled
	err := CreateSafetensorsModel("test-model", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
	if err != nil {
		t.Fatalf("CreateSafetensorsModel failed: %v", err)
	}

	// Verify quantize was passed to the callback (it will be empty for the small test tensor)
	if len(quantizeRequested) == 0 {
		t.Error("no tensors processed")
	}
}

// createMinimalImageGenModel creates a minimal diffusers-style model directory
func createMinimalImageGenModel(t *testing.T, dir string) {
	t.Helper()

	// Create model_index.json
	modelIndex := `{"_class_name": "FluxPipeline", "_diffusers_version": "0.30.0"}`
	if err := os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(modelIndex), 0o644); err != nil {
		t.Fatalf("failed to write model_index.json: %v", err)
	}

	// Create transformer directory with a safetensors file
	transformerDir := filepath.Join(dir, "transformer")
	if err := os.MkdirAll(transformerDir, 0o755); err != nil {
		t.Fatalf("failed to create transformer dir: %v", err)
	}
	createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))

	// Create transformer config
	transformerConfig := `{"hidden_size": 3072}`
	if err := os.WriteFile(filepath.Join(transformerDir, "config.json"), []byte(transformerConfig), 0o644); err != nil {
		t.Fatalf("failed to write transformer config: %v", err)
	}
}

func TestCreateImageGenModel(t *testing.T) {
	dir := t.TempDir()
	createMinimalImageGenModel(t, dir)

	var manifestWritten bool
	var manifestModelName string
	var statusMessages []string

	createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		io.ReadAll(r)
		return LayerInfo{Name: name, Digest: "sha256:test"}, nil
	}

	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		io.ReadAll(r)
		return []LayerInfo{{Name: name, Digest: "sha256:tensor"}}, nil
	}

	writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
		manifestWritten = true
		manifestModelName = modelName
		return nil
	}

	progressFn := func(status string) {
		statusMessages = append(statusMessages, status)
	}

	err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
	if err != nil {
		t.Fatalf("CreateImageGenModel failed: %v", err)
	}

	if !manifestWritten {
		t.Error("manifest was not written")
	}

	if manifestModelName != "test-imagegen" {
		t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-imagegen")
	}

	if len(statusMessages) == 0 {
		t.Error("no status messages received")
	}
}

func TestCreateImageGenModel_NoModelIndex(t *testing.T) {
	dir := t.TempDir()

	// Create only transformer without model_index.json
	transformerDir := filepath.Join(dir, "transformer")
	if err := os.MkdirAll(transformerDir, 0o755); err != nil {
		t.Fatalf("failed to create transformer dir: %v", err)
	}
	createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))

	createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		io.ReadAll(r)
		return LayerInfo{Name: name}, nil
	}
	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		io.ReadAll(r)
		return []LayerInfo{{Name: name}}, nil
	}
	writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
		return nil
	}
	progressFn := func(status string) {}

	err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
	if err == nil {
		t.Error("expected error for missing model_index.json, got nil")
	}
}

func TestCreateImageGenModel_WithQuantize(t *testing.T) {
	dir := t.TempDir()
	createMinimalImageGenModel(t, dir)

	var quantizeRequested []string

	createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		io.ReadAll(r)
		return LayerInfo{Name: name, Digest: "sha256:test"}, nil
	}

	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		io.ReadAll(r)
		quantizeRequested = append(quantizeRequested, quantize)
		return []LayerInfo{{Name: name}}, nil
	}

	writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
		return nil
	}

	progressFn := func(status string) {}

	err := CreateImageGenModel("test-imagegen", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
	if err != nil {
		t.Fatalf("CreateImageGenModel failed: %v", err)
	}

	if len(quantizeRequested) == 0 {
		t.Error("no tensors processed")
	}
}
@@ -1,4 +1,4 @@
package imagegen
package create

import (
	"bytes"
@@ -12,38 +12,24 @@ import (
	"github.com/ollama/ollama/x/imagegen/safetensors"
)

// IsTensorModelDir checks if the directory contains a tensor model
// by looking for model_index.json, which is the standard diffusers pipeline config.
func IsTensorModelDir(dir string) bool {
	_, err := os.Stat(filepath.Join(dir, "model_index.json"))
	return err == nil
}

// LayerInfo holds metadata for a created layer.
type LayerInfo struct {
	Digest    string
	Size      int64
	MediaType string
	Name      string // Path-style name: "component/tensor" or "path/to/config.json"
}

// LayerCreator is called to create a blob layer.
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)

// TensorLayerCreator creates a tensor blob layer with metadata.
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)

// ManifestWriter writes the manifest file.
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error

// CreateModel imports an image generation model from a directory.
// CreateImageGenModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is specified, linear weights in transformer/text_encoder are quantized.
// Supported quantization types: fp4, fp8 (or empty for no quantization).
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
|
||||
func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTensorLayer TensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
||||
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
||||
// Validate quantization type
|
||||
switch quantize {
|
||||
case "", "fp4", "fp8":
|
||||
// valid
|
||||
default:
|
||||
return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8", quantize)
|
||||
}
|
||||
|
||||
var layers []LayerInfo
|
||||
var configLayer LayerInfo
|
||||
var totalParams int64 // Count parameters from original tensor shapes
|
||||
var torchDtype string // Read from component config for quantization display
|
||||
|
||||
// Components to process - extract individual tensors from each
|
||||
components := []string{"text_encoder", "transformer", "vae"}
|
||||
@@ -74,7 +60,11 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
|
||||
}
|
||||
|
||||
tensorNames := extractor.ListTensors()
|
||||
fn(fmt.Sprintf("importing %s/%s (%d tensors)", component, entry.Name(), len(tensorNames)))
|
||||
quantizeMsg := ""
|
||||
if quantize != "" && component != "vae" {
|
||||
quantizeMsg = ", quantizing to " + quantize
|
||||
}
|
||||
fn(fmt.Sprintf("importing %s/%s (%d tensors%s)", component, entry.Name(), len(tensorNames), quantizeMsg))
|
||||
|
||||
for _, tensorName := range tensorNames {
|
||||
td, err := extractor.GetTensor(tensorName)
|
||||
@@ -83,22 +73,52 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
|
||||
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
|
||||
}
|
||||
|
||||
// Count parameters from original tensor shape
|
||||
if len(td.Shape) > 0 {
|
||||
numElements := int64(1)
|
||||
for _, dim := range td.Shape {
|
||||
numElements *= int64(dim)
|
||||
}
|
||||
totalParams += numElements
|
||||
}
|
||||
|
||||
// Store as minimal safetensors format (88 bytes header overhead)
|
||||
// This enables native mmap loading via mlx_load_safetensors
|
||||
// Use path-style name: "component/tensor_name"
|
||||
fullName := component + "/" + tensorName
|
||||
layer, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape)
|
||||
|
||||
// Determine quantization type for this tensor (empty string if not quantizing)
|
||||
quantizeType := ""
|
||||
if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape) {
|
||||
quantizeType = quantize
|
||||
}
|
||||
|
||||
// createTensorLayer returns multiple layers if quantizing (weight + scales)
|
||||
newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, quantizeType)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
layers = append(layers, newLayers...)
|
||||
}
|
||||
|
||||
extractor.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Read torch_dtype from text_encoder config for quantization display
|
||||
if torchDtype == "" {
|
||||
textEncoderConfig := filepath.Join(modelDir, "text_encoder/config.json")
|
||||
if data, err := os.ReadFile(textEncoderConfig); err == nil {
|
||||
var cfg struct {
|
||||
TorchDtype string `json:"torch_dtype"`
|
||||
}
|
||||
if json.Unmarshal(data, &cfg) == nil && cfg.TorchDtype != "" {
|
||||
torchDtype = cfg.TorchDtype
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Import config files
|
||||
configFiles := []string{
|
||||
"model_index.json",
|
||||
@@ -122,7 +142,7 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
|
||||
|
||||
var r io.Reader
|
||||
|
||||
// For model_index.json, normalize to Ollama format
|
||||
// For model_index.json, normalize to Ollama format and add metadata
|
||||
if cfgPath == "model_index.json" {
|
||||
data, err := os.ReadFile(fullPath)
|
||||
if err != nil {
|
||||
@@ -141,6 +161,16 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
|
||||
}
|
||||
delete(cfg, "_diffusers_version")
|
||||
|
||||
// Add parameter count (counted from tensor shapes during import)
|
||||
cfg["parameter_count"] = totalParams
|
||||
|
||||
// Add quantization info - use quantize type if set, otherwise torch_dtype
|
||||
if quantize != "" {
|
||||
cfg["quantization"] = strings.ToUpper(quantize)
|
||||
} else {
|
||||
cfg["quantization"] = torchDtype
|
||||
}
|
||||
|
||||
data, err = json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal %s: %w", cfgPath, err)
|
||||
@@ -181,3 +211,12 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
|
||||
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// canQuantizeShape returns true if a tensor shape is compatible with MLX quantization.
|
||||
// MLX requires the last dimension to be divisible by the group size (32).
|
||||
func canQuantizeShape(shape []int32) bool {
|
||||
if len(shape) < 2 {
|
||||
return false
|
||||
}
|
||||
return shape[len(shape)-1]%32 == 0
|
||||
}
|
||||
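For orientation, here is a hedged sketch of a call site for the new entry point. The callback values (`createLayer`, `createTensorLayer`, `writeManifest`) are placeholders the caller must supply, the directory path is hypothetical, and `QuantizingTensorLayerCreator` is only partially visible in this diff, so treat this as illustrative rather than the server's actual wiring:

```go
// Import a diffusers-style model directory, quantizing linear weights to fp8.
// The path and callback variables below are hypothetical stand-ins.
err := CreateImageGenModel(
	"my-image-model",         // destination model name
	"/path/to/diffusers-dir", // contains text_encoder/, transformer/, vae/
	"fp8",                    // "fp4", "fp8", or "" for no quantization
	createLayer,              // LayerCreator: stores config blobs
	createTensorLayer,        // QuantizingTensorLayerCreator: stores tensor blobs (plus scales when quantizing)
	writeManifest,            // ManifestWriter: writes the final manifest
	func(status string) { fmt.Println(status) },
)
if err != nil {
	log.Fatal(err)
}
```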
@@ -1,185 +0,0 @@
# grammar

Grammar-constrained decoding for LLM outputs using MLX.

## Performance

Performance depends on hardware, vocabulary size, grammar, and whether you
evaluate the MLX graph. See [Benchmarks](#benchmarks) for how to measure on your
setup.

### Design choices that keep masking fast

| Technique | Impact |
|-----------|--------|
| Precomputed token analysis | Terminal matches computed once at startup |
| Mask caching by grammar state signature | Reuse masks for repeated parser states |
| Partitioned tokens | Exact matches separated from DP candidates |
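The mask cache in the table above is essentially an LRU map keyed by a hash of the current parser configuration: two decode steps that reach the same grammar state reuse one mask instead of re-scanning the vocabulary. A minimal sketch of the idea (names illustrative; the package's own `maskCache` adds locking, LRU eviction, and frees evicted MLX arrays):

```go
// maskBySig caches one vocabulary mask per distinct parser-state signature.
type maskBySig struct {
	masks map[uint64][]float32 // signature -> dense 0/1 mask over the vocab
}

func (c *maskBySig) mask(sig uint64, build func() []float32) []float32 {
	if m, ok := c.masks[sig]; ok {
		return m // hit: this parser state was seen before
	}
	m := build() // miss: compute the mask once for this state
	c.masks[sig] = m
	return m
}
```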
### Comparison Notes

- **llama.cpp**: Decodes each token to UTF-8, checks against the PDA. No caching.
- **Outlines**: FSM-based. Compilation can take from 40s to 10min for complex schemas. Fast after compile.
- **XGrammar**: PDA with 99% context-independent tokens precomputed. The state of the art prior to this package.
- **x/grammar**: Precomputed token analysis + mask caching by grammar state signature.

## Usage

```go
import (
    "github.com/ollama/ollama/x/grammar"
    "github.com/ollama/ollama/x/grammar/schema"
)

// Use built-in JSON grammar
g, _ := grammar.JSONGrammar()

// Or from JSON Schema (OpenAI-compatible)
g, _ = schema.Grammar(`{
  "type": "object",
  "properties": {
    "name": {"type": "string"},
    "age": {"type": "integer"}
  },
  "required": ["name", "age"]
}`)

// Or parse custom EBNF
g, _ = grammar.ParseEBNF(myGrammar, "root")

// Create engine with model vocabulary
engine, _ := grammar.NewEngine(g, vocab)
defer engine.Close()

// Generation loop
for !engine.IsComplete() {
    logits := model.Forward(tokens)
    masked := engine.ApplyMask(logits) // Invalid tokens → -inf
    nextToken := sample(masked)
    engine.Accept(nextToken)
}
// Output conforms to the grammar when you only sample from masked tokens and call Accept
```

## EBNF Syntax

```ebnf
rule = expression .   # Rule definition (ends with .)
"literal"             # Literal string
"a" … "z"             # Character range (inclusive)
( a | b )             # Grouping with alternation
[ optional ]          # Optional (0 or 1)
{ repeated }          # Repetition (0 or more)
```

### Example: JSON Grammar

```ebnf
json = value .

value = object | array | string | number | "true" | "false" | "null" .

object  = "{" ws "}" | "{" members "}" .
members = member { "," member } .
member  = ws string ws ":" element .

array    = "[" ws "]" | "[" elements "]" .
elements = element { "," element } .
element  = ws value ws .

string    = "\"" { character } "\"" .
character = unescaped | escaped .
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
escaped   = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" ) .

number   = [ "-" ] integer [ fraction ] [ exponent ] .
integer  = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
digit    = "0" … "9" .
onenine  = "1" … "9" .

ws = { " " | "\t" | "\n" | "\r" } .
```

### Example: Custom Schema

```ebnf
root = "{" ws name_field "," ws age_field ws "}" .

name_field = "\"name\"" ws ":" ws string .
age_field  = "\"age\"" ws ":" ws number .

string = "\"" { char } "\"" .
char   = " " | "!" | "#" … "~" .

number = [ "-" ] digit { digit } .
digit  = "0" … "9" .

ws = { " " | "\n" } .
```

## JSON Schema Support

OpenAI-compatible JSON Schema support with automatic EBNF generation:

```go
schemaJSON := `{
  "type": "object",
  "properties": {
    "user": {"$ref": "#/$defs/User"}
  },
  "required": ["user"],
  "$defs": {
    "User": {
      "type": "object",
      "properties": {
        "name": {"type": "string"},
        "email": {"type": "string", "format": "email"},
        "role": {"enum": ["admin", "user", "guest"]}
      },
      "required": ["name", "email", "role"]
    }
  }
}`

g, _ := schema.Grammar(schemaJSON)
```

### Supported Features

| Feature | Example |
|---------|---------|
| Basic types | `string`, `integer`, `number`, `boolean`, `null` |
| Objects | `properties`, `required` |
| Arrays | `items`, `minItems`, `maxItems` |
| Enums | `enum: ["a", "b", "c"]` |
| Constants | `const: "value"` |
| Union types | `anyOf`, `oneOf`, `type: ["string", "null"]` |
| References | `$ref: "#/$defs/Name"`, `$defs` |
| Formats | `date`, `time`, `date-time`, `email`, `uuid`, `ipv4` |

## Benchmarks

```bash
# Run all tests
go test -tags mlx ./x/grammar/...

# Run benchmarks
go test -tags mlx ./x/grammar/ -bench=.

# Compare with llama.cpp (outputs JSON)
go run -tags mlx ./x/grammar/cmd/compare -vocab-size 128000 -iterations 500

# Compare with a more complex schema
go run -tags mlx ./x/grammar/cmd/compare \
  -gbnf x/grammar/cmd/compare/complex.gbnf \
  -schema x/grammar/cmd/compare/complex.schema.json \
  -vocab-size 128000 -iterations 500
```

## References

- [XGrammar Paper](https://arxiv.org/abs/2411.15100) - Flexible and Efficient Structured Generation
- [Outlines](https://github.com/dottxt-ai/outlines) - Structured Text Generation
- [JSONSchemaBench](https://arxiv.org/abs/2501.10868) - Benchmark for Structured Outputs
@@ -1,161 +0,0 @@
//go:build mlx

package grammar

// terminalTokenGroups contains pre-partitioned tokens for a terminal.
// This enables O(1) lookup of tokens that exactly match vs need DP validation.
type terminalTokenGroups struct {
	// ExactMatches are tokens that exactly match this terminal (O(1) validation)
	ExactMatches []int32

	// DPCandidates are tokens that start with this terminal but need DP validation
	DPCandidates []int
}

// tokenAnalysis contains precomputed terminal matches for a token
type tokenAnalysis struct {
	// The token string
	Token string

	// TokenID in the vocabulary
	TokenID int

	// Matches at each byte position
	// MatchesAtPos[i] = terminals matching at position i with their lengths
	MatchesAtPos [][]terminalMatch

	// Fast path: if token exactly matches one terminal
	// -1 if no exact match
	exactMatch int

	// Whether this token can be consumed at all (has at least one match)
	HasMatches bool
}

// analyzer precomputes terminal matches for a vocabulary
type analyzer struct {
	matcher  *terminalMatcher
	analyses []tokenAnalysis // Indexed by token ID
	vocab    []string

	// Pre-partitioned tokens by terminal (exact match vs DP candidates)
	// This enables direct slice appends instead of per-token branching
	tokensByTerminal []terminalTokenGroups
}

// newAnalyzer creates an analyzer for the given vocabulary and terminals
func newAnalyzer(vocab []string, matcher *terminalMatcher) *analyzer {
	a := &analyzer{
		matcher:  matcher,
		analyses: make([]tokenAnalysis, len(vocab)),
		vocab:    vocab,
	}

	// Precompute analysis for each token
	for i, token := range vocab {
		a.analyses[i] = a.analyze(token, i)
	}

	// Build pre-partitioned token groups for fast ApplyMask
	a.buildTokenPartitions()

	return a
}

// analyze computes terminal matches for a single token
func (a *analyzer) analyze(token string, tokenID int) tokenAnalysis {
	analysis := tokenAnalysis{
		Token:        token,
		TokenID:      tokenID,
		MatchesAtPos: make([][]terminalMatch, len(token)),
		exactMatch:   -1,
		HasMatches:   false,
	}

	if len(token) == 0 {
		return analysis
	}

	// Compute matches at each position
	data := []byte(token)
	for pos := 0; pos < len(data); pos++ {
		matches := a.matcher.matchesAt(data, pos)
		analysis.MatchesAtPos[pos] = matches
		if len(matches) > 0 {
			analysis.HasMatches = true
		}
	}

	// Exact match is only valid when a single terminal spans the entire token
	if len(analysis.MatchesAtPos) > 0 {
		var exactID int = -1
		for _, match := range analysis.MatchesAtPos[0] {
			if match.Length != len(token) {
				continue
			}
			if exactID >= 0 && exactID != match.TerminalID {
				exactID = -1
				break
			}
			exactID = match.TerminalID
		}
		analysis.exactMatch = exactID
	}

	return analysis
}

// analysis returns the precomputed analysis for a token ID
func (a *analyzer) analysis(tokenID int) tokenAnalysis {
	if tokenID < 0 || tokenID >= len(a.analyses) {
		return tokenAnalysis{exactMatch: -1}
	}
	return a.analyses[tokenID]
}

// vocabSize returns the vocabulary size
func (a *analyzer) vocabSize() int {
	return len(a.vocab)
}

// buildTokenPartitions pre-partitions tokens into exact-match vs needs-DP groups per terminal.
// This enables ApplyMask to use direct slice appends instead of per-token branching.
func (a *analyzer) buildTokenPartitions() {
	numTerminals := a.matcher.terminalCount()
	a.tokensByTerminal = make([]terminalTokenGroups, numTerminals)

	for tokenID, analysis := range a.analyses {
		if !analysis.HasMatches {
			continue
		}

		if analysis.exactMatch >= 0 {
			// Token exactly matches one terminal - fast path (O(1) validation)
			tid := analysis.exactMatch
			a.tokensByTerminal[tid].ExactMatches = append(
				a.tokensByTerminal[tid].ExactMatches, int32(tokenID))
		} else {
			// Token needs DP validation - add to all terminals it can start with
			// This way, when a terminal is valid, we know exactly which tokens need DP
			if len(analysis.MatchesAtPos) > 0 {
				seen := make(map[int]bool)
				for _, match := range analysis.MatchesAtPos[0] {
					tid := match.TerminalID
					if !seen[tid] {
						seen[tid] = true
						a.tokensByTerminal[tid].DPCandidates = append(
							a.tokensByTerminal[tid].DPCandidates, tokenID)
					}
				}
			}
		}
	}
}

// terminalGroups returns the pre-partitioned token groups for a terminal ID
func (a *analyzer) terminalGroups(terminalID int) terminalTokenGroups {
	if terminalID < 0 || terminalID >= len(a.tokensByTerminal) {
		return terminalTokenGroups{}
	}
	return a.tokensByTerminal[terminalID]
}
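The consumer of these partitions is not shown in this file, but the intended pattern follows from the comments: for each terminal that is currently valid, exact-match tokens are appended directly while DP candidates still go through full validation. A sketch under that assumption (`collectValid` and its `validate` callback are hypothetical; deduplication across terminals is omitted for brevity):

```go
// collectValid gathers candidate token IDs for the currently valid terminals.
func collectValid(a *analyzer, terminalIDs []int, validate func(tokenID int) bool) []int32 {
	var valid []int32
	for _, tid := range terminalIDs {
		groups := a.terminalGroups(tid)
		// Exact matches need no further checking: O(1) per token.
		valid = append(valid, groups.ExactMatches...)
		// Multi-terminal tokens still need the DP check against the parser state.
		for _, tok := range groups.DPCandidates {
			if validate(tok) {
				valid = append(valid, int32(tok))
			}
		}
	}
	return valid
}
```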
@@ -1,648 +0,0 @@
//go:build mlx

package grammar

import (
	"encoding/binary"
	"hash/fnv"
	"sort"
	"sync"
)

// visitedMapPool reduces allocations for visited maps in bridge operations
var visitedMapPool = sync.Pool{
	New: func() interface{} {
		return make(map[stateStackKey]bool, 16)
	},
}

// getVisitedMap gets a map from the pool
func getVisitedMap() map[stateStackKey]bool {
	return visitedMapPool.Get().(map[stateStackKey]bool)
}

// putVisitedMap returns a map to the pool after clearing it
func putVisitedMap(m map[stateStackKey]bool) {
	for k := range m {
		delete(m, k)
	}
	visitedMapPool.Put(m)
}

// parserConfig represents a pda state+stack combination
type parserConfig struct {
	state state
	Stack []stackSymbol
}

// clone creates a deep copy of the config
func (c *parserConfig) clone() *parserConfig {
	newStack := make([]stackSymbol, len(c.Stack))
	copy(newStack, c.Stack)
	return &parserConfig{
		state: c.state,
		Stack: newStack,
	}
}

// key returns a unique key for this config for deduplication
func (c *parserConfig) key() uint64 {
	h := fnv.New64a()
	var buf [8]byte
	binary.LittleEndian.PutUint64(buf[:], uint64(c.state))
	h.Write(buf[:])
	for _, sym := range c.Stack {
		binary.LittleEndian.PutUint64(buf[:], uint64(sym))
		h.Write(buf[:])
	}
	return h.Sum64()
}

// configSet represents a set of parser configurations (for nondeterminism)
type configSet struct {
	configs    []*parserConfig
	normalized bool   // true if already deduplicated and sorted
	cachedSig  uint64 // cached signature after normalization
}

// newConfigSet creates a new config set with a single configuration
func newConfigSet(state state, stack []stackSymbol) *configSet {
	return &configSet{
		configs: []*parserConfig{
			{state: state, Stack: stack},
		},
		normalized: true, // single config is already normalized
	}
}

// normalize deduplicates and sorts configs for stable signatures
func (c *configSet) normalize() {
	if c.normalized || len(c.configs) <= 1 {
		c.normalized = true
		return
	}

	// Deduplicate using a map
	seen := make(map[uint64]*parserConfig, len(c.configs))
	for _, cfg := range c.configs {
		key := cfg.key()
		if _, exists := seen[key]; !exists {
			seen[key] = cfg
		}
	}

	// Extract unique configs
	unique := make([]*parserConfig, 0, len(seen))
	for _, cfg := range seen {
		unique = append(unique, cfg)
	}

	// Sort by key for deterministic ordering
	sort.Slice(unique, func(i, j int) bool {
		return unique[i].key() < unique[j].key()
	})

	c.configs = unique
	c.normalized = true
}

// signature returns a hash for cache lookup (normalizes first)
func (c *configSet) signature() uint64 {
	c.normalize()

	// Return cached signature if available
	if c.cachedSig != 0 {
		return c.cachedSig
	}

	h := fnv.New64a()

	// Hash number of configs
	var buf [8]byte
	binary.LittleEndian.PutUint64(buf[:], uint64(len(c.configs)))
	h.Write(buf[:])

	// Hash each config (already sorted)
	for _, cfg := range c.configs {
		binary.LittleEndian.PutUint64(buf[:], uint64(cfg.state))
		h.Write(buf[:])

		binary.LittleEndian.PutUint64(buf[:], uint64(len(cfg.Stack)))
		h.Write(buf[:])

		for _, sym := range cfg.Stack {
			binary.LittleEndian.PutUint64(buf[:], uint64(sym))
			h.Write(buf[:])
		}
	}

	c.cachedSig = h.Sum64()
	return c.cachedSig
}

// isEmpty returns true if there are no configurations
func (c *configSet) isEmpty() bool {
	return len(c.configs) == 0
}

// clone creates a deep copy of the config set
func (c *configSet) clone() *configSet {
	newConfigs := make([]*parserConfig, len(c.configs))
	for i, cfg := range c.configs {
		newConfigs[i] = cfg.clone()
	}
	return &configSet{configs: newConfigs}
}

// bridge connects token analysis to pda validation
type bridge struct {
	pda      *pda
	analyzer *analyzer
}

// newBridge creates a new bridge
func newBridge(pda *pda, analyzer *analyzer) *bridge {
	return &bridge{
		pda:      pda,
		analyzer: analyzer,
	}
}

// IsTokenValid checks if token T can be consumed from the current config
// This is the main entry point for token validation
func (b *bridge) IsTokenValid(tokenID int, config *configSet) bool {
	analysis := b.analyzer.analysis(tokenID)

	if !analysis.HasMatches {
		return false
	}

	// Fast path: exact terminal match
	if analysis.exactMatch >= 0 {
		terminal := b.analyzer.matcher.terminals[analysis.exactMatch]
		return b.canAcceptTerminal(config, terminal.Pattern)
	}

	// General path: DP over (pos, config)
	return b.dpValidate(&analysis, config)
}

// canAcceptTerminal checks if any config can accept the terminal
func (b *bridge) canAcceptTerminal(config *configSet, pattern string) bool {
	for _, cfg := range config.configs {
		if b.canConfigAcceptTerminal(cfg, pattern) {
			return true
		}
	}
	return false
}

// canConfigAcceptTerminal checks if a single config can accept the terminal
func (b *bridge) canConfigAcceptTerminal(cfg *parserConfig, pattern string) bool {
	// Use pooled visited map to reduce allocations
	visited := getVisitedMap()
	result := b.tryAcceptTerminal(cfg.state, cfg.Stack, pattern, visited)
	putVisitedMap(visited)
	return result
}

// tryAcceptTerminal recursively tries to accept a terminal from a state
func (b *bridge) tryAcceptTerminal(state state, stack []stackSymbol, pattern string, visited map[stateStackKey]bool) bool {
	key := stateStackKey{state: state, stackSig: stackSignature(stack)}
	if visited[key] {
		return false
	}
	visited[key] = true

	stackTop := stackEmpty
	if len(stack) > 0 {
		stackTop = stack[len(stack)-1]
	}

	for _, t := range b.pda.Transitions[state] {
		// Check stack constraint
		if t.stackTop != stackEmpty && t.stackTop != stackTop {
			continue
		}

		// Can't pop more than we have
		if t.StackPop > len(stack) {
			continue
		}

		if t.Pattern == pattern {
			// Direct match
			return true
		}

		if t.Pattern == "" {
			// Epsilon transition - follow it
			newStack := make([]stackSymbol, len(stack))
			copy(newStack, stack)

			// Pop
			if t.StackPop > 0 {
				newStack = newStack[:len(newStack)-t.StackPop]
			}

			// Push
			newStack = append(newStack, t.StackPush...)

			if b.tryAcceptTerminal(t.ToState, newStack, pattern, visited) {
				return true
			}
		}
	}

	return false
}

// dpValidate runs DP for multi-terminal tokens
func (b *bridge) dpValidate(analysis *tokenAnalysis, startConfig *configSet) bool {
	// state: (pos, configSet)
	// Memoize by (pos, configSig)
	type dpKey struct {
		pos int
		sig uint64
	}
	memo := make(map[dpKey]bool)

	var dp func(pos int, config *configSet) bool
	dp = func(pos int, config *configSet) bool {
		if pos == len(analysis.Token) {
			return true // Consumed entire token
		}

		if config.isEmpty() {
			return false
		}

		key := dpKey{pos, config.signature()}
		if result, ok := memo[key]; ok {
			return result
		}

		// Try each terminal that matches at this position
		for _, match := range analysis.MatchesAtPos[pos] {
			terminal := b.analyzer.matcher.terminals[match.TerminalID]
			newConfig := b.advanceConfig(config, terminal.Pattern)
			if newConfig != nil && !newConfig.isEmpty() && dp(pos+match.Length, newConfig) {
				memo[key] = true
				return true
			}
		}

		memo[key] = false
		return false
	}

	return dp(0, startConfig)
}

// advanceConfig advances all configs that can accept the terminal
func (b *bridge) advanceConfig(config *configSet, pattern string) *configSet {
	var newConfigs []*parserConfig

	for _, cfg := range config.configs {
		advanced := b.advanceSingleConfig(cfg, pattern)
		newConfigs = append(newConfigs, advanced...)
	}

	if len(newConfigs) == 0 {
		return nil
	}

	return &configSet{configs: newConfigs}
}

// advanceSingleConfig advances a single config by accepting a terminal
func (b *bridge) advanceSingleConfig(cfg *parserConfig, pattern string) []*parserConfig {
	var results []*parserConfig
	visited := getVisitedMap()
	b.collectAdvanced(cfg.state, cfg.Stack, pattern, visited, &results)
	putVisitedMap(visited)
	return results
}

// collectAdvanced collects all configs reachable by accepting the pattern
func (b *bridge) collectAdvanced(state state, stack []stackSymbol, pattern string, visited map[stateStackKey]bool, results *[]*parserConfig) {
	key := stateStackKey{state: state, stackSig: stackSignature(stack)}
	if visited[key] {
		return
	}
	visited[key] = true

	stackTop := stackEmpty
	if len(stack) > 0 {
		stackTop = stack[len(stack)-1]
	}

	for _, t := range b.pda.Transitions[state] {
		// Check stack constraint
		if t.stackTop != stackEmpty && t.stackTop != stackTop {
			continue
		}

		// Can't pop more than we have
		if t.StackPop > len(stack) {
			continue
		}

		if t.Pattern == pattern {
			// Match! Create new config after transition
			newStack := make([]stackSymbol, len(stack))
			copy(newStack, stack)

			if t.StackPop > 0 {
				newStack = newStack[:len(newStack)-t.StackPop]
			}
			newStack = append(newStack, t.StackPush...)

			*results = append(*results, &parserConfig{
				state: t.ToState,
				Stack: newStack,
			})
		}

		if t.Pattern == "" {
			// Epsilon transition - follow it
			newStack := make([]stackSymbol, len(stack))
			copy(newStack, stack)

			if t.StackPop > 0 {
				newStack = newStack[:len(newStack)-t.StackPop]
			}
			newStack = append(newStack, t.StackPush...)

			b.collectAdvanced(t.ToState, newStack, pattern, visited, results)
		}
	}
}

// validTokens returns all token IDs that are valid from the given config
func (b *bridge) validTokens(config *configSet) []int {
	var valid []int
	for tokenID := 0; tokenID < b.analyzer.vocabSize(); tokenID++ {
		if b.IsTokenValid(tokenID, config) {
			valid = append(valid, tokenID)
		}
	}
	return valid
}

// acceptToken attempts to accept a token and returns the new config set
// Returns nil if the token is not valid from this config
func (b *bridge) acceptToken(tokenID int, config *configSet) *configSet {
	analysis := b.analyzer.analysis(tokenID)

	if !analysis.HasMatches {
		return nil
	}

	// Fast path: exact terminal match
	if analysis.exactMatch >= 0 {
		terminal := b.analyzer.matcher.terminals[analysis.exactMatch]
		newConfig := b.advanceConfig(config, terminal.Pattern)
		if newConfig != nil && !newConfig.isEmpty() {
			newConfig.normalize()
			return newConfig
		}
		return nil
	}

	// General path: DP to find final config after consuming token
	return b.dpAccept(&analysis, config)
}

// dpAccept runs DP to accept a multi-terminal token and return final config
// Returns the union of all possible end configurations (preserves nondeterminism)
func (b *bridge) dpAccept(analysis *tokenAnalysis, startConfig *configSet) *configSet {
	type dpKey struct {
		pos int
		sig uint64
	}
	// Memoize the configs reachable at each (pos, sig)
	memo := make(map[dpKey]*configSet)

	var dp func(pos int, config *configSet) *configSet
	dp = func(pos int, config *configSet) *configSet {
		if pos == len(analysis.Token) {
			return config // Consumed entire token, return final config
		}

		if config.isEmpty() {
			return nil
		}

		key := dpKey{pos, config.signature()}
		if result, ok := memo[key]; ok {
			return result
		}

		// Collect all valid result configs from all possible paths
		var allConfigs []*parserConfig

		// Try each terminal that matches at this position
		for _, match := range analysis.MatchesAtPos[pos] {
			terminal := b.analyzer.matcher.terminals[match.TerminalID]
			newConfig := b.advanceConfig(config, terminal.Pattern)
			if newConfig != nil && !newConfig.isEmpty() {
				finalConfig := dp(pos+match.Length, newConfig)
				if finalConfig != nil {
					// Collect all configs, don't return early
					allConfigs = append(allConfigs, finalConfig.configs...)
				}
			}
		}

		// Build result: nil if no valid paths, normalized configSet otherwise
		var result *configSet
		if len(allConfigs) > 0 {
			result = &configSet{configs: allConfigs}
			result.normalize() // Dedup using parserConfig.key(), sort for consistent signature
		}
		memo[key] = result // Cache normalized result
		return result
	}

	return dp(0, startConfig)
}

// isAccepting returns true if any config can reach an accepting state
func (b *bridge) isAccepting(config *configSet) bool {
	visited := getVisitedMap()
	defer putVisitedMap(visited)

	for _, cfg := range config.configs {
		// Clear visited for each config check
		for k := range visited {
			delete(visited, k)
		}
		if b.canReachAccept(cfg.state, cfg.Stack, visited) {
			return true
		}
	}
	return false
}

// canReachAccept checks if we can reach an accepting state via epsilon transitions
func (b *bridge) canReachAccept(state state, stack []stackSymbol, visited map[stateStackKey]bool) bool {
	// Check if this state is accepting with empty stack
	if b.pda.AcceptStates[state] && len(stack) == 0 {
		return true
	}

	key := stateStackKey{state: state, stackSig: stackSignature(stack)}
	if visited[key] {
		return false
	}
	visited[key] = true

	// Try epsilon transitions
	stackTop := stackEmpty
	if len(stack) > 0 {
		stackTop = stack[len(stack)-1]
	}

	for _, t := range b.pda.Transitions[state] {
		if t.Pattern != "" {
			continue // Not epsilon
		}
		if t.stackTop != stackEmpty && t.stackTop != stackTop {
			continue
		}
		if t.StackPop > len(stack) {
			continue
		}

		newStack := make([]stackSymbol, len(stack))
		copy(newStack, stack)
		if t.StackPop > 0 {
			newStack = newStack[:len(newStack)-t.StackPop]
		}
		newStack = append(newStack, t.StackPush...)

		if b.canReachAccept(t.ToState, newStack, visited) {
			return true
		}
	}

	return false
}

// validTerminals returns the valid terminal patterns from the given config
func (b *bridge) validTerminals(config *configSet) []string {
	seen := make(map[string]bool)
	var terminals []string

	visited := getVisitedMap()
	defer putVisitedMap(visited)

	for _, cfg := range config.configs {
		// Clear visited for each config
		for k := range visited {
			delete(visited, k)
		}
		b.collectValidTerminals(cfg.state, cfg.Stack, visited, seen, &terminals)
	}

	return terminals
}

// collectValidTerminals collects all reachable terminals
func (b *bridge) collectValidTerminals(state state, stack []stackSymbol, visited map[stateStackKey]bool, seen map[string]bool, terminals *[]string) {
	key := stateStackKey{state: state, stackSig: stackSignature(stack)}
	if visited[key] {
		return
	}
	visited[key] = true

	stackTop := stackEmpty
	if len(stack) > 0 {
		stackTop = stack[len(stack)-1]
	}

	for _, t := range b.pda.Transitions[state] {
		if t.stackTop != stackEmpty && t.stackTop != stackTop {
			continue
		}
		if t.StackPop > len(stack) {
			continue
		}

		if t.Pattern != "" && !seen[t.Pattern] {
			seen[t.Pattern] = true
			*terminals = append(*terminals, t.Pattern)
		}

		if t.Pattern == "" {
			newStack := make([]stackSymbol, len(stack))
			copy(newStack, stack)
			if t.StackPop > 0 {
				newStack = newStack[:len(newStack)-t.StackPop]
			}
			newStack = append(newStack, t.StackPush...)
			b.collectValidTerminals(t.ToState, newStack, visited, seen, terminals)
		}
	}
}

// validTerminalIDs returns the IDs of valid terminals from the given config
func (b *bridge) validTerminalIDs(config *configSet) []int {
	seen := make(map[int]bool)
	var terminalIDs []int

	visited := getVisitedMap()
	defer putVisitedMap(visited)

	for _, cfg := range config.configs {
		// Clear visited for each config
		for k := range visited {
			delete(visited, k)
		}
		b.collectValidTerminalIDs(cfg.state, cfg.Stack, visited, seen, &terminalIDs)
	}

	return terminalIDs
}

// collectValidTerminalIDs collects IDs of all reachable terminals
func (b *bridge) collectValidTerminalIDs(state state, stack []stackSymbol, visited map[stateStackKey]bool, seen map[int]bool, terminalIDs *[]int) {
	key := stateStackKey{state: state, stackSig: stackSignature(stack)}
	if visited[key] {
		return
	}
	visited[key] = true

	stackTop := stackEmpty
	if len(stack) > 0 {
		stackTop = stack[len(stack)-1]
	}

	for _, t := range b.pda.Transitions[state] {
		if t.stackTop != stackEmpty && t.stackTop != stackTop {
			continue
		}
		if t.StackPop > len(stack) {
			continue
		}

		if t.Pattern != "" {
			// Look up terminal ID from pattern
			if tid, ok := b.analyzer.matcher.patternToID[t.Pattern]; ok && !seen[tid] {
				seen[tid] = true
				*terminalIDs = append(*terminalIDs, tid)
			}
		}

		if t.Pattern == "" {
			newStack := make([]stackSymbol, len(stack))
			copy(newStack, stack)
			if t.StackPop > 0 {
				newStack = newStack[:len(newStack)-t.StackPop]
			}
			newStack = append(newStack, t.StackPush...)
			b.collectValidTerminalIDs(t.ToState, newStack, visited, seen, terminalIDs)
		}
	}
}
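Taken together, `IsTokenValid` and `acceptToken` are what one decode step drives: validate the sampled token against the live configuration set, then advance it. A minimal sketch of that step (hypothetical helper; these are unexported functions that the public `Engine` wraps, and `fmt` is assumed imported):

```go
// step validates one sampled token and advances the parser state,
// returning the new configuration set or an error.
func step(b *bridge, cfg *configSet, tokenID int) (*configSet, error) {
	if !b.IsTokenValid(tokenID, cfg) {
		return nil, fmt.Errorf("token %d rejected by grammar", tokenID)
	}
	next := b.acceptToken(tokenID, cfg)
	if next == nil || next.isEmpty() {
		return nil, fmt.Errorf("token %d left no live parser configs", tokenID)
	}
	return next, nil
}
```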
@@ -1,45 +0,0 @@
root ::= ws "{" ws id-field "," ws kind-field "," ws items-field "," ws alt-field "," ws flags-field "," ws meta-field "," ws priority-field ws "}" ws

id-field ::= "\"id\"" ws ":" ws uuid
kind-field ::= "\"kind\"" ws ":" ws kind
items-field ::= "\"items\"" ws ":" ws items
alt-field ::= "\"alt\"" ws ":" ws alt
flags-field ::= "\"flags\"" ws ":" ws flags
meta-field ::= "\"meta\"" ws ":" ws meta
priority-field ::= "\"priority\"" ws ":" ws int

kind ::= "\"order\"" | "\"invoice\"" | "\"shipment\""
status ::= "\"new\"" | "\"backorder\"" | "\"shipped\""
flag ::= "\"fragile\"" | "\"gift\"" | "\"priority\"" | "\"insured\""
source ::= "\"api\"" | "\"batch\"" | "\"import\""

items ::= "[" ws item ( "," ws item )? ( "," ws item )? ws "]"
flags ::= "[" ws "]" | "[" ws flag ( "," ws flag )? ( "," ws flag )? ( "," ws flag )? ws "]"

item ::= "{" ws item-sku "," ws item-qty "," ws item-status "," ws item-notes ws "}"
item-sku ::= "\"sku\"" ws ":" ws string
item-qty ::= "\"qty\"" ws ":" ws int
item-status ::= "\"status\"" ws ":" ws status
item-notes ::= "\"notes\"" ws ":" ws string

meta ::= "{" ws meta-created "," ws meta-source "," ws meta-ip ws "}"
meta-created ::= "\"created\"" ws ":" ws date-time
meta-source ::= "\"source\"" ws ":" ws source
meta-ip ::= "\"ip\"" ws ":" ws ipv4

alt ::= string | int | "null"

uuid ::= "\"" hex hex hex hex hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex hex hex hex hex hex hex hex hex "\""
date-time ::= "\"" digit digit digit digit "-" digit digit "-" digit digit "T" digit digit ":" digit digit ":" digit digit ( "Z" | ( "+" | "-" ) digit digit ":" digit digit ) "\""
ipv4 ::= "\"" digit+ "." digit+ "." digit+ "." digit+ "\""

string ::= "\"" characters "\""
characters ::= character*
character ::= [^"\\] | "\\" escape
escape ::= ["\\bfnrt]

int ::= "-"? digit+
digit ::= [0-9]
hex ::= [0-9a-fA-F]

ws ::= [ \t\n\r]*
@@ -1,46 +0,0 @@
{
  "type": "object",
  "properties": {
    "id": { "type": "string", "format": "uuid" },
    "kind": { "enum": ["order", "invoice", "shipment"] },
    "items": {
      "type": "array",
      "minItems": 1,
      "maxItems": 3,
      "items": {
        "type": "object",
        "properties": {
          "sku": { "type": "string" },
          "qty": { "type": "integer" },
          "status": { "enum": ["new", "backorder", "shipped"] },
          "notes": { "type": "string" }
        },
        "required": ["sku", "qty", "status", "notes"]
      }
    },
    "alt": {
      "oneOf": [
        { "type": "string" },
        { "type": "null" },
        { "type": "integer" }
      ]
    },
    "flags": {
      "type": "array",
      "minItems": 0,
      "maxItems": 4,
      "items": { "enum": ["fragile", "gift", "priority", "insured"] }
    },
    "meta": {
      "type": "object",
      "properties": {
        "created": { "type": "string", "format": "date-time" },
        "source": { "enum": ["api", "batch", "import"] },
        "ip": { "type": "string", "format": "ipv4" }
      },
      "required": ["created", "source", "ip"]
    },
    "priority": { "type": "integer" }
  },
  "required": ["id", "kind", "items", "alt", "flags", "meta", "priority"]
}
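This schema is what the compare tool loads via `-schema`; wiring it into the engine by hand follows the README's usage pattern. A sketch (assumes the repo-relative path above, a `vocab` slice from your tokenizer, and the usual `os`/`log` plus `grammar`/`schema` imports):

```go
data, err := os.ReadFile("x/grammar/cmd/compare/complex.schema.json")
if err != nil {
	log.Fatal(err)
}
g, err := schema.Grammar(string(data))
if err != nil {
	log.Fatal(err)
}
engine, err := grammar.NewEngine(g, vocab)
if err != nil {
	log.Fatal(err)
}
defer engine.Close()
```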
@@ -1,235 +0,0 @@
//go:build mlx

package main

import (
	"encoding/json"
	"flag"
	"fmt"
	"os"
	"time"

	"github.com/ollama/ollama/llama"
	"github.com/ollama/ollama/x/grammar"
	"github.com/ollama/ollama/x/grammar/schema"
	"github.com/ollama/ollama/x/imagegen/mlx"
)

const jsonGBNF = `
root ::= value
value ::= object | array | string | number | "true" | "false" | "null"
object ::= "{" ws "}" | "{" members "}"
members ::= member ("," member)*
member ::= ws string ws ":" element
array ::= "[" ws "]" | "[" elements "]"
elements ::= element ("," element)*
element ::= ws value ws
string ::= "\"" characters "\""
characters ::= character*
character ::= [^"\\] | "\\" escape
escape ::= ["\\bfnrt]
number ::= "-"? integer fraction? exponent?
integer ::= "0" | [1-9] [0-9]*
fraction ::= "." [0-9]+
exponent ::= [eE] [+-]? [0-9]+
ws ::= [ \t\n\r]*
`

type result struct {
	VocabSize           int    `json:"vocab_size"` // exported so encoding/json actually emits it
	Iterations          int    `json:"iterations"`
	Warmup              int    `json:"warmup"`
	ConstrainedSource   string `json:"constrained_source"`
	LlamaSource         string `json:"llama_source"`
	LlamaApply          string `json:"llama_apply"`
	ConstrainedGraph    string `json:"constrained_graph"`
	ConstrainedWithEval string `json:"constrained_with_eval,omitempty"`
	EvalOnly            string `json:"eval_only,omitempty"`
	ConstrainedEvalNet  string `json:"constrained_eval_net,omitempty"`
}

func main() {
	var (
		vocabSize  = flag.Int("vocab-size", 128000, "Vocabulary size")
		iterations = flag.Int("iterations", 500, "Benchmark iterations")
		warmup     = flag.Int("warmup", 50, "Warmup iterations")
		withEval   = flag.Bool("eval", true, "Measure ApplyMask with mlx.Eval")
		gbnfPath   = flag.String("gbnf", "", "GBNF grammar file for llama.cpp")
		schemaPath = flag.String("schema", "", "JSON Schema file for grammar constraints")
		ebnfPath   = flag.String("ebnf", "", "EBNF grammar file for grammar constraints")
		startRule  = flag.String("start", "root", "Start rule for EBNF")
	)
	flag.Parse()

	if *vocabSize <= 0 || *iterations <= 0 || *warmup < 0 {
		fmt.Fprintln(os.Stderr, "invalid flags")
		os.Exit(2)
	}

	vocab := createVocab(*vocabSize)

	if *schemaPath != "" && *ebnfPath != "" {
		fmt.Fprintln(os.Stderr, "only one of -schema or -ebnf may be set")
		os.Exit(2)
	}

	var constrainedSource string
	var compiled *grammar.Grammar
	var err error
	switch {
	case *schemaPath != "":
		data, readErr := os.ReadFile(*schemaPath)
		if readErr != nil {
			fmt.Fprintf(os.Stderr, "read schema: %v\n", readErr)
			os.Exit(1)
		}
		compiled, err = schema.Grammar(string(data))
		constrainedSource = "schema:" + *schemaPath
	case *ebnfPath != "":
		data, readErr := os.ReadFile(*ebnfPath)
		if readErr != nil {
			fmt.Fprintf(os.Stderr, "read ebnf: %v\n", readErr)
			os.Exit(1)
		}
		compiled, err = grammar.ParseEBNF(string(data), *startRule)
		constrainedSource = "ebnf:" + *ebnfPath
	default:
		compiled, err = grammar.JSONGrammar()
		constrainedSource = "json"
	}
	if err != nil {
		fmt.Fprintf(os.Stderr, "grammar: %v\n", err)
		os.Exit(1)
	}
	engine, err := grammar.NewEngine(compiled, vocab)
	if err != nil {
		fmt.Fprintf(os.Stderr, "engine: %v\n", err)
		os.Exit(1)
	}
	defer engine.Close()

	logits := mlx.Ones(int32(*vocabSize))
	mlx.Keep(logits)

	for i := 0; i < *warmup; i++ {
		masked := engine.ApplyMask(logits)
		if *withEval {
			mlx.Eval(masked)
		}
	}

	graphAvg := measure(*iterations, func() {
		_ = engine.ApplyMask(logits)
	})

	var evalAvg time.Duration
	var evalOnlyAvg time.Duration
	if *withEval {
		evalOnlyAvg = measure(*iterations, func() {
			baseline := mlx.MulScalar(logits, 1)
			mlx.Eval(baseline)
			baseline.Free()
		})

		evalAvg = measure(*iterations, func() {
			masked := engine.ApplyMask(logits)
			mlx.Eval(masked)
		})
	}

	vocabIDs := make([]uint32, *vocabSize)
	for i := range vocabIDs {
		vocabIDs[i] = uint32(i)
	}
	eogTokens := []int32{0}

	gbnf := jsonGBNF
	llamaSource := "json"
	if *gbnfPath != "" {
		data, readErr := os.ReadFile(*gbnfPath)
		if readErr != nil {
			fmt.Fprintf(os.Stderr, "read gbnf: %v\n", readErr)
			os.Exit(1)
		}
		gbnf = string(data)
		llamaSource = *gbnfPath
	}

	llamaGrammar := llama.NewGrammar(gbnf, vocabIDs, vocab, eogTokens)
	if llamaGrammar == nil {
		fmt.Fprintln(os.Stderr, "llama grammar initialization failed")
		os.Exit(1)
	}
	defer llamaGrammar.Free()

	llamaTokens := make([]llama.TokenData, *vocabSize)

	for i := 0; i < *warmup; i++ {
		for j := range llamaTokens {
			llamaTokens[j].Logit = 1.0
		}
		llamaGrammar.Apply(llamaTokens)
	}

	llamaAvg := measure(*iterations, func() {
		for j := range llamaTokens {
			llamaTokens[j].Logit = 1.0
		}
		llamaGrammar.Apply(llamaTokens)
	})

	out := result{
		VocabSize:         *vocabSize,
		Iterations:        *iterations,
		Warmup:            *warmup,
		LlamaApply:        llamaAvg.String(),
		ConstrainedGraph:  graphAvg.String(),
		ConstrainedSource: constrainedSource,
		LlamaSource:       llamaSource,
	}
	if *withEval {
		out.ConstrainedWithEval = evalAvg.String()
		out.EvalOnly = evalOnlyAvg.String()
		if evalAvg > evalOnlyAvg {
			out.ConstrainedEvalNet = (evalAvg - evalOnlyAvg).String()
		} else {
			out.ConstrainedEvalNet = "0s"
		}
	}

	enc := json.NewEncoder(os.Stdout)
	if err := enc.Encode(out); err != nil {
		fmt.Fprintf(os.Stderr, "encode: %v\n", err)
		os.Exit(1)
	}
}

func measure(iterations int, fn func()) time.Duration {
	start := time.Now()
	for i := 0; i < iterations; i++ {
		fn()
	}
	return time.Since(start) / time.Duration(iterations)
}

func createVocab(size int) []string {
	vocab := make([]string, size)

	jsonTokens := []string{
		"{", "}", "[", "]", ":", ",",
		"true", "false", "null",
		" ", "\n", "\t", "\r",
		"\"",
	}
	for i, t := range jsonTokens {
		if i < size {
			vocab[i] = t
		}
	}

	for i := len(jsonTokens); i < size; i++ {
		vocab[i] = fmt.Sprintf("tok%d", i)
	}

	return vocab
}
@@ -1,320 +0,0 @@
//go:build mlx

package grammar

import (
	"fmt"
	"strconv"
	"strings"
	"unicode/utf8"
)

// Grammar is the compiled form of an EBNF grammar.
// It contains terminals, parse tables, and the start state.
// Use ParseEBNF or JSONGrammar to create a Grammar.
type Grammar struct {
	// The underlying pda
	pda *pda

	// Compiled terminal matcher
	matcher *terminalMatcher
}

// ParseEBNF compiles an EBNF grammar string into a Grammar.
// startRule is the name of the start rule (e.g., "root", "json").
func ParseEBNF(ebnf string, startRule string) (*Grammar, error) {
	pda, err := compileString(ebnf, startRule)
	if err != nil {
		return nil, fmt.Errorf("failed to compile EBNF: %w", err)
	}

	matcher, err := compileTerminalsStrict(pda)
	if err != nil {
		return nil, fmt.Errorf("failed to compile terminals: %w", err)
	}

	return &Grammar{
		pda:     pda,
		matcher: matcher,
	}, nil
}

// JSONGrammar returns the compiled JSON grammar.
// This is a convenience wrapper for ParseEBNF(JSONGrammarEBNF, "json").
func JSONGrammar() (*Grammar, error) {
	return ParseEBNF(JSONGrammarEBNF, "json")
}

// JSONObjectGrammar returns a JSON grammar that only allows objects at the top level.
// Use this when you want to ensure the output is a JSON object (starts with {).
func JSONObjectGrammar() (*Grammar, error) {
	return ParseEBNF(JSONObjectGrammarEBNF, "json")
}

// compileTerminalsStrict builds a matcher that properly handles:
//   - Escaped literals ("\n", "\"", "\uXXXX")
//   - Unicode ranges (rune-based, not byte-based)
//   - Rejects unsupported patterns with an error (no silent fallback)
func compileTerminalsStrict(pda *pda) (*terminalMatcher, error) {
	m := &terminalMatcher{
		literalTrie: &trieNode{terminalID: -1},
		ranges:      make([]terminal, 0),
		terminals:   make([]terminal, 0, len(pda.Terminals)),
		patternToID: make(map[string]int),
	}

	// Track which pattern produced each unescaped value for collision detection
	unescapedSource := make(map[string]string) // unescaped -> original pattern

	for i, pattern := range pda.Terminals {
		terminal, err := parseTerminalPattern(pattern, i)
		if err != nil {
			return nil, fmt.Errorf("terminal %q: %w", pattern, err)
		}

		if terminal.Type == terminalLiteral {
			// Use the unescaped pattern for trie matching
			m.addLiteralToTrie(terminal.Unescaped, i)

			// Detect collisions between literals that unescape to the same value
			if existingPattern, exists := unescapedSource[terminal.Unescaped]; exists {
				if existingPattern != pattern {
					return nil, fmt.Errorf("collision: patterns %q and %q both unescape to %q",
						existingPattern, pattern, terminal.Unescaped)
				}
			} else {
				unescapedSource[terminal.Unescaped] = pattern
			}
		} else if terminal.Type == terminalRange {
			m.ranges = append(m.ranges, terminal)
		}

		m.terminals = append(m.terminals, terminal)
		m.patternToID[pattern] = i
	}

	return m, nil
}

// parseTerminalPattern parses a terminal pattern and returns a terminal.
// Supports:
//   - Literal strings (with escape sequences)
//   - Character ranges [X-Y] (unicode-aware)
func parseTerminalPattern(pattern string, id int) (terminal, error) {
	if len(pattern) == 0 {
		return terminal{}, fmt.Errorf("empty pattern")
	}

	// Check for range pattern: [X-Y]
	if isUnicodeRangePattern(pattern) {
		lowRune, highRune, err := parseUnicodeRange(pattern)
		if err != nil {
			return terminal{}, err
		}
		return terminal{
			ID:        id,
			Type:      terminalRange,
			Pattern:   pattern,
			Unescaped: pattern,
			LowRune:   lowRune,
			HighRune:  highRune,
		}, nil
	}

	// It's a literal - unescape it
	unescaped, err := unescapeLiteral(pattern)
	if err != nil {
		return terminal{}, fmt.Errorf("invalid escape sequence: %w", err)
	}

	return terminal{
		ID:        id,
		Type:      terminalLiteral,
		Pattern:   pattern,
		Unescaped: unescaped,
	}, nil
}

// isUnicodeRangePattern checks if pattern is a character range like [a-z] or [\u0000-\uFFFF]
func isUnicodeRangePattern(pattern string) bool {
	if len(pattern) < 5 || pattern[0] != '[' || pattern[len(pattern)-1] != ']' {
		return false
	}
	// Find the dash that separates low-high
	inner := pattern[1 : len(pattern)-1]
	dashIdx := strings.Index(inner, "-")
	// Handle escaped dash at start
	if dashIdx <= 0 {
		return false
	}
	return true
}

// parseUnicodeRange parses [X-Y] into low and high runes
func parseUnicodeRange(pattern string) (rune, rune, error) {
	if len(pattern) < 5 || pattern[0] != '[' || pattern[len(pattern)-1] != ']' {
		return 0, 0, fmt.Errorf("invalid range pattern")
	}

	inner := pattern[1 : len(pattern)-1]

	// Simple case: [a-z] where a and z are single chars
	if len(inner) == 3 && inner[1] == '-' {
		return rune(inner[0]), rune(inner[2]), nil
	}

	// Handle escaped characters like [\u0000-\uFFFF]
	dashIdx := findRangeDash(inner)
	if dashIdx < 0 {
		return 0, 0, fmt.Errorf("no dash in range")
	}

	lowStr := inner[:dashIdx]
	highStr := inner[dashIdx+1:]

	lowRune, err := parseRune(lowStr)
	if err != nil {
		return 0, 0, fmt.Errorf("invalid low bound: %w", err)
	}

	highRune, err := parseRune(highStr)
	if err != nil {
		return 0, 0, fmt.Errorf("invalid high bound: %w", err)
	}

	if lowRune > highRune {
		return 0, 0, fmt.Errorf("low bound > high bound")
	}

	return lowRune, highRune, nil
}

// findRangeDash finds the dash separating low-high in a range pattern
func findRangeDash(inner string) int {
	i := 0
	for i < len(inner) {
		if inner[i] == '\\' && i+1 < len(inner) {
			// Skip escape sequence
			if inner[i+1] == 'u' && i+6 <= len(inner) {
				i += 6 // \uXXXX
			} else {
				i += 2 // \n, \t, etc.
			}
			continue
		}
		if inner[i] == '-' && i > 0 {
			return i
		}
		i++
	}
	return -1
}

// parseRune parses a single rune from a string (handles escapes)
func parseRune(s string) (rune, error) {
	if len(s) == 0 {
		return 0, fmt.Errorf("empty rune")
	}

	// Handle escape sequences
	if s[0] == '\\' {
		if len(s) < 2 {
			return 0, fmt.Errorf("incomplete escape")
		}
		switch s[1] {
		case 'n':
			return '\n', nil
		case 't':
			return '\t', nil
		case 'r':
			return '\r', nil
		case '\\':
			return '\\', nil
		case '"':
			return '"', nil
		case '\'':
			return '\'', nil
		case 'u':
			if len(s) < 6 {
				return 0, fmt.Errorf("incomplete unicode escape")
			}
			val, err := strconv.ParseInt(s[2:6], 16, 32)
			if err != nil {
				return 0, fmt.Errorf("invalid unicode escape: %w", err)
			}
			return rune(val), nil
		default:
			return 0, fmt.Errorf("unknown escape: \\%c", s[1])
		}
	}

	// Plain character
	r, _ := utf8.DecodeRuneInString(s)
	if r == utf8.RuneError {
		return 0, fmt.Errorf("invalid utf8")
	}
	return r, nil
}

// unescapeLiteral unescapes a literal pattern string
func unescapeLiteral(pattern string) (string, error) {
	// Try strconv.Unquote if it looks quoted
	if len(pattern) >= 2 && pattern[0] == '"' && pattern[len(pattern)-1] == '"' {
		unquoted, err := strconv.Unquote(pattern)
		if err != nil {
			return "", err
		}
		return unquoted, nil
	}

	// If no backslashes, return as-is
	if !strings.Contains(pattern, "\\") {
		return pattern, nil
	}

	// Manual unescape
	var result strings.Builder
	i := 0
	for i < len(pattern) {
		if pattern[i] == '\\' && i+1 < len(pattern) {
			switch pattern[i+1] {
			case 'n':
				result.WriteByte('\n')
				i += 2
			case 't':
				result.WriteByte('\t')
				i += 2
			case 'r':
				result.WriteByte('\r')
				i += 2
			case '\\':
				result.WriteByte('\\')
				i += 2
			case '"':
				result.WriteByte('"')
				i += 2
			case '\'':
				result.WriteByte('\'')
				i += 2
			case 'u':
				if i+6 <= len(pattern) {
					val, err := strconv.ParseInt(pattern[i+2:i+6], 16, 32)
					if err != nil {
						return "", fmt.Errorf("invalid unicode escape at %d", i)
					}
					result.WriteRune(rune(val))
					i += 6
				} else {
					return "", fmt.Errorf("incomplete unicode escape at %d", i)
				}
			default:
				// Reject unknown escape sequences
				return "", fmt.Errorf("unknown escape sequence: \\%c at position %d", pattern[i+1], i)
			}
		} else {
			result.WriteByte(pattern[i])
			i++
		}
	}
	return result.String(), nil
}
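As a quick sanity check on the terminal handling above, here are a few behaviors one can read directly off the code, written test-style (this is a sketch, not a test file from the commit; it assumes the standard `testing` import):

```go
func TestTerminalPatternsSketch(t *testing.T) {
	// "[0-9]" parses as a single-rune range: low '0', high '9'.
	term, err := parseTerminalPattern("[0-9]", 0)
	if err != nil || term.Type != terminalRange || term.LowRune != '0' || term.HighRune != '9' {
		t.Fatalf("range parse: %+v, %v", term, err)
	}

	// Escaped literals are stored unescaped for trie matching.
	got, err := unescapeLiteral(`\n`)
	if err != nil || got != "\n" {
		t.Fatalf("unescape: %q, %v", got, err)
	}

	// Unknown escapes are rejected rather than silently passed through.
	if _, err := unescapeLiteral(`\q`); err == nil {
		t.Fatal("expected error for unknown escape")
	}
}
```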
@@ -1,329 +0,0 @@
//go:build mlx

package grammar

import (
	"container/list"
	"fmt"
	"math"
	"sync"

	"github.com/ollama/ollama/x/imagegen/mlx"
)

// maskCache provides LRU caching for computed masks.
type maskCache struct {
	cache   map[uint64]*list.Element
	order   *list.List
	maxSize int
	mu      sync.Mutex
}

type maskEntry struct {
	sig  uint64
	mask *mlx.Array
}

// newMaskCache creates a new mask cache with the given max size.
// If maxSize <= 0, the cache is disabled (get/put are no-ops).
func newMaskCache(maxSize int) *maskCache {
	if maxSize <= 0 {
		return &maskCache{
			cache:   make(map[uint64]*list.Element),
			order:   list.New(),
			maxSize: 0, // Signals disabled
		}
	}
	return &maskCache{
		cache:   make(map[uint64]*list.Element),
		order:   list.New(),
		maxSize: maxSize,
	}
}

// get retrieves a cached mask, returning nil if not found.
// Updates LRU order on cache hit.
func (c *maskCache) get(sig uint64) *mlx.Array {
	if c.maxSize <= 0 {
		return nil // Cache disabled
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	if elem, ok := c.cache[sig]; ok {
		c.order.MoveToFront(elem)
		return elem.Value.(*maskEntry).mask
	}
	return nil
}

// put stores a mask in the cache with LRU eviction.
func (c *maskCache) put(sig uint64, mask *mlx.Array) {
	if c.maxSize <= 0 {
		return // Cache disabled
	}
	c.mu.Lock()
	defer c.mu.Unlock()

	if elem, exists := c.cache[sig]; exists {
		c.order.MoveToFront(elem)
		return
	}

	// Evict oldest if at capacity (safe since maxSize > 0)
	if c.order.Len() >= c.maxSize {
		oldest := c.order.Back()
		if oldest != nil {
			entry := oldest.Value.(*maskEntry)
			entry.mask.Free()
			delete(c.cache, entry.sig)
			c.order.Remove(oldest)
		}
	}

	elem := c.order.PushFront(&maskEntry{sig: sig, mask: mask})
	c.cache[sig] = elem
}

// clear frees all cached masks.
func (c *maskCache) clear() {
	c.mu.Lock()
	defer c.mu.Unlock()
	for elem := c.order.Front(); elem != nil; elem = elem.Next() {
		elem.Value.(*maskEntry).mask.Free()
	}
	c.cache = make(map[uint64]*list.Element)
	c.order.Init()
}

// size returns the number of cached masks.
func (c *maskCache) size() int {
	c.mu.Lock()
	defer c.mu.Unlock()
	return len(c.cache)
}
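// A quick illustration of the cache contract above (signature values are
// arbitrary here; maskA/maskB/maskC stand in for *mlx.Array values):
//
//	c := newMaskCache(2)
//	c.put(1, maskA)
//	c.put(2, maskB)
//	_ = c.get(1)    // hit: moves entry 1 to the front of the LRU order
//	c.put(3, maskC) // at capacity: evicts entry 2 and frees its mask
//	// c.size() == 2; newMaskCache(0) disables caching entirely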

// Engine applies grammar constraints to model outputs using MLX.
// It uses a token→pda bridge for strict correctness with arbitrary BPE tokens.
type Engine struct {
	// The compiled grammar
	grammar *Grammar

	// bridge for token validation
	bridge   *bridge
	analyzer *analyzer

	// Current parser state (configSet for nondeterminism)
	configSet *configSet

	// Token vocabulary from the model
	vocab     []string
	tokenToID map[string]int // O(1) lookup for AcceptString

	// Mask cache: configSig → valid token mask (LRU)
	maskCache *maskCache

	// Cached negative infinity mask for invalid tokens
	negInfMask *mlx.Array

	// Threshold for comparison (0.5 since mask values are 0 or 1)
	threshold *mlx.Array

	// Vocabulary size
	vocabSize int32

	// Reusable buffers for candidate filtering (avoid allocations)
	candidateMark []bool // indexed by tokenID, true if in candidate set
	touched       []int  // tokenIDs that were marked (for reset)
	dpCandidates  []int  // candidates requiring DP validation

	// Reusable buffer for valid token indices (for GPU scatter)
	validTokenIDs []int32
}

// EngineOption configures an Engine.
type EngineOption func(*Engine)

// WithMaskCacheSize sets the mask cache size (default 1024).
func WithMaskCacheSize(size int) EngineOption {
	return func(e *Engine) {
		e.maskCache = newMaskCache(size)
	}
}

// NewEngine creates a new constrained decoding engine.
// grammar is the compiled grammar (use JSONGrammar() or ParseEBNF()).
// vocab is the list of token strings from the model's tokenizer.
func NewEngine(grammar *Grammar, vocab []string, opts ...EngineOption) (*Engine, error) {
	if grammar == nil {
		return nil, fmt.Errorf("grammar cannot be nil")
	}

	// Build analyzer and bridge
	analyzer := newAnalyzer(vocab, grammar.matcher)
	bridge := newBridge(grammar.pda, analyzer)

	// Initialize config set from pda initial state
	initialConfig := newConfigSet(grammar.pda.StartState, nil)

	// Build token lookup map for O(1) AcceptString
	tokenToID := make(map[string]int, len(vocab))
	for i, tok := range vocab {
		tokenToID[tok] = i
	}

	e := &Engine{
		grammar:       grammar,
		bridge:        bridge,
		analyzer:      analyzer,
		configSet:     initialConfig,
		vocab:         vocab,
		tokenToID:     tokenToID,
		maskCache:     newMaskCache(1024),
		vocabSize:     int32(len(vocab)),
		candidateMark: make([]bool, len(vocab)),
		touched:       make([]int, 0, 10000),
		validTokenIDs: make([]int32, 0, 10000),
	}

	// Apply options
	for _, opt := range opts {
		opt(e)
	}

	// Create the negative infinity mask and threshold
	if e.vocabSize > 0 {
		e.negInfMask = mlx.FullDtype(float32(math.Inf(-1)), mlx.DtypeFloat32, e.vocabSize)
		mlx.Keep(e.negInfMask)

		e.threshold = mlx.NewScalarArray(0.5)
		mlx.Keep(e.threshold)
	}

	return e, nil
}
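// Typical construction, as a sketch (tokenizerVocab is a placeholder; real
// callers pass the tokenizer's full vocabulary):
//
//	g, err := JSONGrammar()
//	if err != nil {
//		// handle error
//	}
//	e, err := NewEngine(g, tokenizerVocab, WithMaskCacheSize(4096))
//	if err != nil {
//		// handle error
//	}
//	defer e.Close()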

// ApplyMask applies grammar constraints to logits.
// Returns logits with invalid tokens set to -inf.
func (e *Engine) ApplyMask(logits *mlx.Array) *mlx.Array {
	sig := e.configSet.signature()

	// Check state cache first (exact state match)
	if cached := e.maskCache.get(sig); cached != nil {
		condition := mlx.GreaterEqual(cached, e.threshold)
		return mlx.Where(condition, logits, e.negInfMask)
	}

	// Compute valid tokens using candidate filtering:
	//  1. Get valid terminal IDs from current grammar state
	//  2. Get candidate tokens (those that START with valid terminals)
	//  3. Run DP validation only on candidates
	// This is O(candidates) instead of O(vocab_size)

	validTerminalIDs := e.bridge.validTerminalIDs(e.configSet)

	// Use pre-partitioned token groups for fast candidate building.
	// This eliminates per-token branching - just direct slice appends.
	e.validTokenIDs = e.validTokenIDs[:0]
	e.dpCandidates = e.dpCandidates[:0]
	e.touched = e.touched[:0]

	for _, tid := range validTerminalIDs {
		groups := e.analyzer.terminalGroups(tid)

		// Direct append of exact matches (no per-token check needed)
		e.validTokenIDs = append(e.validTokenIDs, groups.ExactMatches...)

		// Collect DP candidates (may have duplicates across terminals)
		for _, tokenID := range groups.DPCandidates {
			if !e.candidateMark[tokenID] {
				e.candidateMark[tokenID] = true
				e.dpCandidates = append(e.dpCandidates, tokenID)
				e.touched = append(e.touched, tokenID)
			}
		}
	}

	// Reset marks for next call
	for _, id := range e.touched {
		e.candidateMark[id] = false
	}

	for _, tokenID := range e.dpCandidates {
		if e.bridge.IsTokenValid(tokenID, e.configSet) {
			e.validTokenIDs = append(e.validTokenIDs, int32(tokenID))
		}
	}

	// Create and cache the mask on GPU using index updates
	mask := mlx.Zeros([]int32{e.vocabSize})
	if len(e.validTokenIDs) > 0 {
		indices := mlx.NewArrayInt32(e.validTokenIDs, []int32{int32(len(e.validTokenIDs))})
		values := mlx.Ones(int32(len(e.validTokenIDs)))
		mask = mlx.PutAlongAxis(mask, indices, values, 0)
	}
	mlx.Keep(mask)

	// Cache by state signature
	e.maskCache.put(sig, mask)

	// Apply mask
	condition := mlx.GreaterEqual(mask, e.threshold)
	return mlx.Where(condition, logits, e.negInfMask)
}
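// In a decode loop, the mask is applied before sampling and the chosen token is
// fed back via Accept. A sketch, where nextLogits and sampleToken are
// hypothetical stand-ins for the caller's model step and sampler:
//
//	for !e.IsComplete() {
//		masked := e.ApplyMask(nextLogits()) // invalid tokens forced to -inf
//		tok := sampleToken(masked)
//		if !e.Accept(tok) {
//			break // should not happen: masked logits only permit valid tokens
//		}
//	}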

// Accept processes a token and updates the parser state.
// Returns true if the token was valid and accepted.
func (e *Engine) Accept(tokenID int) bool {
	if tokenID < 0 || tokenID >= len(e.vocab) {
		return false
	}

	newConfig := e.bridge.acceptToken(tokenID, e.configSet)
	if newConfig == nil {
		return false
	}
	e.configSet = newConfig
	return true
}

// AcceptString processes a token string directly.
// Returns true if the token was valid and accepted.
func (e *Engine) AcceptString(token string) bool {
	if id, ok := e.tokenToID[token]; ok {
		return e.Accept(id)
	}
	return false
}

// IsComplete returns true if the current state is accepting.
func (e *Engine) IsComplete() bool {
	return e.bridge.isAccepting(e.configSet)
}

// Reset resets the engine to initial state.
func (e *Engine) Reset() {
	e.configSet = newConfigSet(e.grammar.pda.StartState, nil)
}

// validTokens returns the indices of tokens that are currently valid.
func (e *Engine) validTokens() []int {
	return e.bridge.validTokens(e.configSet)
}

// validTerminals returns the valid terminal patterns from the current state.
func (e *Engine) validTerminals() []string {
	return e.bridge.validTerminals(e.configSet)
}

// Close releases MLX resources.
func (e *Engine) Close() {
	if e.maskCache != nil {
		e.maskCache.clear()
	}
	if e.negInfMask != nil {
		e.negInfMask.Free()
	}
	if e.threshold != nil {
		e.threshold.Free()
	}
}
@@ -1,414 +0,0 @@
//go:build mlx

package grammar

import (
	"fmt"
	"testing"

	"github.com/ollama/ollama/x/imagegen/mlx"
)

// newBenchEngine creates a JSON engine for benchmarks
func newBenchEngine(b *testing.B, vocab []string) *Engine {
	b.Helper()
	grammar, err := JSONGrammar()
	if err != nil {
		b.Fatalf("failed to create JSON grammar: %v", err)
	}
	e, err := NewEngine(grammar, vocab)
	if err != nil {
		b.Fatalf("failed to create engine: %v", err)
	}
	return e
}

// Vocabulary sizes to test (matching real models)
var vocabSizes = []int{
	32000,  // Llama 2
	128000, // Llama 3
	256000, // Large models
}

// createBenchVocabN creates a vocabulary of size n with realistic token distribution
func createBenchVocabN(n int) []string {
	vocab := make([]string, n)

	// JSON structural tokens (first 20)
	jsonTokens := []string{
		"{", "}", "[", "]", ":", ",",
		"true", "false", "null",
		" ", "\n", "\t", "\r",
		"\"", "'",
	}
	for i, t := range jsonTokens {
		if i < n {
			vocab[i] = t
		}
	}

	// String tokens (indices 20-1000)
	stringIdx := 20
	for i := 0; i < 980 && stringIdx+i < n; i++ {
		vocab[stringIdx+i] = fmt.Sprintf("\"token%d\"", i)
	}

	// Number tokens (indices 1000-2000)
	numberIdx := 1000
	for i := 0; i < 1000 && numberIdx+i < n; i++ {
		vocab[numberIdx+i] = fmt.Sprintf("%d", i)
	}

	// Generic tokens (rest)
	for i := 2000; i < n; i++ {
		vocab[i] = fmt.Sprintf("tok%d", i)
	}

	return vocab
}

// ============ Core Performance Benchmarks ============

// BenchmarkApplyMask_32k measures mask application with 32k vocab
func BenchmarkApplyMask_32k(b *testing.B) {
	benchmarkApplyMask(b, 32000)
}

// BenchmarkApplyMask_128k measures mask application with 128k vocab
func BenchmarkApplyMask_128k(b *testing.B) {
	benchmarkApplyMask(b, 128000)
}

// BenchmarkApplyMask_256k measures mask application with 256k vocab
func BenchmarkApplyMask_256k(b *testing.B) {
	benchmarkApplyMask(b, 256000)
}

func benchmarkApplyMask(b *testing.B, vocabSize int) {
	vocab := createBenchVocabN(vocabSize)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	logits := mlx.Ones(int32(vocabSize))
	mlx.Keep(logits)

	// Warm up
	for i := 0; i < 10; i++ {
		masked := e.ApplyMask(logits)
		mlx.Eval(masked)
	}

	b.ResetTimer()
	b.ReportAllocs()

	for i := 0; i < b.N; i++ {
		masked := e.ApplyMask(logits)
		mlx.Eval(masked)
	}

	b.ReportMetric(float64(vocabSize), "vocab_size")
}

// ============ State-Dependent Benchmarks ============

// BenchmarkApplyMaskAfterBrace measures mask after { (STRING or } valid)
func BenchmarkApplyMaskAfterBrace(b *testing.B) {
	vocab := createBenchVocabN(128000)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	e.AcceptString("{")

	logits := mlx.Ones(int32(128000))
	mlx.Keep(logits)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		masked := e.ApplyMask(logits)
		mlx.Eval(masked)
	}
}

// BenchmarkApplyMaskMidObject measures mask in middle of object
func BenchmarkApplyMaskMidObject(b *testing.B) {
	vocab := createBenchVocabN(128000)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	// State: {"key": _value_
	e.AcceptString("{")
	e.AcceptString("\"key\"")
	e.AcceptString(":")

	logits := mlx.Ones(int32(128000))
	mlx.Keep(logits)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		masked := e.ApplyMask(logits)
		mlx.Eval(masked)
	}
}

// ============ Token Sequence Benchmarks ============

// BenchmarkSequence_SimpleObject benchmarks {"key": "value"}
func BenchmarkSequence_SimpleObject(b *testing.B) {
	vocab := createBenchVocabN(128000)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	logits := mlx.Ones(int32(128000))
	mlx.Keep(logits)

	sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		e.Reset()
		for _, token := range sequence {
			masked := e.ApplyMask(logits)
			mlx.Eval(masked)
			e.AcceptString(token)
		}
	}

	b.ReportMetric(float64(len(sequence)), "tokens")
}

// BenchmarkSequence_NestedObject benchmarks {"a": {"b": {"c": 1}}}
func BenchmarkSequence_NestedObject(b *testing.B) {
	vocab := createBenchVocabN(128000)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	logits := mlx.Ones(int32(128000))
	mlx.Keep(logits)

	sequence := []string{
		"{", "\"a\"", ":", "{", "\"b\"", ":", "{", "\"c\"", ":", "1", "}", "}", "}",
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		e.Reset()
		for _, token := range sequence {
			masked := e.ApplyMask(logits)
			mlx.Eval(masked)
			e.AcceptString(token)
		}
	}

	b.ReportMetric(float64(len(sequence)), "tokens")
}

// BenchmarkSequence_LargeArray benchmarks [1, 2, 3, ..., 50]
func BenchmarkSequence_LargeArray(b *testing.B) {
	vocab := createBenchVocabN(128000)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	logits := mlx.Ones(int32(128000))
	mlx.Keep(logits)

	// Build sequence: [1, 2, 3, ..., 50]
	sequence := []string{"["}
	for i := 1; i <= 50; i++ {
		sequence = append(sequence, fmt.Sprintf("%d", i))
		if i < 50 {
			sequence = append(sequence, ",")
		}
	}
	sequence = append(sequence, "]")

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		e.Reset()
		for _, token := range sequence {
			masked := e.ApplyMask(logits)
			mlx.Eval(masked)
			e.AcceptString(token)
		}
	}

	b.ReportMetric(float64(len(sequence)), "tokens")
}

// BenchmarkSequence_MixedTypes benchmarks a complex mixed-type object
func BenchmarkSequence_MixedTypes(b *testing.B) {
	vocab := createBenchVocabN(128000)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	logits := mlx.Ones(int32(128000))
	mlx.Keep(logits)

	sequence := []string{
		"{",
		"\"name\"", ":", "\"test\"", ",",
		"\"count\"", ":", "42", ",",
		"\"enabled\"", ":", "true", ",",
		"\"data\"", ":", "null", ",",
		"\"items\"", ":", "[", "1", ",", "2", ",", "3", "]",
		"}",
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		e.Reset()
		for _, token := range sequence {
			masked := e.ApplyMask(logits)
			mlx.Eval(masked)
			e.AcceptString(token)
		}
	}

	b.ReportMetric(float64(len(sequence)), "tokens")
}

// ============ Component Benchmarks ============

// BenchmarkValidInputs measures pda valid input computation
func BenchmarkValidInputs(b *testing.B) {
	vocab := createBenchVocabN(128000)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = e.validTerminals()
	}
}

// BenchmarkStateTransition measures pda state transition
func BenchmarkStateTransition(b *testing.B) {
	vocab := createBenchVocabN(128000)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		e.Reset()
		for _, token := range sequence {
			e.AcceptString(token)
		}
	}
}

// BenchmarkConstrainedGrammar_128k benchmarks x/grammar (graph only, no eval).
func BenchmarkConstrainedGrammar_128k(b *testing.B) {
	vocab := createBenchVocabN(128000)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	logits := mlx.Ones(int32(128000))
	mlx.Keep(logits)

	// Warm up
	for i := 0; i < 10; i++ {
		masked := e.ApplyMask(logits)
		mlx.Eval(masked)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = e.ApplyMask(logits) // Graph only, no eval
	}
}

// BenchmarkNewEngine_32k and BenchmarkNewEngine_128k measure one-time engine initialization.
func BenchmarkNewEngine_32k(b *testing.B) {
	benchmarkNewEngine(b, 32000)
}

func BenchmarkNewEngine_128k(b *testing.B) {
	benchmarkNewEngine(b, 128000)
}

func benchmarkNewEngine(b *testing.B, vocabSize int) {
	vocab := createBenchVocabN(vocabSize)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		e := newBenchEngine(b, vocab)
		e.Close()
	}
}

// ============ Memory Benchmarks ============

func BenchmarkMemoryAllocs_32k(b *testing.B) {
	benchmarkMemoryAllocs(b, 32000)
}

func BenchmarkMemoryAllocs_128k(b *testing.B) {
	benchmarkMemoryAllocs(b, 128000)
}

func benchmarkMemoryAllocs(b *testing.B, vocabSize int) {
	vocab := createBenchVocabN(vocabSize)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	logits := mlx.Ones(int32(vocabSize))
	mlx.Keep(logits)

	b.ReportAllocs()
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		masked := e.ApplyMask(logits)
		mlx.Eval(masked)
	}
}

// ============ No-Eval Benchmarks (simulating LLM graph integration) ============

// BenchmarkApplyMaskNoEval_128k measures mask generation WITHOUT GPU sync.
// This simulates adding the mask to the LLM compute graph.
func BenchmarkApplyMaskNoEval_128k(b *testing.B) {
	vocab := createBenchVocabN(128000)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	logits := mlx.Ones(int32(128000))
	mlx.Keep(logits)

	// Warm up
	for i := 0; i < 10; i++ {
		masked := e.ApplyMask(logits)
		mlx.Eval(masked)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = e.ApplyMask(logits) // No Eval - just build graph
	}
}

// BenchmarkSequenceNoEval_SimpleObject simulates real LLM usage - build graph, eval once at end
func BenchmarkSequenceNoEval_SimpleObject(b *testing.B) {
	vocab := createBenchVocabN(128000)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	logits := mlx.Ones(int32(128000))
	mlx.Keep(logits)

	sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		e.Reset()
		var lastMasked *mlx.Array
		for _, token := range sequence {
			lastMasked = e.ApplyMask(logits) // Build graph only
			e.AcceptString(token)
		}
		mlx.Eval(lastMasked) // Single eval at end
	}

	b.ReportMetric(float64(len(sequence)), "tokens")
}
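These benchmarks sit behind the mlx build tag, so they must be requested explicitly; for example (the package path depends on the module layout):

	go test -tags mlx -run '^$' -bench . -benchmem ./...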
@@ -1,689 +0,0 @@
//go:build mlx

package grammar

import (
	"testing"

	"github.com/ollama/ollama/x/imagegen/mlx"
)

// newTestEngine creates a JSON engine for testing
func newTestEngine(t testing.TB, vocab []string) *Engine {
	t.Helper()
	grammar, err := JSONGrammar()
	if err != nil {
		t.Fatalf("failed to create JSON grammar: %v", err)
	}
	e, err := NewEngine(grammar, vocab)
	if err != nil {
		t.Fatalf("failed to create engine: %v", err)
	}
	return e
}

// Mock vocabulary for testing
func testVocab() []string {
	return []string{
		"{",       // 0: object start
		"}",       // 1: object end
		"[",       // 2: array start
		"]",       // 3: array end
		":",       // 4: colon
		",",       // 5: comma
		"\"key\"", // 6: string (quoted)
		"\"val\"", // 7: string (quoted)
		"123",     // 8: number
		"-42.5",   // 9: number
		"true",    // 10: boolean
		"false",   // 11: boolean
		"null",    // 12: null
		" ",       // 13: whitespace (should be ignored)
		"\n",      // 14: whitespace (should be ignored)
		"subword", // 15: bare word (NOT valid JSON - requires quotes)
		"hello",   // 16: bare word (NOT valid JSON - requires quotes)
	}
}

func TestNewEngine(t *testing.T) {
	vocab := testVocab()
	e := newTestEngine(t, vocab)
	defer e.Close()

	if e.vocabSize != int32(len(vocab)) {
		t.Errorf("vocabSize = %d, want %d", e.vocabSize, len(vocab))
	}

	// Verify grammar is set
	if e.grammar == nil {
		t.Error("grammar should not be nil")
	}

	// Verify analyzer is set
	if e.analyzer == nil {
		t.Error("analyzer should not be nil")
	}
}

func TestEngineValidTokens(t *testing.T) {
	vocab := testVocab()
	e := newTestEngine(t, vocab)
	defer e.Close()

	// At start, any value type should be valid
	validTokens := e.validTokens()

	// Should include object start, array start, strings, numbers, booleans, null.
	// Note: bare words like "subword" and "hello" are NOT valid JSON strings
	// (JSON strings must be quoted).
	expectedTokens := map[int]bool{
		0:  true, // {
		2:  true, // [
		6:  true, // "key"
		7:  true, // "val"
		8:  true, // 123
		9:  true, // -42.5
		10: true, // true
		11: true, // false
		12: true, // null
	}

	// Check that expected tokens are present
	validSet := make(map[int]bool)
	for _, idx := range validTokens {
		validSet[idx] = true
	}

	for idx := range expectedTokens {
		if !validSet[idx] {
			t.Errorf("expected token %d (%s) to be valid", idx, vocab[idx])
		}
	}

	if validSet[15] || validSet[16] {
		t.Error("bare words should not be valid JSON at the start state")
	}
}

func TestEngineAccept(t *testing.T) {
	vocab := testVocab()
	e := newTestEngine(t, vocab)
	defer e.Close()

	// Accepting { should work
	if !e.Accept(0) { // {
		t.Error("should accept {")
	}

	// After {, valid tokens should be STRING or }
	validTokens := e.validTokens()

	validSet := make(map[int]bool)
	for _, idx := range validTokens {
		validSet[idx] = true
	}

	// STRING tokens (indices 6, 7) and } (index 1) should be valid
	if !validSet[1] {
		t.Error("} should be valid after {")
	}
	if !validSet[6] && !validSet[7] {
		t.Error("STRING should be valid after { (for keys)")
	}
}

func TestEngineAcceptSequence(t *testing.T) {
	vocab := testVocab()
	e := newTestEngine(t, vocab)
	defer e.Close()

	// Accept {"key": "val"}
	sequence := []int{0, 6, 4, 7, 1} // {, "key", :, "val", }

	for i, tokenID := range sequence {
		if !e.Accept(tokenID) {
			t.Fatalf("failed to accept token %d (%s) at position %d",
				tokenID, vocab[tokenID], i)
		}
	}

	if !e.IsComplete() {
		t.Error("should be in complete state after valid JSON")
	}
}

func TestEngineReset(t *testing.T) {
	vocab := testVocab()
	e := newTestEngine(t, vocab)
	defer e.Close()

	// Accept some tokens
	e.Accept(0) // {
	e.Accept(1) // }

	if !e.IsComplete() {
		t.Error("should be complete after {}")
	}

	// Reset
	e.Reset()

	// Should be back to initial state
	if e.IsComplete() {
		t.Error("should not be complete after reset")
	}

	// Should be able to accept new sequence
	if !e.Accept(0) { // {
		t.Error("should accept { after reset")
	}
}

func TestEngineInvalidTokenRejection(t *testing.T) {
	vocab := testVocab()
	e := newTestEngine(t, vocab)
	defer e.Close()

	// Accept { first
	if !e.Accept(0) {
		t.Fatal("should accept {")
	}

	// Now try to accept [ which is invalid after {
	// (After {, only STRING or } are valid)
	if e.Accept(2) { // [
		t.Error("should not accept [ after { (expecting STRING or })")
	}
}

func TestEngineAcceptString(t *testing.T) {
	vocab := testVocab()
	e := newTestEngine(t, vocab)
	defer e.Close()

	// Accept using string directly
	if !e.AcceptString("{") {
		t.Error("should accept {")
	}
	if !e.AcceptString("\"key\"") {
		t.Error("should accept string key")
	}
	if !e.AcceptString(":") {
		t.Error("should accept :")
	}
	if !e.AcceptString("123") {
		t.Error("should accept number")
	}
	if !e.AcceptString("}") {
		t.Error("should accept }")
	}

	if !e.IsComplete() {
		t.Error("should be complete after valid JSON")
	}
}

func TestJSONBackslashEscape(t *testing.T) {
	vocab := []string{`"`, `\`, "n", "a"}
	e := newTestEngine(t, vocab)
	defer e.Close()

	// Valid escape: "\n"
	if !e.AcceptString(`"`) {
		t.Fatal("should accept string start")
	}
	if !e.AcceptString(`\`) {
		t.Fatal("should accept escape prefix")
	}
	if !e.AcceptString("n") {
		t.Fatal("should accept escape code")
	}
	if !e.AcceptString(`"`) {
		t.Fatal("should accept string end")
	}
	if !e.IsComplete() {
		t.Error("should be complete after escaped string")
	}

	// Invalid escape: "\a"
	e.Reset()
	if !e.AcceptString(`"`) {
		t.Fatal("should accept string start")
	}
	if !e.AcceptString(`\`) {
		t.Fatal("should accept escape prefix")
	}
	if e.AcceptString("a") {
		t.Error("should reject invalid escape code")
	}
}

func TestEngineNegInfMask(t *testing.T) {
	vocab := testVocab()
	e := newTestEngine(t, vocab)
	defer e.Close()

	// Verify negInfMask exists and has correct shape
	if e.negInfMask == nil {
		t.Fatal("negInfMask should not be nil")
	}
}

func TestEngineMaskCache(t *testing.T) {
	vocab := testVocab()
	e := newTestEngine(t, vocab)
	defer e.Close()

	// Create test logits
	logits := mlx.Ones(int32(len(vocab)))

	// Apply mask - should populate cache
	_ = e.ApplyMask(logits)

	// Check cache was populated
	cacheSize := e.maskCache.size()
	if cacheSize == 0 {
		t.Error("mask cache should have at least one entry after ApplyMask")
	}
}

func TestEngineEmptyVocab(t *testing.T) {
	e := newTestEngine(t, []string{})
	defer e.Close()

	if e.vocabSize != 0 {
		t.Errorf("vocabSize = %d, want 0", e.vocabSize)
	}
}

func TestEngineLargeVocab(t *testing.T) {
	// Create a large vocabulary (simulating real model vocab)
	vocab := make([]string, 32000)
	for i := range vocab {
		vocab[i] = "token"
	}
	// Add some actual JSON tokens
	vocab[0] = "{"
	vocab[1] = "}"
	vocab[2] = "["
	vocab[3] = "]"
	vocab[4] = ":"
	vocab[5] = ","
	vocab[6] = "\"test\""
	vocab[7] = "123"
	vocab[8] = "true"
	vocab[9] = "false"
	vocab[10] = "null"

	e := newTestEngine(t, vocab)
	defer e.Close()

	if e.vocabSize != 32000 {
		t.Errorf("vocabSize = %d, want 32000", e.vocabSize)
	}

	// Test that it still works correctly
	if !e.Accept(0) { // {
		t.Error("should accept {")
	}
	if !e.Accept(1) { // }
		t.Error("should accept }")
	}
	if !e.IsComplete() {
		t.Error("should be complete after {}")
	}
}

// TestE2E_JSONDecoding tests end-to-end JSON constrained decoding.
func TestE2E_JSONDecoding(t *testing.T) {
	// Create a realistic vocabulary with JSON tokens
	vocab := []string{
		// Structural tokens
		"{", "}", "[", "]", ":", ",",
		// Keywords
		"true", "false", "null",
		// Quoted strings
		`"name"`, `"value"`, `"items"`, `"count"`, `"enabled"`,
		`"hello"`, `"world"`, `"test"`,
		// Numbers
		"0", "1", "2", "3", "42", "123", "-1", "-42",
		// Whitespace
		" ", "\n", "\t",
		// Multi-terminal tokens (span multiple JSON lexemes)
		`"key":`, `},`, `],`, `{"`, `["`,
		// Partial/invalid tokens (should be rejected)
		"invalid", "foo", "bar",
	}

	grammar, err := JSONGrammar()
	if err != nil {
		t.Fatalf("failed to create JSON grammar: %v", err)
	}

	engine, err := NewEngine(grammar, vocab)
	if err != nil {
		t.Fatalf("failed to create engine: %v", err)
	}
	defer engine.Close()

	tests := []struct {
		name     string
		tokens   []string
		wantPass bool
	}{
		// Simple values
		{"empty object", []string{"{", "}"}, true},
		{"empty array", []string{"[", "]"}, true},
		{"true literal", []string{"true"}, true},
		{"null literal", []string{"null"}, true},
		{"number", []string{"42"}, true},
		{"negative number", []string{"-42"}, true},
		{"quoted string", []string{`"hello"`}, true},

		// Objects
		{"simple object", []string{"{", `"name"`, ":", `"value"`, "}"}, true},
		{"object with single-digit numbers", []string{"{", `"count"`, ":", "1", ",", `"value"`, ":", "2", "}"}, true},
		{"multi-terminal key", []string{"{", `"key":`, `"value"`, "}"}, true},

		// Arrays
		{"array of numbers", []string{"[", "42", "]"}, true},
		{"array of single digits", []string{"[", "1", ",", "2", "]"}, true},
		{"array of strings", []string{"[", `"hello"`, ",", `"world"`, "]"}, true},
		{"nested array", []string{"[", "[", "42", "]", "]"}, true},

		// Nested structures
		{"nested object", []string{"{", `"items"`, ":", "{", `"count"`, ":", "42", "}", "}"}, true},
		{"object with array", []string{"{", `"items"`, ":", "[", "42", "]", "}"}, true},

		// Invalid sequences
		{"unclosed object", []string{"{", `"name"`, ":"}, false}, // incomplete
		{"double comma", []string{"[", "42", ",", ",", "42", "]"}, false}, // invalid
		{"missing value", []string{"{", `"name"`, ":", "}"}, false}, // missing value
		{"bare word", []string{"invalid"}, false}, // not valid JSON
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			engine.Reset()

			// Process each token
			allAccepted := true
			for i, token := range tt.tokens {
				if !engine.AcceptString(token) {
					if tt.wantPass {
						t.Errorf("token %d (%q) rejected unexpectedly", i, token)
					}
					allAccepted = false
					break
				}
			}

			if tt.wantPass {
				if !allAccepted {
					return // Already reported error
				}
				if !engine.IsComplete() {
					t.Errorf("expected complete parse, but not in accepting state")
				}
			} else {
				// For invalid sequences, we expect either rejection or incomplete
				if allAccepted && engine.IsComplete() {
					t.Errorf("expected rejection or incomplete, but parse succeeded")
				}
			}
		})
	}
}

// TestE2E_SimpleExpressionGrammar tests a custom expression grammar.
func TestE2E_SimpleExpressionGrammar(t *testing.T) {
	// Arithmetic expression grammar with precedence:
	//   expr   = term { ("+" | "-") term }
	//   term   = factor { ("*" | "/") factor }
	//   factor = number | "(" expr ")"
	//   number = digit { digit }
	exprGrammar := `
expr = term { addop term } .
addop = "+" | "-" .
term = factor { mulop factor } .
mulop = "*" | "/" .
factor = number | "(" expr ")" .
number = digit { digit } .
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
`

	grammar, err := ParseEBNF(exprGrammar, "expr")
	if err != nil {
		t.Fatalf("failed to parse expression grammar: %v", err)
	}

	// Vocabulary for expression tokens
	vocab := []string{
		"0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
		"+", "-", "*", "/",
		"(", ")",
		// Multi-digit numbers as single tokens
		"10", "42", "100", "123",
		// Invalid tokens
		"x", "y", "invalid",
	}

	engine, err := NewEngine(grammar, vocab)
	if err != nil {
		t.Fatalf("failed to create engine: %v", err)
	}
	defer engine.Close()

	tests := []struct {
		name     string
		tokens   []string
		wantPass bool
	}{
		{"single digit", []string{"5"}, true},
		{"multi-digit", []string{"1", "2", "3"}, true},
		{"addition", []string{"1", "+", "2"}, true},
		{"subtraction", []string{"5", "-", "3"}, true},
		{"multiplication", []string{"2", "*", "3"}, true},
		{"division", []string{"8", "/", "2"}, true},
		{"complex expr", []string{"1", "+", "2", "*", "3"}, true},
		{"parentheses", []string{"(", "1", "+", "2", ")", "*", "3"}, true},
		{"nested parens", []string{"(", "(", "1", ")", ")"}, true},

		// Invalid
		{"just operator", []string{"+"}, false},
		{"double operator", []string{"1", "+", "+", "2"}, false},
		{"unclosed paren", []string{"(", "1", "+", "2"}, false},
		{"variable", []string{"x"}, false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			engine.Reset()

			allAccepted := true
			for i, token := range tt.tokens {
				if !engine.AcceptString(token) {
					if tt.wantPass {
						t.Errorf("token %d (%q) rejected unexpectedly", i, token)
					}
					allAccepted = false
					break
				}
			}

			if tt.wantPass {
				if !allAccepted {
					return
				}
				if !engine.IsComplete() {
					t.Errorf("expected complete parse, but not in accepting state")
				}
			} else {
				if allAccepted && engine.IsComplete() {
					t.Errorf("expected rejection or incomplete, but parse succeeded")
				}
			}
		})
	}
}

// TestE2E_IdentifierGrammar tests a grammar with character ranges.
func TestE2E_IdentifierGrammar(t *testing.T) {
	// Identifier grammar using character ranges
	identGrammar := `
ident = letter { letter | digit } .
letter = "a" … "z" | "A" … "Z" | "_" .
digit = "0" … "9" .
`

	grammar, err := ParseEBNF(identGrammar, "ident")
	if err != nil {
		t.Fatalf("failed to parse identifier grammar: %v", err)
	}

	// Vocabulary with letters and digits
	vocab := []string{
		"a", "b", "c", "x", "y", "z",
		"A", "B", "C", "X", "Y", "Z",
		"_",
		"0", "1", "2", "9",
		// Multi-char tokens
		"foo", "bar", "myVar", "test123",
		// Invalid starting chars
		"1abc", "123",
	}

	engine, err := NewEngine(grammar, vocab)
	if err != nil {
		t.Fatalf("failed to create engine: %v", err)
	}
	defer engine.Close()

	tests := []struct {
		name     string
		tokens   []string
		wantPass bool
	}{
		{"single letter", []string{"a"}, true},
		{"uppercase", []string{"A"}, true},
		{"underscore", []string{"_"}, true},
		{"multi-letter", []string{"a", "b", "c"}, true},
		{"letter then digit", []string{"x", "1"}, true},
		{"underscore prefix", []string{"_", "a", "1"}, true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			engine.Reset()

			allAccepted := true
			for i, token := range tt.tokens {
				if !engine.AcceptString(token) {
					if tt.wantPass {
						t.Errorf("token %d (%q) rejected unexpectedly", i, token)
					}
					allAccepted = false
					break
				}
			}

			if tt.wantPass && allAccepted && !engine.IsComplete() {
				t.Errorf("expected complete parse, but not in accepting state")
			}
		})
	}
}

// TestE2E_UnicodeRange ensures unicode ranges compile and match tokens.
func TestE2E_UnicodeRange(t *testing.T) {
	greekGrammar := `
greek = "α" … "ω" .
`

	grammar, err := ParseEBNF(greekGrammar, "greek")
	if err != nil {
		t.Fatalf("failed to parse unicode grammar: %v", err)
	}

	vocab := []string{"α", "β", "ω", "a"}
	engine, err := NewEngine(grammar, vocab)
	if err != nil {
		t.Fatalf("failed to create engine: %v", err)
	}
	defer engine.Close()

	if !engine.AcceptString("β") {
		t.Error("should accept beta")
	}
	if !engine.IsComplete() {
		t.Error("should be complete after single rune")
	}

	engine.Reset()
	if engine.AcceptString("a") {
		t.Error("should reject ASCII outside unicode range")
	}
}

// TestE2E_NondeterminismPreserved tests that nondeterministic paths are preserved.
func TestE2E_NondeterminismPreserved(t *testing.T) {
	// This grammar has nondeterminism: "ab" could be parsed as
	// a single token or as two tokens "a" "b"
	ambiguousGrammar := `
start = item item .
item = "a" | "b" | "ab" .
`

	grammar, err := ParseEBNF(ambiguousGrammar, "start")
	if err != nil {
		t.Fatalf("failed to parse grammar: %v", err)
	}

	// Vocabulary with both single and combined tokens
	vocab := []string{"a", "b", "ab"}

	engine, err := NewEngine(grammar, vocab)
	if err != nil {
		t.Fatalf("failed to create engine: %v", err)
	}
	defer engine.Close()

	// Test: "ab" "a" should be valid (ab as first item, a as second)
	t.Run("ab then a", func(t *testing.T) {
		engine.Reset()
		if !engine.AcceptString("ab") {
			t.Error("should accept ab")
		}
		if !engine.AcceptString("a") {
			t.Error("should accept a after ab")
		}
		if !engine.IsComplete() {
			t.Error("should be complete")
		}
	})

	t.Run("a then ab", func(t *testing.T) {
		engine.Reset()
		if !engine.AcceptString("a") {
			t.Error("should accept a")
		}
		if !engine.AcceptString("ab") {
			t.Error("should accept ab after a")
		}
		if !engine.IsComplete() {
			t.Error("should be complete")
		}
	})

	t.Run("a then a", func(t *testing.T) {
		engine.Reset()
		if !engine.AcceptString("a") {
			t.Error("should accept first a")
		}
		if !engine.AcceptString("a") {
			t.Error("should accept second a")
		}
		if !engine.IsComplete() {
			t.Error("should be complete")
		}
	})
}
@@ -1,614 +0,0 @@
//go:build mlx

// Package grammar provides GPU-accelerated constrained decoding using MLX.
// It compiles EBNF grammars to pushdown automata (pda) with precomputed token masks.
// For JSON Schema conversion, see the grammar/schema subpackage.
package grammar

import (
	"encoding/binary"
	"fmt"
	"io"
	"strings"

	"golang.org/x/exp/ebnf"
)

// stackSymbol represents a symbol that can be pushed onto the pda stack.
type stackSymbol int

const (
	stackEmpty stackSymbol = iota
	// Additional stack symbols will be generated per-grammar
)

// state represents a pda state.
type state int

const (
	stateError  state = -1
	stateStart  state = 0
	stateAccept state = 1
	// Additional states will be generated per-grammar
)

// transition represents a pda transition.
// On input matching Pattern, from FromState with stackTop:
//   - Move to ToState
//   - Pop StackPop symbols, push StackPush symbols
type transition struct {
	FromState state
	stackTop  stackSymbol // What must be on stack top (stackEmpty = don't care)
	Pattern   string      // Input pattern to match (token or character class)
	ToState   state
	StackPop  int           // Number of symbols to pop
	StackPush []stackSymbol // Symbols to push (in order, first pushed first)
}

// pda represents a compiled pushdown automaton.
type pda struct {
	States       int                    // Total number of states
	StackSymbols int                    // Total number of stack symbols
	StartState   state                  // Initial state
	AcceptStates map[state]bool         // Set of accepting states
	Transitions  map[state][]transition // Transitions indexed by from-state

	// For token-level matching
	Terminals []string // All terminal symbols (patterns to match)
}
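// For orientation, a minimal hand-built instance of these types using the
// helpers defined below (illustrative only; real PDAs come from the compiler):
//
//	// Recognizes the single token "a": start --"a"--> accept.
//	p := newPDA()
//	acc := p.addState()
//	p.AcceptStates[acc] = true
//	p.addTerminal("a")
//	p.addTransition(transition{FromState: stateStart, Pattern: "a", ToState: acc})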

// newPDA creates an empty pda.
func newPDA() *pda {
	return &pda{
		States:       2, // Error and Start
		StackSymbols: 1, // Empty
		StartState:   stateStart,
		AcceptStates: make(map[state]bool),
		Transitions:  make(map[state][]transition),
		Terminals:    make([]string, 0),
	}
}

// addState adds a new state and returns its ID.
func (p *pda) addState() state {
	s := state(p.States)
	p.States++
	return s
}

// addStackSymbol adds a new stack symbol and returns its ID.
func (p *pda) addStackSymbol() stackSymbol {
	s := stackSymbol(p.StackSymbols)
	p.StackSymbols++
	return s
}

// addTransition adds a transition to the pda.
func (p *pda) addTransition(t transition) {
	p.Transitions[t.FromState] = append(p.Transitions[t.FromState], t)
}

// addTerminal registers a terminal pattern and returns its index.
func (p *pda) addTerminal(pattern string) int {
	for i, t := range p.Terminals {
		if t == pattern {
			return i
		}
	}
	p.Terminals = append(p.Terminals, pattern)
	return len(p.Terminals) - 1
}

// compiler compiles EBNF grammars to PDAs.
type compiler struct {
	grammar ebnf.Grammar
	pda     *pda

	// Maps production names to their entry/exit states
	prodEntry map[string]state
	prodExit  map[string]state
}

// compile parses an EBNF grammar and compiles it to a pda.
func compile(name string, src io.Reader, start string) (*pda, error) {
	grammar, err := ebnf.Parse(name, src)
	if err != nil {
		return nil, fmt.Errorf("parse grammar: %w", err)
	}

	if err := ebnf.Verify(grammar, start); err != nil {
		return nil, fmt.Errorf("verify grammar: %w", err)
	}

	c := &compiler{
		grammar:   grammar,
		pda:       newPDA(),
		prodEntry: make(map[string]state),
		prodExit:  make(map[string]state),
	}

	// Create entry/exit states for each production
	for name := range grammar {
		c.prodEntry[name] = c.pda.addState()
		c.prodExit[name] = c.pda.addState()
	}

	// compile each production
	for name, prod := range grammar {
		if err := c.compileProduction(name, prod); err != nil {
			return nil, fmt.Errorf("compile production %q: %w", name, err)
		}
	}

	// Set start state to entry of start production
	if entry, ok := c.prodEntry[start]; ok {
		// Add epsilon transition from pda start to grammar start
		c.pda.addTransition(transition{
			FromState: stateStart,
			Pattern:   "", // epsilon
			ToState:   entry,
		})
	} else {
		return nil, fmt.Errorf("start production %q not found", start)
	}

	// Mark exit of start production as accepting
	if exit, ok := c.prodExit[start]; ok {
		c.pda.AcceptStates[exit] = true
	}

	return c.pda, nil
}

// compileString is a convenience function to compile from a string.
func compileString(grammar string, start string) (*pda, error) {
	return compile("grammar", strings.NewReader(grammar), start)
}
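// Putting the pieces together, a compile-and-run sketch (mirroring the
// compiler tests; the runtime type is defined later in this file):
//
//	p, err := compileString(`S = "a" { "b" } .`, "S")
//	if err != nil {
//		// handle error
//	}
//	rt := newRuntime(p)
//	_ = rt.Accept("a")     // true
//	_ = rt.Accept("b")     // true: { "b" } loops
//	ok := rt.isAccepting() // true: exit of S reachable with an empty stack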

func (c *compiler) compileProduction(name string, prod *ebnf.Production) error {
	entry := c.prodEntry[name]
	exit := c.prodExit[name]

	return c.compileExpr(prod.Expr, entry, exit)
}

func (c *compiler) compileExpr(expr ebnf.Expression, entry, exit state) error {
	switch e := expr.(type) {
	case *ebnf.Name:
		return c.compileName(e, entry, exit)
	case *ebnf.Token:
		return c.compileToken(e, entry, exit)
	case ebnf.Sequence:
		return c.compileSequence(e, entry, exit)
	case ebnf.Alternative:
		return c.compileAlternative(e, entry, exit)
	case *ebnf.Option:
		return c.compileOption(e, entry, exit)
	case *ebnf.Repetition:
		return c.compileRepetition(e, entry, exit)
	case *ebnf.Group:
		return c.compileExpr(e.Body, entry, exit)
	case *ebnf.Range:
		return c.compileRange(e, entry, exit)
	case nil:
		// Empty production - direct epsilon transition
		c.pda.addTransition(transition{
			FromState: entry,
			Pattern:   "",
			ToState:   exit,
		})
		return nil
	default:
		return fmt.Errorf("unsupported expression type: %T", expr)
	}
}

func (c *compiler) compileName(n *ebnf.Name, entry, exit state) error {
	// Reference to another production
	prodName := n.String

	prodEntry, ok := c.prodEntry[prodName]
	if !ok {
		return fmt.Errorf("undefined production: %s", prodName)
	}
	prodExit := c.prodExit[prodName]
	// Use a unique stack symbol per call site so returns are unambiguous.
	stackSym := c.pda.addStackSymbol()

	// Push return address, go to production entry
	c.pda.addTransition(transition{
		FromState: entry,
		Pattern:   "", // epsilon
		ToState:   prodEntry,
		StackPush: []stackSymbol{stackSym},
	})

	// On production exit, pop and return
	c.pda.addTransition(transition{
		FromState: prodExit,
		stackTop:  stackSym,
		Pattern:   "", // epsilon
		ToState:   exit,
		StackPop:  1,
	})

	return nil
}

func (c *compiler) compileToken(t *ebnf.Token, entry, exit state) error {
	// terminal symbol - add transition that consumes this token
	pattern := t.String
	c.pda.addTerminal(pattern)

	c.pda.addTransition(transition{
		FromState: entry,
		Pattern:   pattern,
		ToState:   exit,
	})

	return nil
}

func (c *compiler) compileSequence(seq ebnf.Sequence, entry, exit state) error {
	if len(seq) == 0 {
		// Empty sequence - epsilon transition
		c.pda.addTransition(transition{
			FromState: entry,
			Pattern:   "",
			ToState:   exit,
		})
		return nil
	}

	// Chain: entry -> s1 -> s2 -> ... -> exit
	current := entry
	for i, expr := range seq {
		var next state
		if i == len(seq)-1 {
			next = exit
		} else {
			next = c.pda.addState()
		}

		if err := c.compileExpr(expr, current, next); err != nil {
			return err
		}
		current = next
	}

	return nil
}

func (c *compiler) compileAlternative(alt ebnf.Alternative, entry, exit state) error {
	// Each alternative goes from entry to exit
	for _, expr := range alt {
		if err := c.compileExpr(expr, entry, exit); err != nil {
			return err
		}
	}
	return nil
}

func (c *compiler) compileOption(opt *ebnf.Option, entry, exit state) error {
	// Optional: can skip (epsilon) or take the body

	// Epsilon transition (skip)
	c.pda.addTransition(transition{
		FromState: entry,
		Pattern:   "",
		ToState:   exit,
	})

	// Or take the body
	return c.compileExpr(opt.Body, entry, exit)
}

func (c *compiler) compileRepetition(rep *ebnf.Repetition, entry, exit state) error {
	// Repetition {body}: zero or more
	//   entry -> exit (skip)
	//   entry -> body -> entry (loop back)

	// Skip transition
	c.pda.addTransition(transition{
		FromState: entry,
		Pattern:   "",
		ToState:   exit,
	})

	// Loop: entry -> (body) -> entry
	return c.compileExpr(rep.Body, entry, entry)
}

func (c *compiler) compileRange(r *ebnf.Range, entry, exit state) error {
	// Character range like "a" … "z" or "\u03b1" … "\u03c9"
	begin := strings.Trim(r.Begin.String, "\"")
	end := strings.Trim(r.End.String, "\"")

	// Unescape bounds first (so "\u03b1" works)
	beginUnesc, err := unescapeLiteral(begin)
	if err != nil {
		return fmt.Errorf("invalid range begin: %w", err)
	}
	endUnesc, err := unescapeLiteral(end)
	if err != nil {
		return fmt.Errorf("invalid range end: %w", err)
	}

	// Validate as single runes (not bytes) for Unicode support
	beginRunes := []rune(beginUnesc)
	endRunes := []rune(endUnesc)
	if len(beginRunes) != 1 || len(endRunes) != 1 {
		return fmt.Errorf("range bounds must be single characters: %q..%q", r.Begin.String, r.End.String)
	}

	// Use unescaped rune strings in pattern (consistent with matcher)
	pattern := fmt.Sprintf("[%s-%s]", string(beginRunes[0]), string(endRunes[0]))
	c.pda.addTerminal(pattern)

	c.pda.addTransition(transition{
		FromState: entry,
		Pattern:   pattern,
		ToState:   exit,
	})

	return nil
}

// runtime represents a pda execution instance.
type runtime struct {
	pda   *pda
	state state
	stack []stackSymbol
}

// newRuntime creates a new pda runtime.
func newRuntime(pda *pda) *runtime {
	return &runtime{
		pda:   pda,
		state: pda.StartState,
		stack: make([]stackSymbol, 0, 32),
	}
}

// stackTop returns the top of the stack, or stackEmpty if empty.
func (r *runtime) stackTop() stackSymbol {
	if len(r.stack) == 0 {
		return stackEmpty
	}
	return r.stack[len(r.stack)-1]
}

// isAccepting returns true if we can reach an accepting state via epsilon transitions
// with an empty stack.
func (r *runtime) isAccepting() bool {
	return r.canReachAccept(r.state, r.stack, make(map[stateStackKey]bool))
}

func (r *runtime) canReachAccept(state state, stack []stackSymbol, visited map[stateStackKey]bool) bool {
	// Check if this state is accepting with empty stack
	if r.pda.AcceptStates[state] && len(stack) == 0 {
		return true
	}

	// Avoid infinite loops
	key := stateStackKey{state: state, stackSig: stackSignature(stack)}
	if visited[key] {
		return false
	}
	visited[key] = true

	// Try epsilon transitions
	for _, t := range r.pda.Transitions[state] {
		if t.Pattern != "" {
			continue // Not epsilon
		}

		// Check stack constraint
		stackTop := stackEmpty
		if len(stack) > 0 {
			stackTop = stack[len(stack)-1]
		}
		if t.stackTop != stackEmpty && t.stackTop != stackTop {
			continue
		}

		// Simulate stack operations
		newStack := make([]stackSymbol, len(stack))
		copy(newStack, stack)

		if t.StackPop > 0 && len(newStack) >= t.StackPop {
			newStack = newStack[:len(newStack)-t.StackPop]
		}
		newStack = append(newStack, t.StackPush...)

		if r.canReachAccept(t.ToState, newStack, visited) {
			return true
		}
	}

	return false
}

// Reset resets the runtime to initial state.
func (r *runtime) Reset() {
	r.state = r.pda.StartState
	r.stack = r.stack[:0]
}

// validInputs returns all valid input patterns from current state.
func (r *runtime) validInputs() []string {
	var valid []string
	seen := make(map[string]bool)
	visited := make(map[stateStackKey]bool)

	// Make a copy of the stack for simulation
	simStack := make([]stackSymbol, len(r.stack))
	copy(simStack, r.stack)

	r.collectValidInputs(r.state, simStack, seen, visited, &valid)
	return valid
}

// stateStackKey is used to detect cycles in epsilon closure.
type stateStackKey struct {
	state    state
	stackSig string
}

func stackSignature(stack []stackSymbol) string {
	if len(stack) == 0 {
		return ""
	}
	buf := make([]byte, len(stack)*8)
	for i, sym := range stack {
		binary.LittleEndian.PutUint64(buf[i*8:], uint64(sym))
	}
	return string(buf)
}

func (r *runtime) collectValidInputs(state state, simStack []stackSymbol, seen map[string]bool, visited map[stateStackKey]bool, valid *[]string) {
	// Get stack top for comparisons
	stackTop := stackEmpty
	if len(simStack) > 0 {
		stackTop = simStack[len(simStack)-1]
	}

	// Check for cycles to avoid infinite loops
	key := stateStackKey{state: state, stackSig: stackSignature(simStack)}
	if visited[key] {
		return
	}
	visited[key] = true

	transitions := r.pda.Transitions[state]

	for _, t := range transitions {
		// Check stack constraint
		if t.stackTop != stackEmpty && t.stackTop != stackTop {
			continue
		}

		if t.Pattern == "" {
			// Epsilon transition - simulate stack operations
			newStack := make([]stackSymbol, len(simStack))
			copy(newStack, simStack)

			// Pop
			if t.StackPop > 0 {
				if len(newStack) < t.StackPop {
					continue // Can't pop, skip this transition
				}
				newStack = newStack[:len(newStack)-t.StackPop]
			}

			// Push
			newStack = append(newStack, t.StackPush...)

			r.collectValidInputs(t.ToState, newStack, seen, visited, valid)
		} else {
			// terminal - add if not seen
			if !seen[t.Pattern] {
				seen[t.Pattern] = true
				*valid = append(*valid, t.Pattern)
			}
		}
	}
}

// matchesPattern checks if input matches a pattern.
// Patterns can be:
//   - Exact strings: "a", "{", "true"
//   - Character ranges: "[a-z]", "[0-9]", "[#-~]"
func matchesPattern(input, pattern string) bool {
	// Exact match
	if input == pattern {
		return true
	}

	// Check for character range pattern [X-Y]
	if len(pattern) == 5 && pattern[0] == '[' && pattern[2] == '-' && pattern[4] == ']' {
		if len(input) != 1 {
			return false
		}
		ch := input[0]
		low := pattern[1]
		high := pattern[3]
		return ch >= low && ch <= high
	}

	return false
}
|
||||
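// Illustrative, from the rules above:
//
//	matchesPattern("5", "[0-9]")  == true  (range match)
//	matchesPattern("{", "{")      == true  (exact match)
//	matchesPattern("ab", "[a-z]") == false (ranges match single characters only)
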
// Accept tries to accept an input, returning true if successful.
func (r *runtime) Accept(input string) bool {
	return r.accept(input, make(map[stateStackKey]bool))
}

func (r *runtime) accept(input string, visited map[stateStackKey]bool) bool {
	key := stateStackKey{state: r.state, stackSig: stackSignature(r.stack)}
	if visited[key] {
		return false
	}
	visited[key] = true

	transitions := r.pda.Transitions[r.state]

	// First, try transitions whose pattern matches the input directly.
	for _, t := range transitions {
		if matchesPattern(input, t.Pattern) {
			if t.stackTop != stackEmpty && t.stackTop != r.stackTop() {
				continue
			}
			if t.StackPop > len(r.stack) {
				continue
			}

			// Apply transition
			r.applyTransition(t)
			return true
		}
	}

	// Otherwise, follow epsilon transitions and retry, backtracking on failure.
	for _, t := range transitions {
		if t.Pattern == "" {
			if t.stackTop != stackEmpty && t.stackTop != r.stackTop() {
				continue
			}
			if t.StackPop > len(r.stack) {
				continue
			}

			// Save state for backtracking
			oldState := r.state
			oldStack := make([]stackSymbol, len(r.stack))
			copy(oldStack, r.stack)

			r.applyTransition(t)

			if r.accept(input, visited) {
				return true
			}

			// Backtrack
			r.state = oldState
			r.stack = oldStack
		}
	}

	return false
}

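// Design note: direct matches are attempted before epsilon moves, so the
// common case needs no snapshot; epsilon paths save and restore the
// (state, stack) pair, giving a depth-first search with backtracking.
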
func (r *runtime) applyTransition(t transition) {
	// Pop
	if t.StackPop > 0 && len(r.stack) >= t.StackPop {
		r.stack = r.stack[:len(r.stack)-t.StackPop]
	}

	// Push
	r.stack = append(r.stack, t.StackPush...)

	// Move to new state
	r.state = t.ToState
}

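// Usage sketch (mirrors the tests below): feed terminals one at a time and
// check acceptance at the end:
//
//	rt := newRuntime(pda)
//	ok := rt.Accept("(") && rt.Accept("x") && rt.Accept(")")
//	accepted := ok && rt.isAccepting()
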
@@ -1,540 +0,0 @@
//go:build mlx

package grammar

import (
	"testing"
)

func TestCompileSimpleGrammar(t *testing.T) {
	// Simple grammar: S = "a" "b" .
	grammar := `S = "a" "b" .`

	pda, err := compileString(grammar, "S")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	if pda == nil {
		t.Fatal("pda is nil")
	}

	// Should have terminals "a" and "b"
	if len(pda.Terminals) != 2 {
		t.Errorf("expected 2 terminals, got %d: %v", len(pda.Terminals), pda.Terminals)
	}

	// Test runtime
	rt := newRuntime(pda)

	// Should accept "a" then "b"
	if !rt.Accept("a") {
		t.Error("should accept 'a'")
	}
	if !rt.Accept("b") {
		t.Error("should accept 'b'")
	}
	if !rt.isAccepting() {
		t.Error("should be in accepting state")
	}
}

func TestCompileAlternative(t *testing.T) {
	// Grammar: S = "a" | "b" .
	grammar := `S = "a" | "b" .`

	pda, err := compileString(grammar, "S")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	// Test accepting "a"
	rt := newRuntime(pda)
	if !rt.Accept("a") {
		t.Error("should accept 'a'")
	}
	if !rt.isAccepting() {
		t.Error("should be accepting after 'a'")
	}

	// Test accepting "b"
	rt.Reset()
	if !rt.Accept("b") {
		t.Error("should accept 'b'")
	}
	if !rt.isAccepting() {
		t.Error("should be accepting after 'b'")
	}

	// Test rejecting "c"
	rt.Reset()
	if rt.Accept("c") {
		t.Error("should not accept 'c'")
	}
}

func TestCompileRepetition(t *testing.T) {
	// Grammar: S = {"a"} .
	grammar := `S = {"a"} .`

	pda, err := compileString(grammar, "S")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	// Empty should be accepted (zero repetitions)
	rt := newRuntime(pda)
	if !rt.isAccepting() {
		t.Error("empty should be accepting")
	}

	// "a" should be accepted
	rt.Reset()
	if !rt.Accept("a") {
		t.Error("should accept first 'a'")
	}
	if !rt.isAccepting() {
		t.Error("should be accepting after one 'a'")
	}

	// "aa" should be accepted
	if !rt.Accept("a") {
		t.Error("should accept second 'a'")
	}
	if !rt.isAccepting() {
		t.Error("should be accepting after two 'a's")
	}
}

func TestCompileOption(t *testing.T) {
	// Grammar: S = ["a"] "b" .
	grammar := `S = ["a"] "b" .`

	pda, err := compileString(grammar, "S")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	// "b" alone should be accepted
	rt := newRuntime(pda)
	if !rt.Accept("b") {
		t.Error("should accept 'b' alone")
	}
	if !rt.isAccepting() {
		t.Error("should be accepting after 'b'")
	}

	// "ab" should be accepted
	rt.Reset()
	if !rt.Accept("a") {
		t.Error("should accept 'a'")
	}
	if !rt.Accept("b") {
		t.Error("should accept 'b' after 'a'")
	}
	if !rt.isAccepting() {
		t.Error("should be accepting after 'ab'")
	}
}

func TestCompileRecursive(t *testing.T) {
	// Grammar with recursion: S = "(" S ")" | "x" .
	grammar := `S = "(" S ")" | "x" .`

	pda, err := compileString(grammar, "S")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	// "x" should be accepted
	rt := newRuntime(pda)
	if !rt.Accept("x") {
		t.Error("should accept 'x'")
	}
	if !rt.isAccepting() {
		t.Error("should be accepting after 'x'")
	}

	// "(x)" should be accepted
	rt.Reset()
	if !rt.Accept("(") {
		t.Error("should accept '('")
	}
	if !rt.Accept("x") {
		t.Error("should accept 'x' inside parens")
	}
	if !rt.Accept(")") {
		t.Error("should accept ')'")
	}
	if !rt.isAccepting() {
		t.Error("should be accepting after '(x)'")
	}

	// "((x))" should be accepted
	rt.Reset()
	if !rt.Accept("(") {
		t.Error("should accept first '('")
	}
	if !rt.Accept("(") {
		t.Error("should accept second '('")
	}
	if !rt.Accept("x") {
		t.Error("should accept 'x'")
	}
	if !rt.Accept(")") {
		t.Error("should accept first ')'")
	}
	if !rt.Accept(")") {
		t.Error("should accept second ')'")
	}
	if !rt.isAccepting() {
		t.Error("should be accepting after '((x))'")
	}
}

func TestValidInputs(t *testing.T) {
	// Grammar: S = "a" | "b" .
	grammar := `S = "a" | "b" .`

	pda, err := compileString(grammar, "S")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	rt := newRuntime(pda)
	valid := rt.validInputs()

	// Should have both "a" and "b" as valid
	hasA, hasB := false, false
	for _, v := range valid {
		if v == "a" {
			hasA = true
		}
		if v == "b" {
			hasB = true
		}
	}

	if !hasA {
		t.Error("'a' should be valid input")
	}
	if !hasB {
		t.Error("'b' should be valid input")
	}
}

// TestValidInputsAfterAccept tests that validInputs returns correct values
// after accepting tokens, ensuring proper stack simulation.
func TestValidInputsAfterAccept(t *testing.T) {
	// Grammar: S = "a" "b" "c" .
	grammar := `S = "a" "b" "c" .`

	pda, err := compileString(grammar, "S")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	rt := newRuntime(pda)

	// Initially only "a" should be valid
	valid := rt.validInputs()
	if len(valid) != 1 || valid[0] != "a" {
		t.Errorf("initially expected only 'a', got %v", valid)
	}

	// After accepting "a", only "b" should be valid
	if !rt.Accept("a") {
		t.Fatal("failed to accept 'a'")
	}
	valid = rt.validInputs()
	if len(valid) != 1 || valid[0] != "b" {
		t.Errorf("after 'a', expected only 'b', got %v", valid)
	}

	// After accepting "b", only "c" should be valid
	if !rt.Accept("b") {
		t.Fatal("failed to accept 'b'")
	}
	valid = rt.validInputs()
	if len(valid) != 1 || valid[0] != "c" {
		t.Errorf("after 'ab', expected only 'c', got %v", valid)
	}
}

// TestValidInputsWithRepetitionInProduction tests the critical case where
// a repetition exists inside a called production. This requires proper
// stack simulation to determine when closing symbols are valid.
func TestValidInputsWithRepetitionInProduction(t *testing.T) {
	// Grammar similar to JSON:
	//   S = "(" items ")" .
	//   items = item { "," item } .
	//   item = "x" .
	grammar := `
S = "(" items ")" .
items = item { "," item } .
item = "x" .
`
	pda, err := compileString(grammar, "S")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	rt := newRuntime(pda)

	// Initially only "(" should be valid
	valid := rt.validInputs()
	if len(valid) != 1 || valid[0] != "(" {
		t.Errorf("initially expected only '(', got %v", valid)
	}

	// Accept "("
	if !rt.Accept("(") {
		t.Fatal("failed to accept '('")
	}
	// After "(", should be able to accept "x" (item)
	valid = rt.validInputs()
	hasX := false
	for _, v := range valid {
		if v == "x" {
			hasX = true
		}
	}
	if !hasX {
		t.Errorf("after '(', expected 'x' to be valid, got %v", valid)
	}

	// Accept first item "x"
	if !rt.Accept("x") {
		t.Fatal("failed to accept 'x'")
	}
	// After "(x", should be able to accept "," (more items) OR ")" (end)
	valid = rt.validInputs()
	hasComma, hasClose := false, false
	for _, v := range valid {
		if v == "," {
			hasComma = true
		}
		if v == ")" {
			hasClose = true
		}
	}
	if !hasComma {
		t.Errorf("after '(x', expected ',' to be valid, got %v", valid)
	}
	if !hasClose {
		t.Errorf("after '(x', expected ')' to be valid, got %v", valid)
	}

	// Accept comma for another item
	if !rt.Accept(",") {
		t.Fatal("failed to accept ','")
	}
	// After "(x,", should only be able to accept "x" (next item)
	valid = rt.validInputs()
	if len(valid) != 1 || valid[0] != "x" {
		t.Errorf("after '(x,', expected only 'x', got %v", valid)
	}

	// Accept second item "x"
	if !rt.Accept("x") {
		t.Fatal("failed to accept second 'x'")
	}
	// CRITICAL: After "(x,x", should be able to accept "," OR ")".
	// This tests the stack simulation fix - we need to properly
	// follow epsilon transitions through the production call stack.
	valid = rt.validInputs()
	hasComma, hasClose = false, false
	for _, v := range valid {
		if v == "," {
			hasComma = true
		}
		if v == ")" {
			hasClose = true
		}
	}
	if !hasComma {
		t.Errorf("after '(x,x', expected ',' to be valid, got %v", valid)
	}
	if !hasClose {
		t.Errorf("after '(x,x', expected ')' to be valid, got %v", valid)
	}

	// Close with ")"
	if !rt.Accept(")") {
		t.Fatal("failed to accept ')'")
	}
	if !rt.isAccepting() {
		t.Error("should be accepting after '(x,x)'")
	}
}

// TestValidInputsNestedCalls tests validInputs with deeply nested production calls.
func TestValidInputsNestedCalls(t *testing.T) {
	// Grammar: A = "start" B "end" . B = "middle" .
	grammar := `
A = "start" B "end" .
B = "middle" .
`
	pda, err := compileString(grammar, "A")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	rt := newRuntime(pda)

	// After "start", should accept "middle" (from B)
	rt.Accept("start")
	valid := rt.validInputs()
	if len(valid) != 1 || valid[0] != "middle" {
		t.Errorf("after 'start', expected 'middle', got %v", valid)
	}

	// After "start middle", should accept "end"
	rt.Accept("middle")
	valid = rt.validInputs()
	if len(valid) != 1 || valid[0] != "end" {
		t.Errorf("after 'start middle', expected 'end', got %v", valid)
	}
}

func TestReturnAddressDisambiguation(t *testing.T) {
	// Grammar where the same production is called from different contexts:
	//   S = A "x" | "c" A "y" .
	//   A = "a" .
	grammar := `
S = A "x" | "c" A "y" .
A = "a" .
`
	pda, err := compileString(grammar, "S")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	rt := newRuntime(pda)

	if !rt.Accept("c") {
		t.Fatal("failed to accept 'c'")
	}
	if !rt.Accept("a") {
		t.Fatal("failed to accept 'a'")
	}

	valid := rt.validInputs()
	if len(valid) != 1 || valid[0] != "y" {
		t.Errorf("after 'ca', expected only 'y', got %v", valid)
	}

	rt.Reset()
	rt.Accept("c")
	rt.Accept("a")
	if rt.Accept("x") {
		t.Error("should not accept 'x' after 'ca'")
	}
}

// TestValidInputsRecursiveWithStack tests validInputs with recursive grammars
// which heavily exercise the stack simulation.
func TestValidInputsRecursiveWithStack(t *testing.T) {
	// Grammar: S = "(" S ")" | "x" .
	grammar := `S = "(" S ")" | "x" .`

	pda, err := compileString(grammar, "S")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	rt := newRuntime(pda)

	// Initially: "(" or "x" should be valid
	valid := rt.validInputs()
	hasParen, hasX := false, false
	for _, v := range valid {
		if v == "(" {
			hasParen = true
		}
		if v == "x" {
			hasX = true
		}
	}
	if !hasParen || !hasX {
		t.Errorf("initially expected '(' and 'x', got %v", valid)
	}

	// After "(": "(" or "x" should be valid (nested S)
	rt.Accept("(")
	valid = rt.validInputs()
	hasParen, hasX = false, false
	for _, v := range valid {
		if v == "(" {
			hasParen = true
		}
		if v == "x" {
			hasX = true
		}
	}
	if !hasParen || !hasX {
		t.Errorf("after '(', expected '(' and 'x', got %v", valid)
	}

	// After "((": "(" or "x" should still be valid
	rt.Accept("(")
	valid = rt.validInputs()
	hasParen, hasX = false, false
	for _, v := range valid {
		if v == "(" {
			hasParen = true
		}
		if v == "x" {
			hasX = true
		}
	}
	if !hasParen || !hasX {
		t.Errorf("after '((', expected '(' and 'x', got %v", valid)
	}

	// After "((x": only ")" should be valid
	rt.Accept("x")
	valid = rt.validInputs()
	if len(valid) != 1 || valid[0] != ")" {
		t.Errorf("after '((x', expected only ')', got %v", valid)
	}

	// After "((x)": only ")" should be valid (closing outer)
	rt.Accept(")")
	valid = rt.validInputs()
	if len(valid) != 1 || valid[0] != ")" {
		t.Errorf("after '((x)', expected only ')', got %v", valid)
	}
}

// TestRejectionAfterValid tests that invalid inputs are rejected
// at various points in the grammar.
func TestRejectionAfterValid(t *testing.T) {
	// Grammar: S = "a" "b" .
	grammar := `S = "a" "b" .`

	pda, err := compileString(grammar, "S")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	rt := newRuntime(pda)

	// "b" should be rejected initially
	if rt.Accept("b") {
		t.Error("'b' should be rejected initially")
	}

	// Accept "a"
	rt.Accept("a")

	// "a" should be rejected after "a"
	if rt.Accept("a") {
		t.Error("'a' should be rejected after 'a'")
	}

	// "c" should be rejected (not in grammar)
	if rt.Accept("c") {
		t.Error("'c' should be rejected (not in grammar)")
	}
}