Compare commits
33 Commits
mlx-gpu-cd
...
imagegen-g
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ca1faae6a2 | ||
|
|
a077d996e3 | ||
|
|
c23d5095de | ||
|
|
7601f0e93e | ||
|
|
aad3f03890 | ||
|
|
55d0b6e8b9 | ||
|
|
38eac40d56 | ||
|
|
80f3f1bc25 | ||
|
|
b1a0db547b | ||
|
|
75d7b5f926 | ||
|
|
349d814814 | ||
|
|
c8743031e0 | ||
|
|
4adb9cf4bb | ||
|
|
74f475e735 | ||
|
|
875cecba74 | ||
|
|
7d411a4686 | ||
|
|
02a2401596 | ||
|
|
e4b488a7b5 | ||
|
|
98079ddd79 | ||
|
|
d70942f47b | ||
|
|
58e4701557 | ||
|
|
dbf47ee55a | ||
|
|
af7ea6e96e | ||
|
|
8f1e0140e7 | ||
|
|
35c3c9e3c2 | ||
|
|
d06acbcb19 | ||
|
|
9667c2282f | ||
|
|
a937a68317 | ||
|
|
2185112d84 | ||
|
|
91926601dc | ||
|
|
361d6c16c2 | ||
|
|
7e2496e88e | ||
|
|
5b84e29882 |
2
.github/ISSUE_TEMPLATE/10_bug_report.yml
vendored
@@ -13,7 +13,7 @@ body:
|
||||
id: logs
|
||||
attributes:
|
||||
label: Relevant log output
|
||||
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
|
||||
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.mdx#how-to-troubleshoot-issues) for details.
|
||||
render: shell
|
||||
validations:
|
||||
required: false
|
||||
|
||||
6
.github/workflows/release.yaml
vendored
@@ -372,13 +372,17 @@ jobs:
|
||||
outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
|
||||
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
|
||||
cache-to: type=inline
|
||||
- name: Deduplicate CUDA libraries
|
||||
run: |
|
||||
./scripts/deduplicate_cuda_libs.sh dist/${{ matrix.os }}-${{ matrix.arch }}
|
||||
- run: |
|
||||
for COMPONENT in bin/* lib/ollama/*; do
|
||||
case "$COMPONENT" in
|
||||
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
bin/ollama*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
||||
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
||||
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
||||
|
||||
@@ -48,9 +48,10 @@ if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
|
||||
set(GGML_CPU_ALL_VARIANTS ON)
|
||||
endif()
|
||||
|
||||
if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
|
||||
if(APPLE)
|
||||
set(CMAKE_BUILD_RPATH "@loader_path")
|
||||
set(CMAKE_INSTALL_RPATH "@loader_path")
|
||||
set(CMAKE_BUILD_WITH_INSTALL_RPATH ON)
|
||||
endif()
|
||||
|
||||
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
|
||||
@@ -189,13 +190,21 @@ if(MLX_ENGINE)
|
||||
install(TARGETS mlx mlxc
|
||||
RUNTIME_DEPENDENCIES
|
||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
|
||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
)
|
||||
|
||||
# Install the Metal library for macOS arm64 (must be colocated with the binary)
|
||||
# Metal backend is only built for arm64, not x86_64
|
||||
if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
|
||||
install(FILES ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/backend/metal/kernels/mlx.metallib
|
||||
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
|
||||
# Manually install cudart and cublas since they might not be picked up as direct dependencies
|
||||
if(CUDAToolkit_FOUND)
|
||||
file(GLOB CUDART_LIBS
|
||||
|
||||
@@ -161,6 +161,9 @@ ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
RUN mkdir -p dist/bin
|
||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -tags mlx -trimpath -buildmode=pie -o dist/bin/ollama-mlx .
|
||||
|
||||
FROM base AS build
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
@@ -182,6 +185,7 @@ COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=vulkan dist/lib/ollama /lib/ollama/
|
||||
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
|
||||
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/
|
||||
|
||||
FROM --platform=linux/arm64 scratch AS arm64
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
|
||||
43
README.md
@@ -48,7 +48,7 @@ ollama run gemma3
|
||||
|
||||
## Model library
|
||||
|
||||
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library 'ollama model library')
|
||||
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library "ollama model library")
|
||||
|
||||
Here are some example models that can be downloaded:
|
||||
|
||||
@@ -79,7 +79,7 @@ Here are some example models that can be downloaded:
|
||||
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
|
||||
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
|
||||
| LLaVA | 7B | 4.5GB | `ollama run llava` |
|
||||
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
|
||||
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
|
||||
|
||||
> [!NOTE]
|
||||
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
|
||||
@@ -260,6 +260,38 @@ Finally, in a separate shell, run a model:
|
||||
./ollama run llama3.2
|
||||
```
|
||||
|
||||
## Building with MLX (experimental)
|
||||
|
||||
First build the MLX libraries:
|
||||
|
||||
```shell
|
||||
cmake --preset MLX
|
||||
cmake --build --preset MLX --parallel
|
||||
cmake --install build --component MLX
|
||||
```
|
||||
|
||||
Next, build the `ollama-mlx` binary, which is a separate build of the Ollama runtime with MLX support enabled (needs to be in the same directory as `ollama`):
|
||||
|
||||
```shell
|
||||
go build -tags mlx -o ollama-mlx .
|
||||
```
|
||||
|
||||
Finally, start the server:
|
||||
|
||||
```
|
||||
./ollama serve
|
||||
```
|
||||
|
||||
### Building MLX with CUDA
|
||||
|
||||
When building with CUDA, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with default architectures:
|
||||
|
||||
```shell
|
||||
cmake --preset 'MLX CUDA 13'
|
||||
cmake --build --preset 'MLX CUDA 13' --parallel
|
||||
cmake --install build --component MLX
|
||||
```
|
||||
|
||||
## REST API
|
||||
|
||||
Ollama has a REST API for running and managing models.
|
||||
@@ -290,6 +322,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
|
||||
### Web & Desktop
|
||||
|
||||
- [Onyx](https://github.com/onyx-dot-app/onyx)
|
||||
- [Open WebUI](https://github.com/open-webui/open-webui)
|
||||
- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
|
||||
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
|
||||
@@ -421,7 +454,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
|
||||
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
|
||||
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
|
||||
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
|
||||
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
|
||||
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
|
||||
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
|
||||
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
|
||||
@@ -493,7 +526,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
### Database
|
||||
|
||||
- [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector)
|
||||
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
|
||||
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
|
||||
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
|
||||
- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
|
||||
- [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases)
|
||||
@@ -636,6 +669,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
|
||||
|
||||
### Observability
|
||||
|
||||
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
|
||||
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
|
||||
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
|
||||
@@ -644,4 +678,5 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
|
||||
|
||||
### Security
|
||||
|
||||
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)
|
||||
|
||||
@@ -165,7 +165,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
|
||||
return nil
|
||||
}
|
||||
|
||||
const maxBufferSize = 512 * format.KiloByte
|
||||
const maxBufferSize = 8 * format.MegaByte
|
||||
|
||||
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
|
||||
var buf io.Reader
|
||||
|
||||
21
api/types.go
@@ -97,6 +97,15 @@ type GenerateRequest struct {
|
||||
// request, for multimodal models.
|
||||
Images []ImageData `json:"images,omitempty"`
|
||||
|
||||
// Width is the width of the generated image (for image generation models).
|
||||
Width int32 `json:"width,omitempty"`
|
||||
|
||||
// Height is the height of the generated image (for image generation models).
|
||||
Height int32 `json:"height,omitempty"`
|
||||
|
||||
// Steps is the number of diffusion steps (for image generation models).
|
||||
Steps int32 `json:"steps,omitempty"`
|
||||
|
||||
// Options lists model-specific options. For example, temperature can be
|
||||
// set through this field, if the model supports it.
|
||||
Options map[string]any `json:"options"`
|
||||
@@ -860,6 +869,18 @@ type GenerateResponse struct {
|
||||
// Logprobs contains log probability information for the generated tokens,
|
||||
// if requested via the Logprobs parameter.
|
||||
Logprobs []Logprob `json:"logprobs,omitempty"`
|
||||
|
||||
// Status describes the current phase of generation (e.g., "generating image").
|
||||
Status string `json:"status,omitempty"`
|
||||
|
||||
// Total is the total count for the current phase (e.g., total steps).
|
||||
Total int64 `json:"total,omitempty"`
|
||||
|
||||
// Completed is the completed count for the current phase.
|
||||
Completed int64 `json:"completed,omitempty"`
|
||||
|
||||
// Images contains base64-encoded generated images for image generation models.
|
||||
Images []string `json:"images,omitempty"`
|
||||
}
|
||||
|
||||
// ModelDetails provides details about a model.
|
||||
|
||||
@@ -14,6 +14,7 @@ extern NSString *SystemWidePath;
|
||||
@interface AppDelegate () <NSWindowDelegate, WKNavigationDelegate, WKUIDelegate>
|
||||
@property(strong, nonatomic) NSStatusItem *statusItem;
|
||||
@property(assign, nonatomic) BOOL updateAvailable;
|
||||
@property(assign, nonatomic) BOOL systemShutdownInProgress;
|
||||
@end
|
||||
|
||||
@implementation AppDelegate
|
||||
@@ -40,6 +41,13 @@ bool firstTimeRun,startHidden; // Set in run before initialization
|
||||
}
|
||||
|
||||
- (void)applicationDidFinishLaunching:(NSNotification *)aNotification {
|
||||
// Register for system shutdown/restart notification so we can allow termination
|
||||
[[[NSWorkspace sharedWorkspace] notificationCenter]
|
||||
addObserver:self
|
||||
selector:@selector(systemWillPowerOff:)
|
||||
name:NSWorkspaceWillPowerOffNotification
|
||||
object:nil];
|
||||
|
||||
// if we're in development mode, set the app icon
|
||||
NSString *bundlePath = [[NSBundle mainBundle] bundlePath];
|
||||
if (![bundlePath hasSuffix:@".app"]) {
|
||||
@@ -278,7 +286,18 @@ bool firstTimeRun,startHidden; // Set in run before initialization
|
||||
[NSApp activateIgnoringOtherApps:YES];
|
||||
}
|
||||
|
||||
- (void)systemWillPowerOff:(NSNotification *)notification {
|
||||
// Set flag so applicationShouldTerminate: knows to allow termination.
|
||||
// The system will call applicationShouldTerminate: after posting this notification.
|
||||
self.systemShutdownInProgress = YES;
|
||||
}
|
||||
|
||||
- (NSApplicationTerminateReply)applicationShouldTerminate:(NSApplication *)sender {
|
||||
// Allow termination if the system is shutting down or restarting
|
||||
if (self.systemShutdownInProgress) {
|
||||
return NSTerminateNow;
|
||||
}
|
||||
// Otherwise just hide the app (for Cmd+Q, close button, etc.)
|
||||
[NSApp hide:nil];
|
||||
[NSApp setActivationPolicy:NSApplicationActivationPolicyAccessory];
|
||||
return NSTerminateCancel;
|
||||
|
||||
130
cmd/cmd.go
@@ -46,8 +46,9 @@ import (
|
||||
"github.com/ollama/ollama/types/syncmap"
|
||||
"github.com/ollama/ollama/version"
|
||||
xcmd "github.com/ollama/ollama/x/cmd"
|
||||
"github.com/ollama/ollama/x/create"
|
||||
xcreateclient "github.com/ollama/ollama/x/create/client"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
imagegenclient "github.com/ollama/ollama/x/imagegen/client"
|
||||
)
|
||||
|
||||
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
|
||||
@@ -93,14 +94,87 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
// Validate model name early to fail fast
|
||||
modelName := args[0]
|
||||
name := model.ParseName(modelName)
|
||||
if !name.IsValid() {
|
||||
return fmt.Errorf("invalid model name: %s", modelName)
|
||||
}
|
||||
|
||||
// Check for --experimental flag for safetensors model creation
|
||||
experimental, _ := cmd.Flags().GetBool("experimental")
|
||||
if experimental {
|
||||
// Get Modelfile content - either from -f flag or default to "FROM ."
|
||||
var reader io.Reader
|
||||
filename, err := getModelfileName(cmd)
|
||||
if os.IsNotExist(err) || filename == "" {
|
||||
// No Modelfile specified or found - use default
|
||||
reader = strings.NewReader("FROM .\n")
|
||||
} else if err != nil {
|
||||
return err
|
||||
} else {
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
reader = f
|
||||
}
|
||||
|
||||
// Parse the Modelfile
|
||||
modelfile, err := parser.ParseFile(reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse Modelfile: %w", err)
|
||||
}
|
||||
|
||||
// Extract FROM path and configuration
|
||||
var modelDir string
|
||||
mfConfig := &xcreateclient.ModelfileConfig{}
|
||||
|
||||
for _, cmd := range modelfile.Commands {
|
||||
switch cmd.Name {
|
||||
case "model":
|
||||
modelDir = cmd.Args
|
||||
case "template":
|
||||
mfConfig.Template = cmd.Args
|
||||
case "system":
|
||||
mfConfig.System = cmd.Args
|
||||
case "license":
|
||||
mfConfig.License = cmd.Args
|
||||
}
|
||||
}
|
||||
|
||||
if modelDir == "" {
|
||||
modelDir = "."
|
||||
}
|
||||
|
||||
// Resolve relative paths based on Modelfile location
|
||||
if !filepath.IsAbs(modelDir) && filename != "" {
|
||||
modelDir = filepath.Join(filepath.Dir(filename), modelDir)
|
||||
}
|
||||
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
||||
ModelName: modelName,
|
||||
ModelDir: modelDir,
|
||||
Quantize: quantize,
|
||||
Modelfile: mfConfig,
|
||||
}, p)
|
||||
}
|
||||
|
||||
var reader io.Reader
|
||||
|
||||
filename, err := getModelfileName(cmd)
|
||||
if os.IsNotExist(err) {
|
||||
if filename == "" {
|
||||
// No Modelfile found - check if current directory is an image gen model
|
||||
if imagegen.IsTensorModelDir(".") {
|
||||
return imagegenclient.CreateModel(args[0], ".", p)
|
||||
if create.IsTensorModelDir(".") {
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
||||
ModelName: modelName,
|
||||
ModelDir: ".",
|
||||
Quantize: quantize,
|
||||
}, p)
|
||||
}
|
||||
reader = strings.NewReader("FROM .\n")
|
||||
} else {
|
||||
@@ -133,7 +207,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
spinner.Stop()
|
||||
|
||||
req.Model = args[0]
|
||||
req.Model = modelName
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
if quantize != "" {
|
||||
req.Quantize = quantize
|
||||
@@ -464,14 +538,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
name := args[0]
|
||||
|
||||
// Check if this is a known image generation model (skip Show/Pull)
|
||||
if imagegen.HasTensorLayers(name) {
|
||||
if opts.Prompt == "" && !interactive {
|
||||
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
|
||||
}
|
||||
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
|
||||
}
|
||||
|
||||
info, err := func() (*api.ShowResponse, error) {
|
||||
showReq := &api.ShowRequest{Name: name}
|
||||
info, err := client.Show(cmd.Context(), showReq)
|
||||
@@ -533,9 +599,18 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
|
||||
}
|
||||
|
||||
// Check if this is an image generation model
|
||||
if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
|
||||
if opts.Prompt == "" && !interactive {
|
||||
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
|
||||
}
|
||||
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
|
||||
}
|
||||
|
||||
// Check for experimental flag
|
||||
isExperimental, _ := cmd.Flags().GetBool("experimental")
|
||||
yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
|
||||
enableWebsearch, _ := cmd.Flags().GetBool("experimental-websearch")
|
||||
|
||||
if interactive {
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
@@ -565,7 +640,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
// Use experimental agent loop with tools
|
||||
if isExperimental {
|
||||
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode)
|
||||
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode, enableWebsearch)
|
||||
}
|
||||
|
||||
return generateInteractive(cmd, opts)
|
||||
@@ -671,7 +746,11 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
bar, ok := bars[resp.Digest]
|
||||
if !ok {
|
||||
bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
|
||||
msg := resp.Status
|
||||
if msg == "" {
|
||||
msg = fmt.Sprintf("pushing %s...", resp.Digest[7:19])
|
||||
}
|
||||
bar = progress.NewBar(msg, resp.Total, resp.Completed)
|
||||
bars[resp.Digest] = bar
|
||||
p.Add(resp.Digest, bar)
|
||||
}
|
||||
@@ -837,11 +916,6 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
func ShowHandler(cmd *cobra.Command, args []string) error {
|
||||
// Check if this is an image generation model
|
||||
if imagegen.HasTensorLayers(args[0]) {
|
||||
return imagegen.Show(args[0], os.Stdout)
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1741,15 +1815,22 @@ func NewCLI() *cobra.Command {
|
||||
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
|
||||
|
||||
createCmd := &cobra.Command{
|
||||
Use: "create MODEL",
|
||||
Short: "Create a model",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: CreateHandler,
|
||||
Use: "create MODEL",
|
||||
Short: "Create a model",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Skip server check for experimental mode (writes directly to disk)
|
||||
if experimental, _ := cmd.Flags().GetBool("experimental"); experimental {
|
||||
return nil
|
||||
}
|
||||
return checkServerHeartbeat(cmd, args)
|
||||
},
|
||||
RunE: CreateHandler,
|
||||
}
|
||||
|
||||
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")")
|
||||
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
|
||||
createCmd.Flags().Bool("experimental", false, "Enable experimental safetensors model creation")
|
||||
|
||||
showCmd := &cobra.Command{
|
||||
Use: "show MODEL",
|
||||
@@ -1786,6 +1867,7 @@ func NewCLI() *cobra.Command {
|
||||
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
|
||||
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
|
||||
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
|
||||
runCmd.Flags().Bool("experimental-websearch", false, "Enable web search tool in experimental mode")
|
||||
|
||||
// Image generation flags (width, height, steps, seed, etc.)
|
||||
imagegen.RegisterFlags(runCmd)
|
||||
|
||||
@@ -1547,6 +1547,79 @@ func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowInfoImageGen(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
err := showInfo(&api.ShowResponse{
|
||||
Details: api.ModelDetails{
|
||||
Family: "ZImagePipeline",
|
||||
ParameterSize: "10.3B",
|
||||
QuantizationLevel: "FP8",
|
||||
},
|
||||
Capabilities: []model.Capability{model.CapabilityImageGeneration},
|
||||
Requires: "0.14.0",
|
||||
}, false, &b)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expect := " Model\n" +
|
||||
" architecture ZImagePipeline \n" +
|
||||
" parameters 10.3B \n" +
|
||||
" quantization FP8 \n" +
|
||||
" requires 0.14.0 \n" +
|
||||
"\n" +
|
||||
" Capabilities\n" +
|
||||
" image \n" +
|
||||
"\n"
|
||||
if diff := cmp.Diff(expect, b.String()); diff != "" {
|
||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushProgressMessage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status string
|
||||
digest string
|
||||
wantMsg string
|
||||
}{
|
||||
{
|
||||
name: "uses status when provided",
|
||||
status: "uploading model",
|
||||
digest: "sha256:abc123456789def",
|
||||
wantMsg: "uploading model",
|
||||
},
|
||||
{
|
||||
name: "falls back to digest when status empty",
|
||||
status: "",
|
||||
digest: "sha256:abc123456789def",
|
||||
wantMsg: "pushing abc123456789...",
|
||||
},
|
||||
{
|
||||
name: "handles short digest gracefully",
|
||||
status: "",
|
||||
digest: "sha256:abc",
|
||||
wantMsg: "pushing sha256:abc...",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
msg := tt.status
|
||||
if msg == "" {
|
||||
if len(tt.digest) >= 19 {
|
||||
msg = fmt.Sprintf("pushing %s...", tt.digest[7:19])
|
||||
} else {
|
||||
msg = fmt.Sprintf("pushing %s...", tt.digest)
|
||||
}
|
||||
}
|
||||
if msg != tt.wantMsg {
|
||||
t.Errorf("got %q, want %q", msg, tt.wantMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunOptions_Copy_Independence(t *testing.T) {
|
||||
// Test that modifications to original don't affect copy
|
||||
originalThink := &api.ThinkValue{Value: "original"}
|
||||
|
||||
@@ -116,7 +116,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
Prompt: ">>> ",
|
||||
AltPrompt: "... ",
|
||||
Placeholder: "Send a message (/? for help)",
|
||||
AltPlaceholder: `Use """ to end multi-line input`,
|
||||
AltPlaceholder: "Press Enter to send",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
13
docs/api.md
@@ -47,6 +47,12 @@ Generate a response for a given prompt with a provided model. This is a streamin
|
||||
- `images`: (optional) a list of base64-encoded images (for multimodal models such as `llava`)
|
||||
- `think`: (for thinking models) should the model think before responding?
|
||||
|
||||
Image generation parameters (for image generation models):
|
||||
|
||||
- `width`: (optional) width of the generated image in pixels (default: model-specific)
|
||||
- `height`: (optional) height of the generated image in pixels (default: model-specific)
|
||||
- `steps`: (optional) number of diffusion steps (default: model-specific)
|
||||
|
||||
Advanced parameters (optional):
|
||||
|
||||
- `format`: the format to return a response in. Format can be `json` or a JSON schema
|
||||
@@ -106,6 +112,13 @@ The final response in the stream also includes additional data about the generat
|
||||
- `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory
|
||||
- `response`: empty if the response was streamed, if not streamed, this will contain the full response
|
||||
|
||||
For image generation models, the response includes additional fields:
|
||||
|
||||
- `status`: describes the current phase (e.g., "generating image")
|
||||
- `total`: total count for the current phase (e.g., total steps)
|
||||
- `completed`: completed count for the current phase
|
||||
- `images`: array of base64-encoded generated images (in final response)
|
||||
|
||||
To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration` \* `10^9`.
|
||||
|
||||
```json
|
||||
|
||||
@@ -21,6 +21,7 @@ ollama pull glm-4.7:cloud
|
||||
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama # required but ignored
|
||||
```
|
||||
@@ -247,12 +248,13 @@ curl -X POST http://localhost:11434/v1/messages \
|
||||
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Or set the environment variables in your shell profile:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
```
|
||||
|
||||
@@ -110,7 +110,7 @@ More Ollama [Python example](https://github.com/ollama/ollama-python/blob/main/e
|
||||
import { Ollama } from "ollama";
|
||||
|
||||
const client = new Ollama();
|
||||
const results = await client.webSearch({ query: "what is ollama?" });
|
||||
const results = await client.webSearch("what is ollama?");
|
||||
console.log(JSON.stringify(results, null, 2));
|
||||
```
|
||||
|
||||
@@ -213,7 +213,7 @@ models](https://ollama.com/models)\n\nAvailable for macOS, Windows, and Linux',
|
||||
import { Ollama } from "ollama";
|
||||
|
||||
const client = new Ollama();
|
||||
const fetchResult = await client.webFetch({ url: "https://ollama.com" });
|
||||
const fetchResult = await client.webFetch("https://ollama.com");
|
||||
console.log(JSON.stringify(fetchResult, null, 2));
|
||||
```
|
||||
|
||||
|
||||
@@ -111,7 +111,9 @@
|
||||
"/integrations/zed",
|
||||
"/integrations/roo-code",
|
||||
"/integrations/n8n",
|
||||
"/integrations/xcode"
|
||||
"/integrations/xcode",
|
||||
"/integrations/onyx",
|
||||
"/integrations/marimo"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -22,7 +22,7 @@ Please refer to the [GPU docs](./gpu).
|
||||
|
||||
## How can I specify the context window size?
|
||||
|
||||
By default, Ollama uses a context window size of 2048 tokens.
|
||||
By default, Ollama uses a context window size of 4096 tokens.
|
||||
|
||||
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
|
||||
|
||||
|
||||
BIN
docs/images/marimo-add-model.png
Normal file
|
After Width: | Height: | Size: 174 KiB |
BIN
docs/images/marimo-chat.png
Normal file
|
After Width: | Height: | Size: 80 KiB |
BIN
docs/images/marimo-code-completion.png
Normal file
|
After Width: | Height: | Size: 230 KiB |
BIN
docs/images/marimo-models.png
Normal file
|
After Width: | Height: | Size: 178 KiB |
BIN
docs/images/marimo-settings.png
Normal file
|
After Width: | Height: | Size: 186 KiB |
BIN
docs/images/onyx-login.png
Normal file
|
After Width: | Height: | Size: 100 KiB |
BIN
docs/images/onyx-ollama-form.png
Normal file
|
After Width: | Height: | Size: 306 KiB |
BIN
docs/images/onyx-ollama-llm.png
Normal file
|
After Width: | Height: | Size: 300 KiB |
BIN
docs/images/onyx-query.png
Normal file
|
After Width: | Height: | Size: 211 KiB |
@@ -25,6 +25,7 @@ Claude Code connects to Ollama using the Anthropic-compatible API.
|
||||
1. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
```
|
||||
@@ -38,7 +39,7 @@ claude --model qwen3-coder
|
||||
Or run with environment variables inline:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
```
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
73
docs/integrations/marimo.mdx
Normal file
@@ -0,0 +1,73 @@
|
||||
---
|
||||
title: marimo
|
||||
---
|
||||
|
||||
## Install
|
||||
|
||||
Install [marimo](https://marimo.io). You can use `pip` or `uv` for this. You
|
||||
can also use `uv` to create a sandboxed environment for marimo by running:
|
||||
|
||||
```
|
||||
uvx marimo edit --sandbox notebook.py
|
||||
```
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
1. In marimo, go to the user settings and go to the AI tab. From here
|
||||
you can find and configure Ollama as an AI provider. For local use you
|
||||
would typically point the base url to `http://localhost:11434/v1`.
|
||||
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/marimo-settings.png"
|
||||
alt="Ollama settings in marimo"
|
||||
width="50%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
2. Once the AI provider is set up, you can turn on/off specific AI models you'd like to access.
|
||||
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/marimo-models.png"
|
||||
alt="Selecting an Ollama model"
|
||||
width="50%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
3. You can also add a model to the list of available models by scrolling to the bottom and using the UI there.
|
||||
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/marimo-add-model.png"
|
||||
alt="Adding a new Ollama model"
|
||||
width="50%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
4. Once configured, you can now use Ollama for AI chats in marimo.
|
||||
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/marimo-chat.png"
|
||||
alt="Configure code completion"
|
||||
width="50%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
4. Alternatively, you can now use Ollama for **inline code completion** in marimo. This can be configured in the "AI Features" tab.
|
||||
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/marimo-code-completion.png"
|
||||
alt="Configure code completion"
|
||||
width="50%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
1. Sign in to ollama cloud via `ollama signin`
|
||||
2. In the ollama model settings add a model that ollama hosts, like `gpt-oss:120b`.
|
||||
3. You can now refer to this model in marimo!
|
||||
63
docs/integrations/onyx.mdx
Normal file
@@ -0,0 +1,63 @@
|
||||
---
|
||||
title: Onyx
|
||||
---
|
||||
|
||||
## Overview
|
||||
[Onyx](http://onyx.app/) is a self-hostable Chat UI that integrates with all Ollama models. Features include:
|
||||
- Creating custom Agents
|
||||
- Web search
|
||||
- Deep Research
|
||||
- RAG over uploaded documents and connected apps
|
||||
- Connectors to applications like Google Drive, Email, Slack, etc.
|
||||
- MCP and OpenAPI Actions support
|
||||
- Image generation
|
||||
- User/Groups management, RBAC, SSO, etc.
|
||||
|
||||
Onyx can be deployed for single users or large organizations.
|
||||
|
||||
## Install Onyx
|
||||
|
||||
Deploy Onyx with the [quickstart guide](https://docs.onyx.app/deployment/getting_started/quickstart).
|
||||
|
||||
<Info>
|
||||
Resourcing/scaling docs [here](https://docs.onyx.app/deployment/getting_started/resourcing).
|
||||
</Info>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
1. Login to your Onyx deployment (create an account first).
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/onyx-login.png"
|
||||
alt="Onyx Login Page"
|
||||
width="75%"
|
||||
/>
|
||||
</div>
|
||||
2. In the set-up process select `Ollama` as the LLM provider.
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/onyx-ollama-llm.png"
|
||||
alt="Onyx Set Up Form"
|
||||
width="75%"
|
||||
/>
|
||||
</div>
|
||||
3. Provide your **Ollama API URL** and select your models.
|
||||
<Note>If you're running Onyx in Docker, to access your computer's local network use `http://host.docker.internal` instead of `http://127.0.0.1`.</Note>
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/onyx-ollama-form.png"
|
||||
alt="Selecting Ollama Models"
|
||||
width="75%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
You can also easily connect up Onyx Cloud with the `Ollama Cloud` tab of the setup.
|
||||
|
||||
## Send your first query
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/onyx-query.png"
|
||||
alt="Onyx Query Example"
|
||||
width="75%"
|
||||
/>
|
||||
</div>
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: "Linux"
|
||||
title: Linux
|
||||
---
|
||||
|
||||
## Install
|
||||
@@ -13,14 +13,15 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
## Manual install
|
||||
|
||||
<Note>
|
||||
If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
|
||||
If you are upgrading from a prior version, you should remove the old libraries
|
||||
with `sudo rm -rf /usr/lib/ollama` first.
|
||||
</Note>
|
||||
|
||||
Download and extract the package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
```
|
||||
|
||||
Start Ollama:
|
||||
@@ -40,8 +41,8 @@ ollama -v
|
||||
If you have an AMD GPU, also download and extract the additional ROCm package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
```
|
||||
|
||||
### ARM64 install
|
||||
@@ -49,8 +50,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
|
||||
Download and extract the ARM64-specific package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
```
|
||||
|
||||
### Adding Ollama as a startup service (recommended)
|
||||
@@ -112,7 +113,11 @@ sudo systemctl status ollama
|
||||
```
|
||||
|
||||
<Note>
|
||||
While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
|
||||
While AMD has contributed the `amdgpu` driver upstream to the official linux
|
||||
kernel source, the version is older and may not support all ROCm features. We
|
||||
recommend you install the latest driver from
|
||||
https://www.amd.com/en/support/linux-drivers for best support of your Radeon
|
||||
GPU.
|
||||
</Note>
|
||||
|
||||
## Customizing
|
||||
@@ -141,8 +146,8 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
Or by re-downloading Ollama:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
```
|
||||
|
||||
## Installing specific versions
|
||||
@@ -191,4 +196,4 @@ Remove the downloaded models and Ollama service user and group:
|
||||
sudo userdel ollama
|
||||
sudo groupdel ollama
|
||||
sudo rm -r /usr/share/ollama
|
||||
```
|
||||
```
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
# Troubleshooting
|
||||
|
||||
For troubleshooting, see [https://docs.ollama.com/troubleshooting](https://docs.ollama.com/troubleshooting)
|
||||
18
go.mod
@@ -15,8 +15,8 @@ require (
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/x448/float16 v0.8.4
|
||||
golang.org/x/sync v0.19.0
|
||||
golang.org/x/sys v0.39.0
|
||||
golang.org/x/sync v0.17.0
|
||||
golang.org/x/sys v0.37.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -30,8 +30,8 @@ require (
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8
|
||||
golang.org/x/image v0.22.0
|
||||
golang.org/x/mod v0.31.0
|
||||
golang.org/x/tools v0.40.0
|
||||
golang.org/x/mod v0.30.0
|
||||
golang.org/x/tools v0.38.0
|
||||
gonum.org/v1/gonum v0.15.0
|
||||
)
|
||||
|
||||
@@ -81,11 +81,11 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93
|
||||
golang.org/x/net v0.48.0 // indirect
|
||||
golang.org/x/term v0.38.0
|
||||
golang.org/x/text v0.32.0
|
||||
golang.org/x/crypto v0.43.0
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
|
||||
golang.org/x/net v0.46.0 // indirect
|
||||
golang.org/x/term v0.36.0
|
||||
golang.org/x/text v0.30.0
|
||||
google.golang.org/protobuf v1.34.1
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
36
go.sum
@@ -233,16 +233,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
|
||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 h1:fQsdNF2N+/YewlRZiricy4P1iimyPKZ/xwniHj8Q2a0=
|
||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93/go.mod h1:EPRbTFwzwjXj9NpYyyrvenVh9Y+GFeEvMNh7Xuz7xgU=
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
|
||||
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
|
||||
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
||||
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
@@ -264,8 +264,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
|
||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
|
||||
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
@@ -278,8 +278,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
||||
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
|
||||
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -289,8 +289,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -306,17 +306,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
|
||||
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -330,8 +330,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
|
||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
||||
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
||||
@@ -131,7 +131,7 @@ func TestAPIToolCalling(t *testing.T) {
|
||||
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
|
||||
}
|
||||
|
||||
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
|
||||
if _, ok := lastToolCall.Function.Arguments.Get("location"); !ok {
|
||||
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
|
||||
}
|
||||
case <-ctx.Done():
|
||||
|
||||
@@ -1464,6 +1464,12 @@ type CompletionRequest struct {
|
||||
|
||||
// TopLogprobs specifies the number of most likely alternative tokens to return (0-20)
|
||||
TopLogprobs int
|
||||
|
||||
// Image generation fields
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int32 `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
// DoneReason represents the reason why a completion response is done
|
||||
@@ -1512,6 +1518,15 @@ type CompletionResponse struct {
|
||||
|
||||
// Logprobs contains log probability information if requested
|
||||
Logprobs []Logprob `json:"logprobs,omitempty"`
|
||||
|
||||
// Image contains base64-encoded image data for image generation
|
||||
Image string `json:"image,omitempty"`
|
||||
|
||||
// Step is the current step in image generation
|
||||
Step int `json:"step,omitempty"`
|
||||
|
||||
// TotalSteps is the total number of steps for image generation
|
||||
TotalSteps int `json:"total_steps,omitempty"`
|
||||
}
|
||||
|
||||
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
||||
|
||||
@@ -118,6 +118,9 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// Set think to nil when being used with Anthropic API to connect to tools like claude code
|
||||
c.Set("relax_thinking", true)
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))
|
||||
|
||||
@@ -582,3 +582,26 @@ func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnthropicMessagesMiddleware_SetsRelaxThinkingFlag(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var flagSet bool
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware())
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
_, flagSet = c.Get("relax_thinking")
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if !flagSet {
|
||||
t.Error("expected relax_thinking flag to be set in context")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
@@ -49,6 +50,11 @@ type EmbedWriter struct {
|
||||
encodingFormat string
|
||||
}
|
||||
|
||||
type ImageWriter struct {
|
||||
BaseWriter
|
||||
done bool
|
||||
}
|
||||
|
||||
func (w *BaseWriter) writeError(data []byte) (int, error) {
|
||||
var serr api.StatusError
|
||||
err := json.Unmarshal(data, &serr)
|
||||
@@ -273,6 +279,36 @@ func (w *EmbedWriter) Write(data []byte) (int, error) {
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *ImageWriter) writeResponse(data []byte) (int, error) {
|
||||
var generateResponse api.GenerateResponse
|
||||
err := json.Unmarshal(data, &generateResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Image generation doesn't support streaming in the OpenAI API sense,
|
||||
// so we only write the response when done with images
|
||||
if generateResponse.Done && len(generateResponse.Images) > 0 {
|
||||
w.done = true
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToImageGenerationResponse(generateResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *ImageWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func ListMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
w := &ListWriter{
|
||||
@@ -392,6 +428,43 @@ func EmbeddingsMiddleware() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func ImageGenerationsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.ImageGenerationRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Prompt == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
genReq := openai.FromImageGenerationRequest(req)
|
||||
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &ImageWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func ChatMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.ChatCompletionRequest
|
||||
@@ -441,6 +514,7 @@ type ResponsesWriter struct {
|
||||
stream bool
|
||||
responseID string
|
||||
itemID string
|
||||
request openai.ResponsesRequest
|
||||
}
|
||||
|
||||
func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
|
||||
@@ -478,7 +552,9 @@ func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {
|
||||
|
||||
// Non-streaming response
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse)
|
||||
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse, w.request)
|
||||
completedAt := time.Now().Unix()
|
||||
response.CompletedAt = &completedAt
|
||||
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
|
||||
}
|
||||
|
||||
@@ -523,11 +599,12 @@ func ResponsesMiddleware() gin.HandlerFunc {
|
||||
|
||||
w := &ResponsesWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model),
|
||||
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model, req),
|
||||
model: req.Model,
|
||||
stream: streamRequested,
|
||||
responseID: responseID,
|
||||
itemID: itemID,
|
||||
request: req,
|
||||
}
|
||||
|
||||
// Set headers based on streaming mode
|
||||
|
||||
@@ -961,3 +961,143 @@ func TestRetrieveMiddleware(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageGenerationsMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
body string
|
||||
req api.GenerateRequest
|
||||
err openai.ErrorResponse
|
||||
}
|
||||
|
||||
var capturedRequest *api.GenerateRequest
|
||||
|
||||
streamFalse := false
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "image generation handler",
|
||||
body: `{
|
||||
"model": "flux",
|
||||
"prompt": "a cat"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "flux",
|
||||
Prompt: "a cat",
|
||||
Stream: &streamFalse,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image generation with size",
|
||||
body: `{
|
||||
"model": "flux",
|
||||
"prompt": "a dog",
|
||||
"size": "512x512"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "flux",
|
||||
Prompt: "a dog",
|
||||
Stream: &streamFalse,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing prompt error",
|
||||
body: `{
|
||||
"model": "flux"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "prompt is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing model error",
|
||||
body: `{
|
||||
"prompt": "a cat"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "model is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(ImageGenerationsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
var errResp openai.ErrorResponse
|
||||
if resp.Code != http.StatusOK {
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
||||
t.Fatalf("requests did not match\nExpected: %+v\nActual: %+v", tc.req, *capturedRequest)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tc.err, errResp) {
|
||||
t.Fatalf("errors did not match\nExpected: %+v\nActual: %+v", tc.err, errResp)
|
||||
}
|
||||
|
||||
capturedRequest = nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageWriterIntegration(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Run("transforms generate response to openai format", func(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(ImageGenerationsMiddleware())
|
||||
router.POST("/api/generate", func(c *gin.Context) {
|
||||
// Simulate an image generation response
|
||||
generateResponse := api.GenerateResponse{
|
||||
Done: true,
|
||||
CreatedAt: time.Now(),
|
||||
Images: []string{"base64encodedimage"},
|
||||
}
|
||||
c.JSON(http.StatusOK, generateResponse)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(`{"model":"flux","prompt":"a cat"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
var response openai.ImageGenerationResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if len(response.Data) != 1 {
|
||||
t.Fatalf("expected 1 image, got %d", len(response.Data))
|
||||
}
|
||||
if response.Data[0].B64JSON != "base64encodedimage" {
|
||||
t.Fatalf("expected image data 'base64encodedimage', got '%s'", response.Data[0].B64JSON)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -630,6 +630,10 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {
|
||||
|
||||
// decodeImageURL decodes a base64 data URI into raw image bytes.
|
||||
func decodeImageURL(url string) (api.ImageData, error) {
|
||||
if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
|
||||
return nil, errors.New("image URLs are not currently supported, please use base64 encoded data instead")
|
||||
}
|
||||
|
||||
types := []string{"jpeg", "jpg", "png", "webp"}
|
||||
|
||||
// Support blank mime type to match /api/chat's behavior of taking just unadorned base64
|
||||
@@ -733,3 +737,46 @@ func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||
DebugRenderOnly: r.DebugRenderOnly,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ImageGenerationRequest is an OpenAI-compatible image generation request.
|
||||
type ImageGenerationRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
}
|
||||
|
||||
// ImageGenerationResponse is an OpenAI-compatible image generation response.
|
||||
type ImageGenerationResponse struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []ImageURLOrData `json:"data"`
|
||||
}
|
||||
|
||||
// ImageURLOrData contains either a URL or base64-encoded image data.
|
||||
type ImageURLOrData struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
B64JSON string `json:"b64_json,omitempty"`
|
||||
}
|
||||
|
||||
// FromImageGenerationRequest converts an OpenAI image generation request to an Ollama GenerateRequest.
|
||||
func FromImageGenerationRequest(r ImageGenerationRequest) api.GenerateRequest {
|
||||
stream := false
|
||||
return api.GenerateRequest{
|
||||
Model: r.Model,
|
||||
Prompt: r.Prompt,
|
||||
Stream: &stream,
|
||||
}
|
||||
}
|
||||
|
||||
// ToImageGenerationResponse converts an Ollama GenerateResponse to an OpenAI ImageGenerationResponse.
|
||||
func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationResponse {
|
||||
data := make([]ImageURLOrData, 0)
|
||||
for _, img := range resp.Images {
|
||||
data = append(data, ImageURLOrData{B64JSON: img})
|
||||
}
|
||||
return ImageGenerationResponse{
|
||||
Created: resp.CreatedAt.Unix(),
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
@@ -265,9 +266,9 @@ type ResponsesText struct {
|
||||
type ResponsesTool struct {
|
||||
Type string `json:"type"` // "function"
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Strict bool `json:"strict,omitempty"`
|
||||
Parameters map[string]any `json:"parameters,omitempty"`
|
||||
Description *string `json:"description"` // nullable but required
|
||||
Strict *bool `json:"strict"` // nullable but required
|
||||
Parameters map[string]any `json:"parameters"` // nullable but required
|
||||
}
|
||||
|
||||
type ResponsesRequest struct {
|
||||
@@ -475,11 +476,16 @@ func convertTool(t ResponsesTool) (api.Tool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
var description string
|
||||
if t.Description != nil {
|
||||
description = *t.Description
|
||||
}
|
||||
|
||||
return api.Tool{
|
||||
Type: t.Type,
|
||||
Function: api.ToolFunction{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
Description: description,
|
||||
Parameters: params,
|
||||
},
|
||||
}, nil
|
||||
@@ -516,17 +522,60 @@ func convertInputMessage(m ResponsesInputMessage) (api.Message, error) {
|
||||
|
||||
// Response types for the Responses API
|
||||
|
||||
// ResponsesTextField represents the text output configuration in the response.
|
||||
type ResponsesTextField struct {
|
||||
Format ResponsesTextFormat `json:"format"`
|
||||
}
|
||||
|
||||
// ResponsesReasoningOutput represents reasoning configuration in the response.
|
||||
type ResponsesReasoningOutput struct {
|
||||
Effort *string `json:"effort,omitempty"`
|
||||
Summary *string `json:"summary,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesError represents an error in the response.
|
||||
type ResponsesError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ResponsesIncompleteDetails represents details about why a response was incomplete.
|
||||
type ResponsesIncompleteDetails struct {
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
type ResponsesResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Status string `json:"status"`
|
||||
Model string `json:"model"`
|
||||
Output []ResponsesOutputItem `json:"output"`
|
||||
Usage *ResponsesUsage `json:"usage,omitempty"`
|
||||
// TODO(drifkin): add `temperature` and `top_p` to the response, but this
|
||||
// requires additional plumbing to find the effective values since the
|
||||
// defaults can come from the model or the request
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
CompletedAt *int64 `json:"completed_at"`
|
||||
Status string `json:"status"`
|
||||
IncompleteDetails *ResponsesIncompleteDetails `json:"incomplete_details"`
|
||||
Model string `json:"model"`
|
||||
PreviousResponseID *string `json:"previous_response_id"`
|
||||
Instructions *string `json:"instructions"`
|
||||
Output []ResponsesOutputItem `json:"output"`
|
||||
Error *ResponsesError `json:"error"`
|
||||
Tools []ResponsesTool `json:"tools"`
|
||||
ToolChoice any `json:"tool_choice"`
|
||||
Truncation string `json:"truncation"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls"`
|
||||
Text ResponsesTextField `json:"text"`
|
||||
TopP float64 `json:"top_p"`
|
||||
PresencePenalty float64 `json:"presence_penalty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty"`
|
||||
TopLogprobs int `json:"top_logprobs"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
Reasoning *ResponsesReasoningOutput `json:"reasoning"`
|
||||
Usage *ResponsesUsage `json:"usage"`
|
||||
MaxOutputTokens *int `json:"max_output_tokens"`
|
||||
MaxToolCalls *int `json:"max_tool_calls"`
|
||||
Store bool `json:"store"`
|
||||
Background bool `json:"background"`
|
||||
ServiceTier string `json:"service_tier"`
|
||||
Metadata map[string]any `json:"metadata"`
|
||||
SafetyIdentifier *string `json:"safety_identifier"`
|
||||
PromptCacheKey *string `json:"prompt_cache_key"`
|
||||
}
|
||||
|
||||
type ResponsesOutputItem struct {
|
||||
@@ -550,18 +599,39 @@ type ResponsesReasoningSummary struct {
|
||||
}
|
||||
|
||||
type ResponsesOutputContent struct {
|
||||
Type string `json:"type"` // "output_text"
|
||||
Text string `json:"text"`
|
||||
Type string `json:"type"` // "output_text"
|
||||
Text string `json:"text"`
|
||||
Annotations []any `json:"annotations"`
|
||||
Logprobs []any `json:"logprobs"`
|
||||
}
|
||||
|
||||
type ResponsesInputTokensDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
}
|
||||
|
||||
type ResponsesOutputTokensDetails struct {
|
||||
ReasoningTokens int `json:"reasoning_tokens"`
|
||||
}
|
||||
|
||||
type ResponsesUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
InputTokensDetails ResponsesInputTokensDetails `json:"input_tokens_details"`
|
||||
OutputTokensDetails ResponsesOutputTokensDetails `json:"output_tokens_details"`
|
||||
}
|
||||
|
||||
// ToResponse converts an api.ChatResponse to a Responses API response
|
||||
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse) ResponsesResponse {
|
||||
// derefFloat64 returns the value of a float64 pointer, or a default if nil.
|
||||
func derefFloat64(p *float64, def float64) float64 {
|
||||
if p != nil {
|
||||
return *p
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
// ToResponse converts an api.ChatResponse to a Responses API response.
|
||||
// The request is used to echo back request parameters in the response.
|
||||
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse, request ResponsesRequest) ResponsesResponse {
|
||||
var output []ResponsesOutputItem
|
||||
|
||||
// Add reasoning item if thinking is present
|
||||
@@ -585,6 +655,7 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse)
|
||||
output = append(output, ResponsesOutputItem{
|
||||
ID: fmt.Sprintf("fc_%s_%d", responseID, i),
|
||||
Type: "function_call",
|
||||
Status: "completed",
|
||||
CallID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
@@ -598,25 +669,90 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse)
|
||||
Role: "assistant",
|
||||
Content: []ResponsesOutputContent{
|
||||
{
|
||||
Type: "output_text",
|
||||
Text: chatResponse.Message.Content,
|
||||
Type: "output_text",
|
||||
Text: chatResponse.Message.Content,
|
||||
Annotations: []any{},
|
||||
Logprobs: []any{},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
var instructions *string
|
||||
if request.Instructions != "" {
|
||||
instructions = &request.Instructions
|
||||
}
|
||||
|
||||
// Build truncation with default
|
||||
truncation := "disabled"
|
||||
if request.Truncation != nil {
|
||||
truncation = *request.Truncation
|
||||
}
|
||||
|
||||
tools := request.Tools
|
||||
if tools == nil {
|
||||
tools = []ResponsesTool{}
|
||||
}
|
||||
|
||||
text := ResponsesTextField{
|
||||
Format: ResponsesTextFormat{Type: "text"},
|
||||
}
|
||||
if request.Text != nil && request.Text.Format != nil {
|
||||
text.Format = *request.Text.Format
|
||||
}
|
||||
|
||||
// Build reasoning output from request
|
||||
var reasoning *ResponsesReasoningOutput
|
||||
if request.Reasoning.Effort != "" || request.Reasoning.Summary != "" {
|
||||
reasoning = &ResponsesReasoningOutput{}
|
||||
if request.Reasoning.Effort != "" {
|
||||
reasoning.Effort = &request.Reasoning.Effort
|
||||
}
|
||||
if request.Reasoning.Summary != "" {
|
||||
reasoning.Summary = &request.Reasoning.Summary
|
||||
}
|
||||
}
|
||||
|
||||
return ResponsesResponse{
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
CreatedAt: chatResponse.CreatedAt.Unix(),
|
||||
Status: "completed",
|
||||
Model: model,
|
||||
Output: output,
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
CreatedAt: chatResponse.CreatedAt.Unix(),
|
||||
CompletedAt: nil, // Set by middleware when writing final response
|
||||
Status: "completed",
|
||||
IncompleteDetails: nil, // Only populated if response incomplete
|
||||
Model: model,
|
||||
PreviousResponseID: nil, // Not supported
|
||||
Instructions: instructions,
|
||||
Output: output,
|
||||
Error: nil, // Only populated on failure
|
||||
Tools: tools,
|
||||
ToolChoice: "auto", // Default value
|
||||
Truncation: truncation,
|
||||
ParallelToolCalls: true, // Default value
|
||||
Text: text,
|
||||
TopP: derefFloat64(request.TopP, 1.0),
|
||||
PresencePenalty: 0, // Default value
|
||||
FrequencyPenalty: 0, // Default value
|
||||
TopLogprobs: 0, // Default value
|
||||
Temperature: derefFloat64(request.Temperature, 1.0),
|
||||
Reasoning: reasoning,
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: chatResponse.PromptEvalCount,
|
||||
OutputTokens: chatResponse.EvalCount,
|
||||
TotalTokens: chatResponse.PromptEvalCount + chatResponse.EvalCount,
|
||||
// TODO(drifkin): wire through the actual values
|
||||
InputTokensDetails: ResponsesInputTokensDetails{CachedTokens: 0},
|
||||
// TODO(drifkin): wire through the actual values
|
||||
OutputTokensDetails: ResponsesOutputTokensDetails{ReasoningTokens: 0},
|
||||
},
|
||||
MaxOutputTokens: request.MaxOutputTokens,
|
||||
MaxToolCalls: nil, // Not supported
|
||||
Store: false, // We don't store responses
|
||||
Background: request.Background,
|
||||
ServiceTier: "default", // Default value
|
||||
Metadata: map[string]any{},
|
||||
SafetyIdentifier: nil, // Not supported
|
||||
PromptCacheKey: nil, // Not supported
|
||||
}
|
||||
}
|
||||
|
||||
@@ -636,6 +772,7 @@ type ResponsesStreamConverter struct {
|
||||
responseID string
|
||||
itemID string
|
||||
model string
|
||||
request ResponsesRequest
|
||||
|
||||
// State tracking (mutated across Process calls)
|
||||
firstWrite bool
|
||||
@@ -668,11 +805,12 @@ func (c *ResponsesStreamConverter) newEvent(eventType string, data map[string]an
|
||||
}
|
||||
|
||||
// NewResponsesStreamConverter creates a new converter with the given configuration.
|
||||
func NewResponsesStreamConverter(responseID, itemID, model string) *ResponsesStreamConverter {
|
||||
func NewResponsesStreamConverter(responseID, itemID, model string, request ResponsesRequest) *ResponsesStreamConverter {
|
||||
return &ResponsesStreamConverter{
|
||||
responseID: responseID,
|
||||
itemID: itemID,
|
||||
model: model,
|
||||
request: request,
|
||||
firstWrite: true,
|
||||
}
|
||||
}
|
||||
@@ -717,25 +855,120 @@ func (c *ResponsesStreamConverter) Process(r api.ChatResponse) []ResponsesStream
|
||||
return events
|
||||
}
|
||||
|
||||
// buildResponseObject creates a full response object with all required fields for streaming events.
|
||||
func (c *ResponsesStreamConverter) buildResponseObject(status string, output []any, usage map[string]any) map[string]any {
|
||||
var instructions any = nil
|
||||
if c.request.Instructions != "" {
|
||||
instructions = c.request.Instructions
|
||||
}
|
||||
|
||||
truncation := "disabled"
|
||||
if c.request.Truncation != nil {
|
||||
truncation = *c.request.Truncation
|
||||
}
|
||||
|
||||
var tools []any
|
||||
if c.request.Tools != nil {
|
||||
for _, t := range c.request.Tools {
|
||||
tools = append(tools, map[string]any{
|
||||
"type": t.Type,
|
||||
"name": t.Name,
|
||||
"description": t.Description,
|
||||
"strict": t.Strict,
|
||||
"parameters": t.Parameters,
|
||||
})
|
||||
}
|
||||
}
|
||||
if tools == nil {
|
||||
tools = []any{}
|
||||
}
|
||||
|
||||
textFormat := map[string]any{"type": "text"}
|
||||
if c.request.Text != nil && c.request.Text.Format != nil {
|
||||
textFormat = map[string]any{
|
||||
"type": c.request.Text.Format.Type,
|
||||
}
|
||||
if c.request.Text.Format.Name != "" {
|
||||
textFormat["name"] = c.request.Text.Format.Name
|
||||
}
|
||||
if c.request.Text.Format.Schema != nil {
|
||||
textFormat["schema"] = c.request.Text.Format.Schema
|
||||
}
|
||||
if c.request.Text.Format.Strict != nil {
|
||||
textFormat["strict"] = *c.request.Text.Format.Strict
|
||||
}
|
||||
}
|
||||
|
||||
var reasoning any = nil
|
||||
if c.request.Reasoning.Effort != "" || c.request.Reasoning.Summary != "" {
|
||||
r := map[string]any{}
|
||||
if c.request.Reasoning.Effort != "" {
|
||||
r["effort"] = c.request.Reasoning.Effort
|
||||
} else {
|
||||
r["effort"] = nil
|
||||
}
|
||||
if c.request.Reasoning.Summary != "" {
|
||||
r["summary"] = c.request.Reasoning.Summary
|
||||
} else {
|
||||
r["summary"] = nil
|
||||
}
|
||||
reasoning = r
|
||||
}
|
||||
|
||||
// Build top_p and temperature with defaults
|
||||
topP := 1.0
|
||||
if c.request.TopP != nil {
|
||||
topP = *c.request.TopP
|
||||
}
|
||||
temperature := 1.0
|
||||
if c.request.Temperature != nil {
|
||||
temperature = *c.request.Temperature
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"id": c.responseID,
|
||||
"object": "response",
|
||||
"created_at": time.Now().Unix(),
|
||||
"completed_at": nil,
|
||||
"status": status,
|
||||
"incomplete_details": nil,
|
||||
"model": c.model,
|
||||
"previous_response_id": nil,
|
||||
"instructions": instructions,
|
||||
"output": output,
|
||||
"error": nil,
|
||||
"tools": tools,
|
||||
"tool_choice": "auto",
|
||||
"truncation": truncation,
|
||||
"parallel_tool_calls": true,
|
||||
"text": map[string]any{"format": textFormat},
|
||||
"top_p": topP,
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 0,
|
||||
"top_logprobs": 0,
|
||||
"temperature": temperature,
|
||||
"reasoning": reasoning,
|
||||
"usage": usage,
|
||||
"max_output_tokens": c.request.MaxOutputTokens,
|
||||
"max_tool_calls": nil,
|
||||
"store": false,
|
||||
"background": c.request.Background,
|
||||
"service_tier": "default",
|
||||
"metadata": map[string]any{},
|
||||
"safety_identifier": nil,
|
||||
"prompt_cache_key": nil,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ResponsesStreamConverter) createResponseCreatedEvent() ResponsesStreamEvent {
|
||||
return c.newEvent("response.created", map[string]any{
|
||||
"response": map[string]any{
|
||||
"id": c.responseID,
|
||||
"object": "response",
|
||||
"status": "in_progress",
|
||||
"output": []any{},
|
||||
},
|
||||
"response": c.buildResponseObject("in_progress", []any{}, nil),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ResponsesStreamConverter) createResponseInProgressEvent() ResponsesStreamEvent {
|
||||
return c.newEvent("response.in_progress", map[string]any{
|
||||
"response": map[string]any{
|
||||
"id": c.responseID,
|
||||
"object": "response",
|
||||
"status": "in_progress",
|
||||
"output": []any{},
|
||||
},
|
||||
"response": c.buildResponseObject("in_progress", []any{}, nil),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -762,9 +995,10 @@ func (c *ResponsesStreamConverter) processThinking(thinking string) []ResponsesS
|
||||
|
||||
// Emit delta
|
||||
events = append(events, c.newEvent("response.reasoning_summary_text.delta", map[string]any{
|
||||
"item_id": c.reasoningItemID,
|
||||
"output_index": c.outputIndex,
|
||||
"delta": thinking,
|
||||
"item_id": c.reasoningItemID,
|
||||
"output_index": c.outputIndex,
|
||||
"summary_index": 0,
|
||||
"delta": thinking,
|
||||
}))
|
||||
|
||||
// TODO(drifkin): consider adding
|
||||
@@ -783,9 +1017,10 @@ func (c *ResponsesStreamConverter) finishReasoning() []ResponsesStreamEvent {
|
||||
|
||||
events := []ResponsesStreamEvent{
|
||||
c.newEvent("response.reasoning_summary_text.done", map[string]any{
|
||||
"item_id": c.reasoningItemID,
|
||||
"output_index": c.outputIndex,
|
||||
"text": c.accumulatedThinking,
|
||||
"item_id": c.reasoningItemID,
|
||||
"output_index": c.outputIndex,
|
||||
"summary_index": 0,
|
||||
"text": c.accumulatedThinking,
|
||||
}),
|
||||
c.newEvent("response.output_item.done", map[string]any{
|
||||
"output_index": c.outputIndex,
|
||||
@@ -898,8 +1133,10 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": c.contentIndex,
|
||||
"part": map[string]any{
|
||||
"type": "output_text",
|
||||
"text": "",
|
||||
"type": "output_text",
|
||||
"text": "",
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
},
|
||||
}))
|
||||
}
|
||||
@@ -913,6 +1150,7 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": 0,
|
||||
"delta": content,
|
||||
"logprobs": []any{},
|
||||
}))
|
||||
|
||||
return events
|
||||
@@ -944,8 +1182,10 @@ func (c *ResponsesStreamConverter) buildFinalOutput() []any {
|
||||
"status": "completed",
|
||||
"role": "assistant",
|
||||
"content": []map[string]any{{
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
}},
|
||||
})
|
||||
}
|
||||
@@ -967,6 +1207,7 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": 0,
|
||||
"text": c.accumulatedText,
|
||||
"logprobs": []any{},
|
||||
}))
|
||||
|
||||
// response.content_part.done
|
||||
@@ -975,8 +1216,10 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": 0,
|
||||
"part": map[string]any{
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
},
|
||||
}))
|
||||
|
||||
@@ -989,26 +1232,31 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
|
||||
"status": "completed",
|
||||
"role": "assistant",
|
||||
"content": []map[string]any{{
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
}},
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// response.completed
|
||||
events = append(events, c.newEvent("response.completed", map[string]any{
|
||||
"response": map[string]any{
|
||||
"id": c.responseID,
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": c.buildFinalOutput(),
|
||||
"usage": map[string]any{
|
||||
"input_tokens": r.PromptEvalCount,
|
||||
"output_tokens": r.EvalCount,
|
||||
"total_tokens": r.PromptEvalCount + r.EvalCount,
|
||||
},
|
||||
usage := map[string]any{
|
||||
"input_tokens": r.PromptEvalCount,
|
||||
"output_tokens": r.EvalCount,
|
||||
"total_tokens": r.PromptEvalCount + r.EvalCount,
|
||||
"input_tokens_details": map[string]any{
|
||||
"cached_tokens": 0,
|
||||
},
|
||||
"output_tokens_details": map[string]any{
|
||||
"reasoning_tokens": 0,
|
||||
},
|
||||
}
|
||||
response := c.buildResponseObject("completed", c.buildFinalOutput(), usage)
|
||||
response["completed_at"] = time.Now().Unix()
|
||||
events = append(events, c.newEvent("response.completed", map[string]any{
|
||||
"response": response,
|
||||
}))
|
||||
|
||||
return events
|
||||
|
||||
@@ -850,7 +850,7 @@ func TestFromResponsesRequest_Images(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResponsesStreamConverter_TextOnly(t *testing.T) {
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
// First chunk with content
|
||||
events := converter.Process(api.ChatResponse{
|
||||
@@ -916,7 +916,7 @@ func TestResponsesStreamConverter_TextOnly(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
events := converter.Process(api.ChatResponse{
|
||||
Message: api.Message{
|
||||
@@ -952,7 +952,7 @@ func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResponsesStreamConverter_Reasoning(t *testing.T) {
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
// First chunk with thinking
|
||||
events := converter.Process(api.ChatResponse{
|
||||
@@ -1267,7 +1267,7 @@ func TestToResponse_WithReasoning(t *testing.T) {
|
||||
Content: "The answer is 42",
|
||||
},
|
||||
Done: true,
|
||||
})
|
||||
}, ResponsesRequest{})
|
||||
|
||||
// Should have 2 output items: reasoning + message
|
||||
if len(response.Output) != 2 {
|
||||
@@ -1638,7 +1638,7 @@ func TestFromResponsesRequest_ShorthandFormats(t *testing.T) {
|
||||
|
||||
func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
|
||||
// Verify that response.output_item.done includes content field for messages
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
// First chunk
|
||||
converter.Process(api.ChatResponse{
|
||||
@@ -1686,7 +1686,7 @@ func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
|
||||
|
||||
func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T) {
|
||||
// Verify that response.completed includes the output array
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
// Process some content
|
||||
converter.Process(api.ChatResponse{
|
||||
@@ -1730,7 +1730,7 @@ func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T)
|
||||
|
||||
func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
|
||||
// Verify that response.created includes an empty output array
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
events := converter.Process(api.ChatResponse{
|
||||
Message: api.Message{Content: "Hi"},
|
||||
@@ -1757,7 +1757,7 @@ func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
|
||||
|
||||
func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
|
||||
// Verify that events include incrementing sequence numbers
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
events := converter.Process(api.ChatResponse{
|
||||
Message: api.Message{Content: "Hello"},
|
||||
@@ -1791,7 +1791,7 @@ func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
|
||||
|
||||
func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) {
|
||||
// Verify that function call items include status field
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
events := converter.Process(api.ChatResponse{
|
||||
Message: api.Message{
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Prompt struct {
|
||||
@@ -36,10 +37,11 @@ type Terminal struct {
|
||||
}
|
||||
|
||||
type Instance struct {
|
||||
Prompt *Prompt
|
||||
Terminal *Terminal
|
||||
History *History
|
||||
Pasting bool
|
||||
Prompt *Prompt
|
||||
Terminal *Terminal
|
||||
History *History
|
||||
Pasting bool
|
||||
pastedLines []string
|
||||
}
|
||||
|
||||
func New(prompt Prompt) (*Instance, error) {
|
||||
@@ -174,6 +176,8 @@ func (i *Instance) Readline() (string, error) {
|
||||
case CharEsc:
|
||||
esc = true
|
||||
case CharInterrupt:
|
||||
i.pastedLines = nil
|
||||
i.Prompt.UseAlt = false
|
||||
return "", ErrInterrupt
|
||||
case CharPrev:
|
||||
i.historyPrev(buf, ¤tLineBuf)
|
||||
@@ -188,7 +192,23 @@ func (i *Instance) Readline() (string, error) {
|
||||
case CharForward:
|
||||
buf.MoveRight()
|
||||
case CharBackspace, CharCtrlH:
|
||||
buf.Remove()
|
||||
if buf.IsEmpty() && len(i.pastedLines) > 0 {
|
||||
lastIdx := len(i.pastedLines) - 1
|
||||
prevLine := i.pastedLines[lastIdx]
|
||||
i.pastedLines = i.pastedLines[:lastIdx]
|
||||
fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + ClearToEOL)
|
||||
if len(i.pastedLines) == 0 {
|
||||
fmt.Print(i.Prompt.Prompt)
|
||||
i.Prompt.UseAlt = false
|
||||
} else {
|
||||
fmt.Print(i.Prompt.AltPrompt)
|
||||
}
|
||||
for _, r := range prevLine {
|
||||
buf.Add(r)
|
||||
}
|
||||
} else {
|
||||
buf.Remove()
|
||||
}
|
||||
case CharTab:
|
||||
// todo: convert back to real tabs
|
||||
for range 8 {
|
||||
@@ -211,13 +231,28 @@ func (i *Instance) Readline() (string, error) {
|
||||
case CharCtrlZ:
|
||||
fd := os.Stdin.Fd()
|
||||
return handleCharCtrlZ(fd, i.Terminal.termios)
|
||||
case CharEnter, CharCtrlJ:
|
||||
case CharCtrlJ:
|
||||
i.pastedLines = append(i.pastedLines, buf.String())
|
||||
buf.Buf.Clear()
|
||||
buf.Pos = 0
|
||||
buf.DisplayPos = 0
|
||||
buf.LineHasSpace.Clear()
|
||||
fmt.Println()
|
||||
fmt.Print(i.Prompt.AltPrompt)
|
||||
i.Prompt.UseAlt = true
|
||||
continue
|
||||
case CharEnter:
|
||||
output := buf.String()
|
||||
if len(i.pastedLines) > 0 {
|
||||
output = strings.Join(i.pastedLines, "\n") + "\n" + output
|
||||
i.pastedLines = nil
|
||||
}
|
||||
if output != "" {
|
||||
i.History.Add(output)
|
||||
}
|
||||
buf.MoveToEnd()
|
||||
fmt.Println()
|
||||
i.Prompt.UseAlt = false
|
||||
|
||||
return output, nil
|
||||
default:
|
||||
|
||||
@@ -73,7 +73,7 @@ _build_darwin() {
|
||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
|
||||
fi
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/imagegen ./x/imagegen/cmd/engine
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/ollama-mlx .
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
|
||||
done
|
||||
}
|
||||
@@ -82,19 +82,19 @@ _sign_darwin() {
|
||||
status "Creating universal binary..."
|
||||
mkdir -p dist/darwin
|
||||
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
|
||||
lipo -create -output dist/darwin/imagegen dist/darwin-*/imagegen
|
||||
lipo -create -output dist/darwin/ollama-mlx dist/darwin-*/ollama-mlx
|
||||
chmod +x dist/darwin/ollama
|
||||
chmod +x dist/darwin/imagegen
|
||||
chmod +x dist/darwin/ollama-mlx
|
||||
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/imagegen; do
|
||||
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/ollama-mlx; do
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
|
||||
done
|
||||
|
||||
# create a temporary zip for notarization
|
||||
TEMP=$(mktemp -u).zip
|
||||
ditto -c -k --keepParent dist/darwin/ollama "$TEMP"
|
||||
xcrun notarytool submit "$TEMP" --wait --timeout 10m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
|
||||
xcrun notarytool submit "$TEMP" --wait --timeout 20m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
|
||||
rm -f "$TEMP"
|
||||
fi
|
||||
|
||||
@@ -154,38 +154,40 @@ _build_macapp() {
|
||||
mkdir -p dist/Ollama.app/Contents/Resources
|
||||
if [ -d dist/darwin-amd64 ]; then
|
||||
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
|
||||
lipo -create -output dist/Ollama.app/Contents/Resources/imagegen dist/darwin-amd64/imagegen dist/darwin-arm64/imagegen
|
||||
lipo -create -output dist/Ollama.app/Contents/Resources/ollama-mlx dist/darwin-amd64/ollama-mlx dist/darwin-arm64/ollama-mlx
|
||||
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
|
||||
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
|
||||
done
|
||||
cp dist/darwin-*/lib/ollama/*.so dist/darwin-*/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
cp dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
# Copy MLX metallib (architecture-independent, just use arm64 version)
|
||||
cp dist/darwin-arm64/lib/ollama/*.metallib dist/Ollama.app/Contents/Resources/ 2>/dev/null || true
|
||||
else
|
||||
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
|
||||
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
fi
|
||||
cp -a dist/darwin/imagegen dist/Ollama.app/Contents/Resources/imagegen
|
||||
cp -a dist/darwin/ollama-mlx dist/Ollama.app/Contents/Resources/ollama-mlx
|
||||
chmod a+x dist/Ollama.app/Contents/Resources/ollama
|
||||
|
||||
# Sign
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
|
||||
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/imagegen ; do
|
||||
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib dist/Ollama.app/Contents/Resources/ollama-mlx ; do
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
|
||||
done
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
|
||||
fi
|
||||
|
||||
rm -f dist/Ollama-darwin.zip
|
||||
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama imagegen *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
|
||||
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
|
||||
|
||||
# Notarize and Staple
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 10m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
|
||||
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
|
||||
rm -f dist/Ollama-darwin.zip
|
||||
$(xcrun -f stapler) staple dist/Ollama.app
|
||||
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
|
||||
rm -f dist/Ollama.dmg
|
||||
|
||||
@@ -206,7 +208,7 @@ _build_macapp() {
|
||||
rm -f dist/rw*.dmg
|
||||
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.dmg
|
||||
$(xcrun -f notarytool) submit dist/Ollama.dmg --wait --timeout 10m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
|
||||
$(xcrun -f notarytool) submit dist/Ollama.dmg --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
|
||||
$(xcrun -f stapler) staple dist/Ollama.dmg
|
||||
else
|
||||
echo "WARNING: Code signing disabled, this bundle will not work for upgrade testing"
|
||||
|
||||
@@ -48,53 +48,12 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then
|
||||
.
|
||||
fi
|
||||
|
||||
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
|
||||
deduplicate_cuda_libs() {
|
||||
local base_dir="$1"
|
||||
echo "Deduplicating CUDA libraries in ${base_dir}..."
|
||||
|
||||
# Find all mlx_cuda_* directories
|
||||
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
|
||||
[ -d "${mlx_dir}" ] || continue
|
||||
|
||||
# Extract CUDA version (e.g., v12, v13)
|
||||
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
|
||||
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
|
||||
|
||||
# Skip if corresponding cuda_* directory doesn't exist
|
||||
[ -d "${cuda_dir}" ] || continue
|
||||
|
||||
echo " Checking ${mlx_dir} against ${cuda_dir}..."
|
||||
|
||||
# Find all .so* files in mlx directory
|
||||
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
|
||||
filename=$(basename "${mlx_file}")
|
||||
cuda_file="${cuda_dir}/${filename}"
|
||||
|
||||
# Skip if file doesn't exist in cuda directory
|
||||
[ -f "${cuda_file}" ] || continue
|
||||
|
||||
# Compare checksums
|
||||
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
|
||||
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
|
||||
|
||||
if [ "${mlx_sum}" = "${cuda_sum}" ]; then
|
||||
echo " Deduplicating ${filename}"
|
||||
# Calculate relative path from mlx_dir to cuda_dir
|
||||
rel_path="../cuda_${cuda_version}/${filename}"
|
||||
rm -f "${mlx_file}"
|
||||
ln -s "${rel_path}" "${mlx_file}"
|
||||
fi
|
||||
done
|
||||
done
|
||||
}
|
||||
|
||||
# Run deduplication for each platform output directory
|
||||
if echo $PLATFORM | grep "," > /dev/null ; then
|
||||
deduplicate_cuda_libs "./dist/linux_amd64"
|
||||
deduplicate_cuda_libs "./dist/linux_arm64"
|
||||
$(dirname $0)/deduplicate_cuda_libs.sh "./dist/linux_amd64"
|
||||
$(dirname $0)/deduplicate_cuda_libs.sh "./dist/linux_arm64"
|
||||
elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then
|
||||
deduplicate_cuda_libs "./dist"
|
||||
$(dirname $0)/deduplicate_cuda_libs.sh "./dist"
|
||||
fi
|
||||
|
||||
# buildx behavior changes for single vs. multiplatform
|
||||
|
||||
60
scripts/deduplicate_cuda_libs.sh
Executable file
@@ -0,0 +1,60 @@
|
||||
#!/bin/sh
|
||||
#
|
||||
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
|
||||
# This script finds identical .so* files in mlx_cuda_* directories that exist
|
||||
# in corresponding cuda_* directories and replaces them with symlinks.
|
||||
#
|
||||
|
||||
set -eu
|
||||
|
||||
if [ $# -eq 0 ]; then
|
||||
echo "ERROR: No directory specified" >&2
|
||||
echo "Usage: $0 <base_directory>" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
base_dir="$1"
|
||||
|
||||
if [ ! -d "${base_dir}" ]; then
|
||||
echo "ERROR: Directory ${base_dir} does not exist" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Deduplicating CUDA libraries in ${base_dir}..."
|
||||
|
||||
# Find all mlx_cuda_* directories
|
||||
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
|
||||
[ -d "${mlx_dir}" ] || continue
|
||||
|
||||
# Extract CUDA version (e.g., v12, v13)
|
||||
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
|
||||
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
|
||||
|
||||
# Skip if corresponding cuda_* directory doesn't exist
|
||||
[ -d "${cuda_dir}" ] || continue
|
||||
|
||||
echo " Checking ${mlx_dir} against ${cuda_dir}..."
|
||||
|
||||
# Find all .so* files in mlx directory
|
||||
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
|
||||
filename=$(basename "${mlx_file}")
|
||||
cuda_file="${cuda_dir}/${filename}"
|
||||
|
||||
# Skip if file doesn't exist in cuda directory
|
||||
[ -f "${cuda_file}" ] || continue
|
||||
|
||||
# Compare checksums
|
||||
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
|
||||
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
|
||||
|
||||
if [ "${mlx_sum}" = "${cuda_sum}" ]; then
|
||||
echo " Deduplicating ${filename}"
|
||||
# Calculate relative path from mlx_dir to cuda_dir
|
||||
rel_path="../cuda_${cuda_version}/${filename}"
|
||||
rm -f "${mlx_file}"
|
||||
ln -s "${rel_path}" "${mlx_file}"
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
echo "Deduplication complete"
|
||||
@@ -50,12 +50,17 @@ func (r registryChallenge) URL() (*url.URL, error) {
|
||||
return redirectURL, nil
|
||||
}
|
||||
|
||||
func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (string, error) {
|
||||
func getAuthorizationToken(ctx context.Context, challenge registryChallenge, originalHost string) (string, error) {
|
||||
redirectURL, err := challenge.URL()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Validate that the realm host matches the original request host to prevent sending tokens cross-origin.
|
||||
if redirectURL.Host != originalHost {
|
||||
return "", fmt.Errorf("realm host %q does not match original host %q", redirectURL.Host, originalHost)
|
||||
}
|
||||
|
||||
sha256sum := sha256.Sum256(nil)
|
||||
data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))
|
||||
|
||||
|
||||
113
server/auth_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGetAuthorizationTokenRejectsCrossDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
realm string
|
||||
originalHost string
|
||||
wantMismatch bool
|
||||
}{
|
||||
{"https://example.com/token", "example.com", false},
|
||||
{"https://example.com/token", "other.com", true},
|
||||
{"https://example.com/token", "localhost:8000", true},
|
||||
{"https://localhost:5000/token", "localhost:5000", false},
|
||||
{"https://localhost:5000/token", "localhost:6000", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.originalHost, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
challenge := registryChallenge{Realm: tt.realm, Service: "test", Scope: "repo:x:pull"}
|
||||
_, err := getAuthorizationToken(ctx, challenge, tt.originalHost)
|
||||
|
||||
isMismatch := err != nil && strings.Contains(err.Error(), "does not match")
|
||||
if tt.wantMismatch && !isMismatch {
|
||||
t.Errorf("expected domain mismatch error, got: %v", err)
|
||||
}
|
||||
if !tt.wantMismatch && isMismatch {
|
||||
t.Errorf("unexpected domain mismatch error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRegistryChallenge(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
wantRealm, wantService, wantScope string
|
||||
}{
|
||||
{
|
||||
`Bearer realm="https://auth.example.com/token",service="registry",scope="repo:foo:pull"`,
|
||||
"https://auth.example.com/token", "registry", "repo:foo:pull",
|
||||
},
|
||||
{
|
||||
`Bearer realm="https://r.ollama.ai/v2/token",service="ollama",scope="-"`,
|
||||
"https://r.ollama.ai/v2/token", "ollama", "-",
|
||||
},
|
||||
{"", "", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := parseRegistryChallenge(tt.input)
|
||||
if result.Realm != tt.wantRealm || result.Service != tt.wantService || result.Scope != tt.wantScope {
|
||||
t.Errorf("parseRegistryChallenge(%q) = {%q, %q, %q}, want {%q, %q, %q}",
|
||||
tt.input, result.Realm, result.Service, result.Scope,
|
||||
tt.wantRealm, tt.wantService, tt.wantScope)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryChallengeURL(t *testing.T) {
|
||||
challenge := registryChallenge{
|
||||
Realm: "https://auth.example.com/token",
|
||||
Service: "registry",
|
||||
Scope: "repo:foo:pull repo:bar:push",
|
||||
}
|
||||
|
||||
u, err := challenge.URL()
|
||||
if err != nil {
|
||||
t.Fatalf("URL() error: %v", err)
|
||||
}
|
||||
|
||||
if u.Host != "auth.example.com" {
|
||||
t.Errorf("host = %q, want %q", u.Host, "auth.example.com")
|
||||
}
|
||||
if u.Path != "/token" {
|
||||
t.Errorf("path = %q, want %q", u.Path, "/token")
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
if q.Get("service") != "registry" {
|
||||
t.Errorf("service = %q, want %q", q.Get("service"), "registry")
|
||||
}
|
||||
if scopes := q["scope"]; len(scopes) != 2 {
|
||||
t.Errorf("scope count = %d, want 2", len(scopes))
|
||||
}
|
||||
if q.Get("ts") == "" {
|
||||
t.Error("missing ts")
|
||||
}
|
||||
if q.Get("nonce") == "" {
|
||||
t.Error("missing nonce")
|
||||
}
|
||||
|
||||
// Nonces should differ between calls
|
||||
u2, _ := challenge.URL()
|
||||
if q.Get("nonce") == u2.Query().Get("nonce") {
|
||||
t.Error("nonce should be unique per call")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryChallengeURLInvalid(t *testing.T) {
|
||||
challenge := registryChallenge{Realm: "://invalid"}
|
||||
if _, err := challenge.URL(); err == nil {
|
||||
t.Error("expected error for invalid URL")
|
||||
}
|
||||
}
|
||||
@@ -775,7 +775,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
Realm: challenge.Realm,
|
||||
Service: challenge.Service,
|
||||
Scope: challenge.Scope,
|
||||
})
|
||||
}, base.Host)
|
||||
}
|
||||
|
||||
if err := transfer.Download(ctx, transfer.DownloadOptions{
|
||||
@@ -850,7 +850,7 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
Realm: challenge.Realm,
|
||||
Service: challenge.Service,
|
||||
Scope: challenge.Scope,
|
||||
})
|
||||
}, base.Host)
|
||||
}
|
||||
|
||||
return transfer.Upload(ctx, transfer.UploadOptions{
|
||||
@@ -916,7 +916,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
|
||||
// Handle authentication error with one retry
|
||||
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
|
||||
token, err := getAuthorizationToken(ctx, challenge)
|
||||
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -47,16 +47,40 @@ func (m *Manifest) Remove() error {
|
||||
}
|
||||
|
||||
func (m *Manifest) RemoveLayers() error {
|
||||
for _, layer := range append(m.Layers, m.Config) {
|
||||
if layer.Digest != "" {
|
||||
if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
|
||||
slog.Debug("layer does not exist", "digest", layer.Digest)
|
||||
} else if err != nil {
|
||||
return err
|
||||
ms, err := Manifests(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build set of digests still in use by other manifests
|
||||
inUse := make(map[string]struct{})
|
||||
for _, other := range ms {
|
||||
for _, layer := range append(other.Layers, other.Config) {
|
||||
if layer.Digest != "" {
|
||||
inUse[layer.Digest] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove layers not used by any other manifest
|
||||
for _, layer := range append(m.Layers, m.Config) {
|
||||
if layer.Digest == "" {
|
||||
continue
|
||||
}
|
||||
if _, used := inUse[layer.Digest]; used {
|
||||
continue
|
||||
}
|
||||
blob, err := GetBlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Remove(blob); errors.Is(err, os.ErrNotExist) {
|
||||
slog.Debug("layer does not exist", "digest", layer.Digest)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
108
server/routes.go
@@ -51,7 +51,7 @@ import (
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
|
||||
xserver "github.com/ollama/ollama/x/server"
|
||||
)
|
||||
|
||||
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
|
||||
@@ -164,29 +164,6 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
|
||||
return runner.llama, model, &opts, nil
|
||||
}
|
||||
|
||||
// ScheduleImageGenRunner schedules an image generation model runner.
|
||||
// This implements the imagegenapi.RunnerScheduler interface.
|
||||
func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) {
|
||||
m := &Model{
|
||||
Name: modelName,
|
||||
ShortName: modelName,
|
||||
ModelPath: modelName, // For image gen, ModelPath is just the model name
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"image"},
|
||||
},
|
||||
}
|
||||
|
||||
runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
case runner = <-runnerCh:
|
||||
case err := <-errCh:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return runner.llama, nil
|
||||
}
|
||||
|
||||
func signinURL() (string, error) {
|
||||
pubKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
@@ -214,12 +191,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this is a known image generation model
|
||||
if imagegen.ResolveModelName(req.Model) != "" {
|
||||
imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
// Ideally this is "invalid model name" but we're keeping with
|
||||
@@ -552,6 +523,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
Truncate: req.Truncate == nil || *req.Truncate,
|
||||
Logprobs: req.Logprobs,
|
||||
TopLogprobs: req.TopLogprobs,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
Steps: req.Steps,
|
||||
}, func(cr llm.CompletionResponse) {
|
||||
res := api.GenerateResponse{
|
||||
Model: req.Model,
|
||||
@@ -567,6 +541,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
Logprobs: toAPILogprobs(cr.Logprobs),
|
||||
}
|
||||
|
||||
// Image generation fields
|
||||
if cr.Image != "" {
|
||||
res.Images = []string{cr.Image}
|
||||
}
|
||||
if cr.TotalSteps > 0 {
|
||||
res.Status = "generating image"
|
||||
res.Completed = int64(cr.Step)
|
||||
res.Total = int64(cr.TotalSteps)
|
||||
}
|
||||
|
||||
if builtinParser != nil {
|
||||
content, thinking, toolCalls, err := builtinParser.Add(cr.Content, cr.Done)
|
||||
if err != nil {
|
||||
@@ -1124,6 +1108,31 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
QuantizationLevel: m.Config.FileType,
|
||||
}
|
||||
|
||||
// For image generation models, populate details from imagegen package
|
||||
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
|
||||
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
|
||||
modelDetails.Family = info.Architecture
|
||||
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
|
||||
modelDetails.QuantizationLevel = info.Quantization
|
||||
}
|
||||
}
|
||||
|
||||
// For safetensors LLM models (experimental), populate details from config.json
|
||||
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
|
||||
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
|
||||
if arch, ok := info["general.architecture"].(string); ok && arch != "" {
|
||||
modelDetails.Family = arch
|
||||
}
|
||||
if paramCount, ok := info["general.parameter_count"].(int64); ok && paramCount > 0 {
|
||||
modelDetails.ParameterSize = format.HumanNumber(uint64(paramCount))
|
||||
}
|
||||
}
|
||||
// Get torch_dtype directly from config.json for quantization level
|
||||
if dtype, err := xserver.GetSafetensorsDtype(name.String()); err == nil && dtype != "" {
|
||||
modelDetails.QuantizationLevel = dtype
|
||||
}
|
||||
}
|
||||
|
||||
if req.System != "" {
|
||||
m.System = req.System
|
||||
}
|
||||
@@ -1206,6 +1215,30 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
|
||||
// Populate tensor info if verbose
|
||||
if req.Verbose {
|
||||
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
|
||||
resp.Tensors = tensors
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// For safetensors LLM models (experimental), populate ModelInfo from config.json
|
||||
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
|
||||
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
|
||||
resp.ModelInfo = info
|
||||
}
|
||||
// Populate tensor info if verbose
|
||||
if req.Verbose {
|
||||
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
|
||||
resp.Tensors = tensors
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1574,13 +1607,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
// OpenAI-compatible image generation endpoint
|
||||
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
|
||||
// Experimental image generation support
|
||||
imagegenapi.RegisterRoutes(r, s)
|
||||
|
||||
if rc != nil {
|
||||
// wrap old with new
|
||||
rs := ®istry.Local{
|
||||
@@ -2059,8 +2091,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
} else {
|
||||
if req.Think != nil && req.Think.Bool() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
|
||||
return
|
||||
// Set think to nil when being used with Anthropic API to connect to tools like claude code
|
||||
if _, ok := c.Get("relax_thinking"); ok {
|
||||
slog.Warn("model does not support thinking, relaxing thinking to nil", "model", req.Model)
|
||||
req.Think = nil
|
||||
} else {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -17,7 +16,6 @@ import (
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@@ -807,32 +805,8 @@ func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return n
|
||||
func (s *mockLlm) HasExited() bool { return false }
|
||||
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }
|
||||
|
||||
// TestImageGenCapabilityDetection verifies that models with "image" capability
|
||||
// are correctly identified and routed differently from language models.
|
||||
func TestImageGenCapabilityDetection(t *testing.T) {
|
||||
// Model with image capability should be detected
|
||||
imageModel := &Model{
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"image"},
|
||||
},
|
||||
}
|
||||
require.True(t, slices.Contains(imageModel.Config.Capabilities, "image"))
|
||||
|
||||
// Model without image capability should not be detected
|
||||
langModel := &Model{
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"completion"},
|
||||
},
|
||||
}
|
||||
require.False(t, slices.Contains(langModel.Config.Capabilities, "image"))
|
||||
|
||||
// Empty capabilities should not match
|
||||
emptyModel := &Model{}
|
||||
require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image"))
|
||||
}
|
||||
|
||||
// TestImageGenRunnerCanBeEvicted verifies that an image generation model
|
||||
// loaded in the scheduler can be evicted by a language model request.
|
||||
// loaded in the scheduler can be evicted when idle.
|
||||
func TestImageGenRunnerCanBeEvicted(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
|
||||
defer done()
|
||||
@@ -864,3 +838,59 @@ func TestImageGenRunnerCanBeEvicted(t *testing.T) {
|
||||
require.NotNil(t, runner)
|
||||
require.Equal(t, "/fake/image/model", runner.modelPath)
|
||||
}
|
||||
|
||||
// TestImageGenSchedulerCoexistence verifies that image generation models
|
||||
// can coexist with language models in the scheduler and VRAM is tracked correctly.
|
||||
func TestImageGenSchedulerCoexistence(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
s.getGpuFn = getGpuFn
|
||||
s.getSystemInfoFn = getSystemInfoFn
|
||||
|
||||
// Load both an imagegen runner and a language model runner
|
||||
imageGenRunner := &runnerRef{
|
||||
model: &Model{Name: "flux", ModelPath: "/fake/flux/model"},
|
||||
modelPath: "/fake/flux/model",
|
||||
llama: &mockLlm{vramSize: 8 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 8 * format.GigaByte}},
|
||||
sessionDuration: 10 * time.Millisecond,
|
||||
numParallel: 1,
|
||||
refCount: 0,
|
||||
}
|
||||
|
||||
langModelRunner := &runnerRef{
|
||||
model: &Model{Name: "llama3", ModelPath: "/fake/llama3/model"},
|
||||
modelPath: "/fake/llama3/model",
|
||||
llama: &mockLlm{vramSize: 4 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 4 * format.GigaByte}},
|
||||
sessionDuration: 10 * time.Millisecond,
|
||||
numParallel: 1,
|
||||
refCount: 0,
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
s.loaded["/fake/flux/model"] = imageGenRunner
|
||||
s.loaded["/fake/llama3/model"] = langModelRunner
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// Verify both are loaded
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 2)
|
||||
require.NotNil(t, s.loaded["/fake/flux/model"])
|
||||
require.NotNil(t, s.loaded["/fake/llama3/model"])
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// Verify updateFreeSpace accounts for both
|
||||
gpus := []ml.DeviceInfo{
|
||||
{
|
||||
DeviceID: ml.DeviceID{Library: "Metal"},
|
||||
TotalMemory: 24 * format.GigaByte,
|
||||
FreeMemory: 24 * format.GigaByte,
|
||||
},
|
||||
}
|
||||
s.updateFreeSpace(gpus)
|
||||
|
||||
// Free memory should be reduced by both models
|
||||
expectedFree := uint64(24*format.GigaByte) - uint64(8*format.GigaByte) - uint64(4*format.GigaByte)
|
||||
require.Equal(t, expectedFree, gpus[0].FreeMemory)
|
||||
}
|
||||
|
||||
@@ -279,7 +279,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
|
||||
case resp.StatusCode == http.StatusUnauthorized:
|
||||
w.Rollback()
|
||||
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
|
||||
token, err := getAuthorizationToken(ctx, challenge)
|
||||
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
24
x/README.md
@@ -1,24 +0,0 @@
|
||||
# Experimental Features
|
||||
|
||||
## MLX Backend
|
||||
|
||||
We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx)
|
||||
|
||||
Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors. To build:
|
||||
|
||||
```
|
||||
cmake --preset MLX
|
||||
cmake --build --preset MLX --parallel
|
||||
cmake --install --component MLX
|
||||
go build -tags mlx .
|
||||
```
|
||||
|
||||
On linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled.
|
||||
|
||||
## Image Generation
|
||||
|
||||
Based on the experimental MLX backend, we're working on adding imagegen support. After running the cmake commands above:
|
||||
|
||||
```
|
||||
go build -o imagegen ./x/imagegen/cmd/engine
|
||||
```
|
||||
@@ -41,6 +41,7 @@ var optionLabels = []string{
|
||||
var toolDisplayNames = map[string]string{
|
||||
"bash": "Bash",
|
||||
"web_search": "Web Search",
|
||||
"web_fetch": "Web Fetch",
|
||||
}
|
||||
|
||||
// ToolDisplayName returns the human-readable display name for a tool.
|
||||
@@ -565,6 +566,16 @@ func formatToolDisplay(toolName string, args map[string]any) string {
|
||||
}
|
||||
}
|
||||
|
||||
// For web fetch, show URL and internet notice
|
||||
if toolName == "web_fetch" {
|
||||
if url, ok := args["url"].(string); ok {
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
|
||||
sb.WriteString(fmt.Sprintf("URL: %s\n", url))
|
||||
sb.WriteString("Uses internet via ollama.com")
|
||||
return sb.String()
|
||||
}
|
||||
}
|
||||
|
||||
// Generic display
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s", displayName))
|
||||
if len(args) > 0 {
|
||||
@@ -1017,6 +1028,16 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
|
||||
}
|
||||
}
|
||||
|
||||
if toolName == "web_fetch" {
|
||||
if url, ok := args["url"].(string); ok {
|
||||
// Truncate long URLs
|
||||
if len(url) > 50 {
|
||||
url = url[:47] + "..."
|
||||
}
|
||||
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, url)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("\033[1m%s:\033[0m %s", label, displayName)
|
||||
}
|
||||
|
||||
|
||||
308
x/cmd/run.go
@@ -9,6 +9,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"slices"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -130,6 +131,7 @@ type RunOptions struct {
|
||||
KeepAlive *api.Duration
|
||||
Think *api.ThinkValue
|
||||
HideThinking bool
|
||||
Verbose bool
|
||||
|
||||
// Agent fields (managed externally for session persistence)
|
||||
Tools *tools.Registry
|
||||
@@ -178,6 +180,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
var thinkTagClosed bool = false
|
||||
var pendingToolCalls []api.ToolCall
|
||||
var consecutiveErrors int // Track consecutive 500 errors for retry limit
|
||||
var latest api.ChatResponse
|
||||
|
||||
role := "assistant"
|
||||
messages := opts.Messages
|
||||
@@ -187,6 +190,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
p.StopAndClear()
|
||||
}
|
||||
|
||||
latest = response
|
||||
role = response.Message.Role
|
||||
if response.Message.Thinking != "" && !opts.HideThinking {
|
||||
if !thinkTagOpened {
|
||||
@@ -483,6 +487,10 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
if opts.Verbose {
|
||||
latest.Summary()
|
||||
}
|
||||
|
||||
return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
|
||||
}
|
||||
|
||||
@@ -634,12 +642,13 @@ func checkModelCapabilities(ctx context.Context, modelName string) (supportsTool
|
||||
// GenerateInteractive runs an interactive agent session.
|
||||
// This is called from cmd.go when --experimental flag is set.
|
||||
// If yoloMode is true, all tool approvals are skipped.
|
||||
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool) error {
|
||||
// If enableWebsearch is true, the web search tool is registered.
|
||||
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool, enableWebsearch bool) error {
|
||||
scanner, err := readline.New(readline.Prompt{
|
||||
Prompt: ">>> ",
|
||||
AltPrompt: "... ",
|
||||
Placeholder: "Send a message (/? for help)",
|
||||
AltPlaceholder: `Use """ to end multi-line input`,
|
||||
AltPlaceholder: "Press Enter to send",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -660,6 +669,12 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
if supportsTools {
|
||||
toolRegistry = tools.DefaultRegistry()
|
||||
|
||||
// Register web search and web fetch tools if enabled via flag
|
||||
if enableWebsearch {
|
||||
toolRegistry.RegisterWebSearch()
|
||||
toolRegistry.RegisterWebFetch()
|
||||
}
|
||||
|
||||
if toolRegistry.Has("bash") {
|
||||
fmt.Fprintln(os.Stderr)
|
||||
fmt.Fprintln(os.Stderr, "This experimental version of Ollama has the \033[1mbash\033[0m tool enabled.")
|
||||
@@ -667,6 +682,11 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
fmt.Fprintln(os.Stderr)
|
||||
}
|
||||
|
||||
if toolRegistry.Has("web_search") || toolRegistry.Has("web_fetch") {
|
||||
fmt.Fprintln(os.Stderr, "The \033[1mWeb Search\033[0m and \033[1mWeb Fetch\033[0m tools are enabled. Models can search and fetch web content via ollama.com.")
|
||||
fmt.Fprintln(os.Stderr)
|
||||
}
|
||||
|
||||
if yoloMode {
|
||||
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m yolo mode - all tool approvals will be skipped\n")
|
||||
}
|
||||
@@ -677,6 +697,8 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
|
||||
var messages []api.Message
|
||||
var sb strings.Builder
|
||||
var format string
|
||||
var system string
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
@@ -688,6 +710,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
if line == "" {
|
||||
fmt.Println("\nUse Ctrl + d or /bye to exit.")
|
||||
}
|
||||
scanner.Prompt.UseAlt = false
|
||||
sb.Reset()
|
||||
continue
|
||||
case err != nil:
|
||||
@@ -707,6 +730,10 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
continue
|
||||
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /set Set session variables")
|
||||
fmt.Fprintln(os.Stderr, " /show Show model information")
|
||||
fmt.Fprintln(os.Stderr, " /load Load a different model")
|
||||
fmt.Fprintln(os.Stderr, " /save Save session as a model")
|
||||
fmt.Fprintln(os.Stderr, " /tools Show available tools and approvals")
|
||||
fmt.Fprintln(os.Stderr, " /clear Clear session context and approvals")
|
||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||
@@ -716,6 +743,280 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
fmt.Fprintln(os.Stderr, " Ctrl+O Expand last tool output")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
continue
|
||||
case strings.HasPrefix(line, "/set"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
switch args[1] {
|
||||
case "history":
|
||||
scanner.HistoryEnable()
|
||||
case "nohistory":
|
||||
scanner.HistoryDisable()
|
||||
case "wordwrap":
|
||||
wordWrap = true
|
||||
fmt.Println("Set 'wordwrap' mode.")
|
||||
case "nowordwrap":
|
||||
wordWrap = false
|
||||
fmt.Println("Set 'nowordwrap' mode.")
|
||||
case "verbose":
|
||||
if err := cmd.Flags().Set("verbose", "true"); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("Set 'verbose' mode.")
|
||||
case "quiet":
|
||||
if err := cmd.Flags().Set("verbose", "false"); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("Set 'quiet' mode.")
|
||||
case "think":
|
||||
thinkValue := api.ThinkValue{Value: true}
|
||||
var maybeLevel string
|
||||
if len(args) > 2 {
|
||||
maybeLevel = args[2]
|
||||
}
|
||||
if maybeLevel != "" {
|
||||
thinkValue.Value = maybeLevel
|
||||
}
|
||||
think = &thinkValue
|
||||
// Check if model supports thinking
|
||||
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||
if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil {
|
||||
if !slices.Contains(resp.Capabilities, model.CapabilityThinking) {
|
||||
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
if maybeLevel != "" {
|
||||
fmt.Printf("Set 'think' mode to '%s'.\n", maybeLevel)
|
||||
} else {
|
||||
fmt.Println("Set 'think' mode.")
|
||||
}
|
||||
case "nothink":
|
||||
think = &api.ThinkValue{Value: false}
|
||||
// Check if model supports thinking
|
||||
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||
if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil {
|
||||
if !slices.Contains(resp.Capabilities, model.CapabilityThinking) {
|
||||
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Println("Set 'nothink' mode.")
|
||||
case "format":
|
||||
if len(args) < 3 || args[2] != "json" {
|
||||
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
|
||||
} else {
|
||||
format = args[2]
|
||||
fmt.Printf("Set format to '%s' mode.\n", args[2])
|
||||
}
|
||||
case "noformat":
|
||||
format = ""
|
||||
fmt.Println("Disabled format.")
|
||||
case "parameter":
|
||||
if len(args) < 4 {
|
||||
fmt.Println("Usage: /set parameter <name> <value>")
|
||||
continue
|
||||
}
|
||||
params := args[3:]
|
||||
fp, err := api.FormatParams(map[string][]string{args[2]: params})
|
||||
if err != nil {
|
||||
fmt.Printf("Couldn't set parameter: %q\n", err)
|
||||
continue
|
||||
}
|
||||
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
|
||||
options[args[2]] = fp[args[2]]
|
||||
case "system":
|
||||
if len(args) < 3 {
|
||||
fmt.Println("Usage: /set system <message>")
|
||||
continue
|
||||
}
|
||||
|
||||
system = strings.Join(args[2:], " ")
|
||||
newMessage := api.Message{Role: "system", Content: system}
|
||||
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
|
||||
messages[len(messages)-1] = newMessage
|
||||
} else {
|
||||
messages = append(messages, newMessage)
|
||||
}
|
||||
fmt.Println("Set system message.")
|
||||
continue
|
||||
default:
|
||||
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
|
||||
}
|
||||
} else {
|
||||
fmt.Println("Usage: /set <parameter|system|history|format|wordwrap|think|verbose> [value]")
|
||||
}
|
||||
continue
|
||||
case strings.HasPrefix(line, "/show"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
continue
|
||||
}
|
||||
req := &api.ShowRequest{
|
||||
Name: modelName,
|
||||
Options: options,
|
||||
}
|
||||
resp, err := client.Show(cmd.Context(), req)
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't get model")
|
||||
continue
|
||||
}
|
||||
|
||||
switch args[1] {
|
||||
case "info":
|
||||
fmt.Fprintf(os.Stderr, " Model\n")
|
||||
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Name", modelName)
|
||||
if resp.Details.Family != "" {
|
||||
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Family", resp.Details.Family)
|
||||
}
|
||||
if resp.Details.ParameterSize != "" {
|
||||
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Parameter Size", resp.Details.ParameterSize)
|
||||
}
|
||||
if resp.Details.QuantizationLevel != "" {
|
||||
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Quantization", resp.Details.QuantizationLevel)
|
||||
}
|
||||
if len(resp.Capabilities) > 0 {
|
||||
caps := make([]string, len(resp.Capabilities))
|
||||
for i, c := range resp.Capabilities {
|
||||
caps[i] = string(c)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Capabilities", strings.Join(caps, ", "))
|
||||
}
|
||||
fmt.Fprintln(os.Stderr)
|
||||
case "license":
|
||||
if resp.License == "" {
|
||||
fmt.Println("No license was specified for this model.")
|
||||
} else {
|
||||
fmt.Println(resp.License)
|
||||
}
|
||||
case "modelfile":
|
||||
fmt.Println(resp.Modelfile)
|
||||
case "parameters":
|
||||
fmt.Println("Model defined parameters:")
|
||||
if resp.Parameters == "" {
|
||||
fmt.Println(" No additional parameters were specified.")
|
||||
} else {
|
||||
for _, l := range strings.Split(resp.Parameters, "\n") {
|
||||
fmt.Printf(" %s\n", l)
|
||||
}
|
||||
}
|
||||
if len(options) > 0 {
|
||||
fmt.Println("\nUser defined parameters:")
|
||||
for k, v := range options {
|
||||
fmt.Printf(" %-30s %v\n", k, v)
|
||||
}
|
||||
}
|
||||
case "system":
|
||||
switch {
|
||||
case system != "":
|
||||
fmt.Println(system + "\n")
|
||||
case resp.System != "":
|
||||
fmt.Println(resp.System + "\n")
|
||||
default:
|
||||
fmt.Println("No system message was specified for this model.")
|
||||
}
|
||||
case "template":
|
||||
if resp.Template != "" {
|
||||
fmt.Println(resp.Template)
|
||||
} else {
|
||||
fmt.Println("No prompt template was specified for this model.")
|
||||
}
|
||||
default:
|
||||
fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1])
|
||||
}
|
||||
} else {
|
||||
fmt.Println("Usage: /show <info|license|modelfile|parameters|system|template>")
|
||||
}
|
||||
continue
|
||||
case strings.HasPrefix(line, "/load"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) != 2 {
|
||||
fmt.Println("Usage: /load <modelname>")
|
||||
continue
|
||||
}
|
||||
newModelName := args[1]
|
||||
fmt.Printf("Loading model '%s'\n", newModelName)
|
||||
|
||||
// Create progress spinner
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
spinner := progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
|
||||
// Get client
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
p.StopAndClear()
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if model exists and get its info
|
||||
info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: newModelName})
|
||||
if err != nil {
|
||||
p.StopAndClear()
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
fmt.Printf("Couldn't find model '%s'\n", newModelName)
|
||||
} else {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// For cloud models, no need to preload
|
||||
if info.RemoteHost == "" {
|
||||
// Preload the model by sending an empty generate request
|
||||
req := &api.GenerateRequest{
|
||||
Model: newModelName,
|
||||
Think: think,
|
||||
}
|
||||
err = client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error {
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
p.StopAndClear()
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
fmt.Printf("Couldn't find model '%s'\n", newModelName)
|
||||
} else if strings.Contains(err.Error(), "does not support thinking") {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
} else {
|
||||
fmt.Printf("error loading model: %v\n", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
p.StopAndClear()
|
||||
modelName = newModelName
|
||||
messages = []api.Message{}
|
||||
approval.Reset()
|
||||
continue
|
||||
case strings.HasPrefix(line, "/save"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) != 2 {
|
||||
fmt.Println("Usage: /save <modelname>")
|
||||
continue
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
continue
|
||||
}
|
||||
req := &api.CreateRequest{
|
||||
Model: args[1],
|
||||
From: modelName,
|
||||
Parameters: options,
|
||||
Messages: messages,
|
||||
}
|
||||
fn := func(resp api.ProgressResponse) error { return nil }
|
||||
err = client.Create(cmd.Context(), req, fn)
|
||||
if err != nil {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
fmt.Printf("Created new model '%s'\n", args[1])
|
||||
continue
|
||||
case strings.HasPrefix(line, "/"):
|
||||
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
|
||||
continue
|
||||
@@ -727,10 +1028,12 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
newMessage := api.Message{Role: "user", Content: sb.String()}
|
||||
messages = append(messages, newMessage)
|
||||
|
||||
verbose, _ := cmd.Flags().GetBool("verbose")
|
||||
opts := RunOptions{
|
||||
Model: modelName,
|
||||
Messages: messages,
|
||||
WordWrap: wordWrap,
|
||||
Format: format,
|
||||
Options: options,
|
||||
Think: think,
|
||||
HideThinking: hideThinking,
|
||||
@@ -738,6 +1041,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
Tools: toolRegistry,
|
||||
Approval: approval,
|
||||
YoloMode: yoloMode,
|
||||
Verbose: verbose,
|
||||
}
|
||||
|
||||
assistant, err := Chat(cmd.Context(), opts)
|
||||
|
||||
282
x/create/client/create.go
Normal file
@@ -0,0 +1,282 @@
|
||||
// Package client provides client-side model creation for safetensors-based models.
|
||||
//
|
||||
// This package is in x/ because the safetensors model storage format is under development.
|
||||
// It also exists to break an import cycle: server imports x/create, so x/create
|
||||
// cannot import server. This sub-package can import server because server doesn't
|
||||
// import it.
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/create"
|
||||
)
|
||||
|
||||
// MinOllamaVersion is the minimum Ollama version required for safetensors models.
|
||||
const MinOllamaVersion = "0.14.0"
|
||||
|
||||
// ModelfileConfig holds configuration extracted from a Modelfile.
|
||||
type ModelfileConfig struct {
|
||||
Template string
|
||||
System string
|
||||
License string
|
||||
}
|
||||
|
||||
// CreateOptions holds all options for model creation.
|
||||
type CreateOptions struct {
|
||||
ModelName string
|
||||
ModelDir string
|
||||
Quantize string // "fp8" for quantization
|
||||
Modelfile *ModelfileConfig // template/system/license from Modelfile
|
||||
}
|
||||
|
||||
// CreateModel imports a model from a local directory.
|
||||
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
|
||||
// Automatically detects model type (safetensors LLM vs image gen) and routes accordingly.
|
||||
func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
||||
// Detect model type
|
||||
isSafetensors := create.IsSafetensorsModelDir(opts.ModelDir)
|
||||
isImageGen := create.IsTensorModelDir(opts.ModelDir)
|
||||
|
||||
if !isSafetensors && !isImageGen {
|
||||
return fmt.Errorf("%s is not a supported model directory (needs config.json + *.safetensors or model_index.json)", opts.ModelDir)
|
||||
}
|
||||
|
||||
// Determine model type settings
|
||||
var modelType, spinnerKey string
|
||||
var capabilities []string
|
||||
if isSafetensors {
|
||||
modelType = "safetensors model"
|
||||
spinnerKey = "create"
|
||||
capabilities = []string{"completion"}
|
||||
} else {
|
||||
modelType = "image generation model"
|
||||
spinnerKey = "imagegen"
|
||||
capabilities = []string{"image"}
|
||||
}
|
||||
|
||||
// Set up progress spinner
|
||||
statusMsg := "importing " + modelType
|
||||
spinner := progress.NewSpinner(statusMsg)
|
||||
p.Add(spinnerKey, spinner)
|
||||
|
||||
progressFn := func(msg string) {
|
||||
spinner.Stop()
|
||||
statusMsg = msg
|
||||
spinner = progress.NewSpinner(statusMsg)
|
||||
p.Add(spinnerKey, spinner)
|
||||
}
|
||||
|
||||
// Create the model using shared callbacks
|
||||
var err error
|
||||
if isSafetensors {
|
||||
err = create.CreateSafetensorsModel(
|
||||
opts.ModelName, opts.ModelDir, opts.Quantize,
|
||||
newLayerCreator(), newTensorLayerCreator(),
|
||||
newManifestWriter(opts, capabilities),
|
||||
progressFn,
|
||||
)
|
||||
} else {
|
||||
err = create.CreateImageGenModel(
|
||||
opts.ModelName, opts.ModelDir, opts.Quantize,
|
||||
newLayerCreator(), newTensorLayerCreator(),
|
||||
newManifestWriter(opts, capabilities),
|
||||
progressFn,
|
||||
)
|
||||
}
|
||||
|
||||
spinner.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Created %s '%s'\n", modelType, opts.ModelName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
|
||||
func newLayerCreator() create.LayerCreator {
|
||||
return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, mediaType)
|
||||
if err != nil {
|
||||
return create.LayerInfo{}, err
|
||||
}
|
||||
|
||||
return create.LayerInfo{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// newTensorLayerCreator returns a QuantizingTensorLayerCreator callback for creating tensor layers.
|
||||
// When quantize is non-empty, returns multiple layers (weight + scales + optional qbias).
|
||||
func newTensorLayerCreator() create.QuantizingTensorLayerCreator {
|
||||
return func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) {
|
||||
if quantize != "" {
|
||||
return createQuantizedLayers(r, name, dtype, shape, quantize)
|
||||
}
|
||||
return createUnquantizedLayer(r, name)
|
||||
}
|
||||
}
|
||||
|
||||
// createQuantizedLayers quantizes a tensor and returns the resulting layers.
|
||||
func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) {
|
||||
if !QuantizeSupported() {
|
||||
return nil, fmt.Errorf("quantization requires MLX support")
|
||||
}
|
||||
|
||||
// Quantize the tensor
|
||||
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape, quantize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
|
||||
}
|
||||
|
||||
// Create layer for quantized weight
|
||||
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create layer for scales
|
||||
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers := []create.LayerInfo{
|
||||
{
|
||||
Digest: weightLayer.Digest,
|
||||
Size: weightLayer.Size,
|
||||
MediaType: weightLayer.MediaType,
|
||||
Name: name,
|
||||
},
|
||||
{
|
||||
Digest: scalesLayer.Digest,
|
||||
Size: scalesLayer.Size,
|
||||
MediaType: scalesLayer.MediaType,
|
||||
Name: name + "_scale",
|
||||
},
|
||||
}
|
||||
|
||||
// Add qbiases layer if present (affine mode)
|
||||
if qbiasData != nil {
|
||||
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layers = append(layers, create.LayerInfo{
|
||||
Digest: qbiasLayer.Digest,
|
||||
Size: qbiasLayer.Size,
|
||||
MediaType: qbiasLayer.MediaType,
|
||||
Name: name + "_qbias",
|
||||
})
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
// createUnquantizedLayer creates a single tensor layer without quantization.
|
||||
func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []create.LayerInfo{
|
||||
{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newManifestWriter returns a ManifestWriter callback for writing the model manifest.
|
||||
func newManifestWriter(opts CreateOptions, capabilities []string) create.ManifestWriter {
|
||||
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
|
||||
name := model.ParseName(modelName)
|
||||
if !name.IsValid() {
|
||||
return fmt.Errorf("invalid model name: %s", modelName)
|
||||
}
|
||||
|
||||
// Create config blob with version requirement
|
||||
configData := model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: capabilities,
|
||||
Requires: MinOllamaVersion,
|
||||
}
|
||||
configJSON, err := json.Marshal(configData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
// Create config layer blob
|
||||
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create config layer: %w", err)
|
||||
}
|
||||
|
||||
// Convert LayerInfo to server.Layer
|
||||
serverLayers := make([]server.Layer, 0, len(layers))
|
||||
for _, l := range layers {
|
||||
serverLayers = append(serverLayers, server.Layer{
|
||||
MediaType: l.MediaType,
|
||||
Digest: l.Digest,
|
||||
Size: l.Size,
|
||||
Name: l.Name,
|
||||
})
|
||||
}
|
||||
|
||||
// Add Modelfile layers if present
|
||||
if opts.Modelfile != nil {
|
||||
modelfileLayers, err := createModelfileLayers(opts.Modelfile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serverLayers = append(serverLayers, modelfileLayers...)
|
||||
}
|
||||
|
||||
return server.WriteManifest(name, configLayer, serverLayers)
|
||||
}
|
||||
}
|
||||
|
||||
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
|
||||
func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
|
||||
var layers []server.Layer
|
||||
|
||||
if mf.Template != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create template layer: %w", err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
if mf.System != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create system layer: %w", err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
if mf.License != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create license layer: %w", err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
146
x/create/client/create_test.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestModelfileConfig(t *testing.T) {
|
||||
// Test that ModelfileConfig struct works as expected
|
||||
config := &ModelfileConfig{
|
||||
Template: "{{ .Prompt }}",
|
||||
System: "You are a helpful assistant.",
|
||||
License: "MIT",
|
||||
}
|
||||
|
||||
if config.Template != "{{ .Prompt }}" {
|
||||
t.Errorf("Template = %q, want %q", config.Template, "{{ .Prompt }}")
|
||||
}
|
||||
if config.System != "You are a helpful assistant." {
|
||||
t.Errorf("System = %q, want %q", config.System, "You are a helpful assistant.")
|
||||
}
|
||||
if config.License != "MIT" {
|
||||
t.Errorf("License = %q, want %q", config.License, "MIT")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelfileConfig_Empty(t *testing.T) {
|
||||
config := &ModelfileConfig{}
|
||||
|
||||
if config.Template != "" {
|
||||
t.Errorf("Template should be empty, got %q", config.Template)
|
||||
}
|
||||
if config.System != "" {
|
||||
t.Errorf("System should be empty, got %q", config.System)
|
||||
}
|
||||
if config.License != "" {
|
||||
t.Errorf("License should be empty, got %q", config.License)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelfileConfig_PartialFields(t *testing.T) {
|
||||
// Test config with only some fields set
|
||||
config := &ModelfileConfig{
|
||||
Template: "{{ .Prompt }}",
|
||||
// System and License intentionally empty
|
||||
}
|
||||
|
||||
if config.Template == "" {
|
||||
t.Error("Template should not be empty")
|
||||
}
|
||||
if config.System != "" {
|
||||
t.Error("System should be empty")
|
||||
}
|
||||
if config.License != "" {
|
||||
t.Error("License should be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinOllamaVersion(t *testing.T) {
|
||||
// Verify the minimum version constant is set
|
||||
if MinOllamaVersion == "" {
|
||||
t.Error("MinOllamaVersion should not be empty")
|
||||
}
|
||||
if MinOllamaVersion != "0.14.0" {
|
||||
t.Errorf("MinOllamaVersion = %q, want %q", MinOllamaVersion, "0.14.0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateModel_InvalidDir(t *testing.T) {
|
||||
// Test that CreateModel returns error for invalid directory
|
||||
err := CreateModel(CreateOptions{
|
||||
ModelName: "test-model",
|
||||
ModelDir: "/nonexistent/path",
|
||||
}, nil)
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateModel_NotSafetensorsDir(t *testing.T) {
|
||||
// Test that CreateModel returns error for directory without safetensors
|
||||
dir := t.TempDir()
|
||||
|
||||
err := CreateModel(CreateOptions{
|
||||
ModelName: "test-model",
|
||||
ModelDir: dir,
|
||||
}, nil)
|
||||
if err == nil {
|
||||
t.Error("expected error for empty directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateOptions(t *testing.T) {
|
||||
opts := CreateOptions{
|
||||
ModelName: "my-model",
|
||||
ModelDir: "/path/to/model",
|
||||
Quantize: "fp8",
|
||||
Modelfile: &ModelfileConfig{
|
||||
Template: "test",
|
||||
System: "system",
|
||||
License: "MIT",
|
||||
},
|
||||
}
|
||||
|
||||
if opts.ModelName != "my-model" {
|
||||
t.Errorf("ModelName = %q, want %q", opts.ModelName, "my-model")
|
||||
}
|
||||
if opts.ModelDir != "/path/to/model" {
|
||||
t.Errorf("ModelDir = %q, want %q", opts.ModelDir, "/path/to/model")
|
||||
}
|
||||
if opts.Quantize != "fp8" {
|
||||
t.Errorf("Quantize = %q, want %q", opts.Quantize, "fp8")
|
||||
}
|
||||
if opts.Modelfile == nil {
|
||||
t.Error("Modelfile should not be nil")
|
||||
}
|
||||
if opts.Modelfile.Template != "test" {
|
||||
t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateOptions_Defaults(t *testing.T) {
|
||||
opts := CreateOptions{
|
||||
ModelName: "test",
|
||||
ModelDir: "/tmp",
|
||||
}
|
||||
|
||||
// Quantize should default to empty
|
||||
if opts.Quantize != "" {
|
||||
t.Errorf("Quantize should be empty by default, got %q", opts.Quantize)
|
||||
}
|
||||
|
||||
// Modelfile should default to nil
|
||||
if opts.Modelfile != nil {
|
||||
t.Error("Modelfile should be nil by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuantizeSupported(t *testing.T) {
|
||||
// This just verifies the function exists and returns a boolean
|
||||
// The actual value depends on build tags (mlx vs non-mlx)
|
||||
supported := QuantizeSupported()
|
||||
|
||||
// In non-mlx builds, this should be false
|
||||
// We can't easily test both cases, so just verify it returns something
|
||||
_ = supported
|
||||
}
|
||||
127
x/create/client/quantize.go
Normal file
@@ -0,0 +1,127 @@
|
||||
//go:build mlx
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// quantizeTensor loads a tensor from safetensors format, quantizes it,
|
||||
// and returns safetensors data for the quantized weights, scales, and biases.
|
||||
// Supported quantization types: "fp8" (affine 8-bit)
|
||||
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
|
||||
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
|
||||
tmpDir := ensureTempDir()
|
||||
|
||||
// Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
|
||||
tmpFile, err := os.CreateTemp(tmpDir, "quant-input-*.safetensors")
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to create temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
defer os.Remove(tmpPath)
|
||||
|
||||
if _, err := io.Copy(tmpFile, r); err != nil {
|
||||
tmpFile.Close()
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to write temp file: %w", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
|
||||
// Load the tensor using MLX's native loader
|
||||
st, err := mlx.LoadSafetensorsNative(tmpPath)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to load safetensors: %w", err)
|
||||
}
|
||||
defer st.Free()
|
||||
|
||||
// Get the tensor (it's stored as "data" in our minimal safetensors format)
|
||||
arr := st.Get("data")
|
||||
if arr == nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("tensor 'data' not found in safetensors")
|
||||
}
|
||||
|
||||
// Convert to BFloat16 if needed (quantize expects float type)
|
||||
if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
|
||||
arr = mlx.AsType(arr, mlx.DtypeBFloat16)
|
||||
mlx.Eval(arr)
|
||||
}
|
||||
|
||||
// Quantize based on quantization type
|
||||
var qweight, scales, qbiases *mlx.Array
|
||||
switch quantize {
|
||||
case "fp8":
|
||||
// affine mode: group_size=32, bits=8
|
||||
qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")
|
||||
default:
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
|
||||
}
|
||||
|
||||
// Eval and make contiguous for data access
|
||||
qweight = mlx.Contiguous(qweight)
|
||||
scales = mlx.Contiguous(scales)
|
||||
if qbiases != nil {
|
||||
qbiases = mlx.Contiguous(qbiases)
|
||||
mlx.Eval(qweight, scales, qbiases)
|
||||
} else {
|
||||
mlx.Eval(qweight, scales)
|
||||
}
|
||||
|
||||
// Get shapes
|
||||
qweightShape = qweight.Shape()
|
||||
scalesShape = scales.Shape()
|
||||
|
||||
// Save quantized weight using MLX's native safetensors (correctly handles uint32 dtype)
|
||||
qweightPath := filepath.Join(tmpDir, "qweight.safetensors")
|
||||
defer os.Remove(qweightPath)
|
||||
if err := mlx.SaveSafetensors(qweightPath, map[string]*mlx.Array{"data": qweight}); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save quantized weight: %w", err)
|
||||
}
|
||||
qweightData, err = os.ReadFile(qweightPath)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read quantized weight: %w", err)
|
||||
}
|
||||
|
||||
// Save scales using MLX's native safetensors
|
||||
scalesPath := filepath.Join(tmpDir, "scales.safetensors")
|
||||
defer os.Remove(scalesPath)
|
||||
if err := mlx.SaveSafetensors(scalesPath, map[string]*mlx.Array{"data": scales}); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save scales: %w", err)
|
||||
}
|
||||
scalesData, err = os.ReadFile(scalesPath)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read scales: %w", err)
|
||||
}
|
||||
|
||||
// Affine mode returns qbiases for zero-point offset
|
||||
if qbiases != nil {
|
||||
qbiasShape = qbiases.Shape()
|
||||
qbiasPath := filepath.Join(tmpDir, "qbias.safetensors")
|
||||
defer os.Remove(qbiasPath)
|
||||
if err := mlx.SaveSafetensors(qbiasPath, map[string]*mlx.Array{"data": qbiases}); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save qbiases: %w", err)
|
||||
}
|
||||
qbiasData, err = os.ReadFile(qbiasPath)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read qbiases: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return qweightData, scalesData, qbiasData, qweightShape, scalesShape, qbiasShape, nil
|
||||
}
|
||||
|
||||
// QuantizeSupported returns true if quantization is supported (MLX build)
|
||||
func QuantizeSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// ensureTempDir creates the temp directory for quantization if it doesn't exist
|
||||
func ensureTempDir() string {
|
||||
tmpDir := filepath.Join(os.TempDir(), "ollama-quantize")
|
||||
os.MkdirAll(tmpDir, 0755)
|
||||
return tmpDir
|
||||
}
|
||||
18
x/create/client/quantize_stub.go
Normal file
@@ -0,0 +1,18 @@
|
||||
//go:build !mlx
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// quantizeTensor is not available without MLX
|
||||
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
|
||||
}
|
||||
|
||||
// QuantizeSupported returns false when MLX is not available
|
||||
func QuantizeSupported() bool {
|
||||
return false
|
||||
}
|
||||
399
x/create/create.go
Normal file
@@ -0,0 +1,399 @@
|
||||
package create
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// ModelConfig represents the config blob stored with a model.
|
||||
type ModelConfig struct {
|
||||
ModelFormat string `json:"model_format"`
|
||||
Capabilities []string `json:"capabilities"`
|
||||
}
|
||||
|
||||
// Manifest represents the manifest JSON structure.
|
||||
type Manifest struct {
|
||||
SchemaVersion int `json:"schemaVersion"`
|
||||
MediaType string `json:"mediaType"`
|
||||
Config ManifestLayer `json:"config"`
|
||||
Layers []ManifestLayer `json:"layers"`
|
||||
}
|
||||
|
||||
// ManifestLayer represents a layer in the manifest.
|
||||
type ManifestLayer struct {
|
||||
MediaType string `json:"mediaType"`
|
||||
Digest string `json:"digest"`
|
||||
Size int64 `json:"size"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
// defaultManifestDir returns the manifest storage directory.
|
||||
func defaultManifestDir() string {
|
||||
return filepath.Join(envconfig.Models(), "manifests")
|
||||
}
|
||||
|
||||
// defaultBlobDir returns the blob storage directory.
|
||||
func defaultBlobDir() string {
|
||||
return filepath.Join(envconfig.Models(), "blobs")
|
||||
}
|
||||
|
||||
// resolveManifestPath converts a model name to a manifest file path.
|
||||
func resolveManifestPath(modelName string) string {
|
||||
host := "registry.ollama.ai"
|
||||
namespace := "library"
|
||||
name := modelName
|
||||
tag := "latest"
|
||||
|
||||
if idx := strings.LastIndex(name, ":"); idx != -1 {
|
||||
tag = name[idx+1:]
|
||||
name = name[:idx]
|
||||
}
|
||||
|
||||
parts := strings.Split(name, "/")
|
||||
switch len(parts) {
|
||||
case 3:
|
||||
host = parts[0]
|
||||
namespace = parts[1]
|
||||
name = parts[2]
|
||||
case 2:
|
||||
namespace = parts[0]
|
||||
name = parts[1]
|
||||
}
|
||||
|
||||
return filepath.Join(defaultManifestDir(), host, namespace, name, tag)
|
||||
}
|
||||
|
||||
// loadManifest loads a manifest for the given model name.
|
||||
func loadManifest(modelName string) (*Manifest, error) {
|
||||
manifestPath := resolveManifestPath(modelName)
|
||||
|
||||
data, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var manifest Manifest
|
||||
if err := json.Unmarshal(data, &manifest); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &manifest, nil
|
||||
}
|
||||
|
||||
// loadModelConfig loads the config blob for a model.
|
||||
func loadModelConfig(modelName string) (*ModelConfig, error) {
|
||||
manifest, err := loadManifest(modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read the config blob
|
||||
blobName := strings.Replace(manifest.Config.Digest, ":", "-", 1)
|
||||
blobPath := filepath.Join(defaultBlobDir(), blobName)
|
||||
|
||||
data, err := os.ReadFile(blobPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var config ModelConfig
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// IsSafetensorsModel checks if a model was created with the experimental
|
||||
// safetensors builder by checking the model format in the config.
|
||||
func IsSafetensorsModel(modelName string) bool {
|
||||
config, err := loadModelConfig(modelName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return config.ModelFormat == "safetensors"
|
||||
}
|
||||
|
||||
// IsSafetensorsLLMModel checks if a model is a safetensors LLM model
|
||||
// (has completion capability, not image generation).
|
||||
func IsSafetensorsLLMModel(modelName string) bool {
|
||||
config, err := loadModelConfig(modelName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "completion")
|
||||
}
|
||||
|
||||
// IsImageGenModel checks if a model is an image generation model
|
||||
// (has image capability).
|
||||
func IsImageGenModel(modelName string) bool {
|
||||
config, err := loadModelConfig(modelName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "image")
|
||||
}
|
||||
|
||||
// GetModelArchitecture returns the architecture from the model's config.json layer.
|
||||
func GetModelArchitecture(modelName string) (string, error) {
|
||||
manifest, err := loadManifest(modelName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Find the config.json layer
|
||||
for _, layer := range manifest.Layers {
|
||||
if layer.Name == "config.json" && layer.MediaType == "application/vnd.ollama.image.json" {
|
||||
blobName := strings.Replace(layer.Digest, ":", "-", 1)
|
||||
blobPath := filepath.Join(defaultBlobDir(), blobName)
|
||||
|
||||
data, err := os.ReadFile(blobPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Prefer model_type, fall back to first architecture
|
||||
if cfg.ModelType != "" {
|
||||
return cfg.ModelType, nil
|
||||
}
|
||||
if len(cfg.Architectures) > 0 {
|
||||
return cfg.Architectures[0], nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("architecture not found in model config")
|
||||
}
|
||||
|
||||
// IsTensorModelDir checks if the directory contains a diffusers-style tensor model
|
||||
// by looking for model_index.json, which is the standard diffusers pipeline config.
|
||||
func IsTensorModelDir(dir string) bool {
|
||||
_, err := os.Stat(filepath.Join(dir, "model_index.json"))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// IsSafetensorsModelDir checks if the directory contains a standard safetensors model
|
||||
// by looking for config.json and at least one .safetensors file.
|
||||
func IsSafetensorsModelDir(dir string) bool {
|
||||
// Must have config.json
|
||||
if _, err := os.Stat(filepath.Join(dir, "config.json")); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Must have at least one .safetensors file
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if strings.HasSuffix(entry.Name(), ".safetensors") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// LayerInfo holds metadata for a created layer.
|
||||
type LayerInfo struct {
|
||||
Digest string
|
||||
Size int64
|
||||
MediaType string
|
||||
Name string // Path-style name: "component/tensor" or "path/to/config.json"
|
||||
}
|
||||
|
||||
// LayerCreator is called to create a blob layer.
|
||||
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
|
||||
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
|
||||
|
||||
// TensorLayerCreator creates a tensor blob layer with metadata.
|
||||
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
|
||||
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
|
||||
|
||||
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
|
||||
// When quantize is non-empty (e.g., "fp8"), returns multiple layers (weight + scales + biases).
|
||||
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error)
|
||||
|
||||
// ManifestWriter writes the manifest file.
|
||||
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error
|
||||
|
||||
// ShouldQuantize returns true if a tensor should be quantized.
|
||||
// For image gen models (component non-empty): quantizes linear weights, skipping VAE, embeddings, norms.
|
||||
// For LLM models (component empty): quantizes linear weights, skipping embeddings, norms, and small tensors.
|
||||
func ShouldQuantize(name, component string) bool {
|
||||
// Image gen specific: skip VAE entirely
|
||||
if component == "vae" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip embeddings
|
||||
if strings.Contains(name, "embed") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip layer norms and RMS norms
|
||||
if strings.Contains(name, "norm") || strings.Contains(name, "ln_") || strings.Contains(name, "layernorm") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip biases
|
||||
if strings.HasSuffix(name, ".bias") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Only quantize weights
|
||||
return strings.HasSuffix(name, ".weight")
|
||||
}
|
||||
|
||||
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name and shape.
|
||||
// This is a more detailed check that also considers tensor dimensions.
|
||||
func ShouldQuantizeTensor(name string, shape []int32) bool {
|
||||
// Use basic name-based check first
|
||||
if !ShouldQuantize(name, "") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
|
||||
if len(shape) != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip small tensors (less than 1024 elements) - not worth quantizing
|
||||
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
|
||||
return false
|
||||
}
|
||||
|
||||
// MLX quantization requires last dimension to be divisible by group size (32)
|
||||
if shape[len(shape)-1]%32 != 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// CreateSafetensorsModel imports a standard safetensors model from a directory.
|
||||
// This handles Hugging Face style models with config.json and *.safetensors files.
|
||||
// Stores each tensor as a separate blob for fine-grained deduplication.
|
||||
// If quantize is non-empty (e.g., "fp8"), eligible tensors will be quantized.
|
||||
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
||||
var layers []LayerInfo
|
||||
var configLayer LayerInfo
|
||||
|
||||
entries, err := os.ReadDir(modelDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read directory: %w", err)
|
||||
}
|
||||
|
||||
// Process all safetensors files
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") {
|
||||
continue
|
||||
}
|
||||
|
||||
stPath := filepath.Join(modelDir, entry.Name())
|
||||
|
||||
// Extract individual tensors from safetensors file
|
||||
extractor, err := safetensors.OpenForExtraction(stPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open %s: %w", stPath, err)
|
||||
}
|
||||
|
||||
tensorNames := extractor.ListTensors()
|
||||
quantizeMsg := ""
|
||||
if quantize != "" {
|
||||
quantizeMsg = fmt.Sprintf(", quantizing to %s", quantize)
|
||||
}
|
||||
fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg))
|
||||
|
||||
for _, tensorName := range tensorNames {
|
||||
td, err := extractor.GetTensor(tensorName)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
|
||||
}
|
||||
|
||||
// Determine quantization type for this tensor (empty string if not quantizing)
|
||||
quantizeType := ""
|
||||
if quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape) {
|
||||
quantizeType = quantize
|
||||
}
|
||||
|
||||
// Store as minimal safetensors format (88 bytes header overhead)
|
||||
// This enables native mmap loading via mlx_load_safetensors
|
||||
// createTensorLayer returns multiple layers if quantizing (weight + scales)
|
||||
newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, quantizeType)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
return fmt.Errorf("failed to create layer for %s: %w", tensorName, err)
|
||||
}
|
||||
layers = append(layers, newLayers...)
|
||||
}
|
||||
|
||||
extractor.Close()
|
||||
}
|
||||
|
||||
// Process all JSON config files
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip the index file as we don't need it after extraction
|
||||
if entry.Name() == "model.safetensors.index.json" {
|
||||
continue
|
||||
}
|
||||
|
||||
cfgPath := entry.Name()
|
||||
fullPath := filepath.Join(modelDir, cfgPath)
|
||||
|
||||
fn(fmt.Sprintf("importing config %s", cfgPath))
|
||||
|
||||
f, err := os.Open(fullPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open %s: %w", cfgPath, err)
|
||||
}
|
||||
|
||||
layer, err := createLayer(f, "application/vnd.ollama.image.json", cfgPath)
|
||||
f.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create layer for %s: %w", cfgPath, err)
|
||||
}
|
||||
|
||||
// Use config.json as the config layer
|
||||
if cfgPath == "config.json" {
|
||||
configLayer = layer
|
||||
}
|
||||
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
if configLayer.Digest == "" {
|
||||
return fmt.Errorf("config.json not found in %s", modelDir)
|
||||
}
|
||||
|
||||
fn(fmt.Sprintf("writing manifest for %s", modelName))
|
||||
|
||||
if err := writeManifest(modelName, configLayer, layers); err != nil {
|
||||
return fmt.Errorf("failed to write manifest: %w", err)
|
||||
}
|
||||
|
||||
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
|
||||
return nil
|
||||
}
|
||||
752
x/create/create_test.go
Normal file
@@ -0,0 +1,752 @@
|
||||
package create
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsTensorModelDir(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(dir string) error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "valid diffusers model with model_index.json",
|
||||
setup: func(dir string) error {
|
||||
return os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(`{"_class_name": "FluxPipeline"}`), 0o644)
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "empty directory",
|
||||
setup: func(dir string) error {
|
||||
return nil
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "directory with other files but no model_index.json",
|
||||
setup: func(dir string) error {
|
||||
return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644)
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := tt.setup(dir); err != nil {
|
||||
t.Fatalf("setup failed: %v", err)
|
||||
}
|
||||
|
||||
got := IsTensorModelDir(dir)
|
||||
if got != tt.expected {
|
||||
t.Errorf("IsTensorModelDir() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSafetensorsModelDir(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(dir string) error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "valid safetensors model with config.json and .safetensors file",
|
||||
setup: func(dir string) error {
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type": "gemma3"}`), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0o644)
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "config.json only, no safetensors files",
|
||||
setup: func(dir string) error {
|
||||
return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644)
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "safetensors file only, no config.json",
|
||||
setup: func(dir string) error {
|
||||
return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0o644)
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty directory",
|
||||
setup: func(dir string) error {
|
||||
return nil
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "multiple safetensors files with config.json",
|
||||
setup: func(dir string) error {
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("dummy"), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(filepath.Join(dir, "model-00002-of-00002.safetensors"), []byte("dummy"), 0o644)
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := tt.setup(dir); err != nil {
|
||||
t.Fatalf("setup failed: %v", err)
|
||||
}
|
||||
|
||||
got := IsSafetensorsModelDir(dir)
|
||||
if got != tt.expected {
|
||||
t.Errorf("IsSafetensorsModelDir() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSafetensorsModelDir_NonexistentDir(t *testing.T) {
|
||||
got := IsSafetensorsModelDir("/nonexistent/path/that/does/not/exist")
|
||||
if got != false {
|
||||
t.Errorf("IsSafetensorsModelDir() = %v for nonexistent dir, want false", got)
|
||||
}
|
||||
}
|
||||
|
||||
// createMinimalSafetensors creates a minimal valid safetensors file with one tensor
|
||||
func createMinimalSafetensors(t *testing.T, path string) {
|
||||
t.Helper()
|
||||
|
||||
// Create a minimal safetensors file with a single float32 tensor
|
||||
header := map[string]interface{}{
|
||||
"test_tensor": map[string]interface{}{
|
||||
"dtype": "F32",
|
||||
"shape": []int{2, 2},
|
||||
"data_offsets": []int{0, 16}, // 4 float32 values = 16 bytes
|
||||
},
|
||||
}
|
||||
headerJSON, err := json.Marshal(header)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal header: %v", err)
|
||||
}
|
||||
|
||||
// Pad header to 8-byte alignment
|
||||
padding := (8 - len(headerJSON)%8) % 8
|
||||
headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
|
||||
|
||||
// Write file
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create file: %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Write header size (8 bytes, little endian)
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
|
||||
t.Fatalf("failed to write header size: %v", err)
|
||||
}
|
||||
|
||||
// Write header
|
||||
if _, err := f.Write(headerJSON); err != nil {
|
||||
t.Fatalf("failed to write header: %v", err)
|
||||
}
|
||||
|
||||
// Write tensor data (16 bytes of zeros for 4 float32 values)
|
||||
if _, err := f.Write(make([]byte, 16)); err != nil {
|
||||
t.Fatalf("failed to write tensor data: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create config.json
|
||||
configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write config.json: %v", err)
|
||||
}
|
||||
|
||||
// Create a minimal safetensors file
|
||||
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
|
||||
|
||||
// Track what was created
|
||||
var createdLayers []LayerInfo
|
||||
var manifestWritten bool
|
||||
var manifestModelName string
|
||||
var manifestConfigLayer LayerInfo
|
||||
var manifestLayers []LayerInfo
|
||||
var statusMessages []string
|
||||
|
||||
// Mock callbacks
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return LayerInfo{}, err
|
||||
}
|
||||
layer := LayerInfo{
|
||||
Digest: "sha256:test",
|
||||
Size: int64(len(data)),
|
||||
MediaType: mediaType,
|
||||
Name: name,
|
||||
}
|
||||
createdLayers = append(createdLayers, layer)
|
||||
return layer, nil
|
||||
}
|
||||
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layer := LayerInfo{
|
||||
Digest: "sha256:tensor_" + name,
|
||||
Size: int64(len(data)),
|
||||
MediaType: "application/vnd.ollama.image.tensor",
|
||||
Name: name,
|
||||
}
|
||||
createdLayers = append(createdLayers, layer)
|
||||
return []LayerInfo{layer}, nil
|
||||
}
|
||||
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
manifestWritten = true
|
||||
manifestModelName = modelName
|
||||
manifestConfigLayer = config
|
||||
manifestLayers = layers
|
||||
return nil
|
||||
}
|
||||
|
||||
progressFn := func(status string) {
|
||||
statusMessages = append(statusMessages, status)
|
||||
}
|
||||
|
||||
// Run CreateSafetensorsModel
|
||||
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify manifest was written
|
||||
if !manifestWritten {
|
||||
t.Error("manifest was not written")
|
||||
}
|
||||
|
||||
if manifestModelName != "test-model" {
|
||||
t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-model")
|
||||
}
|
||||
|
||||
// Verify config layer was set
|
||||
if manifestConfigLayer.Name != "config.json" {
|
||||
t.Errorf("config layer name = %q, want %q", manifestConfigLayer.Name, "config.json")
|
||||
}
|
||||
|
||||
// Verify we have at least one tensor and one config layer
|
||||
hasTensor := false
|
||||
hasConfig := false
|
||||
for _, layer := range manifestLayers {
|
||||
if layer.Name == "test_tensor" {
|
||||
hasTensor = true
|
||||
}
|
||||
if layer.Name == "config.json" {
|
||||
hasConfig = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasTensor {
|
||||
t.Error("no tensor layer found in manifest")
|
||||
}
|
||||
if !hasConfig {
|
||||
t.Error("no config layer found in manifest")
|
||||
}
|
||||
|
||||
// Verify status messages were sent
|
||||
if len(statusMessages) == 0 {
|
||||
t.Error("no status messages received")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_NoConfigJson(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create only a safetensors file, no config.json
|
||||
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
|
||||
|
||||
// Mock callbacks (minimal)
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return LayerInfo{Name: name}, nil
|
||||
}
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return []LayerInfo{{Name: name}}, nil
|
||||
}
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
progressFn := func(status string) {}
|
||||
|
||||
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing config.json, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_EmptyDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Mock callbacks
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
return LayerInfo{}, nil
|
||||
}
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
return []LayerInfo{{}}, nil
|
||||
}
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
progressFn := func(status string) {}
|
||||
|
||||
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err == nil {
|
||||
t.Error("expected error for empty directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_SkipsIndexJson(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create config.json
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil {
|
||||
t.Fatalf("failed to write config.json: %v", err)
|
||||
}
|
||||
|
||||
// Create model.safetensors.index.json (should be skipped)
|
||||
indexJSON := `{"metadata": {"total_size": 100}, "weight_map": {}}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(indexJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write index.json: %v", err)
|
||||
}
|
||||
|
||||
// Create a minimal safetensors file
|
||||
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
|
||||
|
||||
var configNames []string
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
configNames = append(configNames, name)
|
||||
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
|
||||
}
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return []LayerInfo{{Name: name}}, nil
|
||||
}
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
progressFn := func(status string) {}
|
||||
|
||||
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify model.safetensors.index.json was not included
|
||||
for _, name := range configNames {
|
||||
if name == "model.safetensors.index.json" {
|
||||
t.Error("model.safetensors.index.json should have been skipped")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveManifestPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelName string
|
||||
wantParts []string // Parts that should appear in the path
|
||||
}{
|
||||
{
|
||||
name: "simple model name",
|
||||
modelName: "llama2",
|
||||
wantParts: []string{"registry.ollama.ai", "library", "llama2", "latest"},
|
||||
},
|
||||
{
|
||||
name: "model name with tag",
|
||||
modelName: "llama2:7b",
|
||||
wantParts: []string{"registry.ollama.ai", "library", "llama2", "7b"},
|
||||
},
|
||||
{
|
||||
name: "model name with namespace",
|
||||
modelName: "myuser/mymodel",
|
||||
wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "latest"},
|
||||
},
|
||||
{
|
||||
name: "model name with namespace and tag",
|
||||
modelName: "myuser/mymodel:v1",
|
||||
wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "v1"},
|
||||
},
|
||||
{
|
||||
name: "fully qualified model name",
|
||||
modelName: "registry.example.com/namespace/model:tag",
|
||||
wantParts: []string{"registry.example.com", "namespace", "model", "tag"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := resolveManifestPath(tt.modelName)
|
||||
|
||||
for _, part := range tt.wantParts {
|
||||
if !strings.Contains(got, part) {
|
||||
t.Errorf("resolveManifestPath(%q) = %q, missing part %q", tt.modelName, got, part)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLayerInfo(t *testing.T) {
|
||||
layer := LayerInfo{
|
||||
Digest: "sha256:abc123",
|
||||
Size: 1024,
|
||||
MediaType: "application/vnd.ollama.image.tensor",
|
||||
Name: "model.weight",
|
||||
}
|
||||
|
||||
if layer.Digest != "sha256:abc123" {
|
||||
t.Errorf("Digest = %q, want %q", layer.Digest, "sha256:abc123")
|
||||
}
|
||||
if layer.Size != 1024 {
|
||||
t.Errorf("Size = %d, want %d", layer.Size, 1024)
|
||||
}
|
||||
if layer.MediaType != "application/vnd.ollama.image.tensor" {
|
||||
t.Errorf("MediaType = %q, want %q", layer.MediaType, "application/vnd.ollama.image.tensor")
|
||||
}
|
||||
if layer.Name != "model.weight" {
|
||||
t.Errorf("Name = %q, want %q", layer.Name, "model.weight")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelConfig(t *testing.T) {
|
||||
config := ModelConfig{
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: []string{"completion", "chat"},
|
||||
}
|
||||
|
||||
if config.ModelFormat != "safetensors" {
|
||||
t.Errorf("ModelFormat = %q, want %q", config.ModelFormat, "safetensors")
|
||||
}
|
||||
if len(config.Capabilities) != 2 {
|
||||
t.Errorf("Capabilities length = %d, want %d", len(config.Capabilities), 2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManifest(t *testing.T) {
|
||||
manifest := Manifest{
|
||||
SchemaVersion: 2,
|
||||
MediaType: "application/vnd.oci.image.manifest.v1+json",
|
||||
Config: ManifestLayer{
|
||||
MediaType: "application/vnd.docker.container.image.v1+json",
|
||||
Digest: "sha256:config",
|
||||
Size: 100,
|
||||
},
|
||||
Layers: []ManifestLayer{
|
||||
{
|
||||
MediaType: "application/vnd.ollama.image.tensor",
|
||||
Digest: "sha256:layer1",
|
||||
Size: 1000,
|
||||
Name: "weight.bin",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if manifest.SchemaVersion != 2 {
|
||||
t.Errorf("SchemaVersion = %d, want %d", manifest.SchemaVersion, 2)
|
||||
}
|
||||
if manifest.Config.Digest != "sha256:config" {
|
||||
t.Errorf("Config.Digest = %q, want %q", manifest.Config.Digest, "sha256:config")
|
||||
}
|
||||
if len(manifest.Layers) != 1 {
|
||||
t.Errorf("Layers length = %d, want %d", len(manifest.Layers), 1)
|
||||
}
|
||||
if manifest.Layers[0].Name != "weight.bin" {
|
||||
t.Errorf("Layers[0].Name = %q, want %q", manifest.Layers[0].Name, "weight.bin")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldQuantize(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tensor string
|
||||
component string
|
||||
want bool
|
||||
}{
|
||||
// VAE component should never be quantized
|
||||
{"vae weight", "decoder.weight", "vae", false},
|
||||
{"vae bias", "decoder.bias", "vae", false},
|
||||
|
||||
// Embeddings should not be quantized
|
||||
{"embedding weight", "embed_tokens.weight", "", false},
|
||||
{"embedding in name", "token_embedding.weight", "", false},
|
||||
|
||||
// Norms should not be quantized
|
||||
{"layer norm", "layer_norm.weight", "", false},
|
||||
{"rms norm", "rms_norm.weight", "", false},
|
||||
{"ln prefix", "ln_1.weight", "", false},
|
||||
{"layernorm in name", "input_layernorm.weight", "", false},
|
||||
|
||||
// Biases should not be quantized
|
||||
{"bias tensor", "attention.bias", "", false},
|
||||
{"proj bias", "o_proj.bias", "", false},
|
||||
|
||||
// Linear weights should be quantized
|
||||
{"linear weight", "q_proj.weight", "", true},
|
||||
{"attention weight", "self_attn.weight", "", true},
|
||||
{"mlp weight", "mlp.gate_proj.weight", "", true},
|
||||
|
||||
// Transformer component weights should be quantized
|
||||
{"transformer weight", "layers.0.weight", "transformer", true},
|
||||
{"text_encoder weight", "encoder.weight", "text_encoder", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ShouldQuantize(tt.tensor, tt.component)
|
||||
if got != tt.want {
|
||||
t.Errorf("ShouldQuantize(%q, %q) = %v, want %v", tt.tensor, tt.component, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldQuantizeTensor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tensor string
|
||||
shape []int32
|
||||
want bool
|
||||
}{
|
||||
// 2D tensors with sufficient size should be quantized
|
||||
{"large 2D weight", "q_proj.weight", []int32{4096, 4096}, true},
|
||||
{"medium 2D weight", "small_proj.weight", []int32{128, 128}, true},
|
||||
|
||||
// Small tensors should not be quantized (< 1024 elements)
|
||||
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, false},
|
||||
{"small 2D weight", "small.weight", []int32{31, 31}, false},
|
||||
|
||||
// 1D tensors should not be quantized
|
||||
{"1D tensor", "layer_norm.weight", []int32{4096}, false},
|
||||
|
||||
// 3D+ tensors should not be quantized
|
||||
{"3D tensor", "conv.weight", []int32{64, 64, 3}, false},
|
||||
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, false},
|
||||
|
||||
// Embeddings should not be quantized regardless of shape
|
||||
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, false},
|
||||
|
||||
// Norms should not be quantized regardless of shape
|
||||
{"norm 2D", "layer_norm.weight", []int32{4096, 1}, false},
|
||||
|
||||
// Biases should not be quantized
|
||||
{"bias 2D", "proj.bias", []int32{4096, 1}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ShouldQuantizeTensor(tt.tensor, tt.shape)
|
||||
if got != tt.want {
|
||||
t.Errorf("ShouldQuantizeTensor(%q, %v) = %v, want %v", tt.tensor, tt.shape, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create config.json
|
||||
configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write config.json: %v", err)
|
||||
}
|
||||
|
||||
// Create a minimal safetensors file
|
||||
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
|
||||
|
||||
var quantizeRequested []string
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
|
||||
}
|
||||
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
quantizeRequested = append(quantizeRequested, quantize)
|
||||
return []LayerInfo{{Name: name}}, nil
|
||||
}
|
||||
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
progressFn := func(status string) {}
|
||||
|
||||
// Run with quantize enabled
|
||||
err := CreateSafetensorsModel("test-model", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify quantize was passed to callback (will be false for small test tensor)
|
||||
if len(quantizeRequested) == 0 {
|
||||
t.Error("no tensors processed")
|
||||
}
|
||||
}
|
||||
|
||||
// createMinimalImageGenModel creates a minimal diffusers-style model directory
|
||||
func createMinimalImageGenModel(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
|
||||
// Create model_index.json
|
||||
modelIndex := `{"_class_name": "FluxPipeline", "_diffusers_version": "0.30.0"}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(modelIndex), 0o644); err != nil {
|
||||
t.Fatalf("failed to write model_index.json: %v", err)
|
||||
}
|
||||
|
||||
// Create transformer directory with a safetensors file
|
||||
transformerDir := filepath.Join(dir, "transformer")
|
||||
if err := os.MkdirAll(transformerDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create transformer dir: %v", err)
|
||||
}
|
||||
createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))
|
||||
|
||||
// Create transformer config
|
||||
transformerConfig := `{"hidden_size": 3072}`
|
||||
if err := os.WriteFile(filepath.Join(transformerDir, "config.json"), []byte(transformerConfig), 0o644); err != nil {
|
||||
t.Fatalf("failed to write transformer config: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateImageGenModel(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
createMinimalImageGenModel(t, dir)
|
||||
|
||||
var manifestWritten bool
|
||||
var manifestModelName string
|
||||
var statusMessages []string
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
|
||||
}
|
||||
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return []LayerInfo{{Name: name, Digest: "sha256:tensor"}}, nil
|
||||
}
|
||||
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
manifestWritten = true
|
||||
manifestModelName = modelName
|
||||
return nil
|
||||
}
|
||||
|
||||
progressFn := func(status string) {
|
||||
statusMessages = append(statusMessages, status)
|
||||
}
|
||||
|
||||
err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateImageGenModel failed: %v", err)
|
||||
}
|
||||
|
||||
if !manifestWritten {
|
||||
t.Error("manifest was not written")
|
||||
}
|
||||
|
||||
if manifestModelName != "test-imagegen" {
|
||||
t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-imagegen")
|
||||
}
|
||||
|
||||
if len(statusMessages) == 0 {
|
||||
t.Error("no status messages received")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateImageGenModel_NoModelIndex(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create only transformer without model_index.json
|
||||
transformerDir := filepath.Join(dir, "transformer")
|
||||
if err := os.MkdirAll(transformerDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create transformer dir: %v", err)
|
||||
}
|
||||
createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return LayerInfo{Name: name}, nil
|
||||
}
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return []LayerInfo{{Name: name}}, nil
|
||||
}
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
progressFn := func(status string) {}
|
||||
|
||||
err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing model_index.json, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateImageGenModel_WithQuantize(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
createMinimalImageGenModel(t, dir)
|
||||
|
||||
var quantizeRequested []string
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
|
||||
}
|
||||
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
quantizeRequested = append(quantizeRequested, quantize)
|
||||
return []LayerInfo{{Name: name}}, nil
|
||||
}
|
||||
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
progressFn := func(status string) {}
|
||||
|
||||
err := CreateImageGenModel("test-imagegen", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateImageGenModel failed: %v", err)
|
||||
}
|
||||
|
||||
if len(quantizeRequested) == 0 {
|
||||
t.Error("no tensors processed")
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package imagegen
|
||||
package create
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -12,38 +12,24 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// IsTensorModelDir checks if the directory contains a tensor model
|
||||
// by looking for model_index.json, which is the standard diffusers pipeline config.
|
||||
func IsTensorModelDir(dir string) bool {
|
||||
_, err := os.Stat(filepath.Join(dir, "model_index.json"))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// LayerInfo holds metadata for a created layer.
|
||||
type LayerInfo struct {
|
||||
Digest string
|
||||
Size int64
|
||||
MediaType string
|
||||
Name string // Path-style name: "component/tensor" or "path/to/config.json"
|
||||
}
|
||||
|
||||
// LayerCreator is called to create a blob layer.
|
||||
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
|
||||
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
|
||||
|
||||
// TensorLayerCreator creates a tensor blob layer with metadata.
|
||||
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
|
||||
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
|
||||
|
||||
// ManifestWriter writes the manifest file.
|
||||
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error
|
||||
|
||||
// CreateModel imports an image generation model from a directory.
|
||||
// CreateImageGenModel imports an image generation model from a directory.
|
||||
// Stores each tensor as a separate blob for fine-grained deduplication.
|
||||
// If quantize is specified, linear weights in transformer/text_encoder are quantized.
|
||||
// Supported quantization types: fp8 (or empty for no quantization).
|
||||
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
|
||||
func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTensorLayer TensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
||||
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
||||
// Validate quantization type
|
||||
switch quantize {
|
||||
case "", "fp8":
|
||||
// valid
|
||||
default:
|
||||
return fmt.Errorf("unsupported quantization type %q: supported types are fp8", quantize)
|
||||
}
|
||||
|
||||
var layers []LayerInfo
|
||||
var configLayer LayerInfo
|
||||
var totalParams int64 // Count parameters from original tensor shapes
|
||||
var torchDtype string // Read from component config for quantization display
|
||||
|
||||
// Components to process - extract individual tensors from each
|
||||
components := []string{"text_encoder", "transformer", "vae"}
|
||||
@@ -74,7 +60,11 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
|
||||
}
|
||||
|
||||
tensorNames := extractor.ListTensors()
|
||||
fn(fmt.Sprintf("importing %s/%s (%d tensors)", component, entry.Name(), len(tensorNames)))
|
||||
quantizeMsg := ""
|
||||
if quantize != "" && component != "vae" {
|
||||
quantizeMsg = ", quantizing to " + quantize
|
||||
}
|
||||
fn(fmt.Sprintf("importing %s/%s (%d tensors%s)", component, entry.Name(), len(tensorNames), quantizeMsg))
|
||||
|
||||
for _, tensorName := range tensorNames {
|
||||
td, err := extractor.GetTensor(tensorName)
|
||||
@@ -83,22 +73,52 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
|
||||
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
|
||||
}
|
||||
|
||||
// Count parameters from original tensor shape
|
||||
if len(td.Shape) > 0 {
|
||||
numElements := int64(1)
|
||||
for _, dim := range td.Shape {
|
||||
numElements *= int64(dim)
|
||||
}
|
||||
totalParams += numElements
|
||||
}
|
||||
|
||||
// Store as minimal safetensors format (88 bytes header overhead)
|
||||
// This enables native mmap loading via mlx_load_safetensors
|
||||
// Use path-style name: "component/tensor_name"
|
||||
fullName := component + "/" + tensorName
|
||||
layer, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape)
|
||||
|
||||
// Determine quantization type for this tensor (empty string if not quantizing)
|
||||
quantizeType := ""
|
||||
if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape) {
|
||||
quantizeType = quantize
|
||||
}
|
||||
|
||||
// createTensorLayer returns multiple layers if quantizing (weight + scales)
|
||||
newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, quantizeType)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
layers = append(layers, newLayers...)
|
||||
}
|
||||
|
||||
extractor.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Read torch_dtype from text_encoder config for quantization display
|
||||
if torchDtype == "" {
|
||||
textEncoderConfig := filepath.Join(modelDir, "text_encoder/config.json")
|
||||
if data, err := os.ReadFile(textEncoderConfig); err == nil {
|
||||
var cfg struct {
|
||||
TorchDtype string `json:"torch_dtype"`
|
||||
}
|
||||
if json.Unmarshal(data, &cfg) == nil && cfg.TorchDtype != "" {
|
||||
torchDtype = cfg.TorchDtype
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Import config files
|
||||
configFiles := []string{
|
||||
"model_index.json",
|
||||
@@ -122,7 +142,7 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
|
||||
|
||||
var r io.Reader
|
||||
|
||||
// For model_index.json, normalize to Ollama format
|
||||
// For model_index.json, normalize to Ollama format and add metadata
|
||||
if cfgPath == "model_index.json" {
|
||||
data, err := os.ReadFile(fullPath)
|
||||
if err != nil {
|
||||
@@ -141,6 +161,16 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
|
||||
}
|
||||
delete(cfg, "_diffusers_version")
|
||||
|
||||
// Add parameter count (counted from tensor shapes during import)
|
||||
cfg["parameter_count"] = totalParams
|
||||
|
||||
// Add quantization info - use quantize type if set, otherwise torch_dtype
|
||||
if quantize != "" {
|
||||
cfg["quantization"] = strings.ToUpper(quantize)
|
||||
} else {
|
||||
cfg["quantization"] = torchDtype
|
||||
}
|
||||
|
||||
data, err = json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal %s: %w", cfgPath, err)
|
||||
@@ -181,3 +211,12 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
|
||||
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// canQuantizeShape returns true if a tensor shape is compatible with MLX quantization.
|
||||
// MLX requires the last dimension to be divisible by the group size (32).
|
||||
func canQuantizeShape(shape []int32) bool {
|
||||
if len(shape) < 2 {
|
||||
return false
|
||||
}
|
||||
return shape[len(shape)-1]%32 == 0
|
||||
}
|
||||
@@ -1,185 +0,0 @@
|
||||
# grammar
|
||||
|
||||
Grammar-constrained decoding for LLM outputs using MLX.
|
||||
|
||||
## Performance
|
||||
|
||||
Performance depends on hardware, vocabulary size, grammar, and whether you
|
||||
evaluate the MLX graph. See [Benchmarks](#benchmarks) for how to measure on your
|
||||
setup.
|
||||
|
||||
### Design choices that keep masking fast
|
||||
|
||||
| Technique | Impact |
|
||||
|-----------|--------|
|
||||
| Precomputed token analysis | Terminal matches computed once at startup |
|
||||
| Mask caching by grammar state signature | Reuse masks for repeated parser states |
|
||||
| Partitioned tokens | Exact matches separated from DP candidates |
|
||||
|
||||
### Comparison Notes
|
||||
|
||||
- **llama.cpp**: Decodes each token to UTF-8, checks against PDA. No caching.
|
||||
- **Outlines**: FSM-based. Compilation can take 40s-10min for complex schemas. Fast after compile.
|
||||
- **XGrammar**: PDA with 99% context-independent tokens precomputed. State-of-the-art before this.
|
||||
- **x/grammar**: Precomputed token analysis + mask caching by grammar state signature.
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
import (
|
||||
"github.com/ollama/ollama/x/grammar"
|
||||
"github.com/ollama/ollama/x/grammar/schema"
|
||||
)
|
||||
|
||||
// Use built-in JSON grammar
|
||||
g, _ := grammar.JSONGrammar()
|
||||
|
||||
// Or from JSON Schema (OpenAI-compatible)
|
||||
g, _ := schema.Grammar(`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"}
|
||||
},
|
||||
"required": ["name", "age"]
|
||||
}`)
|
||||
|
||||
// Or parse custom EBNF
|
||||
g, _ := grammar.ParseEBNF(myGrammar, "root")
|
||||
|
||||
// Create engine with model vocabulary
|
||||
engine, _ := grammar.NewEngine(g, vocab)
|
||||
defer engine.Close()
|
||||
|
||||
// Generation loop
|
||||
for !engine.IsComplete() {
|
||||
logits := model.Forward(tokens)
|
||||
masked := engine.ApplyMask(logits) // Invalid tokens → -inf
|
||||
nextToken := sample(masked)
|
||||
engine.Accept(nextToken)
|
||||
}
|
||||
// Output conforms to the grammar when you only sample from masked tokens and call Accept
|
||||
```
|
||||
|
||||
## EBNF Syntax
|
||||
|
||||
```ebnf
|
||||
rule = expression . # Rule definition (ends with .)
|
||||
"literal" # Literal string
|
||||
"a" … "z" # Character range (inclusive)
|
||||
( a | b ) # Grouping with alternation
|
||||
[ optional ] # Optional (0 or 1)
|
||||
{ repeated } # Repetition (0 or more)
|
||||
```
|
||||
|
||||
### Example: JSON Grammar
|
||||
|
||||
```ebnf
|
||||
json = value .
|
||||
|
||||
value = object | array | string | number | "true" | "false" | "null" .
|
||||
|
||||
object = "{" ws "}" | "{" members "}" .
|
||||
members = member { "," member } .
|
||||
member = ws string ws ":" element .
|
||||
|
||||
array = "[" ws "]" | "[" elements "]" .
|
||||
elements = element { "," element } .
|
||||
element = ws value ws .
|
||||
|
||||
string = "\"" { character } "\"" .
|
||||
character = unescaped | escaped .
|
||||
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
|
||||
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" ) .
|
||||
|
||||
number = [ "-" ] integer [ fraction ] [ exponent ] .
|
||||
integer = "0" | onenine { digit } .
|
||||
fraction = "." digit { digit } .
|
||||
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
|
||||
digit = "0" … "9" .
|
||||
onenine = "1" … "9" .
|
||||
|
||||
ws = { " " | "\t" | "\n" | "\r" } .
|
||||
```
|
||||
|
||||
### Example: Custom Schema
|
||||
|
||||
```ebnf
|
||||
root = "{" ws name_field "," ws age_field ws "}" .
|
||||
|
||||
name_field = "\"name\"" ws ":" ws string .
|
||||
age_field = "\"age\"" ws ":" ws number .
|
||||
|
||||
string = "\"" { char } "\"" .
|
||||
char = " " | "!" | "#" … "~" .
|
||||
|
||||
number = [ "-" ] digit { digit } .
|
||||
digit = "0" … "9" .
|
||||
|
||||
ws = { " " | "\n" } .
|
||||
```
|
||||
|
||||
## JSON Schema Support
|
||||
|
||||
OpenAI-compatible JSON Schema support with automatic EBNF generation:
|
||||
|
||||
```go
|
||||
schema := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"user": {"$ref": "#/$defs/User"}
|
||||
},
|
||||
"required": ["user"],
|
||||
"$defs": {
|
||||
"User": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"email": {"type": "string", "format": "email"},
|
||||
"role": {"enum": ["admin", "user", "guest"]}
|
||||
},
|
||||
"required": ["name", "email", "role"]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
grammar, _ := schema.Grammar(schema)
|
||||
```
|
||||
|
||||
### Supported Features
|
||||
|
||||
| Feature | Example |
|
||||
|---------|---------|
|
||||
| Basic types | `string`, `integer`, `number`, `boolean`, `null` |
|
||||
| Objects | `properties`, `required` |
|
||||
| Arrays | `items`, `minItems`, `maxItems` |
|
||||
| Enums | `enum: ["a", "b", "c"]` |
|
||||
| Constants | `const: "value"` |
|
||||
| Union types | `anyOf`, `oneOf`, `type: ["string", "null"]` |
|
||||
| References | `$ref: "#/$defs/Name"`, `$defs` |
|
||||
| Formats | `date`, `time`, `date-time`, `email`, `uuid`, `ipv4` |
|
||||
|
||||
## Benchmarks
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
go test -tags mlx ./x/grammar/...
|
||||
|
||||
# Run benchmarks
|
||||
go test -tags mlx ./x/grammar/ -bench=.
|
||||
|
||||
# Compare with llama.cpp (outputs JSON)
|
||||
go run -tags mlx ./x/grammar/cmd/compare -vocab-size 128000 -iterations 500
|
||||
|
||||
# Compare with a more complex schema
|
||||
go run -tags mlx ./x/grammar/cmd/compare \
|
||||
-gbnf x/grammar/cmd/compare/complex.gbnf \
|
||||
-schema x/grammar/cmd/compare/complex.schema.json \
|
||||
-vocab-size 128000 -iterations 500
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- [XGrammar Paper](https://arxiv.org/abs/2411.15100) - Flexible and Efficient Structured Generation
|
||||
- [Outlines](https://github.com/dottxt-ai/outlines) - Structured Text Generation
|
||||
- [JSONSchemaBench](https://arxiv.org/abs/2501.10868) - Benchmark for Structured Outputs
|
||||
@@ -1,161 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
// terminalTokenGroups contains pre-partitioned tokens for a terminal.
|
||||
// This enables O(1) lookup of tokens that exactly match vs need DP validation.
|
||||
type terminalTokenGroups struct {
|
||||
// ExactMatches are tokens that exactly match this terminal (O(1) validation)
|
||||
ExactMatches []int32
|
||||
|
||||
// DPCandidates are tokens that start with this terminal but need DP validation
|
||||
DPCandidates []int
|
||||
}
|
||||
|
||||
// tokenAnalysis contains precomputed terminal matches for a token
|
||||
type tokenAnalysis struct {
|
||||
// The token string
|
||||
Token string
|
||||
|
||||
// TokenID in the vocabulary
|
||||
TokenID int
|
||||
|
||||
// Matches at each byte position
|
||||
// MatchesAtPos[i] = terminals matching at position i with their lengths
|
||||
MatchesAtPos [][]terminalMatch
|
||||
|
||||
// Fast path: if token exactly matches one terminal
|
||||
// -1 if no exact match
|
||||
exactMatch int
|
||||
|
||||
// Whether this token can be consumed at all (has at least one match)
|
||||
HasMatches bool
|
||||
}
|
||||
|
||||
// analyzer precomputes terminal matches for a vocabulary
|
||||
type analyzer struct {
|
||||
matcher *terminalMatcher
|
||||
analyses []tokenAnalysis // Indexed by token ID
|
||||
vocab []string
|
||||
|
||||
// Pre-partitioned tokens by terminal (exact match vs DP candidates)
|
||||
// This enables direct slice appends instead of per-token branching
|
||||
tokensByTerminal []terminalTokenGroups
|
||||
}
|
||||
|
||||
// newAnalyzer creates an analyzer for the given vocabulary and terminals
|
||||
func newAnalyzer(vocab []string, matcher *terminalMatcher) *analyzer {
|
||||
a := &analyzer{
|
||||
matcher: matcher,
|
||||
analyses: make([]tokenAnalysis, len(vocab)),
|
||||
vocab: vocab,
|
||||
}
|
||||
|
||||
// Precompute analysis for each token
|
||||
for i, token := range vocab {
|
||||
a.analyses[i] = a.analyze(token, i)
|
||||
}
|
||||
|
||||
// Build pre-partitioned token groups for fast ApplyMask
|
||||
a.buildTokenPartitions()
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
// analyze computes terminal matches for a single token
|
||||
func (a *analyzer) analyze(token string, tokenID int) tokenAnalysis {
|
||||
analysis := tokenAnalysis{
|
||||
Token: token,
|
||||
TokenID: tokenID,
|
||||
MatchesAtPos: make([][]terminalMatch, len(token)),
|
||||
exactMatch: -1,
|
||||
HasMatches: false,
|
||||
}
|
||||
|
||||
if len(token) == 0 {
|
||||
return analysis
|
||||
}
|
||||
|
||||
// Compute matches at each position
|
||||
data := []byte(token)
|
||||
for pos := 0; pos < len(data); pos++ {
|
||||
matches := a.matcher.matchesAt(data, pos)
|
||||
analysis.MatchesAtPos[pos] = matches
|
||||
if len(matches) > 0 {
|
||||
analysis.HasMatches = true
|
||||
}
|
||||
}
|
||||
|
||||
// Exact match is only valid when a single terminal spans the entire token
|
||||
if len(analysis.MatchesAtPos) > 0 {
|
||||
var exactID int = -1
|
||||
for _, match := range analysis.MatchesAtPos[0] {
|
||||
if match.Length != len(token) {
|
||||
continue
|
||||
}
|
||||
if exactID >= 0 && exactID != match.TerminalID {
|
||||
exactID = -1
|
||||
break
|
||||
}
|
||||
exactID = match.TerminalID
|
||||
}
|
||||
analysis.exactMatch = exactID
|
||||
}
|
||||
|
||||
return analysis
|
||||
}
|
||||
|
||||
// analysis returns the precomputed analysis for a token ID
|
||||
func (a *analyzer) analysis(tokenID int) tokenAnalysis {
|
||||
if tokenID < 0 || tokenID >= len(a.analyses) {
|
||||
return tokenAnalysis{exactMatch: -1}
|
||||
}
|
||||
return a.analyses[tokenID]
|
||||
}
|
||||
|
||||
// vocabSize returns the vocabulary size
|
||||
func (a *analyzer) vocabSize() int {
|
||||
return len(a.vocab)
|
||||
}
|
||||
|
||||
// buildTokenPartitions pre-partitions tokens into exact-match vs needs-DP groups per terminal.
|
||||
// This enables ApplyMask to use direct slice appends instead of per-token branching.
|
||||
func (a *analyzer) buildTokenPartitions() {
|
||||
numTerminals := a.matcher.terminalCount()
|
||||
a.tokensByTerminal = make([]terminalTokenGroups, numTerminals)
|
||||
|
||||
for tokenID, analysis := range a.analyses {
|
||||
if !analysis.HasMatches {
|
||||
continue
|
||||
}
|
||||
|
||||
if analysis.exactMatch >= 0 {
|
||||
// Token exactly matches one terminal - fast path (O(1) validation)
|
||||
tid := analysis.exactMatch
|
||||
a.tokensByTerminal[tid].ExactMatches = append(
|
||||
a.tokensByTerminal[tid].ExactMatches, int32(tokenID))
|
||||
} else {
|
||||
// Token needs DP validation - add to all terminals it can start with
|
||||
// This way, when a terminal is valid, we know exactly which tokens need DP
|
||||
if len(analysis.MatchesAtPos) > 0 {
|
||||
seen := make(map[int]bool)
|
||||
for _, match := range analysis.MatchesAtPos[0] {
|
||||
tid := match.TerminalID
|
||||
if !seen[tid] {
|
||||
seen[tid] = true
|
||||
a.tokensByTerminal[tid].DPCandidates = append(
|
||||
a.tokensByTerminal[tid].DPCandidates, tokenID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// terminalGroups returns the pre-partitioned token groups for a terminal ID
|
||||
func (a *analyzer) terminalGroups(terminalID int) terminalTokenGroups {
|
||||
if terminalID < 0 || terminalID >= len(a.tokensByTerminal) {
|
||||
return terminalTokenGroups{}
|
||||
}
|
||||
return a.tokensByTerminal[terminalID]
|
||||
}
|
||||
@@ -1,648 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"hash/fnv"
|
||||
"sort"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// visitedMapPool reduces allocations for visited maps in bridge operations
|
||||
var visitedMapPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make(map[stateStackKey]bool, 16)
|
||||
},
|
||||
}
|
||||
|
||||
// getVisitedMap gets a map from the pool
|
||||
func getVisitedMap() map[stateStackKey]bool {
|
||||
return visitedMapPool.Get().(map[stateStackKey]bool)
|
||||
}
|
||||
|
||||
// putVisitedMap returns a map to the pool after clearing it
|
||||
func putVisitedMap(m map[stateStackKey]bool) {
|
||||
for k := range m {
|
||||
delete(m, k)
|
||||
}
|
||||
visitedMapPool.Put(m)
|
||||
}
|
||||
|
||||
// parserConfig represents a pda state+stack combination
|
||||
type parserConfig struct {
|
||||
state state
|
||||
Stack []stackSymbol
|
||||
}
|
||||
|
||||
// clone creates a deep copy of the config
|
||||
func (c *parserConfig) clone() *parserConfig {
|
||||
newStack := make([]stackSymbol, len(c.Stack))
|
||||
copy(newStack, c.Stack)
|
||||
return &parserConfig{
|
||||
state: c.state,
|
||||
Stack: newStack,
|
||||
}
|
||||
}
|
||||
|
||||
// key returns a unique key for this config for deduplication
|
||||
func (c *parserConfig) key() uint64 {
|
||||
h := fnv.New64a()
|
||||
var buf [8]byte
|
||||
binary.LittleEndian.PutUint64(buf[:], uint64(c.state))
|
||||
h.Write(buf[:])
|
||||
for _, sym := range c.Stack {
|
||||
binary.LittleEndian.PutUint64(buf[:], uint64(sym))
|
||||
h.Write(buf[:])
|
||||
}
|
||||
return h.Sum64()
|
||||
}
|
||||
|
||||
// configSet represents a set of parser configurations (for nondeterminism)
|
||||
type configSet struct {
|
||||
configs []*parserConfig
|
||||
normalized bool // true if already deduplicated and sorted
|
||||
cachedSig uint64 // cached signature after normalization
|
||||
}
|
||||
|
||||
// newConfigSet creates a new config set with a single configuration
|
||||
func newConfigSet(state state, stack []stackSymbol) *configSet {
|
||||
return &configSet{
|
||||
configs: []*parserConfig{
|
||||
{state: state, Stack: stack},
|
||||
},
|
||||
normalized: true, // single config is already normalized
|
||||
}
|
||||
}
|
||||
|
||||
// normalize deduplicates and sorts configs for stable signatures
|
||||
func (c *configSet) normalize() {
|
||||
if c.normalized || len(c.configs) <= 1 {
|
||||
c.normalized = true
|
||||
return
|
||||
}
|
||||
|
||||
// Deduplicate using a map
|
||||
seen := make(map[uint64]*parserConfig, len(c.configs))
|
||||
for _, cfg := range c.configs {
|
||||
key := cfg.key()
|
||||
if _, exists := seen[key]; !exists {
|
||||
seen[key] = cfg
|
||||
}
|
||||
}
|
||||
|
||||
// Extract unique configs
|
||||
unique := make([]*parserConfig, 0, len(seen))
|
||||
for _, cfg := range seen {
|
||||
unique = append(unique, cfg)
|
||||
}
|
||||
|
||||
// Sort by key for deterministic ordering
|
||||
sort.Slice(unique, func(i, j int) bool {
|
||||
return unique[i].key() < unique[j].key()
|
||||
})
|
||||
|
||||
c.configs = unique
|
||||
c.normalized = true
|
||||
}
|
||||
|
||||
// signature returns a hash for cache lookup (normalizes first)
|
||||
func (c *configSet) signature() uint64 {
|
||||
c.normalize()
|
||||
|
||||
// Return cached signature if available
|
||||
if c.cachedSig != 0 {
|
||||
return c.cachedSig
|
||||
}
|
||||
|
||||
h := fnv.New64a()
|
||||
|
||||
// Hash number of configs
|
||||
var buf [8]byte
|
||||
binary.LittleEndian.PutUint64(buf[:], uint64(len(c.configs)))
|
||||
h.Write(buf[:])
|
||||
|
||||
// Hash each config (already sorted)
|
||||
for _, cfg := range c.configs {
|
||||
binary.LittleEndian.PutUint64(buf[:], uint64(cfg.state))
|
||||
h.Write(buf[:])
|
||||
|
||||
binary.LittleEndian.PutUint64(buf[:], uint64(len(cfg.Stack)))
|
||||
h.Write(buf[:])
|
||||
|
||||
for _, sym := range cfg.Stack {
|
||||
binary.LittleEndian.PutUint64(buf[:], uint64(sym))
|
||||
h.Write(buf[:])
|
||||
}
|
||||
}
|
||||
|
||||
c.cachedSig = h.Sum64()
|
||||
return c.cachedSig
|
||||
}
|
||||
|
||||
// isEmpty returns true if there are no configurations
|
||||
func (c *configSet) isEmpty() bool {
|
||||
return len(c.configs) == 0
|
||||
}
|
||||
|
||||
// clone creates a deep copy of the config set
|
||||
func (c *configSet) clone() *configSet {
|
||||
newConfigs := make([]*parserConfig, len(c.configs))
|
||||
for i, cfg := range c.configs {
|
||||
newConfigs[i] = cfg.clone()
|
||||
}
|
||||
return &configSet{configs: newConfigs}
|
||||
}
|
||||
|
||||
// bridge connects token analysis to pda validation
|
||||
type bridge struct {
|
||||
pda *pda
|
||||
analyzer *analyzer
|
||||
}
|
||||
|
||||
// newBridge creates a new bridge
|
||||
func newBridge(pda *pda, analyzer *analyzer) *bridge {
|
||||
return &bridge{
|
||||
pda: pda,
|
||||
analyzer: analyzer,
|
||||
}
|
||||
}
|
||||
|
||||
// IsTokenValid checks if token T can be consumed from the current config
|
||||
// This is the main entry point for token validation
|
||||
func (b *bridge) IsTokenValid(tokenID int, config *configSet) bool {
|
||||
analysis := b.analyzer.analysis(tokenID)
|
||||
|
||||
if !analysis.HasMatches {
|
||||
return false
|
||||
}
|
||||
|
||||
// Fast path: exact terminal match
|
||||
if analysis.exactMatch >= 0 {
|
||||
terminal := b.analyzer.matcher.terminals[analysis.exactMatch]
|
||||
return b.canAcceptTerminal(config, terminal.Pattern)
|
||||
}
|
||||
|
||||
// General path: DP over (pos, config)
|
||||
return b.dpValidate(&analysis, config)
|
||||
}
|
||||
|
||||
// canAcceptTerminal checks if any config can accept the terminal
|
||||
func (b *bridge) canAcceptTerminal(config *configSet, pattern string) bool {
|
||||
for _, cfg := range config.configs {
|
||||
if b.canConfigAcceptTerminal(cfg, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// canConfigAcceptTerminal checks if a single config can accept the terminal
|
||||
func (b *bridge) canConfigAcceptTerminal(cfg *parserConfig, pattern string) bool {
|
||||
// Use pooled visited map to reduce allocations
|
||||
visited := getVisitedMap()
|
||||
result := b.tryAcceptTerminal(cfg.state, cfg.Stack, pattern, visited)
|
||||
putVisitedMap(visited)
|
||||
return result
|
||||
}
|
||||
|
||||
// tryAcceptTerminal recursively tries to accept a terminal from a state
|
||||
func (b *bridge) tryAcceptTerminal(state state, stack []stackSymbol, pattern string, visited map[stateStackKey]bool) bool {
|
||||
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
|
||||
if visited[key] {
|
||||
return false
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
stackTop := stackEmpty
|
||||
if len(stack) > 0 {
|
||||
stackTop = stack[len(stack)-1]
|
||||
}
|
||||
|
||||
for _, t := range b.pda.Transitions[state] {
|
||||
// Check stack constraint
|
||||
if t.stackTop != stackEmpty && t.stackTop != stackTop {
|
||||
continue
|
||||
}
|
||||
|
||||
// Can't pop more than we have
|
||||
if t.StackPop > len(stack) {
|
||||
continue
|
||||
}
|
||||
|
||||
if t.Pattern == pattern {
|
||||
// Direct match
|
||||
return true
|
||||
}
|
||||
|
||||
if t.Pattern == "" {
|
||||
// Epsilon transition - follow it
|
||||
newStack := make([]stackSymbol, len(stack))
|
||||
copy(newStack, stack)
|
||||
|
||||
// Pop
|
||||
if t.StackPop > 0 {
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
|
||||
// Push
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
|
||||
if b.tryAcceptTerminal(t.ToState, newStack, pattern, visited) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// dpValidate runs DP for multi-terminal tokens
|
||||
func (b *bridge) dpValidate(analysis *tokenAnalysis, startConfig *configSet) bool {
|
||||
// state: (pos, configSet)
|
||||
// Memoize by (pos, configSig)
|
||||
type dpKey struct {
|
||||
pos int
|
||||
sig uint64
|
||||
}
|
||||
memo := make(map[dpKey]bool)
|
||||
|
||||
var dp func(pos int, config *configSet) bool
|
||||
dp = func(pos int, config *configSet) bool {
|
||||
if pos == len(analysis.Token) {
|
||||
return true // Consumed entire token
|
||||
}
|
||||
|
||||
if config.isEmpty() {
|
||||
return false
|
||||
}
|
||||
|
||||
key := dpKey{pos, config.signature()}
|
||||
if result, ok := memo[key]; ok {
|
||||
return result
|
||||
}
|
||||
|
||||
// Try each terminal that matches at this position
|
||||
for _, match := range analysis.MatchesAtPos[pos] {
|
||||
terminal := b.analyzer.matcher.terminals[match.TerminalID]
|
||||
newConfig := b.advanceConfig(config, terminal.Pattern)
|
||||
if newConfig != nil && !newConfig.isEmpty() && dp(pos+match.Length, newConfig) {
|
||||
memo[key] = true
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
memo[key] = false
|
||||
return false
|
||||
}
|
||||
|
||||
return dp(0, startConfig)
|
||||
}
|
||||
|
||||
// advanceConfig advances all configs that can accept the terminal
|
||||
func (b *bridge) advanceConfig(config *configSet, pattern string) *configSet {
|
||||
var newConfigs []*parserConfig
|
||||
|
||||
for _, cfg := range config.configs {
|
||||
advanced := b.advanceSingleConfig(cfg, pattern)
|
||||
newConfigs = append(newConfigs, advanced...)
|
||||
}
|
||||
|
||||
if len(newConfigs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &configSet{configs: newConfigs}
|
||||
}
|
||||
|
||||
// advanceSingleConfig advances a single config by accepting a terminal
|
||||
func (b *bridge) advanceSingleConfig(cfg *parserConfig, pattern string) []*parserConfig {
|
||||
var results []*parserConfig
|
||||
visited := getVisitedMap()
|
||||
b.collectAdvanced(cfg.state, cfg.Stack, pattern, visited, &results)
|
||||
putVisitedMap(visited)
|
||||
return results
|
||||
}
|
||||
|
||||
// collectAdvanced collects all configs reachable by accepting the pattern
|
||||
func (b *bridge) collectAdvanced(state state, stack []stackSymbol, pattern string, visited map[stateStackKey]bool, results *[]*parserConfig) {
|
||||
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
|
||||
if visited[key] {
|
||||
return
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
stackTop := stackEmpty
|
||||
if len(stack) > 0 {
|
||||
stackTop = stack[len(stack)-1]
|
||||
}
|
||||
|
||||
for _, t := range b.pda.Transitions[state] {
|
||||
// Check stack constraint
|
||||
if t.stackTop != stackEmpty && t.stackTop != stackTop {
|
||||
continue
|
||||
}
|
||||
|
||||
// Can't pop more than we have
|
||||
if t.StackPop > len(stack) {
|
||||
continue
|
||||
}
|
||||
|
||||
if t.Pattern == pattern {
|
||||
// Match! Create new config after transition
|
||||
newStack := make([]stackSymbol, len(stack))
|
||||
copy(newStack, stack)
|
||||
|
||||
if t.StackPop > 0 {
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
|
||||
*results = append(*results, &parserConfig{
|
||||
state: t.ToState,
|
||||
Stack: newStack,
|
||||
})
|
||||
}
|
||||
|
||||
if t.Pattern == "" {
|
||||
// Epsilon transition - follow it
|
||||
newStack := make([]stackSymbol, len(stack))
|
||||
copy(newStack, stack)
|
||||
|
||||
if t.StackPop > 0 {
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
|
||||
b.collectAdvanced(t.ToState, newStack, pattern, visited, results)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validTokens returns all token IDs that are valid from the given config
|
||||
func (b *bridge) validTokens(config *configSet) []int {
|
||||
var valid []int
|
||||
for tokenID := 0; tokenID < b.analyzer.vocabSize(); tokenID++ {
|
||||
if b.IsTokenValid(tokenID, config) {
|
||||
valid = append(valid, tokenID)
|
||||
}
|
||||
}
|
||||
return valid
|
||||
}
|
||||
|
||||
// acceptToken attempts to accept a token and returns the new config set
|
||||
// Returns nil if the token is not valid from this config
|
||||
func (b *bridge) acceptToken(tokenID int, config *configSet) *configSet {
|
||||
analysis := b.analyzer.analysis(tokenID)
|
||||
|
||||
if !analysis.HasMatches {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fast path: exact terminal match
|
||||
if analysis.exactMatch >= 0 {
|
||||
terminal := b.analyzer.matcher.terminals[analysis.exactMatch]
|
||||
newConfig := b.advanceConfig(config, terminal.Pattern)
|
||||
if newConfig != nil && !newConfig.isEmpty() {
|
||||
newConfig.normalize()
|
||||
return newConfig
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// General path: DP to find final config after consuming token
|
||||
return b.dpAccept(&analysis, config)
|
||||
}
|
||||
|
||||
// dpAccept runs DP to accept a multi-terminal token and return final config
|
||||
// Returns the union of all possible end configurations (preserves nondeterminism)
|
||||
func (b *bridge) dpAccept(analysis *tokenAnalysis, startConfig *configSet) *configSet {
|
||||
type dpKey struct {
|
||||
pos int
|
||||
sig uint64
|
||||
}
|
||||
// Memoize the configs reachable at each (pos, sig)
|
||||
memo := make(map[dpKey]*configSet)
|
||||
|
||||
var dp func(pos int, config *configSet) *configSet
|
||||
dp = func(pos int, config *configSet) *configSet {
|
||||
if pos == len(analysis.Token) {
|
||||
return config // Consumed entire token, return final config
|
||||
}
|
||||
|
||||
if config.isEmpty() {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := dpKey{pos, config.signature()}
|
||||
if result, ok := memo[key]; ok {
|
||||
return result
|
||||
}
|
||||
|
||||
// Collect all valid result configs from all possible paths
|
||||
var allConfigs []*parserConfig
|
||||
|
||||
// Try each terminal that matches at this position
|
||||
for _, match := range analysis.MatchesAtPos[pos] {
|
||||
terminal := b.analyzer.matcher.terminals[match.TerminalID]
|
||||
newConfig := b.advanceConfig(config, terminal.Pattern)
|
||||
if newConfig != nil && !newConfig.isEmpty() {
|
||||
finalConfig := dp(pos+match.Length, newConfig)
|
||||
if finalConfig != nil {
|
||||
// Collect all configs, don't return early
|
||||
allConfigs = append(allConfigs, finalConfig.configs...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build result: nil if no valid paths, normalized configSet otherwise
|
||||
var result *configSet
|
||||
if len(allConfigs) > 0 {
|
||||
result = &configSet{configs: allConfigs}
|
||||
result.normalize() // Dedup using parserConfig.key(), sort for consistent signature
|
||||
}
|
||||
memo[key] = result // Cache normalized result
|
||||
return result
|
||||
}
|
||||
|
||||
return dp(0, startConfig)
|
||||
}
|
||||
|
||||
// isAccepting returns true if any config can reach an accepting state
|
||||
func (b *bridge) isAccepting(config *configSet) bool {
|
||||
visited := getVisitedMap()
|
||||
defer putVisitedMap(visited)
|
||||
|
||||
for _, cfg := range config.configs {
|
||||
// Clear visited for each config check
|
||||
for k := range visited {
|
||||
delete(visited, k)
|
||||
}
|
||||
if b.canReachAccept(cfg.state, cfg.Stack, visited) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// canReachAccept checks if we can reach an accepting state via epsilon transitions
|
||||
func (b *bridge) canReachAccept(state state, stack []stackSymbol, visited map[stateStackKey]bool) bool {
|
||||
// Check if this state is accepting with empty stack
|
||||
if b.pda.AcceptStates[state] && len(stack) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
|
||||
if visited[key] {
|
||||
return false
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
// Try epsilon transitions
|
||||
stackTop := stackEmpty
|
||||
if len(stack) > 0 {
|
||||
stackTop = stack[len(stack)-1]
|
||||
}
|
||||
|
||||
for _, t := range b.pda.Transitions[state] {
|
||||
if t.Pattern != "" {
|
||||
continue // Not epsilon
|
||||
}
|
||||
if t.stackTop != stackEmpty && t.stackTop != stackTop {
|
||||
continue
|
||||
}
|
||||
if t.StackPop > len(stack) {
|
||||
continue
|
||||
}
|
||||
|
||||
newStack := make([]stackSymbol, len(stack))
|
||||
copy(newStack, stack)
|
||||
if t.StackPop > 0 {
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
|
||||
if b.canReachAccept(t.ToState, newStack, visited) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// validTerminals returns the valid terminal patterns from the given config
|
||||
func (b *bridge) validTerminals(config *configSet) []string {
|
||||
seen := make(map[string]bool)
|
||||
var terminals []string
|
||||
|
||||
visited := getVisitedMap()
|
||||
defer putVisitedMap(visited)
|
||||
|
||||
for _, cfg := range config.configs {
|
||||
// Clear visited for each config
|
||||
for k := range visited {
|
||||
delete(visited, k)
|
||||
}
|
||||
b.collectValidTerminals(cfg.state, cfg.Stack, visited, seen, &terminals)
|
||||
}
|
||||
|
||||
return terminals
|
||||
}
|
||||
|
||||
// collectValidTerminals collects all reachable terminals
|
||||
func (b *bridge) collectValidTerminals(state state, stack []stackSymbol, visited map[stateStackKey]bool, seen map[string]bool, terminals *[]string) {
|
||||
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
|
||||
if visited[key] {
|
||||
return
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
stackTop := stackEmpty
|
||||
if len(stack) > 0 {
|
||||
stackTop = stack[len(stack)-1]
|
||||
}
|
||||
|
||||
for _, t := range b.pda.Transitions[state] {
|
||||
if t.stackTop != stackEmpty && t.stackTop != stackTop {
|
||||
continue
|
||||
}
|
||||
if t.StackPop > len(stack) {
|
||||
continue
|
||||
}
|
||||
|
||||
if t.Pattern != "" && !seen[t.Pattern] {
|
||||
seen[t.Pattern] = true
|
||||
*terminals = append(*terminals, t.Pattern)
|
||||
}
|
||||
|
||||
if t.Pattern == "" {
|
||||
newStack := make([]stackSymbol, len(stack))
|
||||
copy(newStack, stack)
|
||||
if t.StackPop > 0 {
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
b.collectValidTerminals(t.ToState, newStack, visited, seen, terminals)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validTerminalIDs returns the IDs of valid terminals from the given config
|
||||
func (b *bridge) validTerminalIDs(config *configSet) []int {
|
||||
seen := make(map[int]bool)
|
||||
var terminalIDs []int
|
||||
|
||||
visited := getVisitedMap()
|
||||
defer putVisitedMap(visited)
|
||||
|
||||
for _, cfg := range config.configs {
|
||||
// Clear visited for each config
|
||||
for k := range visited {
|
||||
delete(visited, k)
|
||||
}
|
||||
b.collectValidTerminalIDs(cfg.state, cfg.Stack, visited, seen, &terminalIDs)
|
||||
}
|
||||
|
||||
return terminalIDs
|
||||
}
|
||||
|
||||
// collectValidTerminalIDs collects IDs of all reachable terminals
|
||||
func (b *bridge) collectValidTerminalIDs(state state, stack []stackSymbol, visited map[stateStackKey]bool, seen map[int]bool, terminalIDs *[]int) {
|
||||
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
|
||||
if visited[key] {
|
||||
return
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
stackTop := stackEmpty
|
||||
if len(stack) > 0 {
|
||||
stackTop = stack[len(stack)-1]
|
||||
}
|
||||
|
||||
for _, t := range b.pda.Transitions[state] {
|
||||
if t.stackTop != stackEmpty && t.stackTop != stackTop {
|
||||
continue
|
||||
}
|
||||
if t.StackPop > len(stack) {
|
||||
continue
|
||||
}
|
||||
|
||||
if t.Pattern != "" {
|
||||
// Look up terminal ID from pattern
|
||||
if tid, ok := b.analyzer.matcher.patternToID[t.Pattern]; ok && !seen[tid] {
|
||||
seen[tid] = true
|
||||
*terminalIDs = append(*terminalIDs, tid)
|
||||
}
|
||||
}
|
||||
|
||||
if t.Pattern == "" {
|
||||
newStack := make([]stackSymbol, len(stack))
|
||||
copy(newStack, stack)
|
||||
if t.StackPop > 0 {
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
b.collectValidTerminalIDs(t.ToState, newStack, visited, seen, terminalIDs)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
root ::= ws "{" ws id-field "," ws kind-field "," ws items-field "," ws alt-field "," ws flags-field "," ws meta-field "," ws priority-field ws "}" ws
|
||||
|
||||
id-field ::= "\"id\"" ws ":" ws uuid
|
||||
kind-field ::= "\"kind\"" ws ":" ws kind
|
||||
items-field ::= "\"items\"" ws ":" ws items
|
||||
alt-field ::= "\"alt\"" ws ":" ws alt
|
||||
flags-field ::= "\"flags\"" ws ":" ws flags
|
||||
meta-field ::= "\"meta\"" ws ":" ws meta
|
||||
priority-field ::= "\"priority\"" ws ":" ws int
|
||||
|
||||
kind ::= "\"order\"" | "\"invoice\"" | "\"shipment\""
|
||||
status ::= "\"new\"" | "\"backorder\"" | "\"shipped\""
|
||||
flag ::= "\"fragile\"" | "\"gift\"" | "\"priority\"" | "\"insured\""
|
||||
source ::= "\"api\"" | "\"batch\"" | "\"import\""
|
||||
|
||||
items ::= "[" ws item ( "," ws item )? ( "," ws item )? ws "]"
|
||||
flags ::= "[" ws "]" | "[" ws flag ( "," ws flag )? ( "," ws flag )? ( "," ws flag )? ws "]"
|
||||
|
||||
item ::= "{" ws item-sku "," ws item-qty "," ws item-status "," ws item-notes ws "}"
|
||||
item-sku ::= "\"sku\"" ws ":" ws string
|
||||
item-qty ::= "\"qty\"" ws ":" ws int
|
||||
item-status ::= "\"status\"" ws ":" ws status
|
||||
item-notes ::= "\"notes\"" ws ":" ws string
|
||||
|
||||
meta ::= "{" ws meta-created "," ws meta-source "," ws meta-ip ws "}"
|
||||
meta-created ::= "\"created\"" ws ":" ws date-time
|
||||
meta-source ::= "\"source\"" ws ":" ws source
|
||||
meta-ip ::= "\"ip\"" ws ":" ws ipv4
|
||||
|
||||
alt ::= string | int | "null"
|
||||
|
||||
uuid ::= "\"" hex hex hex hex hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex hex hex hex hex hex hex hex hex "\""
|
||||
date-time ::= "\"" digit digit digit digit "-" digit digit "-" digit digit "T" digit digit ":" digit digit ":" digit digit ( "Z" | ( "+" | "-" ) digit digit ":" digit digit ) "\""
|
||||
ipv4 ::= "\"" digit+ "." digit+ "." digit+ "." digit+ "\""
|
||||
|
||||
string ::= "\"" characters "\""
|
||||
characters ::= character*
|
||||
character ::= [^"\\] | "\\" escape
|
||||
escape ::= ["\\bfnrt]
|
||||
|
||||
int ::= "-"? digit+
|
||||
digit ::= [0-9]
|
||||
hex ::= [0-9a-fA-F]
|
||||
|
||||
ws ::= [ \t\n\r]*
|
||||
@@ -1,46 +0,0 @@
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": { "type": "string", "format": "uuid" },
|
||||
"kind": { "enum": ["order", "invoice", "shipment"] },
|
||||
"items": {
|
||||
"type": "array",
|
||||
"minItems": 1,
|
||||
"maxItems": 3,
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sku": { "type": "string" },
|
||||
"qty": { "type": "integer" },
|
||||
"status": { "enum": ["new", "backorder", "shipped"] },
|
||||
"notes": { "type": "string" }
|
||||
},
|
||||
"required": ["sku", "qty", "status", "notes"]
|
||||
}
|
||||
},
|
||||
"alt": {
|
||||
"oneOf": [
|
||||
{ "type": "string" },
|
||||
{ "type": "null" },
|
||||
{ "type": "integer" }
|
||||
]
|
||||
},
|
||||
"flags": {
|
||||
"type": "array",
|
||||
"minItems": 0,
|
||||
"maxItems": 4,
|
||||
"items": { "enum": ["fragile", "gift", "priority", "insured"] }
|
||||
},
|
||||
"meta": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created": { "type": "string", "format": "date-time" },
|
||||
"source": { "enum": ["api", "batch", "import"] },
|
||||
"ip": { "type": "string", "format": "ipv4" }
|
||||
},
|
||||
"required": ["created", "source", "ip"]
|
||||
},
|
||||
"priority": { "type": "integer" }
|
||||
},
|
||||
"required": ["id", "kind", "items", "alt", "flags", "meta", "priority"]
|
||||
}
|
||||
@@ -1,235 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/x/grammar"
|
||||
"github.com/ollama/ollama/x/grammar/schema"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
const jsonGBNF = `
|
||||
root ::= value
|
||||
value ::= object | array | string | number | "true" | "false" | "null"
|
||||
object ::= "{" ws "}" | "{" members "}"
|
||||
members ::= member ("," member)*
|
||||
member ::= ws string ws ":" element
|
||||
array ::= "[" ws "]" | "[" elements "]"
|
||||
elements ::= element ("," element)*
|
||||
element ::= ws value ws
|
||||
string ::= "\"" characters "\""
|
||||
characters ::= character*
|
||||
character ::= [^"\\] | "\\" escape
|
||||
escape ::= ["\\bfnrt]
|
||||
number ::= "-"? integer fraction? exponent?
|
||||
integer ::= "0" | [1-9] [0-9]*
|
||||
fraction ::= "." [0-9]+
|
||||
exponent ::= [eE] [+-]? [0-9]+
|
||||
ws ::= [ \t\n\r]*
|
||||
`
|
||||
|
||||
type result struct {
|
||||
vocabSize int `json:"vocab_size"`
|
||||
Iterations int `json:"iterations"`
|
||||
Warmup int `json:"warmup"`
|
||||
ConstrainedSource string `json:"constrained_source"`
|
||||
LlamaSource string `json:"llama_source"`
|
||||
LlamaApply string `json:"llama_apply"`
|
||||
ConstrainedGraph string `json:"constrained_graph"`
|
||||
ConstrainedWithEval string `json:"constrained_with_eval,omitempty"`
|
||||
EvalOnly string `json:"eval_only,omitempty"`
|
||||
ConstrainedEvalNet string `json:"constrained_eval_net,omitempty"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
var (
|
||||
vocabSize = flag.Int("vocab-size", 128000, "Vocabulary size")
|
||||
iterations = flag.Int("iterations", 500, "Benchmark iterations")
|
||||
warmup = flag.Int("warmup", 50, "Warmup iterations")
|
||||
withEval = flag.Bool("eval", true, "Measure ApplyMask with mlx.Eval")
|
||||
gbnfPath = flag.String("gbnf", "", "GBNF grammar file for llama.cpp")
|
||||
schemaPath = flag.String("schema", "", "JSON Schema file for grammar constraints")
|
||||
ebnfPath = flag.String("ebnf", "", "EBNF grammar file for grammar constraints")
|
||||
startRule = flag.String("start", "root", "Start rule for EBNF")
|
||||
)
|
||||
flag.Parse()
|
||||
|
||||
if *vocabSize <= 0 || *iterations <= 0 || *warmup < 0 {
|
||||
fmt.Fprintln(os.Stderr, "invalid flags")
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
vocab := createVocab(*vocabSize)
|
||||
|
||||
if *schemaPath != "" && *ebnfPath != "" {
|
||||
fmt.Fprintln(os.Stderr, "only one of -schema or -ebnf may be set")
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
var constrainedSource string
|
||||
var compiled *grammar.Grammar
|
||||
var err error
|
||||
switch {
|
||||
case *schemaPath != "":
|
||||
data, readErr := os.ReadFile(*schemaPath)
|
||||
if readErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "read schema: %v\n", readErr)
|
||||
os.Exit(1)
|
||||
}
|
||||
compiled, err = schema.Grammar(string(data))
|
||||
constrainedSource = "schema:" + *schemaPath
|
||||
case *ebnfPath != "":
|
||||
data, readErr := os.ReadFile(*ebnfPath)
|
||||
if readErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "read ebnf: %v\n", readErr)
|
||||
os.Exit(1)
|
||||
}
|
||||
compiled, err = grammar.ParseEBNF(string(data), *startRule)
|
||||
constrainedSource = "ebnf:" + *ebnfPath
|
||||
default:
|
||||
compiled, err = grammar.JSONGrammar()
|
||||
constrainedSource = "json"
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "grammar: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
engine, err := grammar.NewEngine(compiled, vocab)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "engine: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
logits := mlx.Ones(int32(*vocabSize))
|
||||
mlx.Keep(logits)
|
||||
|
||||
for i := 0; i < *warmup; i++ {
|
||||
masked := engine.ApplyMask(logits)
|
||||
if *withEval {
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
}
|
||||
|
||||
graphAvg := measure(*iterations, func() {
|
||||
_ = engine.ApplyMask(logits)
|
||||
})
|
||||
|
||||
var evalAvg time.Duration
|
||||
var evalOnlyAvg time.Duration
|
||||
if *withEval {
|
||||
evalOnlyAvg = measure(*iterations, func() {
|
||||
baseline := mlx.MulScalar(logits, 1)
|
||||
mlx.Eval(baseline)
|
||||
baseline.Free()
|
||||
})
|
||||
|
||||
evalAvg = measure(*iterations, func() {
|
||||
masked := engine.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
})
|
||||
}
|
||||
|
||||
vocabIDs := make([]uint32, *vocabSize)
|
||||
for i := range vocabIDs {
|
||||
vocabIDs[i] = uint32(i)
|
||||
}
|
||||
eogTokens := []int32{0}
|
||||
|
||||
gbnf := jsonGBNF
|
||||
llamaSource := "json"
|
||||
if *gbnfPath != "" {
|
||||
data, readErr := os.ReadFile(*gbnfPath)
|
||||
if readErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "read gbnf: %v\n", readErr)
|
||||
os.Exit(1)
|
||||
}
|
||||
gbnf = string(data)
|
||||
llamaSource = *gbnfPath
|
||||
}
|
||||
|
||||
llamaGrammar := llama.NewGrammar(gbnf, vocabIDs, vocab, eogTokens)
|
||||
if llamaGrammar == nil {
|
||||
fmt.Fprintln(os.Stderr, "llama grammar initialization failed")
|
||||
os.Exit(1)
|
||||
}
|
||||
defer llamaGrammar.Free()
|
||||
|
||||
llamaTokens := make([]llama.TokenData, *vocabSize)
|
||||
|
||||
for i := 0; i < *warmup; i++ {
|
||||
for j := range llamaTokens {
|
||||
llamaTokens[j].Logit = 1.0
|
||||
}
|
||||
llamaGrammar.Apply(llamaTokens)
|
||||
}
|
||||
|
||||
llamaAvg := measure(*iterations, func() {
|
||||
for j := range llamaTokens {
|
||||
llamaTokens[j].Logit = 1.0
|
||||
}
|
||||
llamaGrammar.Apply(llamaTokens)
|
||||
})
|
||||
|
||||
out := result{
|
||||
vocabSize: *vocabSize,
|
||||
Iterations: *iterations,
|
||||
Warmup: *warmup,
|
||||
LlamaApply: llamaAvg.String(),
|
||||
ConstrainedGraph: graphAvg.String(),
|
||||
ConstrainedSource: constrainedSource,
|
||||
LlamaSource: llamaSource,
|
||||
}
|
||||
if *withEval {
|
||||
out.ConstrainedWithEval = evalAvg.String()
|
||||
out.EvalOnly = evalOnlyAvg.String()
|
||||
if evalAvg > evalOnlyAvg {
|
||||
out.ConstrainedEvalNet = (evalAvg - evalOnlyAvg).String()
|
||||
} else {
|
||||
out.ConstrainedEvalNet = "0s"
|
||||
}
|
||||
}
|
||||
|
||||
enc := json.NewEncoder(os.Stdout)
|
||||
if err := enc.Encode(out); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "encode: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func measure(iterations int, fn func()) time.Duration {
|
||||
start := time.Now()
|
||||
for i := 0; i < iterations; i++ {
|
||||
fn()
|
||||
}
|
||||
return time.Since(start) / time.Duration(iterations)
|
||||
}
|
||||
|
||||
func createVocab(size int) []string {
|
||||
vocab := make([]string, size)
|
||||
|
||||
jsonTokens := []string{
|
||||
"{", "}", "[", "]", ":", ",",
|
||||
"true", "false", "null",
|
||||
" ", "\n", "\t", "\r",
|
||||
"\"",
|
||||
}
|
||||
for i, t := range jsonTokens {
|
||||
if i < size {
|
||||
vocab[i] = t
|
||||
}
|
||||
}
|
||||
|
||||
for i := len(jsonTokens); i < size; i++ {
|
||||
vocab[i] = fmt.Sprintf("tok%d", i)
|
||||
}
|
||||
|
||||
return vocab
|
||||
}
|
||||
@@ -1,320 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// Grammar is the compiled form of an EBNF grammar.
|
||||
// It contains terminals, parse tables, and the start state.
|
||||
// Use ParseEBNF or JSONGrammar to create a Grammar.
|
||||
type Grammar struct {
|
||||
// The underlying pda
|
||||
pda *pda
|
||||
|
||||
// Compiled terminal matcher
|
||||
matcher *terminalMatcher
|
||||
}
|
||||
|
||||
// ParseEBNF compiles an EBNF grammar string into a Grammar.
|
||||
// startRule is the name of the start rule (e.g., "root", "json").
|
||||
func ParseEBNF(ebnf string, startRule string) (*Grammar, error) {
|
||||
pda, err := compileString(ebnf, startRule)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile EBNF: %w", err)
|
||||
}
|
||||
|
||||
matcher, err := compileTerminalsStrict(pda)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile terminals: %w", err)
|
||||
}
|
||||
|
||||
return &Grammar{
|
||||
pda: pda,
|
||||
matcher: matcher,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// JSONGrammar returns the compiled JSON grammar.
|
||||
// This is a convenience wrapper for ParseEBNF(JSONGrammarEBNF, "json").
|
||||
func JSONGrammar() (*Grammar, error) {
|
||||
return ParseEBNF(JSONGrammarEBNF, "json")
|
||||
}
|
||||
|
||||
// JSONObjectGrammar returns a JSON grammar that only allows objects at the top level.
|
||||
// Use this when you want to ensure the output is a JSON object (starts with {).
|
||||
func JSONObjectGrammar() (*Grammar, error) {
|
||||
return ParseEBNF(JSONObjectGrammarEBNF, "json")
|
||||
}
|
||||
|
||||
// compileTerminalsStrict builds a matcher that properly handles:
|
||||
// - Escaped literals ("\n", \"", \uXXXX)
|
||||
// - Unicode ranges (rune-based, not byte-based)
|
||||
// - Rejects unsupported patterns with an error (no silent fallback)
|
||||
func compileTerminalsStrict(pda *pda) (*terminalMatcher, error) {
|
||||
m := &terminalMatcher{
|
||||
literalTrie: &trieNode{terminalID: -1},
|
||||
ranges: make([]terminal, 0),
|
||||
terminals: make([]terminal, 0, len(pda.Terminals)),
|
||||
patternToID: make(map[string]int),
|
||||
}
|
||||
|
||||
// Track which pattern produced each unescaped value for collision detection
|
||||
unescapedSource := make(map[string]string) // unescaped -> original pattern
|
||||
|
||||
for i, pattern := range pda.Terminals {
|
||||
terminal, err := parseTerminalPattern(pattern, i)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("terminal %q: %w", pattern, err)
|
||||
}
|
||||
|
||||
if terminal.Type == terminalLiteral {
|
||||
// Use the unescaped pattern for trie matching
|
||||
m.addLiteralToTrie(terminal.Unescaped, i)
|
||||
|
||||
// Detect collisions between literals that unescape to the same value
|
||||
if existingPattern, exists := unescapedSource[terminal.Unescaped]; exists {
|
||||
if existingPattern != pattern {
|
||||
return nil, fmt.Errorf("collision: patterns %q and %q both unescape to %q",
|
||||
existingPattern, pattern, terminal.Unescaped)
|
||||
}
|
||||
} else {
|
||||
unescapedSource[terminal.Unescaped] = pattern
|
||||
}
|
||||
} else if terminal.Type == terminalRange {
|
||||
m.ranges = append(m.ranges, terminal)
|
||||
}
|
||||
|
||||
m.terminals = append(m.terminals, terminal)
|
||||
m.patternToID[pattern] = i
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// parseTerminalPattern parses a terminal pattern and returns a terminal.
|
||||
// Supports:
|
||||
// - Literal strings (with escape sequences)
|
||||
// - Character ranges [X-Y] (unicode-aware)
|
||||
func parseTerminalPattern(pattern string, id int) (terminal, error) {
|
||||
if len(pattern) == 0 {
|
||||
return terminal{}, fmt.Errorf("empty pattern")
|
||||
}
|
||||
|
||||
// Check for range pattern: [X-Y]
|
||||
if isUnicodeRangePattern(pattern) {
|
||||
lowRune, highRune, err := parseUnicodeRange(pattern)
|
||||
if err != nil {
|
||||
return terminal{}, err
|
||||
}
|
||||
return terminal{
|
||||
ID: id,
|
||||
Type: terminalRange,
|
||||
Pattern: pattern,
|
||||
Unescaped: pattern,
|
||||
LowRune: lowRune,
|
||||
HighRune: highRune,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// It's a literal - unescape it
|
||||
unescaped, err := unescapeLiteral(pattern)
|
||||
if err != nil {
|
||||
return terminal{}, fmt.Errorf("invalid escape sequence: %w", err)
|
||||
}
|
||||
|
||||
return terminal{
|
||||
ID: id,
|
||||
Type: terminalLiteral,
|
||||
Pattern: pattern,
|
||||
Unescaped: unescaped,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// isUnicodeRangePattern checks if pattern is a character range like [a-z] or [\u0000-\uFFFF]
|
||||
func isUnicodeRangePattern(pattern string) bool {
|
||||
if len(pattern) < 5 || pattern[0] != '[' || pattern[len(pattern)-1] != ']' {
|
||||
return false
|
||||
}
|
||||
// Find the dash that separates low-high
|
||||
inner := pattern[1 : len(pattern)-1]
|
||||
dashIdx := strings.Index(inner, "-")
|
||||
// Handle escaped dash at start
|
||||
if dashIdx <= 0 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// parseUnicodeRange parses [X-Y] into low and high runes
|
||||
func parseUnicodeRange(pattern string) (rune, rune, error) {
|
||||
if len(pattern) < 5 || pattern[0] != '[' || pattern[len(pattern)-1] != ']' {
|
||||
return 0, 0, fmt.Errorf("invalid range pattern")
|
||||
}
|
||||
|
||||
inner := pattern[1 : len(pattern)-1]
|
||||
|
||||
// Simple case: [a-z] where a and z are single chars
|
||||
if len(inner) == 3 && inner[1] == '-' {
|
||||
return rune(inner[0]), rune(inner[2]), nil
|
||||
}
|
||||
|
||||
// Handle escaped characters like [\u0000-\uFFFF]
|
||||
dashIdx := findRangeDash(inner)
|
||||
if dashIdx < 0 {
|
||||
return 0, 0, fmt.Errorf("no dash in range")
|
||||
}
|
||||
|
||||
lowStr := inner[:dashIdx]
|
||||
highStr := inner[dashIdx+1:]
|
||||
|
||||
lowRune, err := parseRune(lowStr)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("invalid low bound: %w", err)
|
||||
}
|
||||
|
||||
highRune, err := parseRune(highStr)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("invalid high bound: %w", err)
|
||||
}
|
||||
|
||||
if lowRune > highRune {
|
||||
return 0, 0, fmt.Errorf("low bound > high bound")
|
||||
}
|
||||
|
||||
return lowRune, highRune, nil
|
||||
}
|
||||
|
||||
// findRangeDash finds the dash separating low-high in a range pattern
|
||||
func findRangeDash(inner string) int {
|
||||
i := 0
|
||||
for i < len(inner) {
|
||||
if inner[i] == '\\' && i+1 < len(inner) {
|
||||
// Skip escape sequence
|
||||
if inner[i+1] == 'u' && i+6 <= len(inner) {
|
||||
i += 6 // \uXXXX
|
||||
} else {
|
||||
i += 2 // \n, \t, etc.
|
||||
}
|
||||
continue
|
||||
}
|
||||
if inner[i] == '-' && i > 0 {
|
||||
return i
|
||||
}
|
||||
i++
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// parseRune parses a single rune from a string (handles escapes)
|
||||
func parseRune(s string) (rune, error) {
|
||||
if len(s) == 0 {
|
||||
return 0, fmt.Errorf("empty rune")
|
||||
}
|
||||
|
||||
// Handle escape sequences
|
||||
if s[0] == '\\' {
|
||||
if len(s) < 2 {
|
||||
return 0, fmt.Errorf("incomplete escape")
|
||||
}
|
||||
switch s[1] {
|
||||
case 'n':
|
||||
return '\n', nil
|
||||
case 't':
|
||||
return '\t', nil
|
||||
case 'r':
|
||||
return '\r', nil
|
||||
case '\\':
|
||||
return '\\', nil
|
||||
case '"':
|
||||
return '"', nil
|
||||
case '\'':
|
||||
return '\'', nil
|
||||
case 'u':
|
||||
if len(s) < 6 {
|
||||
return 0, fmt.Errorf("incomplete unicode escape")
|
||||
}
|
||||
val, err := strconv.ParseInt(s[2:6], 16, 32)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid unicode escape: %w", err)
|
||||
}
|
||||
return rune(val), nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown escape: \\%c", s[1])
|
||||
}
|
||||
}
|
||||
|
||||
// Plain character
|
||||
r, _ := utf8.DecodeRuneInString(s)
|
||||
if r == utf8.RuneError {
|
||||
return 0, fmt.Errorf("invalid utf8")
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// unescapeLiteral unescapes a literal pattern string
|
||||
func unescapeLiteral(pattern string) (string, error) {
|
||||
// Try strconv.Unquote if it looks quoted
|
||||
if len(pattern) >= 2 && pattern[0] == '"' && pattern[len(pattern)-1] == '"' {
|
||||
unquoted, err := strconv.Unquote(pattern)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return unquoted, nil
|
||||
}
|
||||
|
||||
// If no backslashes, return as-is
|
||||
if !strings.Contains(pattern, "\\") {
|
||||
return pattern, nil
|
||||
}
|
||||
|
||||
// Manual unescape
|
||||
var result strings.Builder
|
||||
i := 0
|
||||
for i < len(pattern) {
|
||||
if pattern[i] == '\\' && i+1 < len(pattern) {
|
||||
switch pattern[i+1] {
|
||||
case 'n':
|
||||
result.WriteByte('\n')
|
||||
i += 2
|
||||
case 't':
|
||||
result.WriteByte('\t')
|
||||
i += 2
|
||||
case 'r':
|
||||
result.WriteByte('\r')
|
||||
i += 2
|
||||
case '\\':
|
||||
result.WriteByte('\\')
|
||||
i += 2
|
||||
case '"':
|
||||
result.WriteByte('"')
|
||||
i += 2
|
||||
case '\'':
|
||||
result.WriteByte('\'')
|
||||
i += 2
|
||||
case 'u':
|
||||
if i+6 <= len(pattern) {
|
||||
val, err := strconv.ParseInt(pattern[i+2:i+6], 16, 32)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid unicode escape at %d", i)
|
||||
}
|
||||
result.WriteRune(rune(val))
|
||||
i += 6
|
||||
} else {
|
||||
return "", fmt.Errorf("incomplete unicode escape at %d", i)
|
||||
}
|
||||
default:
|
||||
// Reject unknown escape sequences
|
||||
return "", fmt.Errorf("unknown escape sequence: \\%c at position %d", pattern[i+1], i)
|
||||
}
|
||||
} else {
|
||||
result.WriteByte(pattern[i])
|
||||
i++
|
||||
}
|
||||
}
|
||||
return result.String(), nil
|
||||
}
|
||||
@@ -1,329 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// maskCache provides LRU caching for computed masks.
|
||||
type maskCache struct {
|
||||
cache map[uint64]*list.Element
|
||||
order *list.List
|
||||
maxSize int
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type maskEntry struct {
|
||||
sig uint64
|
||||
mask *mlx.Array
|
||||
}
|
||||
|
||||
// newMaskCache creates a new mask cache with the given max size
|
||||
// If maxSize <= 0, the cache is disabled (Get/Put are no-ops)
|
||||
func newMaskCache(maxSize int) *maskCache {
|
||||
if maxSize <= 0 {
|
||||
return &maskCache{
|
||||
cache: make(map[uint64]*list.Element),
|
||||
order: list.New(),
|
||||
maxSize: 0, // Signals disabled
|
||||
}
|
||||
}
|
||||
return &maskCache{
|
||||
cache: make(map[uint64]*list.Element),
|
||||
order: list.New(),
|
||||
maxSize: maxSize,
|
||||
}
|
||||
}
|
||||
|
||||
// get retrieves a cached mask, returning nil if not found.
|
||||
// Updates LRU order on cache hit.
|
||||
func (c *maskCache) get(sig uint64) *mlx.Array {
|
||||
if c.maxSize <= 0 {
|
||||
return nil // Cache disabled
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if elem, ok := c.cache[sig]; ok {
|
||||
c.order.MoveToFront(elem)
|
||||
return elem.Value.(*maskEntry).mask
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// put stores a mask in the cache with LRU eviction.
|
||||
func (c *maskCache) put(sig uint64, mask *mlx.Array) {
|
||||
if c.maxSize <= 0 {
|
||||
return // Cache disabled
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if elem, exists := c.cache[sig]; exists {
|
||||
c.order.MoveToFront(elem)
|
||||
return
|
||||
}
|
||||
|
||||
// Evict oldest if at capacity (safe since maxSize > 0)
|
||||
if c.order.Len() >= c.maxSize {
|
||||
oldest := c.order.Back()
|
||||
if oldest != nil {
|
||||
entry := oldest.Value.(*maskEntry)
|
||||
entry.mask.Free()
|
||||
delete(c.cache, entry.sig)
|
||||
c.order.Remove(oldest)
|
||||
}
|
||||
}
|
||||
|
||||
elem := c.order.PushFront(&maskEntry{sig: sig, mask: mask})
|
||||
c.cache[sig] = elem
|
||||
}
|
||||
|
||||
// clear frees all cached masks.
|
||||
func (c *maskCache) clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for elem := c.order.Front(); elem != nil; elem = elem.Next() {
|
||||
elem.Value.(*maskEntry).mask.Free()
|
||||
}
|
||||
c.cache = make(map[uint64]*list.Element)
|
||||
c.order.Init()
|
||||
}
|
||||
|
||||
// size returns the number of cached masks.
|
||||
func (c *maskCache) size() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return len(c.cache)
|
||||
}
|
||||
|
||||
// Engine applies grammar constraints to model outputs using MLX.
|
||||
// It uses a token→pda bridge for strict correctness with arbitrary BPE tokens.
|
||||
type Engine struct {
|
||||
// The compiled grammar
|
||||
grammar *Grammar
|
||||
|
||||
// bridge for token validation
|
||||
bridge *bridge
|
||||
analyzer *analyzer
|
||||
|
||||
// Current parser state (configSet for nondeterminism)
|
||||
configSet *configSet
|
||||
|
||||
// Token vocabulary from the model
|
||||
vocab []string
|
||||
tokenToID map[string]int // O(1) lookup for AcceptString
|
||||
|
||||
// Mask cache: configSig → valid token mask (LRU)
|
||||
maskCache *maskCache
|
||||
|
||||
// Cached negative infinity mask for invalid tokens
|
||||
negInfMask *mlx.Array
|
||||
|
||||
// Threshold for comparison (0.5 since mask values are 0 or 1)
|
||||
threshold *mlx.Array
|
||||
|
||||
// Vocabulary size
|
||||
vocabSize int32
|
||||
|
||||
// Reusable buffers for candidate filtering (avoid allocations)
|
||||
candidateMark []bool // indexed by tokenID, true if in candidate set
|
||||
touched []int // tokenIDs that were marked (for reset)
|
||||
dpCandidates []int // candidates requiring DP validation
|
||||
|
||||
// Reusable buffer for valid token indices (for GPU scatter)
|
||||
validTokenIDs []int32
|
||||
}
|
||||
|
||||
// EngineOption configures an Engine
|
||||
type EngineOption func(*Engine)
|
||||
|
||||
// WithMaskCacheSize sets the mask cache size (default 1024)
|
||||
func WithMaskCacheSize(size int) EngineOption {
|
||||
return func(e *Engine) {
|
||||
e.maskCache = newMaskCache(size)
|
||||
}
|
||||
}
|
||||
|
||||
// NewEngine creates a new constrained decoding engine.
|
||||
// grammar is the compiled grammar (use JSONGrammar() or ParseEBNF()).
|
||||
// vocab is the list of token strings from the model's tokenizer.
|
||||
func NewEngine(grammar *Grammar, vocab []string, opts ...EngineOption) (*Engine, error) {
|
||||
if grammar == nil {
|
||||
return nil, fmt.Errorf("grammar cannot be nil")
|
||||
}
|
||||
|
||||
// Build analyzer and bridge
|
||||
analyzer := newAnalyzer(vocab, grammar.matcher)
|
||||
bridge := newBridge(grammar.pda, analyzer)
|
||||
|
||||
// Initialize config set from pda initial state
|
||||
initialConfig := newConfigSet(grammar.pda.StartState, nil)
|
||||
|
||||
// Build token lookup map for O(1) AcceptString
|
||||
tokenToID := make(map[string]int, len(vocab))
|
||||
for i, tok := range vocab {
|
||||
tokenToID[tok] = i
|
||||
}
|
||||
|
||||
e := &Engine{
|
||||
grammar: grammar,
|
||||
bridge: bridge,
|
||||
analyzer: analyzer,
|
||||
configSet: initialConfig,
|
||||
vocab: vocab,
|
||||
tokenToID: tokenToID,
|
||||
maskCache: newMaskCache(1024),
|
||||
vocabSize: int32(len(vocab)),
|
||||
candidateMark: make([]bool, len(vocab)),
|
||||
touched: make([]int, 0, 10000),
|
||||
validTokenIDs: make([]int32, 0, 10000),
|
||||
}
|
||||
|
||||
// Apply options
|
||||
for _, opt := range opts {
|
||||
opt(e)
|
||||
}
|
||||
|
||||
// Create the negative infinity mask and threshold
|
||||
if e.vocabSize > 0 {
|
||||
e.negInfMask = mlx.FullDtype(float32(math.Inf(-1)), mlx.DtypeFloat32, e.vocabSize)
|
||||
mlx.Keep(e.negInfMask)
|
||||
|
||||
e.threshold = mlx.NewScalarArray(0.5)
|
||||
mlx.Keep(e.threshold)
|
||||
}
|
||||
|
||||
return e, nil
|
||||
}
|
||||
|
||||
// ApplyMask applies grammar constraints to logits.
|
||||
// Returns logits with invalid tokens set to -inf.
|
||||
func (e *Engine) ApplyMask(logits *mlx.Array) *mlx.Array {
|
||||
sig := e.configSet.signature()
|
||||
|
||||
// Check state cache first (exact state match)
|
||||
if cached := e.maskCache.get(sig); cached != nil {
|
||||
condition := mlx.GreaterEqual(cached, e.threshold)
|
||||
return mlx.Where(condition, logits, e.negInfMask)
|
||||
}
|
||||
|
||||
// Compute valid tokens using candidate filtering:
|
||||
// 1. Get valid terminal IDs from current grammar state
|
||||
// 2. Get candidate tokens (those that START with valid terminals)
|
||||
// 3. Run DP validation only on candidates
|
||||
// This is O(candidates) instead of O(vocab_size)
|
||||
|
||||
validTerminalIDs := e.bridge.validTerminalIDs(e.configSet)
|
||||
|
||||
// Use pre-partitioned token groups for fast candidate building
|
||||
// This eliminates per-token branching - just direct slice appends
|
||||
e.validTokenIDs = e.validTokenIDs[:0]
|
||||
e.dpCandidates = e.dpCandidates[:0]
|
||||
e.touched = e.touched[:0]
|
||||
|
||||
for _, tid := range validTerminalIDs {
|
||||
groups := e.analyzer.terminalGroups(tid)
|
||||
|
||||
// Direct append of exact matches (no per-token check needed)
|
||||
e.validTokenIDs = append(e.validTokenIDs, groups.ExactMatches...)
|
||||
|
||||
// Collect DP candidates (may have duplicates across terminals)
|
||||
for _, tokenID := range groups.DPCandidates {
|
||||
if !e.candidateMark[tokenID] {
|
||||
e.candidateMark[tokenID] = true
|
||||
e.dpCandidates = append(e.dpCandidates, tokenID)
|
||||
e.touched = append(e.touched, tokenID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reset marks for next call
|
||||
for _, id := range e.touched {
|
||||
e.candidateMark[id] = false
|
||||
}
|
||||
|
||||
for _, tokenID := range e.dpCandidates {
|
||||
if e.bridge.IsTokenValid(tokenID, e.configSet) {
|
||||
e.validTokenIDs = append(e.validTokenIDs, int32(tokenID))
|
||||
}
|
||||
}
|
||||
|
||||
// Create and cache the mask on GPU using index updates
|
||||
mask := mlx.Zeros([]int32{e.vocabSize})
|
||||
if len(e.validTokenIDs) > 0 {
|
||||
indices := mlx.NewArrayInt32(e.validTokenIDs, []int32{int32(len(e.validTokenIDs))})
|
||||
values := mlx.Ones(int32(len(e.validTokenIDs)))
|
||||
mask = mlx.PutAlongAxis(mask, indices, values, 0)
|
||||
}
|
||||
mlx.Keep(mask)
|
||||
|
||||
// Cache by state signature
|
||||
e.maskCache.put(sig, mask)
|
||||
|
||||
// Apply mask
|
||||
condition := mlx.GreaterEqual(mask, e.threshold)
|
||||
return mlx.Where(condition, logits, e.negInfMask)
|
||||
}
|
||||
|
||||
// Accept processes a token and updates the parser state.
|
||||
// Returns true if the token was valid and accepted.
|
||||
func (e *Engine) Accept(tokenID int) bool {
|
||||
if tokenID < 0 || tokenID >= len(e.vocab) {
|
||||
return false
|
||||
}
|
||||
|
||||
newConfig := e.bridge.acceptToken(tokenID, e.configSet)
|
||||
if newConfig == nil {
|
||||
return false
|
||||
}
|
||||
e.configSet = newConfig
|
||||
return true
|
||||
}
|
||||
|
||||
// AcceptString processes a token string directly.
|
||||
// Returns true if the token was valid and accepted.
|
||||
func (e *Engine) AcceptString(token string) bool {
|
||||
if id, ok := e.tokenToID[token]; ok {
|
||||
return e.Accept(id)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsComplete returns true if the current state is accepting.
|
||||
func (e *Engine) IsComplete() bool {
|
||||
return e.bridge.isAccepting(e.configSet)
|
||||
}
|
||||
|
||||
// Reset resets the engine to initial state.
|
||||
func (e *Engine) Reset() {
|
||||
e.configSet = newConfigSet(e.grammar.pda.StartState, nil)
|
||||
}
|
||||
|
||||
// validTokens returns the indices of tokens that are currently valid.
|
||||
func (e *Engine) validTokens() []int {
|
||||
return e.bridge.validTokens(e.configSet)
|
||||
}
|
||||
|
||||
// validTerminals returns the valid terminal patterns from the current state.
|
||||
func (e *Engine) validTerminals() []string {
|
||||
return e.bridge.validTerminals(e.configSet)
|
||||
}
|
||||
|
||||
// Close releases MLX resources.
|
||||
func (e *Engine) Close() {
|
||||
if e.maskCache != nil {
|
||||
e.maskCache.clear()
|
||||
}
|
||||
if e.negInfMask != nil {
|
||||
e.negInfMask.Free()
|
||||
}
|
||||
if e.threshold != nil {
|
||||
e.threshold.Free()
|
||||
}
|
||||
}
|
||||
@@ -1,414 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// newBenchEngine creates a JSON engine for benchmarks
|
||||
func newBenchEngine(b *testing.B, vocab []string) *Engine {
|
||||
b.Helper()
|
||||
grammar, err := JSONGrammar()
|
||||
if err != nil {
|
||||
b.Fatalf("failed to create JSON grammar: %v", err)
|
||||
}
|
||||
e, err := NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to create engine: %v", err)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// Vocabulary sizes to test (matching real models)
|
||||
var vocabSizes = []int{
|
||||
32000, // Llama 2
|
||||
128000, // Llama 3
|
||||
256000, // Large models
|
||||
}
|
||||
|
||||
// createBenchVocabN creates a vocabulary of size n with realistic token distribution
|
||||
func createBenchVocabN(n int) []string {
|
||||
vocab := make([]string, n)
|
||||
|
||||
// JSON structural tokens (first 20)
|
||||
jsonTokens := []string{
|
||||
"{", "}", "[", "]", ":", ",",
|
||||
"true", "false", "null",
|
||||
" ", "\n", "\t", "\r",
|
||||
"\"", "'",
|
||||
}
|
||||
for i, t := range jsonTokens {
|
||||
if i < n {
|
||||
vocab[i] = t
|
||||
}
|
||||
}
|
||||
|
||||
// String tokens (indices 20-1000)
|
||||
stringIdx := 20
|
||||
for i := 0; i < 980 && stringIdx+i < n; i++ {
|
||||
vocab[stringIdx+i] = fmt.Sprintf("\"token%d\"", i)
|
||||
}
|
||||
|
||||
// Number tokens (indices 1000-2000)
|
||||
numberIdx := 1000
|
||||
for i := 0; i < 1000 && numberIdx+i < n; i++ {
|
||||
vocab[numberIdx+i] = fmt.Sprintf("%d", i)
|
||||
}
|
||||
|
||||
// Generic tokens (rest)
|
||||
for i := 2000; i < n; i++ {
|
||||
vocab[i] = fmt.Sprintf("tok%d", i)
|
||||
}
|
||||
|
||||
return vocab
|
||||
}
|
||||
|
||||
// ============ Core Performance Benchmarks ============
|
||||
|
||||
// BenchmarkApplyMask_32k measures mask application with 32k vocab
|
||||
func BenchmarkApplyMask_32k(b *testing.B) {
|
||||
benchmarkApplyMask(b, 32000)
|
||||
}
|
||||
|
||||
// BenchmarkApplyMask_128k measures mask application with 128k vocab
|
||||
func BenchmarkApplyMask_128k(b *testing.B) {
|
||||
benchmarkApplyMask(b, 128000)
|
||||
}
|
||||
|
||||
// BenchmarkApplyMask_256k measures mask application with 256k vocab
|
||||
func BenchmarkApplyMask_256k(b *testing.B) {
|
||||
benchmarkApplyMask(b, 256000)
|
||||
}
|
||||
|
||||
func benchmarkApplyMask(b *testing.B, vocabSize int) {
|
||||
vocab := createBenchVocabN(vocabSize)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(vocabSize))
|
||||
mlx.Keep(logits)
|
||||
|
||||
// Warm up
|
||||
for i := 0; i < 10; i++ {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
|
||||
b.ReportMetric(float64(vocabSize), "vocab_size")
|
||||
}
|
||||
|
||||
// ============ state-Dependent Benchmarks ============
|
||||
|
||||
// BenchmarkApplyMaskAfterBrace measures mask after { (STRING or } valid)
|
||||
func BenchmarkApplyMaskAfterBrace(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
e.AcceptString("{")
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkApplyMaskMidObject measures mask in middle of object
|
||||
func BenchmarkApplyMaskMidObject(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// state: {"key": _value_
|
||||
e.AcceptString("{")
|
||||
e.AcceptString("\"key\"")
|
||||
e.AcceptString(":")
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Token Sequence Benchmarks ============
|
||||
|
||||
// BenchmarkSequence_SimpleObject benchmarks {"key": "value"}
|
||||
func BenchmarkSequence_SimpleObject(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e.Reset()
|
||||
for _, token := range sequence {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
e.AcceptString(token)
|
||||
}
|
||||
}
|
||||
|
||||
b.ReportMetric(float64(len(sequence)), "tokens")
|
||||
}
|
||||
|
||||
// BenchmarkSequence_NestedObject benchmarks {"a": {"b": {"c": 1}}}
|
||||
func BenchmarkSequence_NestedObject(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
sequence := []string{
|
||||
"{", "\"a\"", ":", "{", "\"b\"", ":", "{", "\"c\"", ":", "1", "}", "}", "}",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e.Reset()
|
||||
for _, token := range sequence {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
e.AcceptString(token)
|
||||
}
|
||||
}
|
||||
|
||||
b.ReportMetric(float64(len(sequence)), "tokens")
|
||||
}
|
||||
|
||||
// BenchmarkSequence_LargeArray benchmarks [1, 2, 3, ..., 100]
|
||||
func BenchmarkSequence_LargeArray(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
// Build sequence: [1, 2, 3, ..., 50]
|
||||
sequence := []string{"["}
|
||||
for i := 1; i <= 50; i++ {
|
||||
sequence = append(sequence, fmt.Sprintf("%d", i))
|
||||
if i < 50 {
|
||||
sequence = append(sequence, ",")
|
||||
}
|
||||
}
|
||||
sequence = append(sequence, "]")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e.Reset()
|
||||
for _, token := range sequence {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
e.AcceptString(token)
|
||||
}
|
||||
}
|
||||
|
||||
b.ReportMetric(float64(len(sequence)), "tokens")
|
||||
}
|
||||
|
||||
// BenchmarkSequence_MixedTypes benchmarks complex mixed-type object
|
||||
func BenchmarkSequence_MixedTypes(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
sequence := []string{
|
||||
"{",
|
||||
"\"name\"", ":", "\"test\"", ",",
|
||||
"\"count\"", ":", "42", ",",
|
||||
"\"enabled\"", ":", "true", ",",
|
||||
"\"data\"", ":", "null", ",",
|
||||
"\"items\"", ":", "[", "1", ",", "2", ",", "3", "]",
|
||||
"}",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e.Reset()
|
||||
for _, token := range sequence {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
e.AcceptString(token)
|
||||
}
|
||||
}
|
||||
|
||||
b.ReportMetric(float64(len(sequence)), "tokens")
|
||||
}
|
||||
|
||||
// ============ Component Benchmarks ============
|
||||
|
||||
// BenchmarkValidInputs measures pda valid input computation
|
||||
func BenchmarkValidInputs(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = e.validTerminals()
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkStateTransition measures pda state transition
|
||||
func BenchmarkStateTransition(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e.Reset()
|
||||
for _, token := range sequence {
|
||||
e.AcceptString(token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkConstrainedGrammar_128k benchmarks x/grammar (graph only, no eval).
|
||||
func BenchmarkConstrainedGrammar_128k(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
// Warm up
|
||||
for i := 0; i < 10; i++ {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = e.ApplyMask(logits) // Graph only, no eval
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkNewEngine measures one-time engine initialization.
|
||||
func BenchmarkNewEngine_32k(b *testing.B) {
|
||||
benchmarkNewEngine(b, 32000)
|
||||
}
|
||||
|
||||
func BenchmarkNewEngine_128k(b *testing.B) {
|
||||
benchmarkNewEngine(b, 128000)
|
||||
}
|
||||
|
||||
func benchmarkNewEngine(b *testing.B, vocabSize int) {
|
||||
vocab := createBenchVocabN(vocabSize)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e := newBenchEngine(b, vocab)
|
||||
e.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Memory Benchmarks ============
|
||||
|
||||
func BenchmarkMemoryAllocs_32k(b *testing.B) {
|
||||
benchmarkMemoryAllocs(b, 32000)
|
||||
}
|
||||
|
||||
func BenchmarkMemoryAllocs_128k(b *testing.B) {
|
||||
benchmarkMemoryAllocs(b, 128000)
|
||||
}
|
||||
|
||||
func benchmarkMemoryAllocs(b *testing.B, vocabSize int) {
|
||||
vocab := createBenchVocabN(vocabSize)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(vocabSize))
|
||||
mlx.Keep(logits)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
}
|
||||
|
||||
// ============ No-Eval Benchmarks (simulating LLM graph integration) ============
|
||||
|
||||
// BenchmarkApplyMaskNoEval_128k measures mask generation WITHOUT GPU sync
|
||||
// This simulates adding mask to LLM compute graph
|
||||
func BenchmarkApplyMaskNoEval_128k(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
// Warm up
|
||||
for i := 0; i < 10; i++ {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = e.ApplyMask(logits) // No Eval - just build graph
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSequenceNoEval simulates real LLM usage - build graph, eval once at end
|
||||
func BenchmarkSequenceNoEval_SimpleObject(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e.Reset()
|
||||
var lastMasked *mlx.Array
|
||||
for _, token := range sequence {
|
||||
lastMasked = e.ApplyMask(logits) // Build graph only
|
||||
e.AcceptString(token)
|
||||
}
|
||||
mlx.Eval(lastMasked) // Single eval at end
|
||||
}
|
||||
|
||||
b.ReportMetric(float64(len(sequence)), "tokens")
|
||||
}
|
||||
@@ -1,689 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// newTestEngine creates a JSON engine for testing
|
||||
func newTestEngine(t testing.TB, vocab []string) *Engine {
|
||||
t.Helper()
|
||||
grammar, err := JSONGrammar()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create JSON grammar: %v", err)
|
||||
}
|
||||
e, err := NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create engine: %v", err)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// Mock vocabulary for testing
|
||||
func testVocab() []string {
|
||||
return []string{
|
||||
"{", // 0: object start
|
||||
"}", // 1: object end
|
||||
"[", // 2: array start
|
||||
"]", // 3: array end
|
||||
":", // 4: colon
|
||||
",", // 5: comma
|
||||
"\"key\"", // 6: string (quoted)
|
||||
"\"val\"", // 7: string (quoted)
|
||||
"123", // 8: number
|
||||
"-42.5", // 9: number
|
||||
"true", // 10: boolean
|
||||
"false", // 11: boolean
|
||||
"null", // 12: null
|
||||
" ", // 13: whitespace (should be ignored)
|
||||
"\n", // 14: whitespace (should be ignored)
|
||||
"subword", // 15: bare word (NOT valid JSON - requires quotes)
|
||||
"hello", // 16: bare word (NOT valid JSON - requires quotes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewEngine(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
if e.vocabSize != int32(len(vocab)) {
|
||||
t.Errorf("vocabSize = %d, want %d", e.vocabSize, len(vocab))
|
||||
}
|
||||
|
||||
// Verify grammar is set
|
||||
if e.grammar == nil {
|
||||
t.Error("grammar should not be nil")
|
||||
}
|
||||
|
||||
// Verify analyzer is set
|
||||
if e.analyzer == nil {
|
||||
t.Error("analyzer should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineValidTokens(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// At start, any value type should be valid
|
||||
validTokens := e.validTokens()
|
||||
|
||||
// Should include object start, array start, strings, numbers, booleans, null
|
||||
// Note: bare words like "subword" and "hello" are NOT valid JSON strings
|
||||
// (JSON strings must be quoted)
|
||||
expectedTokens := map[int]bool{
|
||||
0: true, // {
|
||||
2: true, // [
|
||||
6: true, // "key"
|
||||
7: true, // "val"
|
||||
8: true, // 123
|
||||
9: true, // -42.5
|
||||
10: true, // true
|
||||
11: true, // false
|
||||
12: true, // null
|
||||
}
|
||||
|
||||
// Check that expected tokens are present
|
||||
validSet := make(map[int]bool)
|
||||
for _, idx := range validTokens {
|
||||
validSet[idx] = true
|
||||
}
|
||||
|
||||
for idx := range expectedTokens {
|
||||
if !validSet[idx] {
|
||||
t.Errorf("expected token %d (%s) to be valid", idx, vocab[idx])
|
||||
}
|
||||
}
|
||||
|
||||
if validSet[15] || validSet[16] {
|
||||
t.Error("bare words should not be valid JSON at the start state")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineAccept(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Accept { should work
|
||||
if !e.Accept(0) { // {
|
||||
t.Error("should accept {")
|
||||
}
|
||||
|
||||
// After {, valid tokens should be STRING or }
|
||||
validTokens := e.validTokens()
|
||||
|
||||
validSet := make(map[int]bool)
|
||||
for _, idx := range validTokens {
|
||||
validSet[idx] = true
|
||||
}
|
||||
|
||||
// STRING tokens (indices 6, 7) and } (index 1) should be valid
|
||||
if !validSet[1] {
|
||||
t.Error("} should be valid after {")
|
||||
}
|
||||
if !validSet[6] && !validSet[7] {
|
||||
t.Error("STRING should be valid after { (for keys)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineAcceptSequence(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Accept {"key": "val"}
|
||||
sequence := []int{0, 6, 4, 7, 1} // {, "key", :, "val", }
|
||||
|
||||
for i, tokenID := range sequence {
|
||||
if !e.Accept(tokenID) {
|
||||
t.Fatalf("failed to accept token %d (%s) at position %d",
|
||||
tokenID, vocab[tokenID], i)
|
||||
}
|
||||
}
|
||||
|
||||
if !e.IsComplete() {
|
||||
t.Error("should be in complete state after valid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineReset(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Accept some tokens
|
||||
e.Accept(0) // {
|
||||
e.Accept(1) // }
|
||||
|
||||
if !e.IsComplete() {
|
||||
t.Error("should be complete after {}")
|
||||
}
|
||||
|
||||
// Reset
|
||||
e.Reset()
|
||||
|
||||
// Should be back to initial state
|
||||
if e.IsComplete() {
|
||||
t.Error("should not be complete after reset")
|
||||
}
|
||||
|
||||
// Should be able to accept new sequence
|
||||
if !e.Accept(0) { // {
|
||||
t.Error("should accept { after reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineInvalidTokenRejection(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Accept { first
|
||||
if !e.Accept(0) {
|
||||
t.Fatal("should accept {")
|
||||
}
|
||||
|
||||
// Now try to accept [ which is invalid after {
|
||||
// (After {, only STRING or } are valid)
|
||||
if e.Accept(2) { // [
|
||||
t.Error("should not accept [ after { (expecting STRING or })")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineAcceptString(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Accept using string directly
|
||||
if !e.AcceptString("{") {
|
||||
t.Error("should accept {")
|
||||
}
|
||||
if !e.AcceptString("\"key\"") {
|
||||
t.Error("should accept string key")
|
||||
}
|
||||
if !e.AcceptString(":") {
|
||||
t.Error("should accept :")
|
||||
}
|
||||
if !e.AcceptString("123") {
|
||||
t.Error("should accept number")
|
||||
}
|
||||
if !e.AcceptString("}") {
|
||||
t.Error("should accept }")
|
||||
}
|
||||
|
||||
if !e.IsComplete() {
|
||||
t.Error("should be complete after valid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONBackslashEscape(t *testing.T) {
|
||||
vocab := []string{`"`, `\`, "n", "a"}
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Valid escape: "\n"
|
||||
if !e.AcceptString(`"`) {
|
||||
t.Fatal("should accept string start")
|
||||
}
|
||||
if !e.AcceptString(`\`) {
|
||||
t.Fatal("should accept escape prefix")
|
||||
}
|
||||
if !e.AcceptString("n") {
|
||||
t.Fatal("should accept escape code")
|
||||
}
|
||||
if !e.AcceptString(`"`) {
|
||||
t.Fatal("should accept string end")
|
||||
}
|
||||
if !e.IsComplete() {
|
||||
t.Error("should be complete after escaped string")
|
||||
}
|
||||
|
||||
// Invalid escape: "\a"
|
||||
e.Reset()
|
||||
if !e.AcceptString(`"`) {
|
||||
t.Fatal("should accept string start")
|
||||
}
|
||||
if !e.AcceptString(`\`) {
|
||||
t.Fatal("should accept escape prefix")
|
||||
}
|
||||
if e.AcceptString("a") {
|
||||
t.Error("should reject invalid escape code")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineNegInfMask(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Verify negInfMask exists and has correct shape
|
||||
if e.negInfMask == nil {
|
||||
t.Fatal("negInfMask should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineMaskCache(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Create test logits
|
||||
logits := mlx.Ones(int32(len(vocab)))
|
||||
|
||||
// Apply mask - should populate cache
|
||||
_ = e.ApplyMask(logits)
|
||||
|
||||
// Check cache was populated
|
||||
cacheSize := e.maskCache.size()
|
||||
if cacheSize == 0 {
|
||||
t.Error("mask cache should have at least one entry after ApplyMask")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineEmptyVocab(t *testing.T) {
|
||||
e := newTestEngine(t, []string{})
|
||||
defer e.Close()
|
||||
|
||||
if e.vocabSize != 0 {
|
||||
t.Errorf("vocabSize = %d, want 0", e.vocabSize)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineLargeVocab(t *testing.T) {
|
||||
// Create a large vocabulary (simulating real model vocab)
|
||||
vocab := make([]string, 32000)
|
||||
for i := range vocab {
|
||||
vocab[i] = "token"
|
||||
}
|
||||
// Add some actual JSON tokens
|
||||
vocab[0] = "{"
|
||||
vocab[1] = "}"
|
||||
vocab[2] = "["
|
||||
vocab[3] = "]"
|
||||
vocab[4] = ":"
|
||||
vocab[5] = ","
|
||||
vocab[6] = "\"test\""
|
||||
vocab[7] = "123"
|
||||
vocab[8] = "true"
|
||||
vocab[9] = "false"
|
||||
vocab[10] = "null"
|
||||
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
if e.vocabSize != 32000 {
|
||||
t.Errorf("vocabSize = %d, want 32000", e.vocabSize)
|
||||
}
|
||||
|
||||
// Test that it still works correctly
|
||||
if !e.Accept(0) { // {
|
||||
t.Error("should accept {")
|
||||
}
|
||||
if !e.Accept(1) { // }
|
||||
t.Error("should accept }")
|
||||
}
|
||||
if !e.IsComplete() {
|
||||
t.Error("should be complete after {}")
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2E_JSONDecoding tests end-to-end JSON constrained decoding.
|
||||
func TestE2E_JSONDecoding(t *testing.T) {
|
||||
// Create a realistic vocabulary with JSON tokens
|
||||
vocab := []string{
|
||||
// Structural tokens
|
||||
"{", "}", "[", "]", ":", ",",
|
||||
// Keywords
|
||||
"true", "false", "null",
|
||||
// Quoted strings
|
||||
`"name"`, `"value"`, `"items"`, `"count"`, `"enabled"`,
|
||||
`"hello"`, `"world"`, `"test"`,
|
||||
// Numbers
|
||||
"0", "1", "2", "3", "42", "123", "-1", "-42",
|
||||
// Whitespace
|
||||
" ", "\n", "\t",
|
||||
// Multi-terminal tokens (span multiple JSON lexemes)
|
||||
`"key":`, `},`, `],`, `{"`, `["`,
|
||||
// Partial/invalid tokens (should be rejected)
|
||||
"invalid", "foo", "bar",
|
||||
}
|
||||
|
||||
grammar, err := JSONGrammar()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create JSON grammar: %v", err)
|
||||
}
|
||||
|
||||
engine, err := NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create engine: %v", err)
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokens []string
|
||||
wantPass bool
|
||||
}{
|
||||
// Simple values
|
||||
{"empty object", []string{"{", "}"}, true},
|
||||
{"empty array", []string{"[", "]"}, true},
|
||||
{"true literal", []string{"true"}, true},
|
||||
{"null literal", []string{"null"}, true},
|
||||
{"number", []string{"42"}, true},
|
||||
{"negative number", []string{"-42"}, true},
|
||||
{"quoted string", []string{`"hello"`}, true},
|
||||
|
||||
// Objects
|
||||
{"simple object", []string{"{", `"name"`, ":", `"value"`, "}"}, true},
|
||||
{"object with single-digit numbers", []string{"{", `"count"`, ":", "1", ",", `"value"`, ":", "2", "}"}, true},
|
||||
{"multi-terminal key", []string{"{", `"key":`, `"value"`, "}"}, true},
|
||||
|
||||
// Arrays
|
||||
{"array of numbers", []string{"[", "42", "]"}, true},
|
||||
{"array of single digits", []string{"[", "1", ",", "2", "]"}, true},
|
||||
{"array of strings", []string{"[", `"hello"`, ",", `"world"`, "]"}, true},
|
||||
{"nested array", []string{"[", "[", "42", "]", "]"}, true},
|
||||
|
||||
// Nested structures
|
||||
{"nested object", []string{"{", `"items"`, ":", "{", `"count"`, ":", "42", "}", "}"}, true},
|
||||
{"object with array", []string{"{", `"items"`, ":", "[", "42", "]", "}"}, true},
|
||||
|
||||
// Invalid sequences
|
||||
{"unclosed object", []string{"{", `"name"`, ":"}, false}, // incomplete
|
||||
{"double comma", []string{"[", "42", ",", ",", "42", "]"}, false}, // invalid
|
||||
{"missing value", []string{"{", `"name"`, ":", "}"}, false}, // missing value
|
||||
{"bare word", []string{"invalid"}, false}, // not valid JSON
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
engine.Reset()
|
||||
|
||||
// Process each token
|
||||
allAccepted := true
|
||||
for i, token := range tt.tokens {
|
||||
if !engine.AcceptString(token) {
|
||||
if tt.wantPass {
|
||||
t.Errorf("token %d (%q) rejected unexpectedly", i, token)
|
||||
}
|
||||
allAccepted = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if tt.wantPass {
|
||||
if !allAccepted {
|
||||
return // Already reported error
|
||||
}
|
||||
if !engine.IsComplete() {
|
||||
t.Errorf("expected complete parse, but not in accepting state")
|
||||
}
|
||||
} else {
|
||||
// For invalid sequences, we expect either rejection or incomplete
|
||||
if allAccepted && engine.IsComplete() {
|
||||
t.Errorf("expected rejection or incomplete, but parse succeeded")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2E_SimpleExpressionGrammar tests a custom expression grammar.
|
||||
func TestE2E_SimpleExpressionGrammar(t *testing.T) {
|
||||
// Simple expression grammar: expr = term { ("+" | "-") term }
|
||||
// term = number | "(" expr ")"
|
||||
// number = digit { digit }
|
||||
// digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9"
|
||||
exprGrammar := `
|
||||
expr = term { addop term } .
|
||||
addop = "+" | "-" .
|
||||
term = factor { mulop factor } .
|
||||
mulop = "*" | "/" .
|
||||
factor = number | "(" expr ")" .
|
||||
number = digit { digit } .
|
||||
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
|
||||
`
|
||||
|
||||
grammar, err := ParseEBNF(exprGrammar, "expr")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse expression grammar: %v", err)
|
||||
}
|
||||
|
||||
// Vocabulary for expression tokens
|
||||
vocab := []string{
|
||||
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
|
||||
"+", "-", "*", "/",
|
||||
"(", ")",
|
||||
// Multi-digit numbers as single tokens
|
||||
"10", "42", "100", "123",
|
||||
// Invalid tokens
|
||||
"x", "y", "invalid",
|
||||
}
|
||||
|
||||
engine, err := NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create engine: %v", err)
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokens []string
|
||||
wantPass bool
|
||||
}{
|
||||
{"single digit", []string{"5"}, true},
|
||||
{"multi-digit", []string{"1", "2", "3"}, true},
|
||||
{"addition", []string{"1", "+", "2"}, true},
|
||||
{"subtraction", []string{"5", "-", "3"}, true},
|
||||
{"multiplication", []string{"2", "*", "3"}, true},
|
||||
{"division", []string{"8", "/", "2"}, true},
|
||||
{"complex expr", []string{"1", "+", "2", "*", "3"}, true},
|
||||
{"parentheses", []string{"(", "1", "+", "2", ")", "*", "3"}, true},
|
||||
{"nested parens", []string{"(", "(", "1", ")", ")"}, true},
|
||||
|
||||
// Invalid
|
||||
{"just operator", []string{"+"}, false},
|
||||
{"double operator", []string{"1", "+", "+", "2"}, false},
|
||||
{"unclosed paren", []string{"(", "1", "+", "2"}, false},
|
||||
{"variable", []string{"x"}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
engine.Reset()
|
||||
|
||||
allAccepted := true
|
||||
for i, token := range tt.tokens {
|
||||
if !engine.AcceptString(token) {
|
||||
if tt.wantPass {
|
||||
t.Errorf("token %d (%q) rejected unexpectedly", i, token)
|
||||
}
|
||||
allAccepted = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if tt.wantPass {
|
||||
if !allAccepted {
|
||||
return
|
||||
}
|
||||
if !engine.IsComplete() {
|
||||
t.Errorf("expected complete parse, but not in accepting state")
|
||||
}
|
||||
} else {
|
||||
if allAccepted && engine.IsComplete() {
|
||||
t.Errorf("expected rejection or incomplete, but parse succeeded")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2E_IdentifierGrammar tests a grammar with character ranges.
|
||||
func TestE2E_IdentifierGrammar(t *testing.T) {
|
||||
// Identifier grammar using character ranges
|
||||
identGrammar := `
|
||||
ident = letter { letter | digit } .
|
||||
letter = "a" … "z" | "A" … "Z" | "_" .
|
||||
digit = "0" … "9" .
|
||||
`
|
||||
|
||||
grammar, err := ParseEBNF(identGrammar, "ident")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse identifier grammar: %v", err)
|
||||
}
|
||||
|
||||
// Vocabulary with letters and digits
|
||||
vocab := []string{
|
||||
"a", "b", "c", "x", "y", "z",
|
||||
"A", "B", "C", "X", "Y", "Z",
|
||||
"_",
|
||||
"0", "1", "2", "9",
|
||||
// Multi-char tokens
|
||||
"foo", "bar", "myVar", "test123",
|
||||
// Invalid starting chars
|
||||
"1abc", "123",
|
||||
}
|
||||
|
||||
engine, err := NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create engine: %v", err)
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokens []string
|
||||
wantPass bool
|
||||
}{
|
||||
{"single letter", []string{"a"}, true},
|
||||
{"uppercase", []string{"A"}, true},
|
||||
{"underscore", []string{"_"}, true},
|
||||
{"multi-letter", []string{"a", "b", "c"}, true},
|
||||
{"letter then digit", []string{"x", "1"}, true},
|
||||
{"underscore prefix", []string{"_", "a", "1"}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
engine.Reset()
|
||||
|
||||
allAccepted := true
|
||||
for i, token := range tt.tokens {
|
||||
if !engine.AcceptString(token) {
|
||||
if tt.wantPass {
|
||||
t.Errorf("token %d (%q) rejected unexpectedly", i, token)
|
||||
}
|
||||
allAccepted = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if tt.wantPass && allAccepted && !engine.IsComplete() {
|
||||
t.Errorf("expected complete parse, but not in accepting state")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2E_UnicodeRange ensures unicode ranges compile and match tokens.
|
||||
func TestE2E_UnicodeRange(t *testing.T) {
|
||||
greekGrammar := `
|
||||
greek = "α" … "ω" .
|
||||
`
|
||||
|
||||
grammar, err := ParseEBNF(greekGrammar, "greek")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse unicode grammar: %v", err)
|
||||
}
|
||||
|
||||
vocab := []string{"α", "β", "ω", "a"}
|
||||
engine, err := NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create engine: %v", err)
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
if !engine.AcceptString("β") {
|
||||
t.Error("should accept beta")
|
||||
}
|
||||
if !engine.IsComplete() {
|
||||
t.Error("should be complete after single rune")
|
||||
}
|
||||
|
||||
engine.Reset()
|
||||
if engine.AcceptString("a") {
|
||||
t.Error("should reject ASCII outside unicode range")
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2E_NondeterminismPreserved tests that nondeterministic paths are preserved.
|
||||
func TestE2E_NondeterminismPreserved(t *testing.T) {
|
||||
// This grammar has nondeterminism: "ab" could be parsed as
|
||||
// a single token or as two tokens "a" "b"
|
||||
ambiguousGrammar := `
|
||||
start = item item .
|
||||
item = "a" | "b" | "ab" .
|
||||
`
|
||||
|
||||
grammar, err := ParseEBNF(ambiguousGrammar, "start")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse grammar: %v", err)
|
||||
}
|
||||
|
||||
// Vocabulary with both single and combined tokens
|
||||
vocab := []string{"a", "b", "ab"}
|
||||
|
||||
engine, err := NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create engine: %v", err)
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
// Test: "ab" "a" should be valid (ab as first item, a as second)
|
||||
t.Run("ab then a", func(t *testing.T) {
|
||||
engine.Reset()
|
||||
if !engine.AcceptString("ab") {
|
||||
t.Error("should accept ab")
|
||||
}
|
||||
if !engine.AcceptString("a") {
|
||||
t.Error("should accept a after ab")
|
||||
}
|
||||
if !engine.IsComplete() {
|
||||
t.Error("should be complete")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("a then ab", func(t *testing.T) {
|
||||
engine.Reset()
|
||||
if !engine.AcceptString("a") {
|
||||
t.Error("should accept a")
|
||||
}
|
||||
if !engine.AcceptString("ab") {
|
||||
t.Error("should accept ab after a")
|
||||
}
|
||||
if !engine.IsComplete() {
|
||||
t.Error("should be complete")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("a then a", func(t *testing.T) {
|
||||
engine.Reset()
|
||||
if !engine.AcceptString("a") {
|
||||
t.Error("should accept first a")
|
||||
}
|
||||
if !engine.AcceptString("a") {
|
||||
t.Error("should accept second a")
|
||||
}
|
||||
if !engine.IsComplete() {
|
||||
t.Error("should be complete")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,614 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package grammar provides GPU-accelerated constrained decoding using MLX.
|
||||
// It compiles EBNF grammars to pushdown automata (pda) with precomputed token masks.
|
||||
// For JSON Schema conversion, see the grammar/schema subpackage.
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/exp/ebnf"
|
||||
)
|
||||
|
||||
// stackSymbol represents a symbol that can be pushed onto the pda stack.
|
||||
type stackSymbol int
|
||||
|
||||
const (
|
||||
stackEmpty stackSymbol = iota
|
||||
// Additional stack symbols will be generated per-grammar
|
||||
)
|
||||
|
||||
// state represents a pda state.
|
||||
type state int
|
||||
|
||||
const (
|
||||
stateError state = -1
|
||||
stateStart state = 0
|
||||
stateAccept state = 1
|
||||
// Additional states will be generated per-grammar
|
||||
)
|
||||
|
||||
// transition represents a pda transition.
|
||||
// On input matching Pattern, from FromState with stackTop:
|
||||
// - Move to ToState
|
||||
// - Pop StackPop symbols, push StackPush symbols
|
||||
type transition struct {
|
||||
FromState state
|
||||
stackTop stackSymbol // What must be on stack top (stackEmpty = don't care)
|
||||
Pattern string // Input pattern to match (token or character class)
|
||||
ToState state
|
||||
StackPop int // Number of symbols to pop
|
||||
StackPush []stackSymbol // Symbols to push (in order, first pushed first)
|
||||
}
|
||||
|
||||
// pda represents a compiled pushdown automaton.
|
||||
type pda struct {
|
||||
States int // Total number of states
|
||||
StackSymbols int // Total number of stack symbols
|
||||
StartState state // Initial state
|
||||
AcceptStates map[state]bool // Set of accepting states
|
||||
Transitions map[state][]transition // Transitions indexed by from-state
|
||||
|
||||
// For token-level matching
|
||||
Terminals []string // All terminal symbols (patterns to match)
|
||||
}
|
||||
|
||||
// newPDA creates an empty pda.
|
||||
func newPDA() *pda {
|
||||
return &pda{
|
||||
States: 2, // Error and Start
|
||||
StackSymbols: 1, // Empty
|
||||
StartState: stateStart,
|
||||
AcceptStates: make(map[state]bool),
|
||||
Transitions: make(map[state][]transition),
|
||||
Terminals: make([]string, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// addState adds a new state and returns its ID.
|
||||
func (p *pda) addState() state {
|
||||
s := state(p.States)
|
||||
p.States++
|
||||
return s
|
||||
}
|
||||
|
||||
// addStackSymbol adds a new stack symbol and returns its ID.
|
||||
func (p *pda) addStackSymbol() stackSymbol {
|
||||
s := stackSymbol(p.StackSymbols)
|
||||
p.StackSymbols++
|
||||
return s
|
||||
}
|
||||
|
||||
// addTransition adds a transition to the pda.
|
||||
func (p *pda) addTransition(t transition) {
|
||||
p.Transitions[t.FromState] = append(p.Transitions[t.FromState], t)
|
||||
}
|
||||
|
||||
// addTerminal registers a terminal pattern and returns its index.
|
||||
func (p *pda) addTerminal(pattern string) int {
|
||||
for i, t := range p.Terminals {
|
||||
if t == pattern {
|
||||
return i
|
||||
}
|
||||
}
|
||||
p.Terminals = append(p.Terminals, pattern)
|
||||
return len(p.Terminals) - 1
|
||||
}
|
||||
|
||||
// compiler compiles EBNF grammars to PDAs.
|
||||
type compiler struct {
|
||||
grammar ebnf.Grammar
|
||||
pda *pda
|
||||
|
||||
// Maps production names to their entry/exit states
|
||||
prodEntry map[string]state
|
||||
prodExit map[string]state
|
||||
}
|
||||
|
||||
// compile parses an EBNF grammar and compiles it to a pda.
|
||||
func compile(name string, src io.Reader, start string) (*pda, error) {
|
||||
grammar, err := ebnf.Parse(name, src)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse grammar: %w", err)
|
||||
}
|
||||
|
||||
if err := ebnf.Verify(grammar, start); err != nil {
|
||||
return nil, fmt.Errorf("verify grammar: %w", err)
|
||||
}
|
||||
|
||||
c := &compiler{
|
||||
grammar: grammar,
|
||||
pda: newPDA(),
|
||||
prodEntry: make(map[string]state),
|
||||
prodExit: make(map[string]state),
|
||||
}
|
||||
|
||||
// Create entry/exit states for each production
|
||||
for name := range grammar {
|
||||
c.prodEntry[name] = c.pda.addState()
|
||||
c.prodExit[name] = c.pda.addState()
|
||||
}
|
||||
|
||||
// compile each production
|
||||
for name, prod := range grammar {
|
||||
if err := c.compileProduction(name, prod); err != nil {
|
||||
return nil, fmt.Errorf("compile production %q: %w", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Set start state to entry of start production
|
||||
if entry, ok := c.prodEntry[start]; ok {
|
||||
// Add epsilon transition from pda start to grammar start
|
||||
c.pda.addTransition(transition{
|
||||
FromState: stateStart,
|
||||
Pattern: "", // epsilon
|
||||
ToState: entry,
|
||||
})
|
||||
} else {
|
||||
return nil, fmt.Errorf("start production %q not found", start)
|
||||
}
|
||||
|
||||
// Mark exit of start production as accepting
|
||||
if exit, ok := c.prodExit[start]; ok {
|
||||
c.pda.AcceptStates[exit] = true
|
||||
}
|
||||
|
||||
return c.pda, nil
|
||||
}
|
||||
|
||||
// compileString is a convenience function to compile from a string.
|
||||
func compileString(grammar string, start string) (*pda, error) {
|
||||
return compile("grammar", strings.NewReader(grammar), start)
|
||||
}
|
||||
|
||||
func (c *compiler) compileProduction(name string, prod *ebnf.Production) error {
|
||||
entry := c.prodEntry[name]
|
||||
exit := c.prodExit[name]
|
||||
|
||||
return c.compileExpr(prod.Expr, entry, exit)
|
||||
}
|
||||
|
||||
func (c *compiler) compileExpr(expr ebnf.Expression, entry, exit state) error {
|
||||
switch e := expr.(type) {
|
||||
case *ebnf.Name:
|
||||
return c.compileName(e, entry, exit)
|
||||
case *ebnf.Token:
|
||||
return c.compileToken(e, entry, exit)
|
||||
case ebnf.Sequence:
|
||||
return c.compileSequence(e, entry, exit)
|
||||
case ebnf.Alternative:
|
||||
return c.compileAlternative(e, entry, exit)
|
||||
case *ebnf.Option:
|
||||
return c.compileOption(e, entry, exit)
|
||||
case *ebnf.Repetition:
|
||||
return c.compileRepetition(e, entry, exit)
|
||||
case *ebnf.Group:
|
||||
return c.compileExpr(e.Body, entry, exit)
|
||||
case *ebnf.Range:
|
||||
return c.compileRange(e, entry, exit)
|
||||
case nil:
|
||||
// Empty production - direct epsilon transition
|
||||
c.pda.addTransition(transition{
|
||||
FromState: entry,
|
||||
Pattern: "",
|
||||
ToState: exit,
|
||||
})
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unsupported expression type: %T", expr)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *compiler) compileName(n *ebnf.Name, entry, exit state) error {
|
||||
// Reference to another production
|
||||
prodName := n.String
|
||||
|
||||
prodEntry, ok := c.prodEntry[prodName]
|
||||
if !ok {
|
||||
return fmt.Errorf("undefined production: %s", prodName)
|
||||
}
|
||||
prodExit := c.prodExit[prodName]
|
||||
// Use a unique stack symbol per call site so returns are unambiguous.
|
||||
stackSym := c.pda.addStackSymbol()
|
||||
|
||||
// Push return address, go to production entry
|
||||
c.pda.addTransition(transition{
|
||||
FromState: entry,
|
||||
Pattern: "", // epsilon
|
||||
ToState: prodEntry,
|
||||
StackPush: []stackSymbol{stackSym},
|
||||
})
|
||||
|
||||
// On production exit, pop and return
|
||||
c.pda.addTransition(transition{
|
||||
FromState: prodExit,
|
||||
stackTop: stackSym,
|
||||
Pattern: "", // epsilon
|
||||
ToState: exit,
|
||||
StackPop: 1,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *compiler) compileToken(t *ebnf.Token, entry, exit state) error {
|
||||
// terminal symbol - add transition that consumes this token
|
||||
pattern := t.String
|
||||
c.pda.addTerminal(pattern)
|
||||
|
||||
c.pda.addTransition(transition{
|
||||
FromState: entry,
|
||||
Pattern: pattern,
|
||||
ToState: exit,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *compiler) compileSequence(seq ebnf.Sequence, entry, exit state) error {
|
||||
if len(seq) == 0 {
|
||||
// Empty sequence - epsilon transition
|
||||
c.pda.addTransition(transition{
|
||||
FromState: entry,
|
||||
Pattern: "",
|
||||
ToState: exit,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Chain: entry -> s1 -> s2 -> ... -> exit
|
||||
current := entry
|
||||
for i, expr := range seq {
|
||||
var next state
|
||||
if i == len(seq)-1 {
|
||||
next = exit
|
||||
} else {
|
||||
next = c.pda.addState()
|
||||
}
|
||||
|
||||
if err := c.compileExpr(expr, current, next); err != nil {
|
||||
return err
|
||||
}
|
||||
current = next
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *compiler) compileAlternative(alt ebnf.Alternative, entry, exit state) error {
|
||||
// Each alternative goes from entry to exit
|
||||
for _, expr := range alt {
|
||||
if err := c.compileExpr(expr, entry, exit); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *compiler) compileOption(opt *ebnf.Option, entry, exit state) error {
|
||||
// Optional: can skip (epsilon) or take the body
|
||||
|
||||
// Epsilon transition (skip)
|
||||
c.pda.addTransition(transition{
|
||||
FromState: entry,
|
||||
Pattern: "",
|
||||
ToState: exit,
|
||||
})
|
||||
|
||||
// Or take the body
|
||||
return c.compileExpr(opt.Body, entry, exit)
|
||||
}
|
||||
|
||||
func (c *compiler) compileRepetition(rep *ebnf.Repetition, entry, exit state) error {
|
||||
// Repetition {body}: zero or more
|
||||
// entry -> exit (skip)
|
||||
// entry -> body -> entry (loop back)
|
||||
|
||||
// Skip transition
|
||||
c.pda.addTransition(transition{
|
||||
FromState: entry,
|
||||
Pattern: "",
|
||||
ToState: exit,
|
||||
})
|
||||
|
||||
// Loop: entry -> (body) -> entry
|
||||
return c.compileExpr(rep.Body, entry, entry)
|
||||
}
|
||||
|
||||
func (c *compiler) compileRange(r *ebnf.Range, entry, exit state) error {
|
||||
// Character range like "a" … "z" or "\u03b1" … "\u03c9"
|
||||
begin := strings.Trim(r.Begin.String, "\"")
|
||||
end := strings.Trim(r.End.String, "\"")
|
||||
|
||||
// Unescape bounds first (so "\u03b1" works)
|
||||
beginUnesc, err := unescapeLiteral(begin)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid range begin: %w", err)
|
||||
}
|
||||
endUnesc, err := unescapeLiteral(end)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid range end: %w", err)
|
||||
}
|
||||
|
||||
// Validate as single runes (not bytes) for Unicode support
|
||||
beginRunes := []rune(beginUnesc)
|
||||
endRunes := []rune(endUnesc)
|
||||
if len(beginRunes) != 1 || len(endRunes) != 1 {
|
||||
return fmt.Errorf("range bounds must be single characters: %q..%q", r.Begin.String, r.End.String)
|
||||
}
|
||||
|
||||
// Use unescaped rune strings in pattern (consistent with matcher)
|
||||
pattern := fmt.Sprintf("[%s-%s]", string(beginRunes[0]), string(endRunes[0]))
|
||||
c.pda.addTerminal(pattern)
|
||||
|
||||
c.pda.addTransition(transition{
|
||||
FromState: entry,
|
||||
Pattern: pattern,
|
||||
ToState: exit,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// runtime represents a pda execution instance.
|
||||
type runtime struct {
|
||||
pda *pda
|
||||
state state
|
||||
stack []stackSymbol
|
||||
}
|
||||
|
||||
// newRuntime creates a new pda runtime.
|
||||
func newRuntime(pda *pda) *runtime {
|
||||
return &runtime{
|
||||
pda: pda,
|
||||
state: pda.StartState,
|
||||
stack: make([]stackSymbol, 0, 32),
|
||||
}
|
||||
}
|
||||
|
||||
// stackTop returns the top of the stack, or stackEmpty if empty.
|
||||
func (r *runtime) stackTop() stackSymbol {
|
||||
if len(r.stack) == 0 {
|
||||
return stackEmpty
|
||||
}
|
||||
return r.stack[len(r.stack)-1]
|
||||
}
|
||||
|
||||
// isAccepting returns true if we can reach an accepting state via epsilon transitions
|
||||
// with an empty stack.
|
||||
func (r *runtime) isAccepting() bool {
|
||||
return r.canReachAccept(r.state, r.stack, make(map[stateStackKey]bool))
|
||||
}
|
||||
|
||||
func (r *runtime) canReachAccept(state state, stack []stackSymbol, visited map[stateStackKey]bool) bool {
|
||||
// Check if this state is accepting with empty stack
|
||||
if r.pda.AcceptStates[state] && len(stack) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Avoid infinite loops
|
||||
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
|
||||
if visited[key] {
|
||||
return false
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
// Try epsilon transitions
|
||||
for _, t := range r.pda.Transitions[state] {
|
||||
if t.Pattern != "" {
|
||||
continue // Not epsilon
|
||||
}
|
||||
|
||||
// Check stack constraint
|
||||
stackTop := stackEmpty
|
||||
if len(stack) > 0 {
|
||||
stackTop = stack[len(stack)-1]
|
||||
}
|
||||
if t.stackTop != stackEmpty && t.stackTop != stackTop {
|
||||
continue
|
||||
}
|
||||
|
||||
// Simulate stack operations
|
||||
newStack := make([]stackSymbol, len(stack))
|
||||
copy(newStack, stack)
|
||||
|
||||
if t.StackPop > 0 && len(newStack) >= t.StackPop {
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
|
||||
if r.canReachAccept(t.ToState, newStack, visited) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Reset resets the runtime to initial state.
|
||||
func (r *runtime) Reset() {
|
||||
r.state = r.pda.StartState
|
||||
r.stack = r.stack[:0]
|
||||
}
|
||||
|
||||
// validInputs returns all valid input patterns from current state.
|
||||
func (r *runtime) validInputs() []string {
|
||||
var valid []string
|
||||
seen := make(map[string]bool)
|
||||
visited := make(map[stateStackKey]bool)
|
||||
|
||||
// Make a copy of the stack for simulation
|
||||
simStack := make([]stackSymbol, len(r.stack))
|
||||
copy(simStack, r.stack)
|
||||
|
||||
r.collectValidInputs(r.state, simStack, seen, visited, &valid)
|
||||
return valid
|
||||
}
|
||||
|
||||
// stateStackKey is used to detect cycles in epsilon closure
|
||||
type stateStackKey struct {
|
||||
state state
|
||||
stackSig string
|
||||
}
|
||||
|
||||
func stackSignature(stack []stackSymbol) string {
|
||||
if len(stack) == 0 {
|
||||
return ""
|
||||
}
|
||||
buf := make([]byte, len(stack)*8)
|
||||
for i, sym := range stack {
|
||||
binary.LittleEndian.PutUint64(buf[i*8:], uint64(sym))
|
||||
}
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
func (r *runtime) collectValidInputs(state state, simStack []stackSymbol, seen map[string]bool, visited map[stateStackKey]bool, valid *[]string) {
|
||||
// Get stack top for comparisons
|
||||
stackTop := stackEmpty
|
||||
if len(simStack) > 0 {
|
||||
stackTop = simStack[len(simStack)-1]
|
||||
}
|
||||
|
||||
// Check for cycles to avoid infinite loops
|
||||
key := stateStackKey{state: state, stackSig: stackSignature(simStack)}
|
||||
if visited[key] {
|
||||
return
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
transitions := r.pda.Transitions[state]
|
||||
|
||||
for _, t := range transitions {
|
||||
// Check stack constraint
|
||||
if t.stackTop != stackEmpty && t.stackTop != stackTop {
|
||||
continue
|
||||
}
|
||||
|
||||
if t.Pattern == "" {
|
||||
// Epsilon transition - simulate stack operations
|
||||
newStack := make([]stackSymbol, len(simStack))
|
||||
copy(newStack, simStack)
|
||||
|
||||
// Pop
|
||||
if t.StackPop > 0 {
|
||||
if len(newStack) < t.StackPop {
|
||||
continue // Can't pop, skip this transition
|
||||
}
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
|
||||
// Push
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
|
||||
r.collectValidInputs(t.ToState, newStack, seen, visited, valid)
|
||||
} else {
|
||||
// terminal - add if not seen
|
||||
if !seen[t.Pattern] {
|
||||
seen[t.Pattern] = true
|
||||
*valid = append(*valid, t.Pattern)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// matchesPattern checks if input matches a pattern.
|
||||
// Patterns can be:
|
||||
// - Exact strings: "a", "{", "true"
|
||||
// - Character ranges: "[a-z]", "[0-9]", "[#-~]"
|
||||
func matchesPattern(input, pattern string) bool {
|
||||
// Exact match
|
||||
if input == pattern {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for character range pattern [X-Y]
|
||||
if len(pattern) == 5 && pattern[0] == '[' && pattern[2] == '-' && pattern[4] == ']' {
|
||||
if len(input) != 1 {
|
||||
return false
|
||||
}
|
||||
ch := input[0]
|
||||
low := pattern[1]
|
||||
high := pattern[3]
|
||||
return ch >= low && ch <= high
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Accept tries to accept an input, returning true if successful.
|
||||
func (r *runtime) Accept(input string) bool {
|
||||
return r.accept(input, make(map[stateStackKey]bool))
|
||||
}
|
||||
|
||||
func (r *runtime) accept(input string, visited map[stateStackKey]bool) bool {
|
||||
key := stateStackKey{state: r.state, stackSig: stackSignature(r.stack)}
|
||||
if visited[key] {
|
||||
return false
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
transitions := r.pda.Transitions[r.state]
|
||||
|
||||
// First, process any epsilon transitions to reach a state that can accept input
|
||||
// This is a simplified version - full implementation would need epsilon closure
|
||||
for _, t := range transitions {
|
||||
if matchesPattern(input, t.Pattern) {
|
||||
if t.stackTop != stackEmpty && t.stackTop != r.stackTop() {
|
||||
continue
|
||||
}
|
||||
if t.StackPop > len(r.stack) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Apply transition
|
||||
r.applyTransition(t)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Try epsilon transitions first
|
||||
for _, t := range transitions {
|
||||
if t.Pattern == "" {
|
||||
if t.stackTop != stackEmpty && t.stackTop != r.stackTop() {
|
||||
continue
|
||||
}
|
||||
if t.StackPop > len(r.stack) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Save state for backtracking
|
||||
oldState := r.state
|
||||
oldStack := make([]stackSymbol, len(r.stack))
|
||||
copy(oldStack, r.stack)
|
||||
|
||||
r.applyTransition(t)
|
||||
|
||||
if r.accept(input, visited) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Backtrack
|
||||
r.state = oldState
|
||||
r.stack = oldStack
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *runtime) applyTransition(t transition) {
|
||||
// Pop
|
||||
if t.StackPop > 0 && len(r.stack) >= t.StackPop {
|
||||
r.stack = r.stack[:len(r.stack)-t.StackPop]
|
||||
}
|
||||
|
||||
// Push
|
||||
r.stack = append(r.stack, t.StackPush...)
|
||||
|
||||
// Move to new state
|
||||
r.state = t.ToState
|
||||
}
|
||||
@@ -1,540 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompileSimpleGrammar(t *testing.T) {
|
||||
// Simple grammar: S = "a" "b" .
|
||||
grammar := `S = "a" "b" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
if pda == nil {
|
||||
t.Fatal("pda is nil")
|
||||
}
|
||||
|
||||
// Should have terminals "a" and "b"
|
||||
if len(pda.Terminals) != 2 {
|
||||
t.Errorf("expected 2 terminals, got %d: %v", len(pda.Terminals), pda.Terminals)
|
||||
}
|
||||
|
||||
// Test runtime
|
||||
rt := newRuntime(pda)
|
||||
|
||||
// Should accept "a" then "b"
|
||||
if !rt.Accept("a") {
|
||||
t.Error("should accept 'a'")
|
||||
}
|
||||
if !rt.Accept("b") {
|
||||
t.Error("should accept 'b'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be in accepting state")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileAlternative(t *testing.T) {
|
||||
// Grammar: S = "a" | "b" .
|
||||
grammar := `S = "a" | "b" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
// Test accepting "a"
|
||||
rt := newRuntime(pda)
|
||||
if !rt.Accept("a") {
|
||||
t.Error("should accept 'a'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after 'a'")
|
||||
}
|
||||
|
||||
// Test accepting "b"
|
||||
rt.Reset()
|
||||
if !rt.Accept("b") {
|
||||
t.Error("should accept 'b'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after 'b'")
|
||||
}
|
||||
|
||||
// Test rejecting "c"
|
||||
rt.Reset()
|
||||
if rt.Accept("c") {
|
||||
t.Error("should not accept 'c'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileRepetition(t *testing.T) {
|
||||
// Grammar: S = {"a"} .
|
||||
grammar := `S = {"a"} .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
// Empty should be accepted (zero repetitions)
|
||||
rt := newRuntime(pda)
|
||||
if !rt.isAccepting() {
|
||||
t.Error("empty should be accepting")
|
||||
}
|
||||
|
||||
// "a" should be accepted
|
||||
rt.Reset()
|
||||
if !rt.Accept("a") {
|
||||
t.Error("should accept first 'a'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after one 'a'")
|
||||
}
|
||||
|
||||
// "aa" should be accepted
|
||||
if !rt.Accept("a") {
|
||||
t.Error("should accept second 'a'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after two 'a's")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileOption(t *testing.T) {
|
||||
// Grammar: S = ["a"] "b" .
|
||||
grammar := `S = ["a"] "b" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
// "b" alone should be accepted
|
||||
rt := newRuntime(pda)
|
||||
if !rt.Accept("b") {
|
||||
t.Error("should accept 'b' alone")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after 'b'")
|
||||
}
|
||||
|
||||
// "ab" should be accepted
|
||||
rt.Reset()
|
||||
if !rt.Accept("a") {
|
||||
t.Error("should accept 'a'")
|
||||
}
|
||||
if !rt.Accept("b") {
|
||||
t.Error("should accept 'b' after 'a'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after 'ab'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileRecursive(t *testing.T) {
|
||||
// Grammar with recursion: S = "(" S ")" | "x" .
|
||||
grammar := `S = "(" S ")" | "x" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
// "x" should be accepted
|
||||
rt := newRuntime(pda)
|
||||
if !rt.Accept("x") {
|
||||
t.Error("should accept 'x'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after 'x'")
|
||||
}
|
||||
|
||||
// "(x)" should be accepted
|
||||
rt.Reset()
|
||||
if !rt.Accept("(") {
|
||||
t.Error("should accept '('")
|
||||
}
|
||||
if !rt.Accept("x") {
|
||||
t.Error("should accept 'x' inside parens")
|
||||
}
|
||||
if !rt.Accept(")") {
|
||||
t.Error("should accept ')'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after '(x)'")
|
||||
}
|
||||
|
||||
// "((x))" should be accepted
|
||||
rt.Reset()
|
||||
if !rt.Accept("(") {
|
||||
t.Error("should accept first '('")
|
||||
}
|
||||
if !rt.Accept("(") {
|
||||
t.Error("should accept second '('")
|
||||
}
|
||||
if !rt.Accept("x") {
|
||||
t.Error("should accept 'x'")
|
||||
}
|
||||
if !rt.Accept(")") {
|
||||
t.Error("should accept first ')'")
|
||||
}
|
||||
if !rt.Accept(")") {
|
||||
t.Error("should accept second ')'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after '((x))'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidInputs(t *testing.T) {
|
||||
// Grammar: S = "a" | "b" .
|
||||
grammar := `S = "a" | "b" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
rt := newRuntime(pda)
|
||||
valid := rt.validInputs()
|
||||
|
||||
// Should have both "a" and "b" as valid
|
||||
hasA, hasB := false, false
|
||||
for _, v := range valid {
|
||||
if v == "a" {
|
||||
hasA = true
|
||||
}
|
||||
if v == "b" {
|
||||
hasB = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasA {
|
||||
t.Error("'a' should be valid input")
|
||||
}
|
||||
if !hasB {
|
||||
t.Error("'b' should be valid input")
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidInputsAfterAccept tests that validInputs returns correct values
|
||||
// after accepting tokens, ensuring proper stack simulation.
|
||||
func TestValidInputsAfterAccept(t *testing.T) {
|
||||
// Grammar: S = "a" "b" "c" .
|
||||
grammar := `S = "a" "b" "c" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
rt := newRuntime(pda)
|
||||
|
||||
// Initially only "a" should be valid
|
||||
valid := rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "a" {
|
||||
t.Errorf("initially expected only 'a', got %v", valid)
|
||||
}
|
||||
|
||||
// After accepting "a", only "b" should be valid
|
||||
if !rt.Accept("a") {
|
||||
t.Fatal("failed to accept 'a'")
|
||||
}
|
||||
valid = rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "b" {
|
||||
t.Errorf("after 'a', expected only 'b', got %v", valid)
|
||||
}
|
||||
|
||||
// After accepting "b", only "c" should be valid
|
||||
if !rt.Accept("b") {
|
||||
t.Fatal("failed to accept 'b'")
|
||||
}
|
||||
valid = rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "c" {
|
||||
t.Errorf("after 'ab', expected only 'c', got %v", valid)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidInputsWithRepetitionInProduction tests the critical case where
|
||||
// a repetition exists inside a called production. This requires proper
|
||||
// stack simulation to determine when closing symbols are valid.
|
||||
func TestValidInputsWithRepetitionInProduction(t *testing.T) {
|
||||
// Grammar similar to JSON:
|
||||
// S = "(" items ")" .
|
||||
// items = item { "," item } .
|
||||
// item = "x" .
|
||||
grammar := `
|
||||
S = "(" items ")" .
|
||||
items = item { "," item } .
|
||||
item = "x" .
|
||||
`
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
rt := newRuntime(pda)
|
||||
|
||||
// Initially only "(" should be valid
|
||||
valid := rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "(" {
|
||||
t.Errorf("initially expected only '(', got %v", valid)
|
||||
}
|
||||
|
||||
// Accept "("
|
||||
if !rt.Accept("(") {
|
||||
t.Fatal("failed to accept '('")
|
||||
}
|
||||
// After "(", should be able to accept "x" (item)
|
||||
valid = rt.validInputs()
|
||||
hasX := false
|
||||
for _, v := range valid {
|
||||
if v == "x" {
|
||||
hasX = true
|
||||
}
|
||||
}
|
||||
if !hasX {
|
||||
t.Errorf("after '(', expected 'x' to be valid, got %v", valid)
|
||||
}
|
||||
|
||||
// Accept first item "x"
|
||||
if !rt.Accept("x") {
|
||||
t.Fatal("failed to accept 'x'")
|
||||
}
|
||||
// After "(x", should be able to accept "," (more items) OR ")" (end)
|
||||
valid = rt.validInputs()
|
||||
hasComma, hasClose := false, false
|
||||
for _, v := range valid {
|
||||
if v == "," {
|
||||
hasComma = true
|
||||
}
|
||||
if v == ")" {
|
||||
hasClose = true
|
||||
}
|
||||
}
|
||||
if !hasComma {
|
||||
t.Errorf("after '(x', expected ',' to be valid, got %v", valid)
|
||||
}
|
||||
if !hasClose {
|
||||
t.Errorf("after '(x', expected ')' to be valid, got %v", valid)
|
||||
}
|
||||
|
||||
// Accept comma for another item
|
||||
if !rt.Accept(",") {
|
||||
t.Fatal("failed to accept ','")
|
||||
}
|
||||
// After "(x,", should only be able to accept "x" (next item)
|
||||
valid = rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "x" {
|
||||
t.Errorf("after '(x,', expected only 'x', got %v", valid)
|
||||
}
|
||||
|
||||
// Accept second item "x"
|
||||
if !rt.Accept("x") {
|
||||
t.Fatal("failed to accept second 'x'")
|
||||
}
|
||||
// CRITICAL: After "(x,x", should be able to accept "," OR ")"
|
||||
// This tests the stack simulation fix - we need to properly
|
||||
// follow epsilon transitions through the production call stack.
|
||||
valid = rt.validInputs()
|
||||
hasComma, hasClose = false, false
|
||||
for _, v := range valid {
|
||||
if v == "," {
|
||||
hasComma = true
|
||||
}
|
||||
if v == ")" {
|
||||
hasClose = true
|
||||
}
|
||||
}
|
||||
if !hasComma {
|
||||
t.Errorf("after '(x,x', expected ',' to be valid, got %v", valid)
|
||||
}
|
||||
if !hasClose {
|
||||
t.Errorf("after '(x,x', expected ')' to be valid, got %v", valid)
|
||||
}
|
||||
|
||||
// Close with ")"
|
||||
if !rt.Accept(")") {
|
||||
t.Fatal("failed to accept ')'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after '(x,x)'")
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidInputsNestedCalls tests validInputs with deeply nested production calls.
|
||||
func TestValidInputsNestedCalls(t *testing.T) {
|
||||
// Grammar: A = "start" B "end" . B = "middle" .
|
||||
grammar := `
|
||||
A = "start" B "end" .
|
||||
B = "middle" .
|
||||
`
|
||||
pda, err := compileString(grammar, "A")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
rt := newRuntime(pda)
|
||||
|
||||
// After "start", should accept "middle" (from B)
|
||||
rt.Accept("start")
|
||||
valid := rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "middle" {
|
||||
t.Errorf("after 'start', expected 'middle', got %v", valid)
|
||||
}
|
||||
|
||||
// After "start middle", should accept "end"
|
||||
rt.Accept("middle")
|
||||
valid = rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "end" {
|
||||
t.Errorf("after 'start middle', expected 'end', got %v", valid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReturnAddressDisambiguation(t *testing.T) {
|
||||
// Grammar where the same production is called from different contexts:
|
||||
// S = A "x" | "c" A "y" .
|
||||
// A = "a" .
|
||||
grammar := `
|
||||
S = A "x" | "c" A "y" .
|
||||
A = "a" .
|
||||
`
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
rt := newRuntime(pda)
|
||||
|
||||
if !rt.Accept("c") {
|
||||
t.Fatal("failed to accept 'c'")
|
||||
}
|
||||
if !rt.Accept("a") {
|
||||
t.Fatal("failed to accept 'a'")
|
||||
}
|
||||
|
||||
valid := rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "y" {
|
||||
t.Errorf("after 'ca', expected only 'y', got %v", valid)
|
||||
}
|
||||
|
||||
rt.Reset()
|
||||
rt.Accept("c")
|
||||
rt.Accept("a")
|
||||
if rt.Accept("x") {
|
||||
t.Error("should not accept 'x' after 'ca'")
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidInputsRecursiveWithStack tests validInputs with recursive grammars
|
||||
// which heavily exercise the stack simulation.
|
||||
func TestValidInputsRecursiveWithStack(t *testing.T) {
|
||||
// Grammar: S = "(" S ")" | "x" .
|
||||
grammar := `S = "(" S ")" | "x" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
rt := newRuntime(pda)
|
||||
|
||||
// Initially: "(" or "x" should be valid
|
||||
valid := rt.validInputs()
|
||||
hasParen, hasX := false, false
|
||||
for _, v := range valid {
|
||||
if v == "(" {
|
||||
hasParen = true
|
||||
}
|
||||
if v == "x" {
|
||||
hasX = true
|
||||
}
|
||||
}
|
||||
if !hasParen || !hasX {
|
||||
t.Errorf("initially expected '(' and 'x', got %v", valid)
|
||||
}
|
||||
|
||||
// After "(": "(" or "x" should be valid (nested S)
|
||||
rt.Accept("(")
|
||||
valid = rt.validInputs()
|
||||
hasParen, hasX = false, false
|
||||
for _, v := range valid {
|
||||
if v == "(" {
|
||||
hasParen = true
|
||||
}
|
||||
if v == "x" {
|
||||
hasX = true
|
||||
}
|
||||
}
|
||||
if !hasParen || !hasX {
|
||||
t.Errorf("after '(', expected '(' and 'x', got %v", valid)
|
||||
}
|
||||
|
||||
// After "((": "(" or "x" should still be valid
|
||||
rt.Accept("(")
|
||||
valid = rt.validInputs()
|
||||
hasParen, hasX = false, false
|
||||
for _, v := range valid {
|
||||
if v == "(" {
|
||||
hasParen = true
|
||||
}
|
||||
if v == "x" {
|
||||
hasX = true
|
||||
}
|
||||
}
|
||||
if !hasParen || !hasX {
|
||||
t.Errorf("after '((', expected '(' and 'x', got %v", valid)
|
||||
}
|
||||
|
||||
// After "((x": only ")" should be valid
|
||||
rt.Accept("x")
|
||||
valid = rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != ")" {
|
||||
t.Errorf("after '((x', expected only ')', got %v", valid)
|
||||
}
|
||||
|
||||
// After "((x)": only ")" should be valid (closing outer)
|
||||
rt.Accept(")")
|
||||
valid = rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != ")" {
|
||||
t.Errorf("after '((x)', expected only ')', got %v", valid)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRejectionAfterValid tests that invalid inputs are rejected
|
||||
// at various points in the grammar.
|
||||
func TestRejectionAfterValid(t *testing.T) {
|
||||
// Grammar: S = "a" "b" .
|
||||
grammar := `S = "a" "b" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
rt := newRuntime(pda)
|
||||
|
||||
// "b" should be rejected initially
|
||||
if rt.Accept("b") {
|
||||
t.Error("'b' should be rejected initially")
|
||||
}
|
||||
|
||||
// Accept "a"
|
||||
rt.Accept("a")
|
||||
|
||||
// "a" should be rejected after "a"
|
||||
if rt.Accept("a") {
|
||||
t.Error("'a' should be rejected after 'a'")
|
||||
}
|
||||
|
||||
// "c" should be rejected (not in grammar)
|
||||
if rt.Accept("c") {
|
||||
t.Error("'c' should be rejected (not in grammar)")
|
||||
}
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
# Example Grammars
|
||||
|
||||
This directory contains example EBNF grammars for constrained decoding.
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
go run -tags mlx ./x/imagegen/cmd/engine/ \
|
||||
-model /path/to/model \
|
||||
-prompt "Your prompt" \
|
||||
-grammar x/grammar/grammars/json.ebnf \
|
||||
-grammar-start value
|
||||
```
|
||||
|
||||
## Available Grammars
|
||||
|
||||
| File | Start Rule | Description |
|
||||
|------|------------|-------------|
|
||||
| `json.ebnf` | `value` | Standard JSON (RFC 8259) |
|
||||
| `expression.ebnf` | `expr` | Arithmetic expressions (+, -, *, /, parens) |
|
||||
| `identifier.ebnf` | `ident` | Programming language identifiers |
|
||||
| `boolean.ebnf` | `expr` | Boolean expressions (AND, OR, NOT) |
|
||||
| `list.ebnf` | `list` | Comma-separated word list |
|
||||
| `yesno.ebnf` | `response` | Simple yes/no responses |
|
||||
| `date.ebnf` | `date` | Dates in YYYY-MM-DD format |
|
||||
| `email.ebnf` | `email` | Basic email addresses |
|
||||
| `phone.ebnf` | `phone` | US phone numbers |
|
||||
| `hexcolor.ebnf` | `color` | CSS hex colors (#RGB or #RRGGBB) |
|
||||
| `url.ebnf` | `url` | HTTP/HTTPS URLs |
|
||||
|
||||
## Grammar Syntax
|
||||
|
||||
**Note:** Comments are not supported. Grammar files must contain only EBNF productions.
|
||||
|
||||
The grammars use EBNF notation:
|
||||
|
||||
- `=` defines a production rule
|
||||
- `|` is alternation (or)
|
||||
- `{ }` is repetition (zero or more)
|
||||
- `[ ]` is optional (zero or one)
|
||||
- `" "` is a literal string
|
||||
- `…` is a character range (e.g., `"a" … "z"`)
|
||||
- `.` ends a production
|
||||
|
||||
## Writing Custom Grammars
|
||||
|
||||
1. Define your grammar in a `.ebnf` file
|
||||
2. Choose a start rule name
|
||||
3. Pass `-grammar path/to/grammar.ebnf -grammar-start rulename`
|
||||
|
||||
Example custom grammar for RGB colors:
|
||||
|
||||
```ebnf
|
||||
color = "#" hexdigit hexdigit hexdigit hexdigit hexdigit hexdigit .
|
||||
hexdigit = "0" … "9" | "a" … "f" | "A" … "F" .
|
||||
```
|
||||
@@ -1,7 +0,0 @@
|
||||
expr = term { " OR " term } .
|
||||
term = factor { " AND " factor } .
|
||||
factor = "NOT " factor | atom | "(" expr ")" .
|
||||
atom = "true" | "false" | ident .
|
||||
ident = letter { letter | digit } .
|
||||
letter = "a" … "z" | "A" … "Z" .
|
||||
digit = "0" … "9" .
|
||||
@@ -1,6 +0,0 @@
|
||||
date = year "-" month "-" day .
|
||||
year = digit digit digit digit .
|
||||
month = ( "0" digit1to9 ) | ( "1" ( "0" | "1" | "2" ) ) .
|
||||
day = ( "0" digit1to9 ) | ( ( "1" | "2" ) digit ) | ( "3" ( "0" | "1" ) ) .
|
||||
digit1to9 = "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
|
||||
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
|
||||
@@ -1,5 +0,0 @@
|
||||
email = localpart "@" domain .
|
||||
localpart = word { "." word } .
|
||||
domain = word { "." word } .
|
||||
word = alphanum { alphanum | "-" | "_" } .
|
||||
alphanum = "a" … "z" | "A" … "Z" | "0" … "9" .
|
||||
@@ -1,7 +0,0 @@
|
||||
expr = term { addop term } .
|
||||
addop = "+" | "-" .
|
||||
term = factor { mulop factor } .
|
||||
mulop = "*" | "/" .
|
||||
factor = number | "(" expr ")" .
|
||||
number = [ "-" ] digit { digit } .
|
||||
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
|
||||
@@ -1,4 +0,0 @@
|
||||
color = "#" ( hex6 | hex3 ) .
|
||||
hex6 = hexdigit hexdigit hexdigit hexdigit hexdigit hexdigit .
|
||||
hex3 = hexdigit hexdigit hexdigit .
|
||||
hexdigit = "0" … "9" | "a" … "f" | "A" … "F" .
|
||||
@@ -1,3 +0,0 @@
|
||||
ident = letter { letter | digit | "_" } .
|
||||
letter = "a" … "z" | "A" … "Z" | "_" .
|
||||
digit = "0" … "9" .
|
||||
@@ -1,16 +0,0 @@
|
||||
value = object | array | string | number | "true" | "false" | "null" .
|
||||
object = "{" [ members ] "}" .
|
||||
members = pair { "," pair } .
|
||||
pair = string ":" value .
|
||||
array = "[" [ elements ] "]" .
|
||||
elements = value { "," value } .
|
||||
string = "\"" { char } "\"" .
|
||||
char = unescaped | escaped .
|
||||
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
|
||||
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" ) .
|
||||
number = [ "-" ] integer [ fraction ] [ exponent ] .
|
||||
integer = "0" | onenine { digit } .
|
||||
fraction = "." digit { digit } .
|
||||
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
|
||||
onenine = "1" … "9" .
|
||||
digit = "0" … "9" .
|
||||
@@ -1,27 +0,0 @@
|
||||
root = array .
|
||||
|
||||
value = object | array | string | number | "true" | "false" | "null" .
|
||||
|
||||
object = "{" ws "}" | "{" members "}" .
|
||||
members = member { "," member } .
|
||||
member = ws string ws ":" element .
|
||||
|
||||
array = "[" ws "]" | "[" elements "]" .
|
||||
elements = element { "," element } .
|
||||
element = ws value ws .
|
||||
|
||||
string = "\"" { character } "\"" .
|
||||
character = unescaped | escaped .
|
||||
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
|
||||
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
|
||||
unicode = "u" hex hex hex hex .
|
||||
hex = "0" … "9" | "A" … "F" | "a" … "f" .
|
||||
|
||||
number = [ "-" ] integer [ fraction ] [ exponent ] .
|
||||
integer = "0" | onenine { digit } .
|
||||
fraction = "." digit { digit } .
|
||||
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
|
||||
digit = "0" … "9" .
|
||||
onenine = "1" … "9" .
|
||||
|
||||
ws = { " " | "\t" | "\n" | "\r" } .
|
||||
@@ -1,4 +0,0 @@
|
||||
list = item { ", " item } .
|
||||
item = word .
|
||||
word = letter { letter } .
|
||||
letter = "a" … "z" | "A" … "Z" .
|
||||
@@ -1,19 +0,0 @@
|
||||
root = "[" ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person { "," ws person } ws "]" .
|
||||
|
||||
person = "{" ws name_field "," ws age_field "," ws email_field ws "}" .
|
||||
|
||||
name_field = "\"" "n" "a" "m" "e" "\"" ws ":" ws string .
|
||||
age_field = "\"" "a" "g" "e" "\"" ws ":" ws number .
|
||||
email_field = "\"" "e" "m" "a" "i" "l" "\"" ws ":" ws string .
|
||||
|
||||
string = "\"" { character } "\"" .
|
||||
character = unescaped | escaped .
|
||||
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
|
||||
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" ) .
|
||||
|
||||
number = [ "-" ] integer .
|
||||
integer = "0" | onenine { digit } .
|
||||
digit = "0" … "9" .
|
||||
onenine = "1" … "9" .
|
||||
|
||||
ws = { " " | "\t" | "\n" | "\r" } .
|
||||
@@ -1,15 +0,0 @@
|
||||
root = "{" ws name_field "," ws age_field "," ws email_field ws "}" .
|
||||
|
||||
name_field = "\"name\"" ws ":" ws string .
|
||||
age_field = "\"age\"" ws ":" ws number .
|
||||
email_field = "\"email\"" ws ":" ws string .
|
||||
|
||||
string = "\"" { character } "\"" .
|
||||
character = " " | "!" | "#" … "~" .
|
||||
|
||||
number = [ "-" ] integer .
|
||||
integer = "0" | onenine { digit } .
|
||||
digit = "0" … "9" .
|
||||
onenine = "1" … "9" .
|
||||
|
||||
ws = { " " | "\t" | "\n" | "\r" } .
|
||||
@@ -1,7 +0,0 @@
|
||||
phone = parenformat | dashformat .
|
||||
parenformat = "(" areacode ") " exchange "-" subscriber .
|
||||
dashformat = areacode "-" exchange "-" subscriber .
|
||||
areacode = digit digit digit .
|
||||
exchange = digit digit digit .
|
||||
subscriber = digit digit digit digit .
|
||||
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
|
||||
@@ -1,11 +0,0 @@
|
||||
url = scheme "://" host [ ":" port ] [ path ] [ query ] .
|
||||
scheme = "http" | "https" .
|
||||
host = word { "." word } .
|
||||
port = digit { digit } .
|
||||
path = "/" { pathseg } .
|
||||
pathseg = word [ "/" ] .
|
||||
query = "?" param { "&" param } .
|
||||
param = word "=" word .
|
||||
word = alphanum { alphanum | "-" | "_" } .
|
||||
alphanum = "a" … "z" | "A" … "Z" | "0" … "9" .
|
||||
digit = "0" … "9" .
|
||||
@@ -1,3 +0,0 @@
|
||||
response = affirmative | negative .
|
||||
affirmative = "yes" | "Yes" | "YES" | "y" | "Y" | "true" | "True" .
|
||||
negative = "no" | "No" | "NO" | "n" | "N" | "false" | "False" .
|
||||
@@ -1,69 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
// JSONGrammarEBNF is the EBNF grammar for JSON (character-level).
|
||||
// Based on https://www.json.org/json-en.html
|
||||
//
|
||||
// This grammar operates at the character level. The engine validates
|
||||
// tokens by matching them as sequences of these character-level terminals.
|
||||
const JSONGrammarEBNF = `
|
||||
json = value .
|
||||
|
||||
value = object | array | string | number | "true" | "false" | "null" .
|
||||
|
||||
object = "{" ws "}" | "{" members "}" .
|
||||
members = member { "," member } .
|
||||
member = ws string ws ":" element .
|
||||
|
||||
array = "[" ws "]" | "[" elements "]" .
|
||||
elements = element { "," element } .
|
||||
element = ws value ws .
|
||||
|
||||
string = "\"" { character } "\"" .
|
||||
character = unescaped | escaped .
|
||||
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
|
||||
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
|
||||
unicode = "u" hex hex hex hex .
|
||||
hex = "0" … "9" | "A" … "F" | "a" … "f" .
|
||||
|
||||
number = [ "-" ] integer [ fraction ] [ exponent ] .
|
||||
integer = "0" | onenine { digit } .
|
||||
fraction = "." digit { digit } .
|
||||
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
|
||||
digit = "0" … "9" .
|
||||
onenine = "1" … "9" .
|
||||
|
||||
ws = { " " | "\t" | "\n" | "\r" } .
|
||||
`
|
||||
|
||||
// JSONObjectGrammarEBNF is like JSONGrammarEBNF but only allows objects at the top level.
|
||||
const JSONObjectGrammarEBNF = `
|
||||
json = object .
|
||||
|
||||
value = object | array | string | number | "true" | "false" | "null" .
|
||||
|
||||
object = "{" ws "}" | "{" members "}" .
|
||||
members = member { "," member } .
|
||||
member = ws string ws ":" element .
|
||||
|
||||
array = "[" ws "]" | "[" elements "]" .
|
||||
elements = element { "," element } .
|
||||
element = ws value ws .
|
||||
|
||||
string = "\"" { character } "\"" .
|
||||
character = unescaped | escaped .
|
||||
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
|
||||
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
|
||||
unicode = "u" hex hex hex hex .
|
||||
hex = "0" … "9" | "A" … "F" | "a" … "f" .
|
||||
|
||||
number = [ "-" ] integer [ fraction ] [ exponent ] .
|
||||
integer = "0" | onenine { digit } .
|
||||
fraction = "." digit { digit } .
|
||||
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
|
||||
digit = "0" … "9" .
|
||||
onenine = "1" … "9" .
|
||||
|
||||
ws = { " " | "\t" | "\n" | "\r" } .
|
||||
`
|
||||
@@ -1,726 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package schema converts OpenAI-compatible JSON Schema into constrained grammars.
|
||||
package schema
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/grammar"
|
||||
)
|
||||
|
||||
// schemaNode represents OpenAI-compatible JSON Schema for structured outputs.
|
||||
// See: https://platform.openai.com/docs/guides/structured-outputs
|
||||
type schemaNode struct {
|
||||
// Core types
|
||||
Type interface{} `json:"type"` // string, []string, or nil
|
||||
|
||||
// Object properties
|
||||
Properties map[string]*schemaNode `json:"properties"`
|
||||
Required []string `json:"required"`
|
||||
AdditionalProperties interface{} `json:"additionalProperties"`
|
||||
|
||||
// Array properties
|
||||
Items *schemaNode `json:"items"`
|
||||
MinItems *int `json:"minItems"`
|
||||
MaxItems *int `json:"maxItems"`
|
||||
|
||||
// String properties
|
||||
Pattern string `json:"pattern"` // Regex pattern
|
||||
Format string `json:"format"` // date-time, email, uuid, etc.
|
||||
|
||||
// Number properties (noted but not enforced in grammar - validated post-generation)
|
||||
Minimum *float64 `json:"minimum"`
|
||||
Maximum *float64 `json:"maximum"`
|
||||
ExclusiveMinimum *float64 `json:"exclusiveMinimum"`
|
||||
ExclusiveMaximum *float64 `json:"exclusiveMaximum"`
|
||||
MultipleOf *float64 `json:"multipleOf"`
|
||||
|
||||
// Enum and const
|
||||
Enum []interface{} `json:"enum"`
|
||||
Const interface{} `json:"const"`
|
||||
|
||||
// Composition
|
||||
AnyOf []*schemaNode `json:"anyOf"`
|
||||
OneOf []*schemaNode `json:"oneOf"` // Treated same as anyOf for grammar
|
||||
|
||||
// References and definitions
|
||||
Ref string `json:"$ref"`
|
||||
Defs map[string]*schemaNode `json:"$defs"`
|
||||
|
||||
// Description (ignored for grammar but useful for docs)
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// converter handles JSON Schema to EBNF conversion with state.
|
||||
type converter struct {
|
||||
schema *schemaNode
|
||||
definitions map[string]*schemaNode // Resolved $defs
|
||||
usedTypes map[string]bool
|
||||
rules []string
|
||||
ruleNum int
|
||||
definedRefs map[string]bool // Track which refs we've already defined as rules
|
||||
}
|
||||
|
||||
// EBNF converts a JSON Schema to EBNF grammar
|
||||
func EBNF(schemaJSON string) (string, error) {
|
||||
var schema schemaNode
|
||||
if err := json.Unmarshal([]byte(schemaJSON), &schema); err != nil {
|
||||
return "", fmt.Errorf("failed to parse JSON Schema: %w", err)
|
||||
}
|
||||
|
||||
conv := &converter{
|
||||
schema: &schema,
|
||||
definitions: schema.Defs,
|
||||
usedTypes: make(map[string]bool),
|
||||
definedRefs: make(map[string]bool),
|
||||
}
|
||||
|
||||
return conv.convert()
|
||||
}
|
||||
|
||||
func (c *converter) convert() (string, error) {
|
||||
var b strings.Builder
|
||||
|
||||
// Generate root rule
|
||||
rootExpr := c.schemaToExpr(c.schema, "root")
|
||||
b.WriteString("root = ")
|
||||
b.WriteString(rootExpr)
|
||||
b.WriteString(" .\n")
|
||||
|
||||
// Add generated rules (refs, items, etc.)
|
||||
for _, rule := range c.rules {
|
||||
b.WriteString(rule)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Add primitives based on usage
|
||||
c.addPrimitives(&b)
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
func (c *converter) addPrimitives(b *strings.Builder) {
|
||||
if c.usedTypes["string"] {
|
||||
b.WriteString(`
|
||||
string = "\"" { character } "\"" .
|
||||
`)
|
||||
}
|
||||
|
||||
if c.usedTypes["string"] || c.usedTypes["character"] {
|
||||
b.WriteString(`
|
||||
character = unescaped | escaped .
|
||||
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
|
||||
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
|
||||
unicode = "u" hex hex hex hex .
|
||||
`)
|
||||
}
|
||||
|
||||
if c.usedTypes["number"] {
|
||||
b.WriteString(`
|
||||
number = [ "-" ] integer [ fraction ] [ exponent ] .
|
||||
integer = "0" | onenine { digit } .
|
||||
fraction = "." digit { digit } .
|
||||
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
|
||||
`)
|
||||
}
|
||||
|
||||
if c.usedTypes["integer"] {
|
||||
b.WriteString(`
|
||||
int = [ "-" ] ( "0" | onenine { digit } ) .
|
||||
`)
|
||||
}
|
||||
|
||||
if c.usedTypes["number"] || c.usedTypes["integer"] || c.usedTypes["digit"] {
|
||||
b.WriteString(`
|
||||
digit = "0" … "9" .
|
||||
`)
|
||||
}
|
||||
|
||||
// onenine only needed for number/integer, not for digit-only formats
|
||||
if c.usedTypes["number"] || c.usedTypes["integer"] {
|
||||
b.WriteString(`onenine = "1" … "9" .
|
||||
`)
|
||||
}
|
||||
|
||||
if c.usedTypes["string"] || c.usedTypes["character"] || c.usedTypes["hex"] {
|
||||
b.WriteString(`
|
||||
hex = "0" … "9" | "A" … "F" | "a" … "f" .
|
||||
`)
|
||||
}
|
||||
|
||||
if c.usedTypes["ws"] {
|
||||
b.WriteString(`
|
||||
ws = { " " | "\t" | "\n" | "\r" } .
|
||||
`)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *converter) schemaToExpr(schema *schemaNode, name string) string {
|
||||
if schema == nil {
|
||||
c.usedTypes["string"] = true
|
||||
c.usedTypes["number"] = true
|
||||
return "( string | number | object | array | \"true\" | \"false\" | \"null\" )"
|
||||
}
|
||||
|
||||
// Handle $ref first
|
||||
if schema.Ref != "" {
|
||||
return c.resolveRef(schema.Ref)
|
||||
}
|
||||
|
||||
// Handle const
|
||||
if schema.Const != nil {
|
||||
return c.constToExpr(schema.Const)
|
||||
}
|
||||
|
||||
// Handle enum
|
||||
if len(schema.Enum) > 0 {
|
||||
return c.enumToExpr(schema.Enum)
|
||||
}
|
||||
|
||||
// Handle anyOf / oneOf
|
||||
if len(schema.AnyOf) > 0 {
|
||||
return c.anyOfToExpr(schema.AnyOf, name)
|
||||
}
|
||||
if len(schema.OneOf) > 0 {
|
||||
return c.anyOfToExpr(schema.OneOf, name)
|
||||
}
|
||||
|
||||
// Handle type
|
||||
types := c.getTypes(schema.Type)
|
||||
if len(types) == 0 {
|
||||
// No type specified, could be anything
|
||||
c.usedTypes["string"] = true
|
||||
c.usedTypes["number"] = true
|
||||
return "( string | number | \"true\" | \"false\" | \"null\" )"
|
||||
}
|
||||
|
||||
if len(types) == 1 {
|
||||
return c.typeToExpr(types[0], schema, name)
|
||||
}
|
||||
|
||||
// Multiple types (e.g., ["string", "null"])
|
||||
var parts []string
|
||||
for _, t := range types {
|
||||
parts = append(parts, c.typeToExpr(t, schema, name))
|
||||
}
|
||||
return "( " + strings.Join(parts, " | ") + " )"
|
||||
}
|
||||
|
||||
func (c *converter) typeToExpr(typeName string, schema *schemaNode, name string) string {
|
||||
switch typeName {
|
||||
case "object":
|
||||
return c.objectToExpr(schema, name)
|
||||
case "array":
|
||||
return c.arrayToExpr(schema, name)
|
||||
case "string":
|
||||
return c.stringToExpr(schema, name)
|
||||
case "number":
|
||||
c.usedTypes["number"] = true
|
||||
return "number"
|
||||
case "integer":
|
||||
c.usedTypes["integer"] = true
|
||||
c.usedTypes["digit"] = true
|
||||
return "int"
|
||||
case "boolean":
|
||||
return `( "true" | "false" )`
|
||||
case "null":
|
||||
return `"null"`
|
||||
default:
|
||||
c.usedTypes["string"] = true
|
||||
c.usedTypes["number"] = true
|
||||
return "string"
|
||||
}
|
||||
}
|
||||
|
||||
func (c *converter) objectToExpr(schema *schemaNode, name string) string {
|
||||
c.usedTypes["ws"] = true
|
||||
|
||||
if len(schema.Properties) == 0 {
|
||||
return `"{" ws "}"`
|
||||
}
|
||||
|
||||
// Sort properties for deterministic output
|
||||
// Required properties come first, in their required order
|
||||
var propOrder []string
|
||||
requiredSet := make(map[string]bool)
|
||||
for _, r := range schema.Required {
|
||||
requiredSet[r] = true
|
||||
propOrder = append(propOrder, r)
|
||||
}
|
||||
|
||||
// Add any non-required properties (though OpenAI requires all to be required)
|
||||
var optionalProps []string
|
||||
for propName := range schema.Properties {
|
||||
if !requiredSet[propName] {
|
||||
optionalProps = append(optionalProps, propName)
|
||||
}
|
||||
}
|
||||
sort.Strings(optionalProps)
|
||||
propOrder = append(propOrder, optionalProps...)
|
||||
|
||||
var propExprs []string
|
||||
first := true
|
||||
|
||||
for _, propName := range propOrder {
|
||||
propSchema, exists := schema.Properties[propName]
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
|
||||
propExpr := c.schemaToExpr(propSchema, propName)
|
||||
|
||||
prefix := ""
|
||||
if !first {
|
||||
prefix = `"," ws `
|
||||
}
|
||||
first = false
|
||||
|
||||
propExprs = append(propExprs, fmt.Sprintf(`%s"\"%s\"" ws ":" ws %s`, prefix, propName, propExpr))
|
||||
}
|
||||
|
||||
if len(propExprs) == 0 {
|
||||
return `"{" ws "}"`
|
||||
}
|
||||
|
||||
return `"{" ws ` + strings.Join(propExprs, " ") + ` ws "}"`
|
||||
}
|
||||
|
||||
func (c *converter) arrayToExpr(schema *schemaNode, name string) string {
|
||||
c.usedTypes["ws"] = true
|
||||
|
||||
itemExpr := "value"
|
||||
if schema.Items != nil {
|
||||
itemExpr = c.schemaToExpr(schema.Items, name+"_item")
|
||||
} else {
|
||||
c.usedTypes["string"] = true
|
||||
c.usedTypes["number"] = true
|
||||
}
|
||||
|
||||
// Create item rule
|
||||
c.ruleNum++
|
||||
itemRule := fmt.Sprintf("item%d", c.ruleNum)
|
||||
c.rules = append(c.rules, fmt.Sprintf("%s = %s .", itemRule, itemExpr))
|
||||
|
||||
// Handle minItems/maxItems
|
||||
if schema.MinItems != nil || schema.MaxItems != nil {
|
||||
return c.arrayWithBounds(itemRule, schema.MinItems, schema.MaxItems)
|
||||
}
|
||||
|
||||
// Default: zero or more items
|
||||
return fmt.Sprintf(`( "[" ws "]" | "[" ws %s { "," ws %s } ws "]" )`, itemRule, itemRule)
|
||||
}
|
||||
|
||||
func (c *converter) arrayWithBounds(itemRule string, minItems, maxItems *int) string {
|
||||
min := 0
|
||||
max := -1 // unlimited
|
||||
|
||||
if minItems != nil {
|
||||
min = *minItems
|
||||
}
|
||||
if maxItems != nil {
|
||||
max = *maxItems
|
||||
}
|
||||
|
||||
if min == 0 && max < 0 {
|
||||
// No constraints
|
||||
return fmt.Sprintf(`( "[" ws "]" | "[" ws %s { "," ws %s } ws "]" )`, itemRule, itemRule)
|
||||
}
|
||||
|
||||
if min == 0 && max == 0 {
|
||||
return `"[" ws "]"`
|
||||
}
|
||||
|
||||
// Build pattern for bounded array
|
||||
// For min=2, max=4: item "," item [ "," item ] [ "," item ]
|
||||
var parts []string
|
||||
|
||||
// Required items
|
||||
for i := 0; i < min; i++ {
|
||||
if i > 0 {
|
||||
parts = append(parts, `"," ws`)
|
||||
}
|
||||
parts = append(parts, itemRule)
|
||||
}
|
||||
|
||||
// Optional items up to max
|
||||
if max > min {
|
||||
for i := min; i < max; i++ {
|
||||
if i == 0 {
|
||||
parts = append(parts, fmt.Sprintf(`[ %s`, itemRule))
|
||||
} else {
|
||||
parts = append(parts, fmt.Sprintf(`[ "," ws %s`, itemRule))
|
||||
}
|
||||
}
|
||||
// Close all optional brackets
|
||||
for i := min; i < max; i++ {
|
||||
parts = append(parts, "]")
|
||||
}
|
||||
} else if max < 0 {
|
||||
// Unlimited after min
|
||||
if min > 0 {
|
||||
parts = append(parts, fmt.Sprintf(`{ "," ws %s }`, itemRule))
|
||||
} else {
|
||||
parts = append(parts, fmt.Sprintf(`[ %s { "," ws %s } ]`, itemRule, itemRule))
|
||||
}
|
||||
}
|
||||
|
||||
if min == 0 {
|
||||
return fmt.Sprintf(`( "[" ws "]" | "[" ws %s ws "]" )`, strings.Join(parts, " "))
|
||||
}
|
||||
return fmt.Sprintf(`"[" ws %s ws "]"`, strings.Join(parts, " "))
|
||||
}
|
||||
|
||||
func (c *converter) stringToExpr(schema *schemaNode, name string) string {
|
||||
// Handle format
|
||||
if schema.Format != "" {
|
||||
return c.formatToExpr(schema.Format)
|
||||
}
|
||||
|
||||
// Handle pattern (regex)
|
||||
if schema.Pattern != "" {
|
||||
return c.patternToExpr(schema.Pattern, name)
|
||||
}
|
||||
|
||||
// Default string
|
||||
c.usedTypes["string"] = true
|
||||
if name == "root" {
|
||||
c.usedTypes["character"] = true
|
||||
return `"\"" { character } "\""`
|
||||
}
|
||||
return "string"
|
||||
}
|
||||
|
||||
func (c *converter) formatToExpr(format string) string {
|
||||
switch format {
|
||||
case "date":
|
||||
// YYYY-MM-DD
|
||||
c.ruleNum++
|
||||
c.usedTypes["digit"] = true
|
||||
ruleName := fmt.Sprintf("date%d", c.ruleNum)
|
||||
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit digit digit "-" digit digit "-" digit digit "\"" .`, ruleName))
|
||||
return ruleName
|
||||
|
||||
case "time":
|
||||
// HH:MM:SS
|
||||
c.ruleNum++
|
||||
c.usedTypes["digit"] = true
|
||||
ruleName := fmt.Sprintf("time%d", c.ruleNum)
|
||||
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit ":" digit digit ":" digit digit "\"" .`, ruleName))
|
||||
return ruleName
|
||||
|
||||
case "date-time":
|
||||
// YYYY-MM-DDTHH:MM:SSZ or with offset
|
||||
c.ruleNum++
|
||||
c.usedTypes["digit"] = true
|
||||
ruleName := fmt.Sprintf("datetime%d", c.ruleNum)
|
||||
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit digit digit "-" digit digit "-" digit digit "T" digit digit ":" digit digit ":" digit digit ( "Z" | ( "+" | "-" ) digit digit ":" digit digit ) "\"" .`, ruleName))
|
||||
return ruleName
|
||||
|
||||
case "email":
|
||||
// Simplified email pattern
|
||||
c.ruleNum++
|
||||
ruleName := fmt.Sprintf("email%d", c.ruleNum)
|
||||
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" emailchar { emailchar } "@" emailchar { emailchar } "." emailchar { emailchar } "\"" .`, ruleName))
|
||||
c.rules = append(c.rules, `emailchar = "a" … "z" | "A" … "Z" | "0" … "9" | "." | "-" | "_" .`)
|
||||
return ruleName
|
||||
|
||||
case "uuid":
|
||||
// 8-4-4-4-12 hex pattern
|
||||
c.ruleNum++
|
||||
ruleName := fmt.Sprintf("uuid%d", c.ruleNum)
|
||||
c.usedTypes["hex"] = true
|
||||
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" hex hex hex hex hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex hex hex hex hex hex hex hex hex "\"" .`, ruleName))
|
||||
return ruleName
|
||||
|
||||
case "ipv4":
|
||||
c.ruleNum++
|
||||
c.usedTypes["digit"] = true
|
||||
ruleName := fmt.Sprintf("ipv4_%d", c.ruleNum)
|
||||
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit { digit } "." digit { digit } "." digit { digit } "." digit { digit } "\"" .`, ruleName))
|
||||
return ruleName
|
||||
|
||||
case "uri", "hostname":
|
||||
// Fallback to general string for complex formats
|
||||
c.usedTypes["string"] = true
|
||||
return "string"
|
||||
|
||||
default:
|
||||
c.usedTypes["string"] = true
|
||||
return "string"
|
||||
}
|
||||
}
|
||||
|
||||
func (c *converter) patternToExpr(pattern string, name string) string {
|
||||
// Try to convert simple regex patterns to EBNF
|
||||
// This handles common cases; complex regex falls back to string
|
||||
|
||||
// Remove anchors
|
||||
pattern = strings.TrimPrefix(pattern, "^")
|
||||
pattern = strings.TrimSuffix(pattern, "$")
|
||||
|
||||
// Try to parse and convert
|
||||
expr, ok := c.regexToEBNF(pattern)
|
||||
if !ok {
|
||||
// Fallback to general string
|
||||
c.usedTypes["string"] = true
|
||||
return "string"
|
||||
}
|
||||
|
||||
c.ruleNum++
|
||||
ruleName := fmt.Sprintf("pattern%d", c.ruleNum)
|
||||
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" %s "\"" .`, ruleName, expr))
|
||||
return ruleName
|
||||
}
|
||||
|
||||
func (c *converter) regexToEBNF(pattern string) (string, bool) {
|
||||
// Simple regex to EBNF converter
|
||||
// Handles: literals, [a-z], [A-Z], [0-9], +, *, ?, basic groups
|
||||
|
||||
var result strings.Builder
|
||||
i := 0
|
||||
|
||||
for i < len(pattern) {
|
||||
ch := pattern[i]
|
||||
|
||||
switch ch {
|
||||
case '[':
|
||||
// Character class
|
||||
end := strings.Index(pattern[i:], "]")
|
||||
if end == -1 {
|
||||
return "", false
|
||||
}
|
||||
class := pattern[i+1 : i+end]
|
||||
ebnfClass, ok := c.charClassToEBNF(class)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
result.WriteString(ebnfClass)
|
||||
i += end + 1
|
||||
|
||||
case '(':
|
||||
// Group - find matching )
|
||||
depth := 1
|
||||
start := i + 1
|
||||
j := start
|
||||
for j < len(pattern) && depth > 0 {
|
||||
if pattern[j] == '(' {
|
||||
depth++
|
||||
} else if pattern[j] == ')' {
|
||||
depth--
|
||||
}
|
||||
j++
|
||||
}
|
||||
if depth != 0 {
|
||||
return "", false
|
||||
}
|
||||
groupContent := pattern[start : j-1]
|
||||
groupExpr, ok := c.regexToEBNF(groupContent)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
result.WriteString("( ")
|
||||
result.WriteString(groupExpr)
|
||||
result.WriteString(" )")
|
||||
i = j
|
||||
|
||||
case '|':
|
||||
result.WriteString(" | ")
|
||||
i++
|
||||
|
||||
case '+':
|
||||
// One or more - wrap previous in { } and add one required
|
||||
// This is a simplification
|
||||
return "", false // TODO: handle properly
|
||||
|
||||
case '*':
|
||||
// Zero or more - need to wrap previous
|
||||
return "", false // TODO: handle properly
|
||||
|
||||
case '?':
|
||||
// Optional - need to wrap previous in [ ]
|
||||
return "", false // TODO: handle properly
|
||||
|
||||
case '\\':
|
||||
// Escape sequence
|
||||
if i+1 >= len(pattern) {
|
||||
return "", false
|
||||
}
|
||||
next := pattern[i+1]
|
||||
switch next {
|
||||
case 'd':
|
||||
result.WriteString("digit")
|
||||
c.usedTypes["digit"] = true
|
||||
case 'w':
|
||||
result.WriteString(`( "a" … "z" | "A" … "Z" | "0" … "9" | "_" )`)
|
||||
case 's':
|
||||
result.WriteString(`( " " | "\t" )`)
|
||||
default:
|
||||
result.WriteString(fmt.Sprintf(`"%c"`, next))
|
||||
}
|
||||
i += 2
|
||||
|
||||
default:
|
||||
// Literal character
|
||||
if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' || ch == '-' || ch == '.' {
|
||||
result.WriteString(fmt.Sprintf(`"%c" `, ch))
|
||||
} else {
|
||||
// Special char, try to escape
|
||||
result.WriteString(fmt.Sprintf(`"%c" `, ch))
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
return strings.TrimSpace(result.String()), true
|
||||
}
|
||||
|
||||
func (c *converter) charClassToEBNF(class string) (string, bool) {
|
||||
// Handle character classes like a-z, A-Z, 0-9
|
||||
if class == "a-zA-Z0-9_" || class == "a-zA-Z_" {
|
||||
return `( "a" … "z" | "A" … "Z" | "0" … "9" | "_" )`, true
|
||||
}
|
||||
if class == "a-zA-Z0-9" {
|
||||
return `( "a" … "z" | "A" … "Z" | "0" … "9" )`, true
|
||||
}
|
||||
if class == "a-z" {
|
||||
return `"a" … "z"`, true
|
||||
}
|
||||
if class == "A-Z" {
|
||||
return `"A" … "Z"`, true
|
||||
}
|
||||
if class == "0-9" {
|
||||
c.usedTypes["digit"] = true
|
||||
return "digit", true
|
||||
}
|
||||
|
||||
// Try to parse range patterns
|
||||
if matched, _ := regexp.MatchString(`^[a-zA-Z]-[a-zA-Z]$`, class); matched {
|
||||
return fmt.Sprintf(`"%c" … "%c"`, class[0], class[2]), true
|
||||
}
|
||||
if matched, _ := regexp.MatchString(`^[0-9]-[0-9]$`, class); matched {
|
||||
return fmt.Sprintf(`"%c" … "%c"`, class[0], class[2]), true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (c *converter) anyOfToExpr(schemas []*schemaNode, name string) string {
|
||||
var parts []string
|
||||
for i, s := range schemas {
|
||||
expr := c.schemaToExpr(s, fmt.Sprintf("%s_opt%d", name, i))
|
||||
parts = append(parts, expr)
|
||||
}
|
||||
return "( " + strings.Join(parts, " | ") + " )"
|
||||
}
|
||||
|
||||
func (c *converter) enumToExpr(values []interface{}) string {
|
||||
var parts []string
|
||||
for _, v := range values {
|
||||
parts = append(parts, c.constToExpr(v))
|
||||
}
|
||||
return "( " + strings.Join(parts, " | ") + " )"
|
||||
}
|
||||
|
||||
func (c *converter) constToExpr(v interface{}) string {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
return fmt.Sprintf(`"\"%s\""`, c.escapeString(val))
|
||||
case float64:
|
||||
if val == float64(int(val)) {
|
||||
return fmt.Sprintf(`"%d"`, int(val))
|
||||
}
|
||||
return fmt.Sprintf(`"%v"`, val)
|
||||
case bool:
|
||||
if val {
|
||||
return `"true"`
|
||||
}
|
||||
return `"false"`
|
||||
case nil:
|
||||
return `"null"`
|
||||
default:
|
||||
c.usedTypes["string"] = true
|
||||
return "string"
|
||||
}
|
||||
}
|
||||
|
||||
func (c *converter) resolveRef(ref string) string {
|
||||
// Handle #/$defs/name references
|
||||
if strings.HasPrefix(ref, "#/$defs/") {
|
||||
defName := strings.TrimPrefix(ref, "#/$defs/")
|
||||
return c.resolveDefRef(defName)
|
||||
}
|
||||
|
||||
// Handle root recursion #
|
||||
if ref == "#" {
|
||||
return "root"
|
||||
}
|
||||
|
||||
// Unknown ref format
|
||||
c.usedTypes["string"] = true
|
||||
return "string"
|
||||
}
|
||||
|
||||
func (c *converter) resolveDefRef(defName string) string {
|
||||
// Check if we've already defined this as a rule
|
||||
ruleName := "def_" + defName
|
||||
if c.definedRefs[defName] {
|
||||
return ruleName
|
||||
}
|
||||
|
||||
// Mark as defined to prevent infinite recursion
|
||||
c.definedRefs[defName] = true
|
||||
|
||||
// Look up the definition
|
||||
if c.definitions == nil {
|
||||
c.usedTypes["string"] = true
|
||||
return "string"
|
||||
}
|
||||
|
||||
defSchema, ok := c.definitions[defName]
|
||||
if !ok {
|
||||
c.usedTypes["string"] = true
|
||||
return "string"
|
||||
}
|
||||
|
||||
// Generate the rule
|
||||
expr := c.schemaToExpr(defSchema, ruleName)
|
||||
c.rules = append(c.rules, fmt.Sprintf("%s = %s .", ruleName, expr))
|
||||
|
||||
return ruleName
|
||||
}
|
||||
|
||||
func (c *converter) getTypes(t interface{}) []string {
|
||||
switch v := t.(type) {
|
||||
case string:
|
||||
return []string{v}
|
||||
case []interface{}:
|
||||
var types []string
|
||||
for _, item := range v {
|
||||
if s, ok := item.(string); ok {
|
||||
types = append(types, s)
|
||||
}
|
||||
}
|
||||
return types
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *converter) escapeString(s string) string {
|
||||
s = strings.ReplaceAll(s, `\`, `\\`)
|
||||
s = strings.ReplaceAll(s, `"`, `\"`)
|
||||
return s
|
||||
}
|
||||
|
||||
// Grammar converts a JSON Schema string into a compiled grammar.
|
||||
func Grammar(schemaJSON string) (*grammar.Grammar, error) {
|
||||
ebnf, err := EBNF(schemaJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return grammar.ParseEBNF(ebnf, "root")
|
||||
}
|
||||
@@ -1,336 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package schema
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
gram "github.com/ollama/ollama/x/grammar"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
func TestJSONEBNF(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
schema string
|
||||
}{
|
||||
{
|
||||
name: "simple object",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"}
|
||||
},
|
||||
"required": ["name", "age"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "with enum",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"status": {"enum": ["active", "inactive", "pending"]}
|
||||
},
|
||||
"required": ["status"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "array of objects",
|
||||
schema: `{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"type": "integer"}
|
||||
},
|
||||
"required": ["id"]
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "nested object",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"user": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"email": {"type": "string"}
|
||||
},
|
||||
"required": ["email"]
|
||||
}
|
||||
},
|
||||
"required": ["user"]
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ebnf, err := EBNF(tc.schema)
|
||||
if err != nil {
|
||||
t.Fatalf("EBNF failed: %v", err)
|
||||
}
|
||||
|
||||
// Try to compile it
|
||||
grammar, err := gram.ParseEBNF(ebnf, "root")
|
||||
if err != nil {
|
||||
t.Fatalf("ParseEBNF failed: %v", err)
|
||||
}
|
||||
|
||||
if grammar == nil {
|
||||
t.Fatal("grammar is nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGrammarEngine(t *testing.T) {
|
||||
schema := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"}
|
||||
},
|
||||
"required": ["name", "age"]
|
||||
}`
|
||||
|
||||
grammar, err := Grammar(schema)
|
||||
if err != nil {
|
||||
t.Fatalf("Grammar failed: %v", err)
|
||||
}
|
||||
|
||||
vocab := []string{
|
||||
"{", "}", "[", "]", ":", ",",
|
||||
"\"name\"", "\"age\"", "\"test\"",
|
||||
"\"", "a", "b", "c",
|
||||
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
|
||||
" ", "\n",
|
||||
"true", "false", "null",
|
||||
}
|
||||
|
||||
engine, err := gram.NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
t.Fatalf("grammar.NewEngine failed: %v", err)
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
logits := mlx.Ones(int32(len(vocab)))
|
||||
mlx.Keep(logits)
|
||||
|
||||
// Test that we can apply mask
|
||||
masked := engine.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
|
||||
// TestOpenAIStructuredOutputs tests features required for OpenAI compatibility
|
||||
func TestOpenAIStructuredOutputs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
schema string
|
||||
}{
|
||||
{
|
||||
name: "anyOf union",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {
|
||||
"anyOf": [
|
||||
{"type": "string"},
|
||||
{"type": "integer"}
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": ["value"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "nullable string via type array",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": ["string", "null"]}
|
||||
},
|
||||
"required": ["name"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "$ref with $defs",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"person": {"$ref": "#/$defs/Person"}
|
||||
},
|
||||
"required": ["person"],
|
||||
"$defs": {
|
||||
"Person": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"}
|
||||
},
|
||||
"required": ["name", "age"]
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "const value",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"const": "user"}
|
||||
},
|
||||
"required": ["type"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "format date-time",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created": {"type": "string", "format": "date-time"}
|
||||
},
|
||||
"required": ["created"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "format date",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"birthday": {"type": "string", "format": "date"}
|
||||
},
|
||||
"required": ["birthday"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "format email",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"email": {"type": "string", "format": "email"}
|
||||
},
|
||||
"required": ["email"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "format uuid",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"type": "string", "format": "uuid"}
|
||||
},
|
||||
"required": ["id"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "array with minItems maxItems",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tags": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"minItems": 1,
|
||||
"maxItems": 3
|
||||
}
|
||||
},
|
||||
"required": ["tags"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "deeply nested with refs",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"company": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"employees": {
|
||||
"type": "array",
|
||||
"items": {"$ref": "#/$defs/Employee"}
|
||||
}
|
||||
},
|
||||
"required": ["name", "employees"]
|
||||
}
|
||||
},
|
||||
"required": ["company"],
|
||||
"$defs": {
|
||||
"Employee": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"role": {"enum": ["engineer", "manager", "intern"]}
|
||||
},
|
||||
"required": ["name", "role"]
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "multiple refs same def",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"from": {"$ref": "#/$defs/Address"},
|
||||
"to": {"$ref": "#/$defs/Address"}
|
||||
},
|
||||
"required": ["from", "to"],
|
||||
"$defs": {
|
||||
"Address": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string"},
|
||||
"zip": {"type": "string"}
|
||||
},
|
||||
"required": ["city", "zip"]
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "oneOf variant",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"result": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"success": {"type": "boolean"}},
|
||||
"required": ["success"]
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"error": {"type": "string"}},
|
||||
"required": ["error"]
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": ["result"]
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ebnf, err := EBNF(tc.schema)
|
||||
if err != nil {
|
||||
t.Fatalf("EBNF failed: %v", err)
|
||||
}
|
||||
|
||||
grammar, err := gram.ParseEBNF(ebnf, "root")
|
||||
if err != nil {
|
||||
t.Fatalf("ParseEBNF failed: %v", err)
|
||||
}
|
||||
|
||||
if grammar == nil {
|
||||
t.Fatal("grammar is nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,105 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
import "unicode/utf8"
|
||||
|
||||
// terminalType distinguishes different kinds of grammar terminals
|
||||
type terminalType int
|
||||
|
||||
const (
|
||||
terminalLiteral terminalType = iota // Exact string: "true", "{"
|
||||
terminalRange // Character range: [a-z], [0-9]
|
||||
)
|
||||
|
||||
// terminal represents a compiled grammar terminal
|
||||
type terminal struct {
|
||||
ID int
|
||||
Type terminalType
|
||||
Pattern string // Original pattern from grammar
|
||||
Unescaped string // Unescaped literal (for terminalLiteral)
|
||||
LowRune rune // For unicode ranges: low bound
|
||||
HighRune rune // For unicode ranges: high bound
|
||||
}
|
||||
|
||||
// terminalMatch represents a terminal that matched at a position
|
||||
type terminalMatch struct {
|
||||
TerminalID int
|
||||
Length int // Number of bytes consumed
|
||||
}
|
||||
|
||||
// trieNode is a node in the literal matching trie
|
||||
type trieNode struct {
|
||||
children [256]*trieNode // Byte-indexed children
|
||||
terminalID int // -1 if not accepting, else terminal ID
|
||||
}
|
||||
|
||||
// terminalMatcher tests which terminals match at a position in a byte slice
|
||||
type terminalMatcher struct {
|
||||
// Trie for literal matching (fast path)
|
||||
literalTrie *trieNode
|
||||
|
||||
// Range terminals (single-byte matches)
|
||||
ranges []terminal
|
||||
|
||||
// All terminals for enumeration
|
||||
terminals []terminal
|
||||
|
||||
// Pattern to terminal ID map for fast lookup (keyed by raw pattern)
|
||||
patternToID map[string]int
|
||||
}
|
||||
|
||||
// addLiteralToTrie adds a literal pattern to the trie
|
||||
func (m *terminalMatcher) addLiteralToTrie(pattern string, terminalID int) {
|
||||
node := m.literalTrie
|
||||
for i := 0; i < len(pattern); i++ {
|
||||
c := pattern[i]
|
||||
if node.children[c] == nil {
|
||||
node.children[c] = &trieNode{terminalID: -1}
|
||||
}
|
||||
node = node.children[c]
|
||||
}
|
||||
node.terminalID = terminalID
|
||||
}
|
||||
|
||||
// matchesAt returns all terminals that match at pos in data
|
||||
func (m *terminalMatcher) matchesAt(data []byte, pos int) []terminalMatch {
|
||||
if pos >= len(data) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var matches []terminalMatch
|
||||
|
||||
// Check literal matches via trie
|
||||
node := m.literalTrie
|
||||
for i := pos; i < len(data) && node != nil; i++ {
|
||||
c := data[i]
|
||||
node = node.children[c]
|
||||
if node != nil && node.terminalID >= 0 {
|
||||
matches = append(matches, terminalMatch{
|
||||
TerminalID: node.terminalID,
|
||||
Length: i - pos + 1,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Check range matches (unicode-aware)
|
||||
r, runeLen := utf8.DecodeRune(data[pos:])
|
||||
if r != utf8.RuneError {
|
||||
for _, rng := range m.ranges {
|
||||
if r >= rng.LowRune && r <= rng.HighRune {
|
||||
matches = append(matches, terminalMatch{
|
||||
TerminalID: rng.ID,
|
||||
Length: runeLen,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return matches
|
||||
}
|
||||
|
||||
// terminalCount returns the number of terminals
|
||||
func (m *terminalMatcher) terminalCount() int {
|
||||
return len(m.terminals)
|
||||
}
|
||||
@@ -234,3 +234,17 @@ ollama create z-image
|
||||
3. Copy config files (*.json) as config layers
|
||||
4. Write manifest
|
||||
```
|
||||
|
||||
## FP8 Quantization
|
||||
|
||||
Z-Image supports FP8 quantization to reduce memory usage by ~50% while maintaining image quality.
|
||||
|
||||
### Usage
|
||||
|
||||
```bash
|
||||
cd ./weights/Z-Image-Turbo
|
||||
ollama create z-image-fp8 --quantize fp8
|
||||
```
|
||||
|
||||
This quantizes weights during import. The resulting model will be ~15GB instead of ~31GB.
|
||||
|
||||
|
||||
@@ -1,235 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
// RunnerScheduler is the interface for scheduling a model runner.
|
||||
// This is implemented by server.Server to avoid circular imports.
|
||||
type RunnerScheduler interface {
|
||||
ScheduleImageGenRunner(ctx *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error)
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the image generation API routes.
|
||||
func RegisterRoutes(r gin.IRouter, scheduler RunnerScheduler) {
|
||||
r.POST("/v1/images/generations", func(c *gin.Context) {
|
||||
ImageGenerationHandler(c, scheduler)
|
||||
})
|
||||
}
|
||||
|
||||
// ImageGenerationHandler handles OpenAI-compatible image generation requests.
|
||||
func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
|
||||
var req ImageGenerationRequest
|
||||
if err := c.BindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": err.Error()}})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.Model == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "model is required"}})
|
||||
return
|
||||
}
|
||||
if req.Prompt == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "prompt is required"}})
|
||||
return
|
||||
}
|
||||
|
||||
// Apply defaults
|
||||
if req.N == 0 {
|
||||
req.N = 1
|
||||
}
|
||||
if req.Size == "" {
|
||||
req.Size = "1024x1024"
|
||||
}
|
||||
if req.ResponseFormat == "" {
|
||||
req.ResponseFormat = "b64_json"
|
||||
}
|
||||
|
||||
// Verify model exists
|
||||
if imagegen.ResolveModelName(req.Model) == "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": gin.H{"message": fmt.Sprintf("model %q not found", req.Model)}})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse size
|
||||
width, height := parseSize(req.Size)
|
||||
|
||||
// Build options - we repurpose NumCtx/NumGPU for width/height
|
||||
opts := api.Options{}
|
||||
opts.NumCtx = int(width)
|
||||
opts.NumGPU = int(height)
|
||||
|
||||
// Schedule runner
|
||||
runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, opts, nil)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
status = http.StatusNotFound
|
||||
}
|
||||
c.JSON(status, gin.H{"error": gin.H{"message": err.Error()}})
|
||||
return
|
||||
}
|
||||
|
||||
// Build completion request
|
||||
completionReq := llm.CompletionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Options: &opts,
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
handleStreamingResponse(c, runner, completionReq, req.ResponseFormat)
|
||||
} else {
|
||||
handleNonStreamingResponse(c, runner, completionReq, req.ResponseFormat)
|
||||
}
|
||||
}
|
||||
|
||||
func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
var imagePath string
|
||||
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
|
||||
if resp.Done {
|
||||
imagePath = extractPath(resp.Content)
|
||||
} else {
|
||||
progress := parseProgress(resp.Content)
|
||||
if progress.Total > 0 {
|
||||
c.SSEvent("progress", progress)
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
c.SSEvent("error", gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.SSEvent("done", buildResponse(imagePath, format))
|
||||
}
|
||||
|
||||
func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
|
||||
var imagePath string
|
||||
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
|
||||
if resp.Done {
|
||||
imagePath = extractPath(resp.Content)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, buildResponse(imagePath, format))
|
||||
}
|
||||
|
||||
func parseSize(size string) (int32, int32) {
|
||||
parts := strings.Split(size, "x")
|
||||
if len(parts) != 2 {
|
||||
return 1024, 1024
|
||||
}
|
||||
w, _ := strconv.Atoi(parts[0])
|
||||
h, _ := strconv.Atoi(parts[1])
|
||||
if w == 0 {
|
||||
w = 1024
|
||||
}
|
||||
if h == 0 {
|
||||
h = 1024
|
||||
}
|
||||
return int32(w), int32(h)
|
||||
}
|
||||
|
||||
func extractPath(content string) string {
|
||||
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
|
||||
return strings.TrimSpace(content[idx+16:])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseProgress(content string) ImageProgressEvent {
|
||||
var step, total int
|
||||
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
|
||||
return ImageProgressEvent{Step: step, Total: total}
|
||||
}
|
||||
|
||||
func buildResponse(imagePath, format string) ImageGenerationResponse {
|
||||
resp := ImageGenerationResponse{
|
||||
Created: time.Now().Unix(),
|
||||
Data: make([]ImageData, 1),
|
||||
}
|
||||
|
||||
if imagePath == "" {
|
||||
return resp
|
||||
}
|
||||
|
||||
if format == "url" {
|
||||
resp.Data[0].URL = "file://" + imagePath
|
||||
} else {
|
||||
data, err := os.ReadFile(imagePath)
|
||||
if err == nil {
|
||||
resp.Data[0].B64JSON = base64.StdEncoding.EncodeToString(data)
|
||||
}
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// HandleGenerateRequest handles Ollama /api/generate requests for image gen models.
|
||||
// This allows routes.go to delegate image generation with minimal code.
|
||||
func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName, prompt string, keepAlive *api.Duration, streamFn func(c *gin.Context, ch chan any)) {
|
||||
opts := api.Options{}
|
||||
|
||||
// Schedule runner
|
||||
runner, err := scheduler.ScheduleImageGenRunner(c, modelName, opts, keepAlive)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Build completion request
|
||||
completionReq := llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Options: &opts,
|
||||
}
|
||||
|
||||
// Stream responses via channel
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := runner.Completion(c.Request.Context(), completionReq, func(resp llm.CompletionResponse) {
|
||||
ch <- GenerateResponse{
|
||||
Model: modelName,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Response: resp.Content,
|
||||
Done: resp.Done,
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
// Log error but don't block - channel is already being consumed
|
||||
_ = err
|
||||
}
|
||||
}()
|
||||
|
||||
streamFn(c, ch)
|
||||
}
|
||||
|
||||
// GenerateResponse matches api.GenerateResponse structure for streaming.
|
||||
type GenerateResponse struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Response string `json:"response"`
|
||||
Done bool `json:"done"`
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
// Package api provides OpenAI-compatible image generation API types.
|
||||
package api
|
||||
|
||||
// ImageGenerationRequest is an OpenAI-compatible image generation request.
|
||||
type ImageGenerationRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
// ImageGenerationResponse is an OpenAI-compatible image generation response.
|
||||
type ImageGenerationResponse struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []ImageData `json:"data"`
|
||||
}
|
||||
|
||||
// ImageData contains the generated image data.
|
||||
type ImageData struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
B64JSON string `json:"b64_json,omitempty"`
|
||||
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
||||
}
|
||||
|
||||
// ImageProgressEvent is sent during streaming to indicate generation progress.
|
||||
type ImageProgressEvent struct {
|
||||
Step int `json:"step"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
197
x/imagegen/cache/teacache.go
vendored
Normal file
@@ -0,0 +1,197 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package cache provides caching mechanisms for diffusion model inference.
|
||||
package cache
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TeaCache implements Timestep Embedding Aware Caching for diffusion models.
|
||||
// It caches the transformer output and reuses it when timestep values
|
||||
// are similar between consecutive steps.
|
||||
//
|
||||
// For CFG (classifier-free guidance), it caches pos and neg predictions
|
||||
// separately and always computes CFG fresh to avoid error amplification.
|
||||
//
|
||||
// Reference: "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model"
|
||||
// https://github.com/ali-vilab/TeaCache
|
||||
type TeaCache struct {
|
||||
// Cached transformer output from last computed step (non-CFG mode)
|
||||
cachedOutput *mlx.Array
|
||||
|
||||
// Cached CFG outputs (pos and neg separately)
|
||||
cachedPosOutput *mlx.Array
|
||||
cachedNegOutput *mlx.Array
|
||||
|
||||
// Previous timestep value for difference calculation
|
||||
prevTimestep float32
|
||||
|
||||
// Accumulated difference for rescaling
|
||||
accumulatedDiff float32
|
||||
|
||||
// Configuration
|
||||
threshold float32 // Threshold for recomputation decision
|
||||
rescaleFactor float32 // Model-specific rescaling factor
|
||||
skipEarlySteps int // Number of early steps to never cache
|
||||
|
||||
// Statistics
|
||||
cacheHits int
|
||||
cacheMisses int
|
||||
}
|
||||
|
||||
// TeaCacheConfig holds configuration for TeaCache.
|
||||
type TeaCacheConfig struct {
|
||||
// Threshold for recomputation. Lower = more cache hits, potential quality loss.
|
||||
// Recommended: 0.05-0.15 for image models
|
||||
Threshold float32
|
||||
|
||||
// Rescale factor to adjust timestep embedding differences.
|
||||
// Model-specific, typically 1.0-2.0
|
||||
RescaleFactor float32
|
||||
|
||||
// SkipEarlySteps: number of early steps to always compute (never cache).
|
||||
// Set to 2-3 for CFG mode to preserve structure. 0 = no skipping.
|
||||
SkipEarlySteps int
|
||||
}
|
||||
|
||||
// DefaultTeaCacheConfig returns default configuration for TeaCache.
|
||||
func DefaultTeaCacheConfig() *TeaCacheConfig {
|
||||
return &TeaCacheConfig{
|
||||
Threshold: 0.1,
|
||||
RescaleFactor: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTeaCache creates a new TeaCache instance.
|
||||
func NewTeaCache(cfg *TeaCacheConfig) *TeaCache {
|
||||
if cfg == nil {
|
||||
cfg = DefaultTeaCacheConfig()
|
||||
}
|
||||
return &TeaCache{
|
||||
threshold: cfg.Threshold,
|
||||
rescaleFactor: cfg.RescaleFactor,
|
||||
skipEarlySteps: cfg.SkipEarlySteps,
|
||||
}
|
||||
}
|
||||
|
||||
// ShouldCompute determines if we should compute the full forward pass
|
||||
// or reuse the cached output based on timestep similarity.
|
||||
//
|
||||
// Algorithm:
|
||||
// 1. First step always computes
|
||||
// 2. Subsequent steps compare |currTimestep - prevTimestep| * rescaleFactor
|
||||
// 3. If accumulated difference > threshold, compute new output
|
||||
// 4. Otherwise, reuse cached output
|
||||
func (tc *TeaCache) ShouldCompute(step int, timestep float32) bool {
|
||||
// Always compute early steps (critical for structure)
|
||||
// Check both regular cache and CFG cache
|
||||
hasCachedOutput := tc.cachedOutput != nil || tc.HasCFGCache()
|
||||
if step < tc.skipEarlySteps || step == 0 || !hasCachedOutput {
|
||||
return true
|
||||
}
|
||||
|
||||
// Compute absolute difference between current and previous timestep
|
||||
diff := timestep - tc.prevTimestep
|
||||
if diff < 0 {
|
||||
diff = -diff
|
||||
}
|
||||
|
||||
// Apply rescaling factor
|
||||
scaledDiff := diff * tc.rescaleFactor
|
||||
|
||||
// Accumulate difference (helps track drift over multiple cached steps)
|
||||
tc.accumulatedDiff += scaledDiff
|
||||
|
||||
// Decision based on accumulated difference
|
||||
if tc.accumulatedDiff > tc.threshold {
|
||||
tc.accumulatedDiff = 0 // Reset accumulator
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateCache stores the computed output for potential reuse (non-CFG mode).
|
||||
func (tc *TeaCache) UpdateCache(output *mlx.Array, timestep float32) {
|
||||
// Free previous cached output
|
||||
if tc.cachedOutput != nil {
|
||||
tc.cachedOutput.Free()
|
||||
}
|
||||
|
||||
// Store new cached values
|
||||
tc.cachedOutput = output
|
||||
tc.prevTimestep = timestep
|
||||
tc.cacheMisses++
|
||||
}
|
||||
|
||||
// UpdateCFGCache stores pos and neg outputs separately for CFG mode.
|
||||
// This allows CFG to be computed fresh each step, avoiding error amplification.
|
||||
func (tc *TeaCache) UpdateCFGCache(posOutput, negOutput *mlx.Array, timestep float32) {
|
||||
// Free previous cached outputs
|
||||
if tc.cachedPosOutput != nil {
|
||||
tc.cachedPosOutput.Free()
|
||||
}
|
||||
if tc.cachedNegOutput != nil {
|
||||
tc.cachedNegOutput.Free()
|
||||
}
|
||||
|
||||
// Store new cached values
|
||||
tc.cachedPosOutput = posOutput
|
||||
tc.cachedNegOutput = negOutput
|
||||
tc.prevTimestep = timestep
|
||||
tc.cacheMisses++
|
||||
}
|
||||
|
||||
// GetCached returns the cached output (non-CFG mode).
|
||||
func (tc *TeaCache) GetCached() *mlx.Array {
|
||||
tc.cacheHits++
|
||||
return tc.cachedOutput
|
||||
}
|
||||
|
||||
// GetCFGCached returns cached pos and neg outputs for CFG mode.
|
||||
func (tc *TeaCache) GetCFGCached() (pos, neg *mlx.Array) {
|
||||
tc.cacheHits++
|
||||
return tc.cachedPosOutput, tc.cachedNegOutput
|
||||
}
|
||||
|
||||
// HasCFGCache returns true if CFG cache is available.
|
||||
func (tc *TeaCache) HasCFGCache() bool {
|
||||
return tc.cachedPosOutput != nil && tc.cachedNegOutput != nil
|
||||
}
|
||||
|
||||
// Arrays returns all arrays that should be kept alive.
|
||||
func (tc *TeaCache) Arrays() []*mlx.Array {
|
||||
var arrays []*mlx.Array
|
||||
if tc.cachedOutput != nil {
|
||||
arrays = append(arrays, tc.cachedOutput)
|
||||
}
|
||||
if tc.cachedPosOutput != nil {
|
||||
arrays = append(arrays, tc.cachedPosOutput)
|
||||
}
|
||||
if tc.cachedNegOutput != nil {
|
||||
arrays = append(arrays, tc.cachedNegOutput)
|
||||
}
|
||||
return arrays
|
||||
}
|
||||
|
||||
// Stats returns cache hit/miss statistics.
|
||||
func (tc *TeaCache) Stats() (hits, misses int) {
|
||||
return tc.cacheHits, tc.cacheMisses
|
||||
}
|
||||
|
||||
// Free releases all cached arrays.
|
||||
func (tc *TeaCache) Free() {
|
||||
if tc.cachedOutput != nil {
|
||||
tc.cachedOutput.Free()
|
||||
tc.cachedOutput = nil
|
||||
}
|
||||
if tc.cachedPosOutput != nil {
|
||||
tc.cachedPosOutput.Free()
|
||||
tc.cachedPosOutput = nil
|
||||
}
|
||||
if tc.cachedNegOutput != nil {
|
||||
tc.cachedNegOutput.Free()
|
||||
tc.cachedNegOutput = nil
|
||||
}
|
||||
}
|
||||
@@ -7,7 +7,6 @@ package imagegen
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -39,75 +38,17 @@ func DefaultOptions() ImageGenOptions {
|
||||
return ImageGenOptions{
|
||||
Width: 1024,
|
||||
Height: 1024,
|
||||
Steps: 9,
|
||||
Steps: 0, // 0 means model default
|
||||
Seed: 0, // 0 means random
|
||||
}
|
||||
}
|
||||
|
||||
// Show displays information about an image generation model.
|
||||
func Show(modelName string, w io.Writer) error {
|
||||
manifest, err := LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
// Count total size
|
||||
var totalSize int64
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
||||
totalSize += layer.Size
|
||||
}
|
||||
}
|
||||
|
||||
// Read model_index.json for architecture
|
||||
var architecture string
|
||||
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
|
||||
var index struct {
|
||||
Architecture string `json:"architecture"`
|
||||
}
|
||||
if json.Unmarshal(data, &index) == nil {
|
||||
architecture = index.Architecture
|
||||
}
|
||||
}
|
||||
|
||||
// Estimate parameter count from total size (assuming BF16 = 2 bytes per param)
|
||||
paramCount := totalSize / 2
|
||||
paramStr := formatParamCount(paramCount)
|
||||
|
||||
// Print Model info
|
||||
fmt.Fprintln(w, " Model")
|
||||
if architecture != "" {
|
||||
fmt.Fprintf(w, " %-20s %s\n", "architecture", architecture)
|
||||
}
|
||||
fmt.Fprintf(w, " %-20s %s\n", "parameters", paramStr)
|
||||
fmt.Fprintf(w, " %-20s %s\n", "quantization", "BF16")
|
||||
fmt.Fprintln(w)
|
||||
|
||||
// Print Capabilities
|
||||
fmt.Fprintln(w, " Capabilities")
|
||||
fmt.Fprintf(w, " %s\n", "image")
|
||||
fmt.Fprintln(w)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatParamCount formats parameter count as human-readable string.
|
||||
func formatParamCount(count int64) string {
|
||||
if count >= 1_000_000_000 {
|
||||
return fmt.Sprintf("%.1fB", float64(count)/1_000_000_000)
|
||||
}
|
||||
if count >= 1_000_000 {
|
||||
return fmt.Sprintf("%.1fM", float64(count)/1_000_000)
|
||||
}
|
||||
return fmt.Sprintf("%d", count)
|
||||
}
|
||||
|
||||
// RegisterFlags adds image generation flags to the given command.
|
||||
// Flags are hidden since they only apply to image generation models.
|
||||
func RegisterFlags(cmd *cobra.Command) {
|
||||
cmd.Flags().Int("width", 1024, "Image width")
|
||||
cmd.Flags().Int("height", 1024, "Image height")
|
||||
cmd.Flags().Int("steps", 9, "Denoising steps")
|
||||
cmd.Flags().Int("steps", 0, "Denoising steps (0 = model default)")
|
||||
cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
|
||||
cmd.Flags().String("negative", "", "Negative prompt")
|
||||
cmd.Flags().MarkHidden("width")
|
||||
@@ -121,11 +62,6 @@ func RegisterFlags(cmd *cobra.Command) {
|
||||
// Returns true if it handled the request, false if the caller should continue with normal flow.
|
||||
// Supports flags: --width, --height, --steps, --seed, --negative
|
||||
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
|
||||
// Verify it's a valid image gen model
|
||||
if ResolveModelName(name) == "" {
|
||||
return fmt.Errorf("unknown image generation model: %s", name)
|
||||
}
|
||||
|
||||
// Get options from flags (with env var defaults)
|
||||
opts := DefaultOptions()
|
||||
if cmd != nil && cmd.Flags() != nil {
|
||||
@@ -161,17 +97,12 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
|
||||
return err
|
||||
}
|
||||
|
||||
// Build request with image gen options encoded in Options fields
|
||||
// NumCtx=width, NumGPU=height, NumPredict=steps, Seed=seed
|
||||
req := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: prompt,
|
||||
Options: map[string]any{
|
||||
"num_ctx": opts.Width,
|
||||
"num_gpu": opts.Height,
|
||||
"num_predict": opts.Steps,
|
||||
"seed": opts.Seed,
|
||||
},
|
||||
Width: int32(opts.Width),
|
||||
Height: int32(opts.Height),
|
||||
Steps: int32(opts.Steps),
|
||||
}
|
||||
if keepAlive != nil {
|
||||
req.KeepAlive = keepAlive
|
||||
@@ -183,31 +114,21 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
|
||||
p.Add("", spinner)
|
||||
|
||||
var stepBar *progress.StepBar
|
||||
var imagePath string
|
||||
|
||||
var imageBase64 string
|
||||
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
|
||||
content := resp.Response
|
||||
|
||||
// Handle progress updates - parse step info and switch to step bar
|
||||
if strings.HasPrefix(content, "\rGenerating:") {
|
||||
var step, total int
|
||||
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
|
||||
if stepBar == nil && total > 0 {
|
||||
// Handle progress updates using structured fields
|
||||
if resp.Total > 0 && resp.Completed > 0 {
|
||||
if stepBar == nil {
|
||||
spinner.Stop()
|
||||
stepBar = progress.NewStepBar("Generating", total)
|
||||
stepBar = progress.NewStepBar("Generating", int(resp.Total))
|
||||
p.Add("", stepBar)
|
||||
}
|
||||
if stepBar != nil {
|
||||
stepBar.Set(step)
|
||||
}
|
||||
return nil
|
||||
stepBar.Set(int(resp.Completed))
|
||||
}
|
||||
|
||||
// Handle final response with image path
|
||||
if resp.Done && strings.Contains(content, "Image saved to:") {
|
||||
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
|
||||
imagePath = strings.TrimSpace(content[idx+16:])
|
||||
}
|
||||
// Handle final response with image data
|
||||
if resp.Done && len(resp.Images) > 0 {
|
||||
imageBase64 = resp.Images[0]
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -218,9 +139,27 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
|
||||
return err
|
||||
}
|
||||
|
||||
if imagePath != "" {
|
||||
displayImageInTerminal(imagePath)
|
||||
fmt.Printf("Image saved to: %s\n", imagePath)
|
||||
if imageBase64 != "" {
|
||||
// Decode base64 and save to CWD
|
||||
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode image: %w", err)
|
||||
}
|
||||
|
||||
// Create filename from prompt
|
||||
safeName := sanitizeFilename(prompt)
|
||||
if len(safeName) > 50 {
|
||||
safeName = safeName[:50]
|
||||
}
|
||||
timestamp := time.Now().Format("20060102-150405")
|
||||
filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
|
||||
|
||||
if err := os.WriteFile(filename, imageData, 0o644); err != nil {
|
||||
return fmt.Errorf("failed to save image: %w", err)
|
||||
}
|
||||
|
||||
displayImageInTerminal(filename)
|
||||
fmt.Printf("Image saved to: %s\n", filename)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -289,12 +228,9 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
|
||||
req := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: line,
|
||||
Options: map[string]any{
|
||||
"num_ctx": opts.Width,
|
||||
"num_gpu": opts.Height,
|
||||
"num_predict": opts.Steps,
|
||||
"seed": opts.Seed,
|
||||
},
|
||||
Width: int32(opts.Width),
|
||||
Height: int32(opts.Height),
|
||||
Steps: int32(opts.Steps),
|
||||
}
|
||||
if keepAlive != nil {
|
||||
req.KeepAlive = keepAlive
|
||||
@@ -306,31 +242,22 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
|
||||
p.Add("", spinner)
|
||||
|
||||
var stepBar *progress.StepBar
|
||||
var imagePath string
|
||||
var imageBase64 string
|
||||
|
||||
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
|
||||
content := resp.Response
|
||||
|
||||
// Handle progress updates - parse step info and switch to step bar
|
||||
if strings.HasPrefix(content, "\rGenerating:") {
|
||||
var step, total int
|
||||
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
|
||||
if stepBar == nil && total > 0 {
|
||||
// Handle progress updates using structured fields
|
||||
if resp.Total > 0 && resp.Completed > 0 {
|
||||
if stepBar == nil {
|
||||
spinner.Stop()
|
||||
stepBar = progress.NewStepBar("Generating", total)
|
||||
stepBar = progress.NewStepBar("Generating", int(resp.Total))
|
||||
p.Add("", stepBar)
|
||||
}
|
||||
if stepBar != nil {
|
||||
stepBar.Set(step)
|
||||
}
|
||||
return nil
|
||||
stepBar.Set(int(resp.Completed))
|
||||
}
|
||||
|
||||
// Handle final response with image path
|
||||
if resp.Done && strings.Contains(content, "Image saved to:") {
|
||||
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
|
||||
imagePath = strings.TrimSpace(content[idx+16:])
|
||||
}
|
||||
// Handle final response with image data
|
||||
if resp.Done && len(resp.Images) > 0 {
|
||||
imageBase64 = resp.Images[0]
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -342,25 +269,30 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
|
||||
continue
|
||||
}
|
||||
|
||||
// Copy image to current directory with descriptive name
|
||||
if imagePath != "" {
|
||||
// Save image to current directory with descriptive name
|
||||
if imageBase64 != "" {
|
||||
// Decode base64 image data
|
||||
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error decoding image: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Create filename from prompt (sanitized)
|
||||
safeName := sanitizeFilename(line)
|
||||
if len(safeName) > 50 {
|
||||
safeName = safeName[:50]
|
||||
}
|
||||
timestamp := time.Now().Format("20060102-150405")
|
||||
newName := fmt.Sprintf("%s-%s.png", safeName, timestamp)
|
||||
filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
|
||||
|
||||
// Copy file to CWD
|
||||
if err := copyFile(imagePath, newName); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error saving to current directory: %v\n", err)
|
||||
displayImageInTerminal(imagePath)
|
||||
fmt.Printf("Image saved to: %s\n", imagePath)
|
||||
} else {
|
||||
displayImageInTerminal(newName)
|
||||
fmt.Printf("Image saved to: %s\n", newName)
|
||||
if err := os.WriteFile(filename, imageData, 0o644); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error saving image: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
displayImageInTerminal(filename)
|
||||
fmt.Printf("Image saved to: %s\n", filename)
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
@@ -381,24 +313,6 @@ func sanitizeFilename(s string) string {
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// copyFile copies a file from src to dst.
|
||||
func copyFile(src, dst string) error {
|
||||
sourceFile, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sourceFile.Close()
|
||||
|
||||
destFile, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer destFile.Close()
|
||||
|
||||
_, err = io.Copy(destFile, sourceFile)
|
||||
return err
|
||||
}
|
||||
|
||||
// printInteractiveHelp prints help for interactive mode commands.
|
||||
func printInteractiveHelp(opts ImageGenOptions) {
|
||||
fmt.Fprintln(os.Stderr, "Commands:")
|
||||
@@ -509,10 +423,7 @@ func displayImageInTerminal(imagePath string) bool {
|
||||
// Send in chunks for large images
|
||||
const chunkSize = 4096
|
||||
for i := 0; i < len(encoded); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(encoded) {
|
||||
end = len(encoded)
|
||||
}
|
||||
end := min(i+chunkSize, len(encoded))
|
||||
chunk := encoded[i:end]
|
||||
|
||||
if i == 0 {
|
||||
|
||||
@@ -1,130 +0,0 @@
|
||||
// Package client provides client-side model creation for tensor-based models.
|
||||
//
|
||||
// This package is in x/ because the tensor model storage format is under development.
|
||||
// It also exists to break an import cycle: server imports x/imagegen, so x/imagegen
|
||||
// cannot import server. This sub-package can import server because server doesn't
|
||||
// import it.
|
||||
//
|
||||
// TODO (jmorganca): This is temporary. When tensor models are promoted to production:
|
||||
// 1. Add proper API endpoints for tensor model creation
|
||||
// 2. Move tensor extraction to server-side
|
||||
// 3. Remove this package
|
||||
// 4. Follow the same client→server pattern as regular model creation
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
// MinOllamaVersion is the minimum Ollama version required for image generation models.
|
||||
const MinOllamaVersion = "0.14.0"
|
||||
|
||||
// CreateModel imports a tensor-based model from a local directory.
|
||||
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
|
||||
//
|
||||
// TODO (jmorganca): Replace with API-based creation when promoted to production.
|
||||
func CreateModel(modelName, modelDir string, p *progress.Progress) error {
|
||||
if !imagegen.IsTensorModelDir(modelDir) {
|
||||
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
|
||||
}
|
||||
|
||||
status := "importing image generation model"
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add("imagegen", spinner)
|
||||
|
||||
// Create layer callback for config files
|
||||
createLayer := func(r io.Reader, mediaType, name string) (imagegen.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, mediaType)
|
||||
if err != nil {
|
||||
return imagegen.LayerInfo{}, err
|
||||
}
|
||||
layer.Name = name
|
||||
|
||||
return imagegen.LayerInfo{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Create tensor layer callback for individual tensors
|
||||
// name is path-style: "component/tensor_name"
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32) (imagegen.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return imagegen.LayerInfo{}, err
|
||||
}
|
||||
layer.Name = name
|
||||
|
||||
return imagegen.LayerInfo{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Create manifest writer callback
|
||||
writeManifest := func(modelName string, config imagegen.LayerInfo, layers []imagegen.LayerInfo) error {
|
||||
name := model.ParseName(modelName)
|
||||
if !name.IsValid() {
|
||||
return fmt.Errorf("invalid model name: %s", modelName)
|
||||
}
|
||||
|
||||
// Create a proper config blob with version requirement
|
||||
configData := model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: []string{"image"},
|
||||
Requires: MinOllamaVersion,
|
||||
}
|
||||
configJSON, err := json.Marshal(configData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
// Create config layer blob
|
||||
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create config layer: %w", err)
|
||||
}
|
||||
|
||||
// Convert LayerInfo to server.Layer (include the original model_index.json in layers)
|
||||
serverLayers := make([]server.Layer, len(layers))
|
||||
for i, l := range layers {
|
||||
serverLayers[i] = server.Layer{
|
||||
MediaType: l.MediaType,
|
||||
Digest: l.Digest,
|
||||
Size: l.Size,
|
||||
Name: l.Name,
|
||||
}
|
||||
}
|
||||
|
||||
return server.WriteManifest(name, configLayer, serverLayers)
|
||||
}
|
||||
|
||||
// Progress callback
|
||||
progressFn := func(msg string) {
|
||||
spinner.Stop()
|
||||
status = msg
|
||||
spinner = progress.NewSpinner(status)
|
||||
p.Add("imagegen", spinner)
|
||||
}
|
||||
|
||||
err := imagegen.CreateModel(modelName, modelDir, createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
spinner.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Created image generation model '%s'\n", modelName)
|
||||
return nil
|
||||
}
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ollama/ollama/x/grammar"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
@@ -110,11 +109,7 @@ type input struct {
|
||||
Temperature float32
|
||||
TopP float32
|
||||
TopK int
|
||||
WiredLimitGB int // Metal wired memory limit in GB (default 32)
|
||||
JSONMode bool // Enable JSON grammar constraint
|
||||
GrammarEBNF string // Raw EBNF grammar string
|
||||
GrammarStart string // Start rule name for grammar
|
||||
Vocab []string // Vocabulary for constrained decoding
|
||||
WiredLimitGB int // Metal wired memory limit in GB (default 32)
|
||||
}
|
||||
|
||||
type output struct {
|
||||
@@ -132,11 +127,9 @@ type Decoder struct {
|
||||
temp float32
|
||||
topK int
|
||||
topP float32
|
||||
token *mlx.Array // Current token (kept across pools)
|
||||
oldCacheState []*mlx.Array // Preallocated slice for old cache state
|
||||
image *mlx.Array // Optional image for multimodal prefill
|
||||
grammar *grammar.Engine // Optional grammar constraint engine
|
||||
grammarVocab []string // Vocab for grammar debug
|
||||
token *mlx.Array // Current token (kept across pools)
|
||||
oldCacheState []*mlx.Array // Preallocated slice for old cache state
|
||||
image *mlx.Array // Optional image for multimodal prefill
|
||||
}
|
||||
|
||||
func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
|
||||
@@ -152,12 +145,6 @@ func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
|
||||
}
|
||||
}
|
||||
|
||||
// SetGrammar enables constrained decoding with the given grammar engine.
|
||||
func (d *Decoder) SetGrammar(g *grammar.Engine, vocab []string) {
|
||||
d.grammar = g
|
||||
d.grammarVocab = vocab
|
||||
}
|
||||
|
||||
// SetImage sets the image for multimodal prefill (call before prefill)
|
||||
func (d *Decoder) SetImage(img *mlx.Array) {
|
||||
d.image = img
|
||||
@@ -235,16 +222,6 @@ func (d *Decoder) prefill(inputIDs []int32) int {
|
||||
} else {
|
||||
logits = d.model.Forward(x, d.caches)
|
||||
}
|
||||
|
||||
// Apply grammar constraints if enabled
|
||||
if d.grammar != nil {
|
||||
shape := logits.Shape()
|
||||
lastLogits := mlx.Slice(logits, []int32{0, shape[1] - 1, 0}, []int32{1, shape[1], d.vocabSize})
|
||||
lastLogits = mlx.Reshape(lastLogits, d.vocabSize)
|
||||
maskedLogits := d.grammar.ApplyMask(lastLogits)
|
||||
logits = mlx.Reshape(maskedLogits, 1, 1, d.vocabSize)
|
||||
}
|
||||
|
||||
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
|
||||
})
|
||||
// Keep cache state (token auto-kept by AsyncEval)
|
||||
@@ -268,15 +245,6 @@ func (d *Decoder) prefill(inputIDs []int32) int {
|
||||
func (d *Decoder) step() int32 {
|
||||
prevToken := d.token
|
||||
|
||||
// Sync on previous token FIRST to get its value and update grammar state
|
||||
// This must happen before computing the next mask
|
||||
val := prevToken.ItemInt32()
|
||||
|
||||
// Update grammar state with the token we just synced
|
||||
if d.grammar != nil {
|
||||
d.grammar.Accept(int(val))
|
||||
}
|
||||
|
||||
// Save old cache state (reuse preallocated slice)
|
||||
d.oldCacheState = d.oldCacheState[:0]
|
||||
for _, c := range d.caches {
|
||||
@@ -285,18 +253,6 @@ func (d *Decoder) step() int32 {
|
||||
|
||||
withStream(func() {
|
||||
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
|
||||
|
||||
// Apply grammar constraints if enabled
|
||||
if d.grammar != nil {
|
||||
// Get last position logits: [1, 1, vocab] -> [vocab]
|
||||
shape := logits.Shape()
|
||||
lastLogits := mlx.Slice(logits, []int32{0, shape[1] - 1, 0}, []int32{1, shape[1], d.vocabSize})
|
||||
lastLogits = mlx.Reshape(lastLogits, d.vocabSize)
|
||||
maskedLogits := d.grammar.ApplyMask(lastLogits)
|
||||
// Reshape back to [1, 1, vocab] for sample()
|
||||
logits = mlx.Reshape(maskedLogits, 1, 1, d.vocabSize)
|
||||
}
|
||||
|
||||
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
|
||||
})
|
||||
// Keep token and new cache state so they survive cleanup
|
||||
@@ -306,6 +262,9 @@ func (d *Decoder) step() int32 {
|
||||
}
|
||||
mlx.AsyncEval(d.token)
|
||||
|
||||
// Sync on previous token (GPU already working on next step)
|
||||
val := prevToken.ItemInt32()
|
||||
|
||||
// Free old token and old cache state
|
||||
prevToken.Free()
|
||||
for _, arr := range d.oldCacheState {
|
||||
@@ -330,48 +289,6 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
|
||||
tok := m.Tokenizer()
|
||||
dec := NewDecoder(m, temp, in.TopK, in.TopP)
|
||||
|
||||
// Set up grammar constraint if enabled
|
||||
var grammarEngine *grammar.Engine
|
||||
var grammarVocab []string
|
||||
if (in.JSONMode || in.GrammarEBNF != "") && len(in.Vocab) > 0 {
|
||||
var compiled *grammar.Grammar
|
||||
var err error
|
||||
|
||||
if in.GrammarEBNF != "" {
|
||||
// Custom EBNF grammar
|
||||
startRule := in.GrammarStart
|
||||
if startRule == "" {
|
||||
startRule = "root"
|
||||
}
|
||||
compiled, err = grammar.ParseEBNF(in.GrammarEBNF, startRule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse grammar: %w", err)
|
||||
}
|
||||
fmt.Printf("[Grammar mode: start=%s]\n", startRule)
|
||||
} else {
|
||||
// JSON object grammar (only allows objects at top level)
|
||||
compiled, err = grammar.JSONObjectGrammar()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create JSON grammar: %w", err)
|
||||
}
|
||||
fmt.Println("[JSON object mode enabled]")
|
||||
}
|
||||
|
||||
// Pad vocab to match model's vocab size if needed
|
||||
grammarVocab = in.Vocab
|
||||
modelVocabSize := int(m.VocabSize())
|
||||
if len(grammarVocab) < modelVocabSize {
|
||||
padded := make([]string, modelVocabSize)
|
||||
copy(padded, grammarVocab)
|
||||
grammarVocab = padded
|
||||
}
|
||||
grammarEngine, err = grammar.NewEngine(compiled, grammarVocab)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create grammar engine: %w", err)
|
||||
}
|
||||
defer grammarEngine.Close()
|
||||
}
|
||||
|
||||
// Apply chat template - use image template if we have an image
|
||||
prompt := in.Prompt
|
||||
var tokens []int32
|
||||
@@ -387,10 +304,6 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
|
||||
tokens = tok.Encode(prompt, true)
|
||||
}
|
||||
|
||||
if grammarEngine != nil {
|
||||
dec.SetGrammar(grammarEngine, grammarVocab)
|
||||
}
|
||||
|
||||
prefillStart := time.Now()
|
||||
prefillTokens := dec.prefill(tokens)
|
||||
// Prefill measurement should include time to first token (like mlx-lm)
|
||||
@@ -414,11 +327,6 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
|
||||
if text := streamer.Write(tok.Decode([]int32{firstToken})); text != "" {
|
||||
cb(output{Text: text})
|
||||
}
|
||||
// Check if grammar is complete after first token
|
||||
if dec.grammar != nil && dec.grammar.IsComplete() {
|
||||
cb(output{Done: true, PrefillTokSec: prefillTokSec, GenTokSec: float64(genTokens) / time.Since(genStart).Seconds()})
|
||||
return nil
|
||||
}
|
||||
|
||||
for n := 1; n < maxTokens; n++ {
|
||||
if ctx.Err() != nil {
|
||||
@@ -433,10 +341,6 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
|
||||
if text := streamer.Write(tok.Decode([]int32{token})); text != "" {
|
||||
cb(output{Text: text})
|
||||
}
|
||||
// Check if grammar is complete (valid JSON document finished)
|
||||
if dec.grammar != nil && dec.grammar.IsComplete() {
|
||||
break
|
||||
}
|
||||
|
||||
if n%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
|
||||