Compare commits

..

19 Commits

Author SHA1 Message Date
jmorganca
bb1a5617b6 readme: add instructions to build with MLX 2026-01-15 09:52:56 -08:00
jmorganca
0d3648c1be glm-image wip 2026-01-14 16:46:50 -08:00
Daniel Hiltgen
02a2401596 mlx: bundle openblas dependency (#13706) 2026-01-13 15:29:47 -08:00
Daniel Hiltgen
e4b488a7b5 CI: dedup cuda libraries to reduce payload size (#13704) 2026-01-13 11:25:31 -08:00
Daniel Hiltgen
98079ddd79 ci: add missing mlx components to release build (#13702) 2026-01-13 09:13:09 -08:00
Jeffrey Morgan
d70942f47b x/imagegen/cli: skip local model check (#13699) 2026-01-12 22:38:10 -08:00
Jeffrey Morgan
58e4701557 scripts: increase notarization timeout to 20m (#13697)
The 100MB mlx.metallib file significantly increased the app bundle size,
causing Apple's notarization service to timeout with the previous 10m limit.
2026-01-12 20:38:38 -08:00
Jeffrey Morgan
dbf47ee55a cmake: use CMAKE_SYSTEM_PROCESSOR instead of CMAKE_OSX_ARCHITECTURES for mlx.metallib install (#13696)
The CMake condition for installing mlx.metallib checks
CMAKE_OSX_ARCHITECTURES, but this variable is only set when explicitly
passed - not auto-detected. The arm64 build was missing this flag,
causing the metallib to not be installed, which then caused codesign
to fail on the unexpanded glob pattern.
2026-01-12 20:05:11 -08:00
Jeffrey Morgan
af7ea6e96e x/imagegen: install mlx.metallib and fix macOS rpath handling, add mlx library directories to LD_LIBRARY_PATH (#13695)
- Install mlx.metallib for arm64 builds (required for Metal GPU acceleration)
- Apply rpath settings to all macOS builds, not just x86_64
- Add CMAKE_BUILD_WITH_INSTALL_RPATH to avoid install_name_tool errors
- Update build_darwin.sh to copy, sign, and package the metallib
2026-01-12 19:03:11 -08:00
Jeffrey Morgan
8f1e0140e7 x/imagegen: fix mlx build in Dockerfile and macOS build script (#13693) 2026-01-12 15:52:43 -08:00
Parth Sareen
35c3c9e3c2 anthropic: allow non-thinking models when using Anthropic API (#13692) 2026-01-12 15:13:26 -08:00
Parth Sareen
d06acbcb19 x/cmd: enable web search and web fetch with flag (#13690) 2026-01-12 13:59:40 -08:00
Jeffrey Morgan
9667c2282f x/imagegen: add naive TeaCache and FP8 quantization support (#13683)
TeaCache:
- Timestep embedding similarity caching for diffusion models
- Polynomial rescaling with configurable thresholds
- Reduces transformer forward passes by ~30-50%

FP8 quantization:
- Support for FP8 quantized models (8-bit weights with scales)
- QuantizedMatmul on Metal, Dequantize on CUDA
- Client-side quantization via ollama create --quantize fp8

Other bug fixes:
- Fix `/api/show` API for image generation models
- Server properly returns model info (architecture, parameters, quantization)
- Memory allocation optimizations
- CLI improvements for image generation
2026-01-12 13:45:22 -08:00
Jeffrey Morgan
a937a68317 server: fix slow 'ollama rm' of models with many layers (#13680)
RemoveLayers was calling Manifests() for each layer to check if it was
shared with other models. For models with many blobs (e.g., tensor
models), this caused O(N*M) manifest reads.

Now loads manifests once and builds a set of in-use digests.
2026-01-12 13:17:48 -08:00
Parth Sareen
2185112d84 x/cmd: connect /set flags to behavior in experimental mode (#13684) 2026-01-12 00:40:44 -08:00
Parth Sareen
91926601dc x: add missing /set, /show, /load, /save commands to experimental mode (#13682) 2026-01-11 23:12:31 -08:00
Jeffrey Morgan
361d6c16c2 x/imagegen/transfer: fix timeout and progress reporting (#13679)
Removes 5-minute HTTP client timeout that caused "context deadline
exceeded" errors on large file downloads. Stall detection (10s)
already handles unresponsive connections.

Fixes progress bar total going down on resume by calculating total
from all blobs upfront and reporting already-downloaded bytes
as completed immediately.
2026-01-11 15:33:53 -08:00
Patrick Devine
7e2496e88e Fix cmake install command in README (#13678)
Update installation command for MLX component in README.
2026-01-11 13:16:42 -08:00
WhatToPutHere
5b84e29882 docs: fix troubleshooting page (#13674)
Updated the link in the log output description to point to the correct troubleshooting guide format.
2026-01-11 00:58:07 -08:00
91 changed files with 6698 additions and 6233 deletions

View File

@@ -13,7 +13,7 @@ body:
id: logs
attributes:
label: Relevant log output
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.mdx#how-to-troubleshoot-issues) for details.
render: shell
validations:
required: false

View File

@@ -372,13 +372,17 @@ jobs:
outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
cache-to: type=inline
- name: Deduplicate CUDA libraries
run: |
./scripts/deduplicate_cuda_libs.sh dist/${{ matrix.os }}-${{ matrix.arch }}
- run: |
for COMPONENT in bin/* lib/ollama/*; do
case "$COMPONENT" in
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
bin/ollama*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;

View File

@@ -48,9 +48,10 @@ if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
set(GGML_CPU_ALL_VARIANTS ON)
endif()
if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
if(APPLE)
set(CMAKE_BUILD_RPATH "@loader_path")
set(CMAKE_INSTALL_RPATH "@loader_path")
set(CMAKE_BUILD_WITH_INSTALL_RPATH ON)
endif()
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
@@ -189,13 +190,21 @@ if(MLX_ENGINE)
install(TARGETS mlx mlxc
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
)
# Install the Metal library for macOS arm64 (must be colocated with the binary)
# Metal backend is only built for arm64, not x86_64
if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
install(FILES ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/backend/metal/kernels/mlx.metallib
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
# Manually install cudart and cublas since they might not be picked up as direct dependencies
if(CUDAToolkit_FOUND)
file(GLOB CUDART_LIBS

View File

@@ -161,6 +161,9 @@ ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
RUN mkdir -p dist/bin
RUN --mount=type=cache,target=/root/.cache/go-build \
go build -tags mlx -trimpath -buildmode=pie -o dist/bin/ollama-mlx .
FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama
@@ -182,6 +185,7 @@ COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
COPY --from=vulkan dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/
FROM --platform=linux/arm64 scratch AS arm64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/

View File

@@ -48,7 +48,7 @@ ollama run gemma3
## Model library
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library 'ollama model library')
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library "ollama model library")
Here are some example models that can be downloaded:
@@ -79,7 +79,7 @@ Here are some example models that can be downloaded:
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| LLaVA | 7B | 4.5GB | `ollama run llava` |
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
> [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -260,6 +260,38 @@ Finally, in a separate shell, run a model:
./ollama run llama3.2
```
## Building with MLX (experimental)
First build the MLX libraries:
```shell
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install build --component MLX
```
Next, build the `ollama-mlx` binary, which is a separate build of the Ollama runtime with MLX support enabled (needs to be in the same directory as `ollama`):
```shell
go build -tags mlx -o ollama-mlx .
```
Finally, start the server:
```
./ollama serve
```
### Building MLX with CUDA
When building with CUDA, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with default architectures:
```shell
cmake --preset 'MLX CUDA 13'
cmake --build --preset 'MLX CUDA 13' --parallel
cmake --install build --component MLX
```
## REST API
Ollama has a REST API for running and managing models.
@@ -421,7 +453,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
@@ -493,7 +525,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Database
- [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector)
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
- [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases)
@@ -636,6 +668,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
### Observability
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
@@ -644,4 +677,5 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
### Security
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)

View File

@@ -165,7 +165,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
return nil
}
const maxBufferSize = 512 * format.KiloByte
const maxBufferSize = 8 * format.MegaByte
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf io.Reader

View File

@@ -100,7 +100,8 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
if filename == "" {
// No Modelfile found - check if current directory is an image gen model
if imagegen.IsTensorModelDir(".") {
return imagegenclient.CreateModel(args[0], ".", p)
quantize, _ := cmd.Flags().GetString("quantize")
return imagegenclient.CreateModel(args[0], ".", quantize, p)
}
reader = strings.NewReader("FROM .\n")
} else {
@@ -464,14 +465,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
name := args[0]
// Check if this is a known image generation model (skip Show/Pull)
if imagegen.HasTensorLayers(name) {
if opts.Prompt == "" && !interactive {
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
}
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
}
info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: name}
info, err := client.Show(cmd.Context(), showReq)
@@ -533,9 +526,18 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
}
// Check if this is an image generation model
if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
if opts.Prompt == "" && !interactive {
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
}
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
}
// Check for experimental flag
isExperimental, _ := cmd.Flags().GetBool("experimental")
yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
enableWebsearch, _ := cmd.Flags().GetBool("experimental-websearch")
if interactive {
if err := loadOrUnloadModel(cmd, &opts); err != nil {
@@ -565,7 +567,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
// Use experimental agent loop with tools
if isExperimental {
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode)
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode, enableWebsearch)
}
return generateInteractive(cmd, opts)
@@ -671,7 +673,11 @@ func PushHandler(cmd *cobra.Command, args []string) error {
bar, ok := bars[resp.Digest]
if !ok {
bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
msg := resp.Status
if msg == "" {
msg = fmt.Sprintf("pushing %s...", resp.Digest[7:19])
}
bar = progress.NewBar(msg, resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
@@ -837,11 +843,6 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
}
func ShowHandler(cmd *cobra.Command, args []string) error {
// Check if this is an image generation model
if imagegen.HasTensorLayers(args[0]) {
return imagegen.Show(args[0], os.Stdout)
}
client, err := api.ClientFromEnvironment()
if err != nil {
return err
@@ -1786,6 +1787,7 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
runCmd.Flags().Bool("experimental-websearch", false, "Enable web search tool in experimental mode")
// Image generation flags (width, height, steps, seed, etc.)
imagegen.RegisterFlags(runCmd)

View File

@@ -1547,6 +1547,79 @@ func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) {
}
}
func TestShowInfoImageGen(t *testing.T) {
var b bytes.Buffer
err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "ZImagePipeline",
ParameterSize: "10.3B",
QuantizationLevel: "FP8",
},
Capabilities: []model.Capability{model.CapabilityImageGeneration},
Requires: "0.14.0",
}, false, &b)
if err != nil {
t.Fatal(err)
}
expect := " Model\n" +
" architecture ZImagePipeline \n" +
" parameters 10.3B \n" +
" quantization FP8 \n" +
" requires 0.14.0 \n" +
"\n" +
" Capabilities\n" +
" image \n" +
"\n"
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
}
func TestPushProgressMessage(t *testing.T) {
tests := []struct {
name string
status string
digest string
wantMsg string
}{
{
name: "uses status when provided",
status: "uploading model",
digest: "sha256:abc123456789def",
wantMsg: "uploading model",
},
{
name: "falls back to digest when status empty",
status: "",
digest: "sha256:abc123456789def",
wantMsg: "pushing abc123456789...",
},
{
name: "handles short digest gracefully",
status: "",
digest: "sha256:abc",
wantMsg: "pushing sha256:abc...",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msg := tt.status
if msg == "" {
if len(tt.digest) >= 19 {
msg = fmt.Sprintf("pushing %s...", tt.digest[7:19])
} else {
msg = fmt.Sprintf("pushing %s...", tt.digest)
}
}
if msg != tt.wantMsg {
t.Errorf("got %q, want %q", msg, tt.wantMsg)
}
})
}
}
func TestRunOptions_Copy_Independence(t *testing.T) {
// Test that modifications to original don't affect copy
originalThink := &api.ThinkValue{Value: "original"}

View File

@@ -1,3 +0,0 @@
# Troubleshooting
For troubleshooting, see [https://docs.ollama.com/troubleshooting](https://docs.ollama.com/troubleshooting)

18
go.mod
View File

@@ -15,8 +15,8 @@ require (
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.19.0
golang.org/x/sys v0.39.0
golang.org/x/sync v0.17.0
golang.org/x/sys v0.37.0
)
require (
@@ -30,8 +30,8 @@ require (
github.com/tkrajina/typescriptify-golang-structs v0.2.0
github.com/wk8/go-ordered-map/v2 v2.1.8
golang.org/x/image v0.22.0
golang.org/x/mod v0.31.0
golang.org/x/tools v0.40.0
golang.org/x/mod v0.30.0
golang.org/x/tools v0.38.0
gonum.org/v1/gonum v0.15.0
)
@@ -81,11 +81,11 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.46.0
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93
golang.org/x/net v0.48.0 // indirect
golang.org/x/term v0.38.0
golang.org/x/text v0.32.0
golang.org/x/crypto v0.43.0
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
golang.org/x/net v0.46.0 // indirect
golang.org/x/term v0.36.0
golang.org/x/text v0.30.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)

36
go.sum
View File

@@ -233,16 +233,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 h1:fQsdNF2N+/YewlRZiricy4P1iimyPKZ/xwniHj8Q2a0=
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93/go.mod h1:EPRbTFwzwjXj9NpYyyrvenVh9Y+GFeEvMNh7Xuz7xgU=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
@@ -264,8 +264,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -278,8 +278,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -289,8 +289,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -306,17 +306,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -330,8 +330,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@@ -118,6 +118,9 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
return
}
// Set think to nil when being used with Anthropic API to connect to tools like claude code
c.Set("relax_thinking", true)
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))

View File

@@ -582,3 +582,26 @@ func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
})
}
}
func TestAnthropicMessagesMiddleware_SetsRelaxThinkingFlag(t *testing.T) {
gin.SetMode(gin.TestMode)
var flagSet bool
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
_, flagSet = c.Get("relax_thinking")
c.Status(http.StatusOK)
})
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if !flagSet {
t.Error("expected relax_thinking flag to be set in context")
}
}

View File

@@ -73,7 +73,7 @@ _build_darwin() {
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
fi
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/imagegen ./x/imagegen/cmd/engine
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/ollama-mlx .
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
done
}
@@ -82,19 +82,19 @@ _sign_darwin() {
status "Creating universal binary..."
mkdir -p dist/darwin
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
lipo -create -output dist/darwin/imagegen dist/darwin-*/imagegen
lipo -create -output dist/darwin/ollama-mlx dist/darwin-*/ollama-mlx
chmod +x dist/darwin/ollama
chmod +x dist/darwin/imagegen
chmod +x dist/darwin/ollama-mlx
if [ -n "$APPLE_IDENTITY" ]; then
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/imagegen; do
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/ollama-mlx; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
done
# create a temporary zip for notarization
TEMP=$(mktemp -u).zip
ditto -c -k --keepParent dist/darwin/ollama "$TEMP"
xcrun notarytool submit "$TEMP" --wait --timeout 10m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
xcrun notarytool submit "$TEMP" --wait --timeout 20m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
rm -f "$TEMP"
fi
@@ -154,23 +154,25 @@ _build_macapp() {
mkdir -p dist/Ollama.app/Contents/Resources
if [ -d dist/darwin-amd64 ]; then
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
lipo -create -output dist/Ollama.app/Contents/Resources/imagegen dist/darwin-amd64/imagegen dist/darwin-arm64/imagegen
lipo -create -output dist/Ollama.app/Contents/Resources/ollama-mlx dist/darwin-amd64/ollama-mlx dist/darwin-arm64/ollama-mlx
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
done
cp dist/darwin-*/lib/ollama/*.so dist/darwin-*/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
cp dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
# Copy MLX metallib (architecture-independent, just use arm64 version)
cp dist/darwin-arm64/lib/ollama/*.metallib dist/Ollama.app/Contents/Resources/ 2>/dev/null || true
else
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
fi
cp -a dist/darwin/imagegen dist/Ollama.app/Contents/Resources/imagegen
cp -a dist/darwin/ollama-mlx dist/Ollama.app/Contents/Resources/ollama-mlx
chmod a+x dist/Ollama.app/Contents/Resources/ollama
# Sign
if [ -n "$APPLE_IDENTITY" ]; then
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/imagegen ; do
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib dist/Ollama.app/Contents/Resources/ollama-mlx ; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
done
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
@@ -178,11 +180,11 @@ _build_macapp() {
rm -f dist/Ollama-darwin.zip
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama imagegen *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
# Notarize and Staple
if [ -n "$APPLE_IDENTITY" ]; then
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 10m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
rm -f dist/Ollama-darwin.zip
$(xcrun -f stapler) staple dist/Ollama.app
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
@@ -206,7 +208,7 @@ _build_macapp() {
rm -f dist/rw*.dmg
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.dmg
$(xcrun -f notarytool) submit dist/Ollama.dmg --wait --timeout 10m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
$(xcrun -f notarytool) submit dist/Ollama.dmg --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
$(xcrun -f stapler) staple dist/Ollama.dmg
else
echo "WARNING: Code signing disabled, this bundle will not work for upgrade testing"

View File

@@ -48,53 +48,12 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then
.
fi
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
deduplicate_cuda_libs() {
local base_dir="$1"
echo "Deduplicating CUDA libraries in ${base_dir}..."
# Find all mlx_cuda_* directories
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
[ -d "${mlx_dir}" ] || continue
# Extract CUDA version (e.g., v12, v13)
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
# Skip if corresponding cuda_* directory doesn't exist
[ -d "${cuda_dir}" ] || continue
echo " Checking ${mlx_dir} against ${cuda_dir}..."
# Find all .so* files in mlx directory
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
filename=$(basename "${mlx_file}")
cuda_file="${cuda_dir}/${filename}"
# Skip if file doesn't exist in cuda directory
[ -f "${cuda_file}" ] || continue
# Compare checksums
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
if [ "${mlx_sum}" = "${cuda_sum}" ]; then
echo " Deduplicating ${filename}"
# Calculate relative path from mlx_dir to cuda_dir
rel_path="../cuda_${cuda_version}/${filename}"
rm -f "${mlx_file}"
ln -s "${rel_path}" "${mlx_file}"
fi
done
done
}
# Run deduplication for each platform output directory
if echo $PLATFORM | grep "," > /dev/null ; then
deduplicate_cuda_libs "./dist/linux_amd64"
deduplicate_cuda_libs "./dist/linux_arm64"
$(dirname $0)/deduplicate_cuda_libs.sh "./dist/linux_amd64"
$(dirname $0)/deduplicate_cuda_libs.sh "./dist/linux_arm64"
elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then
deduplicate_cuda_libs "./dist"
$(dirname $0)/deduplicate_cuda_libs.sh "./dist"
fi
# buildx behavior changes for single vs. multiplatform

View File

@@ -0,0 +1,60 @@
#!/bin/sh
#
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
# This script finds identical .so* files in mlx_cuda_* directories that exist
# in corresponding cuda_* directories and replaces them with symlinks.
#
set -eu
if [ $# -eq 0 ]; then
echo "ERROR: No directory specified" >&2
echo "Usage: $0 <base_directory>" >&2
exit 1
fi
base_dir="$1"
if [ ! -d "${base_dir}" ]; then
echo "ERROR: Directory ${base_dir} does not exist" >&2
exit 1
fi
echo "Deduplicating CUDA libraries in ${base_dir}..."
# Find all mlx_cuda_* directories
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
[ -d "${mlx_dir}" ] || continue
# Extract CUDA version (e.g., v12, v13)
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
# Skip if corresponding cuda_* directory doesn't exist
[ -d "${cuda_dir}" ] || continue
echo " Checking ${mlx_dir} against ${cuda_dir}..."
# Find all .so* files in mlx directory
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
filename=$(basename "${mlx_file}")
cuda_file="${cuda_dir}/${filename}"
# Skip if file doesn't exist in cuda directory
[ -f "${cuda_file}" ] || continue
# Compare checksums
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
if [ "${mlx_sum}" = "${cuda_sum}" ]; then
echo " Deduplicating ${filename}"
# Calculate relative path from mlx_dir to cuda_dir
rel_path="../cuda_${cuda_version}/${filename}"
rm -f "${mlx_file}"
ln -s "${rel_path}" "${mlx_file}"
fi
done
done
echo "Deduplication complete"

View File

@@ -47,16 +47,40 @@ func (m *Manifest) Remove() error {
}
func (m *Manifest) RemoveLayers() error {
for _, layer := range append(m.Layers, m.Config) {
if layer.Digest != "" {
if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
slog.Debug("layer does not exist", "digest", layer.Digest)
} else if err != nil {
return err
ms, err := Manifests(true)
if err != nil {
return err
}
// Build set of digests still in use by other manifests
inUse := make(map[string]struct{})
for _, other := range ms {
for _, layer := range append(other.Layers, other.Config) {
if layer.Digest != "" {
inUse[layer.Digest] = struct{}{}
}
}
}
// Remove layers not used by any other manifest
for _, layer := range append(m.Layers, m.Config) {
if layer.Digest == "" {
continue
}
if _, used := inUse[layer.Digest]; used {
continue
}
blob, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(blob); errors.Is(err, os.ErrNotExist) {
slog.Debug("layer does not exist", "digest", layer.Digest)
} else if err != nil {
return err
}
}
return nil
}

View File

@@ -1124,6 +1124,15 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
QuantizationLevel: m.Config.FileType,
}
// For image generation models, populate details from imagegen package
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
modelDetails.Family = info.Architecture
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
modelDetails.QuantizationLevel = info.Quantization
}
}
if req.System != "" {
m.System = req.System
}
@@ -1206,6 +1215,10 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
return resp, nil
}
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
return resp, nil
}
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
if err != nil {
return nil, err
@@ -2059,8 +2072,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
} else {
if req.Think != nil && req.Think.Bool() {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
return
// Set think to nil when being used with Anthropic API to connect to tools like claude code
if _, ok := c.Get("relax_thinking"); ok {
slog.Warn("model does not support thinking, relaxing thinking to nil", "model", req.Model)
req.Think = nil
} else {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
return
}
}
}

View File

@@ -1,24 +0,0 @@
# Experimental Features
## MLX Backend
We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx)
Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors. To build:
```
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install --component MLX
go build -tags mlx .
```
On linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled.
## Image Generation
Based on the experimental MLX backend, we're working on adding imagegen support. After running the cmake commands above:
```
go build -o imagegen ./x/imagegen/cmd/engine
```

View File

@@ -41,6 +41,7 @@ var optionLabels = []string{
var toolDisplayNames = map[string]string{
"bash": "Bash",
"web_search": "Web Search",
"web_fetch": "Web Fetch",
}
// ToolDisplayName returns the human-readable display name for a tool.
@@ -565,6 +566,16 @@ func formatToolDisplay(toolName string, args map[string]any) string {
}
}
// For web fetch, show URL and internet notice
if toolName == "web_fetch" {
if url, ok := args["url"].(string); ok {
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
sb.WriteString(fmt.Sprintf("URL: %s\n", url))
sb.WriteString("Uses internet via ollama.com")
return sb.String()
}
}
// Generic display
sb.WriteString(fmt.Sprintf("Tool: %s", displayName))
if len(args) > 0 {
@@ -1017,6 +1028,16 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
}
}
if toolName == "web_fetch" {
if url, ok := args["url"].(string); ok {
// Truncate long URLs
if len(url) > 50 {
url = url[:47] + "..."
}
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, url)
}
}
return fmt.Sprintf("\033[1m%s:\033[0m %s", label, displayName)
}

View File

@@ -9,6 +9,7 @@ import (
"net/url"
"os"
"os/signal"
"slices"
"strings"
"syscall"
"time"
@@ -24,6 +25,14 @@ import (
"github.com/ollama/ollama/x/tools"
)
// MultilineState tracks the state of multiline input
type MultilineState int
const (
MultilineNone MultilineState = iota
MultilineSystem
)
// Tool output capping constants
const (
// localModelTokenLimit is the token limit for local models (smaller context).
@@ -130,6 +139,7 @@ type RunOptions struct {
KeepAlive *api.Duration
Think *api.ThinkValue
HideThinking bool
Verbose bool
// Agent fields (managed externally for session persistence)
Tools *tools.Registry
@@ -178,6 +188,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
var thinkTagClosed bool = false
var pendingToolCalls []api.ToolCall
var consecutiveErrors int // Track consecutive 500 errors for retry limit
var latest api.ChatResponse
role := "assistant"
messages := opts.Messages
@@ -187,6 +198,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
p.StopAndClear()
}
latest = response
role = response.Message.Role
if response.Message.Thinking != "" && !opts.HideThinking {
if !thinkTagOpened {
@@ -483,6 +495,10 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
fmt.Println()
}
if opts.Verbose {
latest.Summary()
}
return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
}
@@ -634,7 +650,8 @@ func checkModelCapabilities(ctx context.Context, modelName string) (supportsTool
// GenerateInteractive runs an interactive agent session.
// This is called from cmd.go when --experimental flag is set.
// If yoloMode is true, all tool approvals are skipped.
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool) error {
// If enableWebsearch is true, the web search tool is registered.
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool, enableWebsearch bool) error {
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
AltPrompt: "... ",
@@ -660,6 +677,12 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
if supportsTools {
toolRegistry = tools.DefaultRegistry()
// Register web search and web fetch tools if enabled via flag
if enableWebsearch {
toolRegistry.RegisterWebSearch()
toolRegistry.RegisterWebFetch()
}
if toolRegistry.Has("bash") {
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "This experimental version of Ollama has the \033[1mbash\033[0m tool enabled.")
@@ -667,6 +690,11 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
fmt.Fprintln(os.Stderr)
}
if toolRegistry.Has("web_search") || toolRegistry.Has("web_fetch") {
fmt.Fprintln(os.Stderr, "The \033[1mWeb Search\033[0m and \033[1mWeb Fetch\033[0m tools are enabled. Models can search and fetch web content via ollama.com.")
fmt.Fprintln(os.Stderr)
}
if yoloMode {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m yolo mode - all tool approvals will be skipped\n")
}
@@ -677,6 +705,9 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
var messages []api.Message
var sb strings.Builder
var format string
var system string
var multiline MultilineState = MultilineNone
for {
line, err := scanner.Readline()
@@ -688,13 +719,39 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
if line == "" {
fmt.Println("\nUse Ctrl + d or /bye to exit.")
}
scanner.Prompt.UseAlt = false
sb.Reset()
multiline = MultilineNone
continue
case err != nil:
return err
}
switch {
case multiline != MultilineNone:
// check if there's a multiline terminating string
before, ok := strings.CutSuffix(line, `"""`)
sb.WriteString(before)
if !ok {
fmt.Fprintln(&sb)
continue
}
switch multiline {
case MultilineSystem:
system = sb.String()
newMessage := api.Message{Role: "system", Content: system}
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
}
multiline = MultilineNone
scanner.Prompt.UseAlt = false
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
return nil
case strings.HasPrefix(line, "/clear"):
@@ -707,6 +764,10 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
continue
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set Set session variables")
fmt.Fprintln(os.Stderr, " /show Show model information")
fmt.Fprintln(os.Stderr, " /load Load a different model")
fmt.Fprintln(os.Stderr, " /save Save session as a model")
fmt.Fprintln(os.Stderr, " /tools Show available tools and approvals")
fmt.Fprintln(os.Stderr, " /clear Clear session context and approvals")
fmt.Fprintln(os.Stderr, " /bye Exit")
@@ -716,6 +777,303 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
fmt.Fprintln(os.Stderr, " Ctrl+O Expand last tool output")
fmt.Fprintln(os.Stderr, "")
continue
case strings.HasPrefix(line, "/set"):
args := strings.Fields(line)
if len(args) > 1 {
switch args[1] {
case "history":
scanner.HistoryEnable()
case "nohistory":
scanner.HistoryDisable()
case "wordwrap":
wordWrap = true
fmt.Println("Set 'wordwrap' mode.")
case "nowordwrap":
wordWrap = false
fmt.Println("Set 'nowordwrap' mode.")
case "verbose":
if err := cmd.Flags().Set("verbose", "true"); err != nil {
return err
}
fmt.Println("Set 'verbose' mode.")
case "quiet":
if err := cmd.Flags().Set("verbose", "false"); err != nil {
return err
}
fmt.Println("Set 'quiet' mode.")
case "think":
thinkValue := api.ThinkValue{Value: true}
var maybeLevel string
if len(args) > 2 {
maybeLevel = args[2]
}
if maybeLevel != "" {
thinkValue.Value = maybeLevel
}
think = &thinkValue
// Check if model supports thinking
if client, err := api.ClientFromEnvironment(); err == nil {
if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil {
if !slices.Contains(resp.Capabilities, model.CapabilityThinking) {
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName)
}
}
}
if maybeLevel != "" {
fmt.Printf("Set 'think' mode to '%s'.\n", maybeLevel)
} else {
fmt.Println("Set 'think' mode.")
}
case "nothink":
think = &api.ThinkValue{Value: false}
// Check if model supports thinking
if client, err := api.ClientFromEnvironment(); err == nil {
if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil {
if !slices.Contains(resp.Capabilities, model.CapabilityThinking) {
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName)
}
}
}
fmt.Println("Set 'nothink' mode.")
case "format":
if len(args) < 3 || args[2] != "json" {
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
} else {
format = args[2]
fmt.Printf("Set format to '%s' mode.\n", args[2])
}
case "noformat":
format = ""
fmt.Println("Disabled format.")
case "parameter":
if len(args) < 4 {
fmt.Println("Usage: /set parameter <name> <value>")
continue
}
params := args[3:]
fp, err := api.FormatParams(map[string][]string{args[2]: params})
if err != nil {
fmt.Printf("Couldn't set parameter: %q\n", err)
continue
}
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
options[args[2]] = fp[args[2]]
case "system":
if len(args) < 3 {
fmt.Println("Usage: /set system <message> or /set system \"\"\"<multi-line message>\"\"\"")
continue
}
multiline = MultilineSystem
line := strings.Join(args[2:], " ")
line, ok := strings.CutPrefix(line, `"""`)
if !ok {
multiline = MultilineNone
} else {
// only cut suffix if the line is multiline
line, ok = strings.CutSuffix(line, `"""`)
if ok {
multiline = MultilineNone
}
}
sb.WriteString(line)
if multiline != MultilineNone {
scanner.Prompt.UseAlt = true
continue
}
system = sb.String()
newMessage := api.Message{Role: "system", Content: sb.String()}
// Check if the slice is not empty and the last message is from 'system'
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
// Replace the last message
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
continue
default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
}
} else {
fmt.Println("Usage: /set <parameter|system|history|format|wordwrap|think|verbose> [value]")
}
continue
case strings.HasPrefix(line, "/show"):
args := strings.Fields(line)
if len(args) > 1 {
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
continue
}
req := &api.ShowRequest{
Name: modelName,
Options: options,
}
resp, err := client.Show(cmd.Context(), req)
if err != nil {
fmt.Println("error: couldn't get model")
continue
}
switch args[1] {
case "info":
fmt.Fprintf(os.Stderr, " Model\n")
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Name", modelName)
if resp.Details.Family != "" {
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Family", resp.Details.Family)
}
if resp.Details.ParameterSize != "" {
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Parameter Size", resp.Details.ParameterSize)
}
if resp.Details.QuantizationLevel != "" {
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Quantization", resp.Details.QuantizationLevel)
}
if len(resp.Capabilities) > 0 {
caps := make([]string, len(resp.Capabilities))
for i, c := range resp.Capabilities {
caps[i] = string(c)
}
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Capabilities", strings.Join(caps, ", "))
}
fmt.Fprintln(os.Stderr)
case "license":
if resp.License == "" {
fmt.Println("No license was specified for this model.")
} else {
fmt.Println(resp.License)
}
case "modelfile":
fmt.Println(resp.Modelfile)
case "parameters":
fmt.Println("Model defined parameters:")
if resp.Parameters == "" {
fmt.Println(" No additional parameters were specified.")
} else {
for _, l := range strings.Split(resp.Parameters, "\n") {
fmt.Printf(" %s\n", l)
}
}
if len(options) > 0 {
fmt.Println("\nUser defined parameters:")
for k, v := range options {
fmt.Printf(" %-30s %v\n", k, v)
}
}
case "system":
switch {
case system != "":
fmt.Println(system + "\n")
case resp.System != "":
fmt.Println(resp.System + "\n")
default:
fmt.Println("No system message was specified for this model.")
}
case "template":
if resp.Template != "" {
fmt.Println(resp.Template)
} else {
fmt.Println("No prompt template was specified for this model.")
}
default:
fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1])
}
} else {
fmt.Println("Usage: /show <info|license|modelfile|parameters|system|template>")
}
continue
case strings.HasPrefix(line, "/load"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage: /load <modelname>")
continue
}
newModelName := args[1]
fmt.Printf("Loading model '%s'\n", newModelName)
// Create progress spinner
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)
// Get client
client, err := api.ClientFromEnvironment()
if err != nil {
p.StopAndClear()
fmt.Println("error: couldn't connect to ollama server")
continue
}
// Check if model exists and get its info
info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: newModelName})
if err != nil {
p.StopAndClear()
if strings.Contains(err.Error(), "not found") {
fmt.Printf("Couldn't find model '%s'\n", newModelName)
} else {
fmt.Printf("error: %v\n", err)
}
continue
}
// For cloud models, no need to preload
if info.RemoteHost == "" {
// Preload the model by sending an empty generate request
req := &api.GenerateRequest{
Model: newModelName,
Think: think,
}
err = client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error {
return nil
})
if err != nil {
p.StopAndClear()
if strings.Contains(err.Error(), "not found") {
fmt.Printf("Couldn't find model '%s'\n", newModelName)
} else if strings.Contains(err.Error(), "does not support thinking") {
fmt.Printf("error: %v\n", err)
} else {
fmt.Printf("error loading model: %v\n", err)
}
continue
}
}
p.StopAndClear()
modelName = newModelName
messages = []api.Message{}
approval.Reset()
continue
case strings.HasPrefix(line, "/save"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage: /save <modelname>")
continue
}
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
continue
}
req := &api.CreateRequest{
Model: args[1],
From: modelName,
Parameters: options,
Messages: messages,
}
fn := func(resp api.ProgressResponse) error { return nil }
err = client.Create(cmd.Context(), req, fn)
if err != nil {
fmt.Printf("error: %v\n", err)
continue
}
fmt.Printf("Created new model '%s'\n", args[1])
continue
case strings.HasPrefix(line, "/"):
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
continue
@@ -723,14 +1081,16 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
sb.WriteString(line)
}
if sb.Len() > 0 {
if sb.Len() > 0 && multiline == MultilineNone {
newMessage := api.Message{Role: "user", Content: sb.String()}
messages = append(messages, newMessage)
verbose, _ := cmd.Flags().GetBool("verbose")
opts := RunOptions{
Model: modelName,
Messages: messages,
WordWrap: wordWrap,
Format: format,
Options: options,
Think: think,
HideThinking: hideThinking,
@@ -738,6 +1098,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
Tools: toolRegistry,
Approval: approval,
YoloMode: yoloMode,
Verbose: verbose,
}
assistant, err := Chat(cmd.Context(), opts)

View File

@@ -1,185 +0,0 @@
# grammar
Grammar-constrained decoding for LLM outputs using MLX.
## Performance
Performance depends on hardware, vocabulary size, grammar, and whether you
evaluate the MLX graph. See [Benchmarks](#benchmarks) for how to measure on your
setup.
### Design choices that keep masking fast
| Technique | Impact |
|-----------|--------|
| Precomputed token analysis | Terminal matches computed once at startup |
| Mask caching by grammar state signature | Reuse masks for repeated parser states |
| Partitioned tokens | Exact matches separated from DP candidates |
### Comparison Notes
- **llama.cpp**: Decodes each token to UTF-8, checks against PDA. No caching.
- **Outlines**: FSM-based. Compilation can take 40s-10min for complex schemas. Fast after compile.
- **XGrammar**: PDA with 99% context-independent tokens precomputed. State-of-the-art before this.
- **x/grammar**: Precomputed token analysis + mask caching by grammar state signature.
## Usage
```go
import (
"github.com/ollama/ollama/x/grammar"
"github.com/ollama/ollama/x/grammar/schema"
)
// Use built-in JSON grammar
g, _ := grammar.JSONGrammar()
// Or from JSON Schema (OpenAI-compatible)
g, _ := schema.Grammar(`{
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"}
},
"required": ["name", "age"]
}`)
// Or parse custom EBNF
g, _ := grammar.ParseEBNF(myGrammar, "root")
// Create engine with model vocabulary
engine, _ := grammar.NewEngine(g, vocab)
defer engine.Close()
// Generation loop
for !engine.IsComplete() {
logits := model.Forward(tokens)
masked := engine.ApplyMask(logits) // Invalid tokens → -inf
nextToken := sample(masked)
engine.Accept(nextToken)
}
// Output conforms to the grammar when you only sample from masked tokens and call Accept
```
## EBNF Syntax
```ebnf
rule = expression . # Rule definition (ends with .)
"literal" # Literal string
"a" "z" # Character range (inclusive)
( a | b ) # Grouping with alternation
[ optional ] # Optional (0 or 1)
{ repeated } # Repetition (0 or more)
```
### Example: JSON Grammar
```ebnf
json = value .
value = object | array | string | number | "true" | "false" | "null" .
object = "{" ws "}" | "{" members "}" .
members = member { "," member } .
member = ws string ws ":" element .
array = "[" ws "]" | "[" elements "]" .
elements = element { "," element } .
element = ws value ws .
string = "\"" { character } "\"" .
character = unescaped | escaped .
unescaped = " " | "!" | "#" "[" | "]" "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" ) .
number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
digit = "0" "9" .
onenine = "1" "9" .
ws = { " " | "\t" | "\n" | "\r" } .
```
### Example: Custom Schema
```ebnf
root = "{" ws name_field "," ws age_field ws "}" .
name_field = "\"name\"" ws ":" ws string .
age_field = "\"age\"" ws ":" ws number .
string = "\"" { char } "\"" .
char = " " | "!" | "#" "~" .
number = [ "-" ] digit { digit } .
digit = "0" "9" .
ws = { " " | "\n" } .
```
## JSON Schema Support
OpenAI-compatible JSON Schema support with automatic EBNF generation:
```go
schema := `{
"type": "object",
"properties": {
"user": {"$ref": "#/$defs/User"}
},
"required": ["user"],
"$defs": {
"User": {
"type": "object",
"properties": {
"name": {"type": "string"},
"email": {"type": "string", "format": "email"},
"role": {"enum": ["admin", "user", "guest"]}
},
"required": ["name", "email", "role"]
}
}
}`
grammar, _ := schema.Grammar(schema)
```
### Supported Features
| Feature | Example |
|---------|---------|
| Basic types | `string`, `integer`, `number`, `boolean`, `null` |
| Objects | `properties`, `required` |
| Arrays | `items`, `minItems`, `maxItems` |
| Enums | `enum: ["a", "b", "c"]` |
| Constants | `const: "value"` |
| Union types | `anyOf`, `oneOf`, `type: ["string", "null"]` |
| References | `$ref: "#/$defs/Name"`, `$defs` |
| Formats | `date`, `time`, `date-time`, `email`, `uuid`, `ipv4` |
## Benchmarks
```bash
# Run all tests
go test -tags mlx ./x/grammar/...
# Run benchmarks
go test -tags mlx ./x/grammar/ -bench=.
# Compare with llama.cpp (outputs JSON)
go run -tags mlx ./x/grammar/cmd/compare -vocab-size 128000 -iterations 500
# Compare with a more complex schema
go run -tags mlx ./x/grammar/cmd/compare \
-gbnf x/grammar/cmd/compare/complex.gbnf \
-schema x/grammar/cmd/compare/complex.schema.json \
-vocab-size 128000 -iterations 500
```
## References
- [XGrammar Paper](https://arxiv.org/abs/2411.15100) - Flexible and Efficient Structured Generation
- [Outlines](https://github.com/dottxt-ai/outlines) - Structured Text Generation
- [JSONSchemaBench](https://arxiv.org/abs/2501.10868) - Benchmark for Structured Outputs

View File

@@ -1,161 +0,0 @@
//go:build mlx
package grammar
// terminalTokenGroups contains pre-partitioned tokens for a terminal.
// This enables O(1) lookup of tokens that exactly match vs need DP validation.
type terminalTokenGroups struct {
// ExactMatches are tokens that exactly match this terminal (O(1) validation)
ExactMatches []int32
// DPCandidates are tokens that start with this terminal but need DP validation
DPCandidates []int
}
// tokenAnalysis contains precomputed terminal matches for a token
type tokenAnalysis struct {
// The token string
Token string
// TokenID in the vocabulary
TokenID int
// Matches at each byte position
// MatchesAtPos[i] = terminals matching at position i with their lengths
MatchesAtPos [][]terminalMatch
// Fast path: if token exactly matches one terminal
// -1 if no exact match
exactMatch int
// Whether this token can be consumed at all (has at least one match)
HasMatches bool
}
// analyzer precomputes terminal matches for a vocabulary
type analyzer struct {
matcher *terminalMatcher
analyses []tokenAnalysis // Indexed by token ID
vocab []string
// Pre-partitioned tokens by terminal (exact match vs DP candidates)
// This enables direct slice appends instead of per-token branching
tokensByTerminal []terminalTokenGroups
}
// newAnalyzer creates an analyzer for the given vocabulary and terminals
func newAnalyzer(vocab []string, matcher *terminalMatcher) *analyzer {
a := &analyzer{
matcher: matcher,
analyses: make([]tokenAnalysis, len(vocab)),
vocab: vocab,
}
// Precompute analysis for each token
for i, token := range vocab {
a.analyses[i] = a.analyze(token, i)
}
// Build pre-partitioned token groups for fast ApplyMask
a.buildTokenPartitions()
return a
}
// analyze computes terminal matches for a single token
func (a *analyzer) analyze(token string, tokenID int) tokenAnalysis {
analysis := tokenAnalysis{
Token: token,
TokenID: tokenID,
MatchesAtPos: make([][]terminalMatch, len(token)),
exactMatch: -1,
HasMatches: false,
}
if len(token) == 0 {
return analysis
}
// Compute matches at each position
data := []byte(token)
for pos := 0; pos < len(data); pos++ {
matches := a.matcher.matchesAt(data, pos)
analysis.MatchesAtPos[pos] = matches
if len(matches) > 0 {
analysis.HasMatches = true
}
}
// Exact match is only valid when a single terminal spans the entire token
if len(analysis.MatchesAtPos) > 0 {
var exactID int = -1
for _, match := range analysis.MatchesAtPos[0] {
if match.Length != len(token) {
continue
}
if exactID >= 0 && exactID != match.TerminalID {
exactID = -1
break
}
exactID = match.TerminalID
}
analysis.exactMatch = exactID
}
return analysis
}
// analysis returns the precomputed analysis for a token ID
func (a *analyzer) analysis(tokenID int) tokenAnalysis {
if tokenID < 0 || tokenID >= len(a.analyses) {
return tokenAnalysis{exactMatch: -1}
}
return a.analyses[tokenID]
}
// vocabSize returns the vocabulary size
func (a *analyzer) vocabSize() int {
return len(a.vocab)
}
// buildTokenPartitions pre-partitions tokens into exact-match vs needs-DP groups per terminal.
// This enables ApplyMask to use direct slice appends instead of per-token branching.
func (a *analyzer) buildTokenPartitions() {
numTerminals := a.matcher.terminalCount()
a.tokensByTerminal = make([]terminalTokenGroups, numTerminals)
for tokenID, analysis := range a.analyses {
if !analysis.HasMatches {
continue
}
if analysis.exactMatch >= 0 {
// Token exactly matches one terminal - fast path (O(1) validation)
tid := analysis.exactMatch
a.tokensByTerminal[tid].ExactMatches = append(
a.tokensByTerminal[tid].ExactMatches, int32(tokenID))
} else {
// Token needs DP validation - add to all terminals it can start with
// This way, when a terminal is valid, we know exactly which tokens need DP
if len(analysis.MatchesAtPos) > 0 {
seen := make(map[int]bool)
for _, match := range analysis.MatchesAtPos[0] {
tid := match.TerminalID
if !seen[tid] {
seen[tid] = true
a.tokensByTerminal[tid].DPCandidates = append(
a.tokensByTerminal[tid].DPCandidates, tokenID)
}
}
}
}
}
}
// terminalGroups returns the pre-partitioned token groups for a terminal ID
func (a *analyzer) terminalGroups(terminalID int) terminalTokenGroups {
if terminalID < 0 || terminalID >= len(a.tokensByTerminal) {
return terminalTokenGroups{}
}
return a.tokensByTerminal[terminalID]
}

View File

@@ -1,648 +0,0 @@
//go:build mlx
package grammar
import (
"encoding/binary"
"hash/fnv"
"sort"
"sync"
)
// visitedMapPool reduces allocations for visited maps in bridge operations
var visitedMapPool = sync.Pool{
New: func() interface{} {
return make(map[stateStackKey]bool, 16)
},
}
// getVisitedMap gets a map from the pool
func getVisitedMap() map[stateStackKey]bool {
return visitedMapPool.Get().(map[stateStackKey]bool)
}
// putVisitedMap returns a map to the pool after clearing it
func putVisitedMap(m map[stateStackKey]bool) {
for k := range m {
delete(m, k)
}
visitedMapPool.Put(m)
}
// parserConfig represents a pda state+stack combination
type parserConfig struct {
state state
Stack []stackSymbol
}
// clone creates a deep copy of the config
func (c *parserConfig) clone() *parserConfig {
newStack := make([]stackSymbol, len(c.Stack))
copy(newStack, c.Stack)
return &parserConfig{
state: c.state,
Stack: newStack,
}
}
// key returns a unique key for this config for deduplication
func (c *parserConfig) key() uint64 {
h := fnv.New64a()
var buf [8]byte
binary.LittleEndian.PutUint64(buf[:], uint64(c.state))
h.Write(buf[:])
for _, sym := range c.Stack {
binary.LittleEndian.PutUint64(buf[:], uint64(sym))
h.Write(buf[:])
}
return h.Sum64()
}
// configSet represents a set of parser configurations (for nondeterminism)
type configSet struct {
configs []*parserConfig
normalized bool // true if already deduplicated and sorted
cachedSig uint64 // cached signature after normalization
}
// newConfigSet creates a new config set with a single configuration
func newConfigSet(state state, stack []stackSymbol) *configSet {
return &configSet{
configs: []*parserConfig{
{state: state, Stack: stack},
},
normalized: true, // single config is already normalized
}
}
// normalize deduplicates and sorts configs for stable signatures
func (c *configSet) normalize() {
if c.normalized || len(c.configs) <= 1 {
c.normalized = true
return
}
// Deduplicate using a map
seen := make(map[uint64]*parserConfig, len(c.configs))
for _, cfg := range c.configs {
key := cfg.key()
if _, exists := seen[key]; !exists {
seen[key] = cfg
}
}
// Extract unique configs
unique := make([]*parserConfig, 0, len(seen))
for _, cfg := range seen {
unique = append(unique, cfg)
}
// Sort by key for deterministic ordering
sort.Slice(unique, func(i, j int) bool {
return unique[i].key() < unique[j].key()
})
c.configs = unique
c.normalized = true
}
// signature returns a hash for cache lookup (normalizes first)
func (c *configSet) signature() uint64 {
c.normalize()
// Return cached signature if available
if c.cachedSig != 0 {
return c.cachedSig
}
h := fnv.New64a()
// Hash number of configs
var buf [8]byte
binary.LittleEndian.PutUint64(buf[:], uint64(len(c.configs)))
h.Write(buf[:])
// Hash each config (already sorted)
for _, cfg := range c.configs {
binary.LittleEndian.PutUint64(buf[:], uint64(cfg.state))
h.Write(buf[:])
binary.LittleEndian.PutUint64(buf[:], uint64(len(cfg.Stack)))
h.Write(buf[:])
for _, sym := range cfg.Stack {
binary.LittleEndian.PutUint64(buf[:], uint64(sym))
h.Write(buf[:])
}
}
c.cachedSig = h.Sum64()
return c.cachedSig
}
// isEmpty returns true if there are no configurations
func (c *configSet) isEmpty() bool {
return len(c.configs) == 0
}
// clone creates a deep copy of the config set
func (c *configSet) clone() *configSet {
newConfigs := make([]*parserConfig, len(c.configs))
for i, cfg := range c.configs {
newConfigs[i] = cfg.clone()
}
return &configSet{configs: newConfigs}
}
// bridge connects token analysis to pda validation
type bridge struct {
pda *pda
analyzer *analyzer
}
// newBridge creates a new bridge
func newBridge(pda *pda, analyzer *analyzer) *bridge {
return &bridge{
pda: pda,
analyzer: analyzer,
}
}
// IsTokenValid checks if token T can be consumed from the current config
// This is the main entry point for token validation
func (b *bridge) IsTokenValid(tokenID int, config *configSet) bool {
analysis := b.analyzer.analysis(tokenID)
if !analysis.HasMatches {
return false
}
// Fast path: exact terminal match
if analysis.exactMatch >= 0 {
terminal := b.analyzer.matcher.terminals[analysis.exactMatch]
return b.canAcceptTerminal(config, terminal.Pattern)
}
// General path: DP over (pos, config)
return b.dpValidate(&analysis, config)
}
// canAcceptTerminal checks if any config can accept the terminal
func (b *bridge) canAcceptTerminal(config *configSet, pattern string) bool {
for _, cfg := range config.configs {
if b.canConfigAcceptTerminal(cfg, pattern) {
return true
}
}
return false
}
// canConfigAcceptTerminal checks if a single config can accept the terminal
func (b *bridge) canConfigAcceptTerminal(cfg *parserConfig, pattern string) bool {
// Use pooled visited map to reduce allocations
visited := getVisitedMap()
result := b.tryAcceptTerminal(cfg.state, cfg.Stack, pattern, visited)
putVisitedMap(visited)
return result
}
// tryAcceptTerminal recursively tries to accept a terminal from a state
func (b *bridge) tryAcceptTerminal(state state, stack []stackSymbol, pattern string, visited map[stateStackKey]bool) bool {
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
if visited[key] {
return false
}
visited[key] = true
stackTop := stackEmpty
if len(stack) > 0 {
stackTop = stack[len(stack)-1]
}
for _, t := range b.pda.Transitions[state] {
// Check stack constraint
if t.stackTop != stackEmpty && t.stackTop != stackTop {
continue
}
// Can't pop more than we have
if t.StackPop > len(stack) {
continue
}
if t.Pattern == pattern {
// Direct match
return true
}
if t.Pattern == "" {
// Epsilon transition - follow it
newStack := make([]stackSymbol, len(stack))
copy(newStack, stack)
// Pop
if t.StackPop > 0 {
newStack = newStack[:len(newStack)-t.StackPop]
}
// Push
newStack = append(newStack, t.StackPush...)
if b.tryAcceptTerminal(t.ToState, newStack, pattern, visited) {
return true
}
}
}
return false
}
// dpValidate runs DP for multi-terminal tokens
func (b *bridge) dpValidate(analysis *tokenAnalysis, startConfig *configSet) bool {
// state: (pos, configSet)
// Memoize by (pos, configSig)
type dpKey struct {
pos int
sig uint64
}
memo := make(map[dpKey]bool)
var dp func(pos int, config *configSet) bool
dp = func(pos int, config *configSet) bool {
if pos == len(analysis.Token) {
return true // Consumed entire token
}
if config.isEmpty() {
return false
}
key := dpKey{pos, config.signature()}
if result, ok := memo[key]; ok {
return result
}
// Try each terminal that matches at this position
for _, match := range analysis.MatchesAtPos[pos] {
terminal := b.analyzer.matcher.terminals[match.TerminalID]
newConfig := b.advanceConfig(config, terminal.Pattern)
if newConfig != nil && !newConfig.isEmpty() && dp(pos+match.Length, newConfig) {
memo[key] = true
return true
}
}
memo[key] = false
return false
}
return dp(0, startConfig)
}
// advanceConfig advances all configs that can accept the terminal
func (b *bridge) advanceConfig(config *configSet, pattern string) *configSet {
var newConfigs []*parserConfig
for _, cfg := range config.configs {
advanced := b.advanceSingleConfig(cfg, pattern)
newConfigs = append(newConfigs, advanced...)
}
if len(newConfigs) == 0 {
return nil
}
return &configSet{configs: newConfigs}
}
// advanceSingleConfig advances a single config by accepting a terminal
func (b *bridge) advanceSingleConfig(cfg *parserConfig, pattern string) []*parserConfig {
var results []*parserConfig
visited := getVisitedMap()
b.collectAdvanced(cfg.state, cfg.Stack, pattern, visited, &results)
putVisitedMap(visited)
return results
}
// collectAdvanced collects all configs reachable by accepting the pattern
func (b *bridge) collectAdvanced(state state, stack []stackSymbol, pattern string, visited map[stateStackKey]bool, results *[]*parserConfig) {
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
if visited[key] {
return
}
visited[key] = true
stackTop := stackEmpty
if len(stack) > 0 {
stackTop = stack[len(stack)-1]
}
for _, t := range b.pda.Transitions[state] {
// Check stack constraint
if t.stackTop != stackEmpty && t.stackTop != stackTop {
continue
}
// Can't pop more than we have
if t.StackPop > len(stack) {
continue
}
if t.Pattern == pattern {
// Match! Create new config after transition
newStack := make([]stackSymbol, len(stack))
copy(newStack, stack)
if t.StackPop > 0 {
newStack = newStack[:len(newStack)-t.StackPop]
}
newStack = append(newStack, t.StackPush...)
*results = append(*results, &parserConfig{
state: t.ToState,
Stack: newStack,
})
}
if t.Pattern == "" {
// Epsilon transition - follow it
newStack := make([]stackSymbol, len(stack))
copy(newStack, stack)
if t.StackPop > 0 {
newStack = newStack[:len(newStack)-t.StackPop]
}
newStack = append(newStack, t.StackPush...)
b.collectAdvanced(t.ToState, newStack, pattern, visited, results)
}
}
}
// validTokens returns all token IDs that are valid from the given config
func (b *bridge) validTokens(config *configSet) []int {
var valid []int
for tokenID := 0; tokenID < b.analyzer.vocabSize(); tokenID++ {
if b.IsTokenValid(tokenID, config) {
valid = append(valid, tokenID)
}
}
return valid
}
// acceptToken attempts to accept a token and returns the new config set
// Returns nil if the token is not valid from this config
func (b *bridge) acceptToken(tokenID int, config *configSet) *configSet {
analysis := b.analyzer.analysis(tokenID)
if !analysis.HasMatches {
return nil
}
// Fast path: exact terminal match
if analysis.exactMatch >= 0 {
terminal := b.analyzer.matcher.terminals[analysis.exactMatch]
newConfig := b.advanceConfig(config, terminal.Pattern)
if newConfig != nil && !newConfig.isEmpty() {
newConfig.normalize()
return newConfig
}
return nil
}
// General path: DP to find final config after consuming token
return b.dpAccept(&analysis, config)
}
// dpAccept runs DP to accept a multi-terminal token and return final config
// Returns the union of all possible end configurations (preserves nondeterminism)
func (b *bridge) dpAccept(analysis *tokenAnalysis, startConfig *configSet) *configSet {
type dpKey struct {
pos int
sig uint64
}
// Memoize the configs reachable at each (pos, sig)
memo := make(map[dpKey]*configSet)
var dp func(pos int, config *configSet) *configSet
dp = func(pos int, config *configSet) *configSet {
if pos == len(analysis.Token) {
return config // Consumed entire token, return final config
}
if config.isEmpty() {
return nil
}
key := dpKey{pos, config.signature()}
if result, ok := memo[key]; ok {
return result
}
// Collect all valid result configs from all possible paths
var allConfigs []*parserConfig
// Try each terminal that matches at this position
for _, match := range analysis.MatchesAtPos[pos] {
terminal := b.analyzer.matcher.terminals[match.TerminalID]
newConfig := b.advanceConfig(config, terminal.Pattern)
if newConfig != nil && !newConfig.isEmpty() {
finalConfig := dp(pos+match.Length, newConfig)
if finalConfig != nil {
// Collect all configs, don't return early
allConfigs = append(allConfigs, finalConfig.configs...)
}
}
}
// Build result: nil if no valid paths, normalized configSet otherwise
var result *configSet
if len(allConfigs) > 0 {
result = &configSet{configs: allConfigs}
result.normalize() // Dedup using parserConfig.key(), sort for consistent signature
}
memo[key] = result // Cache normalized result
return result
}
return dp(0, startConfig)
}
// isAccepting returns true if any config can reach an accepting state
func (b *bridge) isAccepting(config *configSet) bool {
visited := getVisitedMap()
defer putVisitedMap(visited)
for _, cfg := range config.configs {
// Clear visited for each config check
for k := range visited {
delete(visited, k)
}
if b.canReachAccept(cfg.state, cfg.Stack, visited) {
return true
}
}
return false
}
// canReachAccept checks if we can reach an accepting state via epsilon transitions
func (b *bridge) canReachAccept(state state, stack []stackSymbol, visited map[stateStackKey]bool) bool {
// Check if this state is accepting with empty stack
if b.pda.AcceptStates[state] && len(stack) == 0 {
return true
}
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
if visited[key] {
return false
}
visited[key] = true
// Try epsilon transitions
stackTop := stackEmpty
if len(stack) > 0 {
stackTop = stack[len(stack)-1]
}
for _, t := range b.pda.Transitions[state] {
if t.Pattern != "" {
continue // Not epsilon
}
if t.stackTop != stackEmpty && t.stackTop != stackTop {
continue
}
if t.StackPop > len(stack) {
continue
}
newStack := make([]stackSymbol, len(stack))
copy(newStack, stack)
if t.StackPop > 0 {
newStack = newStack[:len(newStack)-t.StackPop]
}
newStack = append(newStack, t.StackPush...)
if b.canReachAccept(t.ToState, newStack, visited) {
return true
}
}
return false
}
// validTerminals returns the valid terminal patterns from the given config
func (b *bridge) validTerminals(config *configSet) []string {
seen := make(map[string]bool)
var terminals []string
visited := getVisitedMap()
defer putVisitedMap(visited)
for _, cfg := range config.configs {
// Clear visited for each config
for k := range visited {
delete(visited, k)
}
b.collectValidTerminals(cfg.state, cfg.Stack, visited, seen, &terminals)
}
return terminals
}
// collectValidTerminals collects all reachable terminals
func (b *bridge) collectValidTerminals(state state, stack []stackSymbol, visited map[stateStackKey]bool, seen map[string]bool, terminals *[]string) {
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
if visited[key] {
return
}
visited[key] = true
stackTop := stackEmpty
if len(stack) > 0 {
stackTop = stack[len(stack)-1]
}
for _, t := range b.pda.Transitions[state] {
if t.stackTop != stackEmpty && t.stackTop != stackTop {
continue
}
if t.StackPop > len(stack) {
continue
}
if t.Pattern != "" && !seen[t.Pattern] {
seen[t.Pattern] = true
*terminals = append(*terminals, t.Pattern)
}
if t.Pattern == "" {
newStack := make([]stackSymbol, len(stack))
copy(newStack, stack)
if t.StackPop > 0 {
newStack = newStack[:len(newStack)-t.StackPop]
}
newStack = append(newStack, t.StackPush...)
b.collectValidTerminals(t.ToState, newStack, visited, seen, terminals)
}
}
}
// validTerminalIDs returns the IDs of valid terminals from the given config
func (b *bridge) validTerminalIDs(config *configSet) []int {
seen := make(map[int]bool)
var terminalIDs []int
visited := getVisitedMap()
defer putVisitedMap(visited)
for _, cfg := range config.configs {
// Clear visited for each config
for k := range visited {
delete(visited, k)
}
b.collectValidTerminalIDs(cfg.state, cfg.Stack, visited, seen, &terminalIDs)
}
return terminalIDs
}
// collectValidTerminalIDs collects IDs of all reachable terminals
func (b *bridge) collectValidTerminalIDs(state state, stack []stackSymbol, visited map[stateStackKey]bool, seen map[int]bool, terminalIDs *[]int) {
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
if visited[key] {
return
}
visited[key] = true
stackTop := stackEmpty
if len(stack) > 0 {
stackTop = stack[len(stack)-1]
}
for _, t := range b.pda.Transitions[state] {
if t.stackTop != stackEmpty && t.stackTop != stackTop {
continue
}
if t.StackPop > len(stack) {
continue
}
if t.Pattern != "" {
// Look up terminal ID from pattern
if tid, ok := b.analyzer.matcher.patternToID[t.Pattern]; ok && !seen[tid] {
seen[tid] = true
*terminalIDs = append(*terminalIDs, tid)
}
}
if t.Pattern == "" {
newStack := make([]stackSymbol, len(stack))
copy(newStack, stack)
if t.StackPop > 0 {
newStack = newStack[:len(newStack)-t.StackPop]
}
newStack = append(newStack, t.StackPush...)
b.collectValidTerminalIDs(t.ToState, newStack, visited, seen, terminalIDs)
}
}
}

View File

@@ -1,45 +0,0 @@
root ::= ws "{" ws id-field "," ws kind-field "," ws items-field "," ws alt-field "," ws flags-field "," ws meta-field "," ws priority-field ws "}" ws
id-field ::= "\"id\"" ws ":" ws uuid
kind-field ::= "\"kind\"" ws ":" ws kind
items-field ::= "\"items\"" ws ":" ws items
alt-field ::= "\"alt\"" ws ":" ws alt
flags-field ::= "\"flags\"" ws ":" ws flags
meta-field ::= "\"meta\"" ws ":" ws meta
priority-field ::= "\"priority\"" ws ":" ws int
kind ::= "\"order\"" | "\"invoice\"" | "\"shipment\""
status ::= "\"new\"" | "\"backorder\"" | "\"shipped\""
flag ::= "\"fragile\"" | "\"gift\"" | "\"priority\"" | "\"insured\""
source ::= "\"api\"" | "\"batch\"" | "\"import\""
items ::= "[" ws item ( "," ws item )? ( "," ws item )? ws "]"
flags ::= "[" ws "]" | "[" ws flag ( "," ws flag )? ( "," ws flag )? ( "," ws flag )? ws "]"
item ::= "{" ws item-sku "," ws item-qty "," ws item-status "," ws item-notes ws "}"
item-sku ::= "\"sku\"" ws ":" ws string
item-qty ::= "\"qty\"" ws ":" ws int
item-status ::= "\"status\"" ws ":" ws status
item-notes ::= "\"notes\"" ws ":" ws string
meta ::= "{" ws meta-created "," ws meta-source "," ws meta-ip ws "}"
meta-created ::= "\"created\"" ws ":" ws date-time
meta-source ::= "\"source\"" ws ":" ws source
meta-ip ::= "\"ip\"" ws ":" ws ipv4
alt ::= string | int | "null"
uuid ::= "\"" hex hex hex hex hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex hex hex hex hex hex hex hex hex "\""
date-time ::= "\"" digit digit digit digit "-" digit digit "-" digit digit "T" digit digit ":" digit digit ":" digit digit ( "Z" | ( "+" | "-" ) digit digit ":" digit digit ) "\""
ipv4 ::= "\"" digit+ "." digit+ "." digit+ "." digit+ "\""
string ::= "\"" characters "\""
characters ::= character*
character ::= [^"\\] | "\\" escape
escape ::= ["\\bfnrt]
int ::= "-"? digit+
digit ::= [0-9]
hex ::= [0-9a-fA-F]
ws ::= [ \t\n\r]*

View File

@@ -1,46 +0,0 @@
{
"type": "object",
"properties": {
"id": { "type": "string", "format": "uuid" },
"kind": { "enum": ["order", "invoice", "shipment"] },
"items": {
"type": "array",
"minItems": 1,
"maxItems": 3,
"items": {
"type": "object",
"properties": {
"sku": { "type": "string" },
"qty": { "type": "integer" },
"status": { "enum": ["new", "backorder", "shipped"] },
"notes": { "type": "string" }
},
"required": ["sku", "qty", "status", "notes"]
}
},
"alt": {
"oneOf": [
{ "type": "string" },
{ "type": "null" },
{ "type": "integer" }
]
},
"flags": {
"type": "array",
"minItems": 0,
"maxItems": 4,
"items": { "enum": ["fragile", "gift", "priority", "insured"] }
},
"meta": {
"type": "object",
"properties": {
"created": { "type": "string", "format": "date-time" },
"source": { "enum": ["api", "batch", "import"] },
"ip": { "type": "string", "format": "ipv4" }
},
"required": ["created", "source", "ip"]
},
"priority": { "type": "integer" }
},
"required": ["id", "kind", "items", "alt", "flags", "meta", "priority"]
}

View File

@@ -1,235 +0,0 @@
//go:build mlx
package main
import (
"encoding/json"
"flag"
"fmt"
"os"
"time"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/x/grammar"
"github.com/ollama/ollama/x/grammar/schema"
"github.com/ollama/ollama/x/imagegen/mlx"
)
const jsonGBNF = `
root ::= value
value ::= object | array | string | number | "true" | "false" | "null"
object ::= "{" ws "}" | "{" members "}"
members ::= member ("," member)*
member ::= ws string ws ":" element
array ::= "[" ws "]" | "[" elements "]"
elements ::= element ("," element)*
element ::= ws value ws
string ::= "\"" characters "\""
characters ::= character*
character ::= [^"\\] | "\\" escape
escape ::= ["\\bfnrt]
number ::= "-"? integer fraction? exponent?
integer ::= "0" | [1-9] [0-9]*
fraction ::= "." [0-9]+
exponent ::= [eE] [+-]? [0-9]+
ws ::= [ \t\n\r]*
`
type result struct {
vocabSize int `json:"vocab_size"`
Iterations int `json:"iterations"`
Warmup int `json:"warmup"`
ConstrainedSource string `json:"constrained_source"`
LlamaSource string `json:"llama_source"`
LlamaApply string `json:"llama_apply"`
ConstrainedGraph string `json:"constrained_graph"`
ConstrainedWithEval string `json:"constrained_with_eval,omitempty"`
EvalOnly string `json:"eval_only,omitempty"`
ConstrainedEvalNet string `json:"constrained_eval_net,omitempty"`
}
func main() {
var (
vocabSize = flag.Int("vocab-size", 128000, "Vocabulary size")
iterations = flag.Int("iterations", 500, "Benchmark iterations")
warmup = flag.Int("warmup", 50, "Warmup iterations")
withEval = flag.Bool("eval", true, "Measure ApplyMask with mlx.Eval")
gbnfPath = flag.String("gbnf", "", "GBNF grammar file for llama.cpp")
schemaPath = flag.String("schema", "", "JSON Schema file for grammar constraints")
ebnfPath = flag.String("ebnf", "", "EBNF grammar file for grammar constraints")
startRule = flag.String("start", "root", "Start rule for EBNF")
)
flag.Parse()
if *vocabSize <= 0 || *iterations <= 0 || *warmup < 0 {
fmt.Fprintln(os.Stderr, "invalid flags")
os.Exit(2)
}
vocab := createVocab(*vocabSize)
if *schemaPath != "" && *ebnfPath != "" {
fmt.Fprintln(os.Stderr, "only one of -schema or -ebnf may be set")
os.Exit(2)
}
var constrainedSource string
var compiled *grammar.Grammar
var err error
switch {
case *schemaPath != "":
data, readErr := os.ReadFile(*schemaPath)
if readErr != nil {
fmt.Fprintf(os.Stderr, "read schema: %v\n", readErr)
os.Exit(1)
}
compiled, err = schema.Grammar(string(data))
constrainedSource = "schema:" + *schemaPath
case *ebnfPath != "":
data, readErr := os.ReadFile(*ebnfPath)
if readErr != nil {
fmt.Fprintf(os.Stderr, "read ebnf: %v\n", readErr)
os.Exit(1)
}
compiled, err = grammar.ParseEBNF(string(data), *startRule)
constrainedSource = "ebnf:" + *ebnfPath
default:
compiled, err = grammar.JSONGrammar()
constrainedSource = "json"
}
if err != nil {
fmt.Fprintf(os.Stderr, "grammar: %v\n", err)
os.Exit(1)
}
engine, err := grammar.NewEngine(compiled, vocab)
if err != nil {
fmt.Fprintf(os.Stderr, "engine: %v\n", err)
os.Exit(1)
}
defer engine.Close()
logits := mlx.Ones(int32(*vocabSize))
mlx.Keep(logits)
for i := 0; i < *warmup; i++ {
masked := engine.ApplyMask(logits)
if *withEval {
mlx.Eval(masked)
}
}
graphAvg := measure(*iterations, func() {
_ = engine.ApplyMask(logits)
})
var evalAvg time.Duration
var evalOnlyAvg time.Duration
if *withEval {
evalOnlyAvg = measure(*iterations, func() {
baseline := mlx.MulScalar(logits, 1)
mlx.Eval(baseline)
baseline.Free()
})
evalAvg = measure(*iterations, func() {
masked := engine.ApplyMask(logits)
mlx.Eval(masked)
})
}
vocabIDs := make([]uint32, *vocabSize)
for i := range vocabIDs {
vocabIDs[i] = uint32(i)
}
eogTokens := []int32{0}
gbnf := jsonGBNF
llamaSource := "json"
if *gbnfPath != "" {
data, readErr := os.ReadFile(*gbnfPath)
if readErr != nil {
fmt.Fprintf(os.Stderr, "read gbnf: %v\n", readErr)
os.Exit(1)
}
gbnf = string(data)
llamaSource = *gbnfPath
}
llamaGrammar := llama.NewGrammar(gbnf, vocabIDs, vocab, eogTokens)
if llamaGrammar == nil {
fmt.Fprintln(os.Stderr, "llama grammar initialization failed")
os.Exit(1)
}
defer llamaGrammar.Free()
llamaTokens := make([]llama.TokenData, *vocabSize)
for i := 0; i < *warmup; i++ {
for j := range llamaTokens {
llamaTokens[j].Logit = 1.0
}
llamaGrammar.Apply(llamaTokens)
}
llamaAvg := measure(*iterations, func() {
for j := range llamaTokens {
llamaTokens[j].Logit = 1.0
}
llamaGrammar.Apply(llamaTokens)
})
out := result{
vocabSize: *vocabSize,
Iterations: *iterations,
Warmup: *warmup,
LlamaApply: llamaAvg.String(),
ConstrainedGraph: graphAvg.String(),
ConstrainedSource: constrainedSource,
LlamaSource: llamaSource,
}
if *withEval {
out.ConstrainedWithEval = evalAvg.String()
out.EvalOnly = evalOnlyAvg.String()
if evalAvg > evalOnlyAvg {
out.ConstrainedEvalNet = (evalAvg - evalOnlyAvg).String()
} else {
out.ConstrainedEvalNet = "0s"
}
}
enc := json.NewEncoder(os.Stdout)
if err := enc.Encode(out); err != nil {
fmt.Fprintf(os.Stderr, "encode: %v\n", err)
os.Exit(1)
}
}
func measure(iterations int, fn func()) time.Duration {
start := time.Now()
for i := 0; i < iterations; i++ {
fn()
}
return time.Since(start) / time.Duration(iterations)
}
func createVocab(size int) []string {
vocab := make([]string, size)
jsonTokens := []string{
"{", "}", "[", "]", ":", ",",
"true", "false", "null",
" ", "\n", "\t", "\r",
"\"",
}
for i, t := range jsonTokens {
if i < size {
vocab[i] = t
}
}
for i := len(jsonTokens); i < size; i++ {
vocab[i] = fmt.Sprintf("tok%d", i)
}
return vocab
}

View File

@@ -1,320 +0,0 @@
//go:build mlx
package grammar
import (
"fmt"
"strconv"
"strings"
"unicode/utf8"
)
// Grammar is the compiled form of an EBNF grammar.
// It contains terminals, parse tables, and the start state.
// Use ParseEBNF or JSONGrammar to create a Grammar.
type Grammar struct {
// The underlying pda
pda *pda
// Compiled terminal matcher
matcher *terminalMatcher
}
// ParseEBNF compiles an EBNF grammar string into a Grammar.
// startRule is the name of the start rule (e.g., "root", "json").
func ParseEBNF(ebnf string, startRule string) (*Grammar, error) {
pda, err := compileString(ebnf, startRule)
if err != nil {
return nil, fmt.Errorf("failed to compile EBNF: %w", err)
}
matcher, err := compileTerminalsStrict(pda)
if err != nil {
return nil, fmt.Errorf("failed to compile terminals: %w", err)
}
return &Grammar{
pda: pda,
matcher: matcher,
}, nil
}
// JSONGrammar returns the compiled JSON grammar.
// This is a convenience wrapper for ParseEBNF(JSONGrammarEBNF, "json").
func JSONGrammar() (*Grammar, error) {
return ParseEBNF(JSONGrammarEBNF, "json")
}
// JSONObjectGrammar returns a JSON grammar that only allows objects at the top level.
// Use this when you want to ensure the output is a JSON object (starts with {).
func JSONObjectGrammar() (*Grammar, error) {
return ParseEBNF(JSONObjectGrammarEBNF, "json")
}
// compileTerminalsStrict builds a matcher that properly handles:
// - Escaped literals ("\n", \"", \uXXXX)
// - Unicode ranges (rune-based, not byte-based)
// - Rejects unsupported patterns with an error (no silent fallback)
func compileTerminalsStrict(pda *pda) (*terminalMatcher, error) {
m := &terminalMatcher{
literalTrie: &trieNode{terminalID: -1},
ranges: make([]terminal, 0),
terminals: make([]terminal, 0, len(pda.Terminals)),
patternToID: make(map[string]int),
}
// Track which pattern produced each unescaped value for collision detection
unescapedSource := make(map[string]string) // unescaped -> original pattern
for i, pattern := range pda.Terminals {
terminal, err := parseTerminalPattern(pattern, i)
if err != nil {
return nil, fmt.Errorf("terminal %q: %w", pattern, err)
}
if terminal.Type == terminalLiteral {
// Use the unescaped pattern for trie matching
m.addLiteralToTrie(terminal.Unescaped, i)
// Detect collisions between literals that unescape to the same value
if existingPattern, exists := unescapedSource[terminal.Unescaped]; exists {
if existingPattern != pattern {
return nil, fmt.Errorf("collision: patterns %q and %q both unescape to %q",
existingPattern, pattern, terminal.Unescaped)
}
} else {
unescapedSource[terminal.Unescaped] = pattern
}
} else if terminal.Type == terminalRange {
m.ranges = append(m.ranges, terminal)
}
m.terminals = append(m.terminals, terminal)
m.patternToID[pattern] = i
}
return m, nil
}
// parseTerminalPattern parses a terminal pattern and returns a terminal.
// Supports:
// - Literal strings (with escape sequences)
// - Character ranges [X-Y] (unicode-aware)
func parseTerminalPattern(pattern string, id int) (terminal, error) {
if len(pattern) == 0 {
return terminal{}, fmt.Errorf("empty pattern")
}
// Check for range pattern: [X-Y]
if isUnicodeRangePattern(pattern) {
lowRune, highRune, err := parseUnicodeRange(pattern)
if err != nil {
return terminal{}, err
}
return terminal{
ID: id,
Type: terminalRange,
Pattern: pattern,
Unescaped: pattern,
LowRune: lowRune,
HighRune: highRune,
}, nil
}
// It's a literal - unescape it
unescaped, err := unescapeLiteral(pattern)
if err != nil {
return terminal{}, fmt.Errorf("invalid escape sequence: %w", err)
}
return terminal{
ID: id,
Type: terminalLiteral,
Pattern: pattern,
Unescaped: unescaped,
}, nil
}
// isUnicodeRangePattern checks if pattern is a character range like [a-z] or [\u0000-\uFFFF]
func isUnicodeRangePattern(pattern string) bool {
if len(pattern) < 5 || pattern[0] != '[' || pattern[len(pattern)-1] != ']' {
return false
}
// Find the dash that separates low-high
inner := pattern[1 : len(pattern)-1]
dashIdx := strings.Index(inner, "-")
// Handle escaped dash at start
if dashIdx <= 0 {
return false
}
return true
}
// parseUnicodeRange parses [X-Y] into low and high runes
func parseUnicodeRange(pattern string) (rune, rune, error) {
if len(pattern) < 5 || pattern[0] != '[' || pattern[len(pattern)-1] != ']' {
return 0, 0, fmt.Errorf("invalid range pattern")
}
inner := pattern[1 : len(pattern)-1]
// Simple case: [a-z] where a and z are single chars
if len(inner) == 3 && inner[1] == '-' {
return rune(inner[0]), rune(inner[2]), nil
}
// Handle escaped characters like [\u0000-\uFFFF]
dashIdx := findRangeDash(inner)
if dashIdx < 0 {
return 0, 0, fmt.Errorf("no dash in range")
}
lowStr := inner[:dashIdx]
highStr := inner[dashIdx+1:]
lowRune, err := parseRune(lowStr)
if err != nil {
return 0, 0, fmt.Errorf("invalid low bound: %w", err)
}
highRune, err := parseRune(highStr)
if err != nil {
return 0, 0, fmt.Errorf("invalid high bound: %w", err)
}
if lowRune > highRune {
return 0, 0, fmt.Errorf("low bound > high bound")
}
return lowRune, highRune, nil
}
// findRangeDash finds the dash separating low-high in a range pattern
func findRangeDash(inner string) int {
i := 0
for i < len(inner) {
if inner[i] == '\\' && i+1 < len(inner) {
// Skip escape sequence
if inner[i+1] == 'u' && i+6 <= len(inner) {
i += 6 // \uXXXX
} else {
i += 2 // \n, \t, etc.
}
continue
}
if inner[i] == '-' && i > 0 {
return i
}
i++
}
return -1
}
// parseRune parses a single rune from a string (handles escapes)
func parseRune(s string) (rune, error) {
if len(s) == 0 {
return 0, fmt.Errorf("empty rune")
}
// Handle escape sequences
if s[0] == '\\' {
if len(s) < 2 {
return 0, fmt.Errorf("incomplete escape")
}
switch s[1] {
case 'n':
return '\n', nil
case 't':
return '\t', nil
case 'r':
return '\r', nil
case '\\':
return '\\', nil
case '"':
return '"', nil
case '\'':
return '\'', nil
case 'u':
if len(s) < 6 {
return 0, fmt.Errorf("incomplete unicode escape")
}
val, err := strconv.ParseInt(s[2:6], 16, 32)
if err != nil {
return 0, fmt.Errorf("invalid unicode escape: %w", err)
}
return rune(val), nil
default:
return 0, fmt.Errorf("unknown escape: \\%c", s[1])
}
}
// Plain character
r, _ := utf8.DecodeRuneInString(s)
if r == utf8.RuneError {
return 0, fmt.Errorf("invalid utf8")
}
return r, nil
}
// unescapeLiteral unescapes a literal pattern string
func unescapeLiteral(pattern string) (string, error) {
// Try strconv.Unquote if it looks quoted
if len(pattern) >= 2 && pattern[0] == '"' && pattern[len(pattern)-1] == '"' {
unquoted, err := strconv.Unquote(pattern)
if err != nil {
return "", err
}
return unquoted, nil
}
// If no backslashes, return as-is
if !strings.Contains(pattern, "\\") {
return pattern, nil
}
// Manual unescape
var result strings.Builder
i := 0
for i < len(pattern) {
if pattern[i] == '\\' && i+1 < len(pattern) {
switch pattern[i+1] {
case 'n':
result.WriteByte('\n')
i += 2
case 't':
result.WriteByte('\t')
i += 2
case 'r':
result.WriteByte('\r')
i += 2
case '\\':
result.WriteByte('\\')
i += 2
case '"':
result.WriteByte('"')
i += 2
case '\'':
result.WriteByte('\'')
i += 2
case 'u':
if i+6 <= len(pattern) {
val, err := strconv.ParseInt(pattern[i+2:i+6], 16, 32)
if err != nil {
return "", fmt.Errorf("invalid unicode escape at %d", i)
}
result.WriteRune(rune(val))
i += 6
} else {
return "", fmt.Errorf("incomplete unicode escape at %d", i)
}
default:
// Reject unknown escape sequences
return "", fmt.Errorf("unknown escape sequence: \\%c at position %d", pattern[i+1], i)
}
} else {
result.WriteByte(pattern[i])
i++
}
}
return result.String(), nil
}

View File

@@ -1,329 +0,0 @@
//go:build mlx
package grammar
import (
"container/list"
"fmt"
"math"
"sync"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// maskCache provides LRU caching for computed masks.
type maskCache struct {
cache map[uint64]*list.Element
order *list.List
maxSize int
mu sync.Mutex
}
type maskEntry struct {
sig uint64
mask *mlx.Array
}
// newMaskCache creates a new mask cache with the given max size
// If maxSize <= 0, the cache is disabled (Get/Put are no-ops)
func newMaskCache(maxSize int) *maskCache {
if maxSize <= 0 {
return &maskCache{
cache: make(map[uint64]*list.Element),
order: list.New(),
maxSize: 0, // Signals disabled
}
}
return &maskCache{
cache: make(map[uint64]*list.Element),
order: list.New(),
maxSize: maxSize,
}
}
// get retrieves a cached mask, returning nil if not found.
// Updates LRU order on cache hit.
func (c *maskCache) get(sig uint64) *mlx.Array {
if c.maxSize <= 0 {
return nil // Cache disabled
}
c.mu.Lock()
defer c.mu.Unlock()
if elem, ok := c.cache[sig]; ok {
c.order.MoveToFront(elem)
return elem.Value.(*maskEntry).mask
}
return nil
}
// put stores a mask in the cache with LRU eviction.
func (c *maskCache) put(sig uint64, mask *mlx.Array) {
if c.maxSize <= 0 {
return // Cache disabled
}
c.mu.Lock()
defer c.mu.Unlock()
if elem, exists := c.cache[sig]; exists {
c.order.MoveToFront(elem)
return
}
// Evict oldest if at capacity (safe since maxSize > 0)
if c.order.Len() >= c.maxSize {
oldest := c.order.Back()
if oldest != nil {
entry := oldest.Value.(*maskEntry)
entry.mask.Free()
delete(c.cache, entry.sig)
c.order.Remove(oldest)
}
}
elem := c.order.PushFront(&maskEntry{sig: sig, mask: mask})
c.cache[sig] = elem
}
// clear frees all cached masks.
func (c *maskCache) clear() {
c.mu.Lock()
defer c.mu.Unlock()
for elem := c.order.Front(); elem != nil; elem = elem.Next() {
elem.Value.(*maskEntry).mask.Free()
}
c.cache = make(map[uint64]*list.Element)
c.order.Init()
}
// size returns the number of cached masks.
func (c *maskCache) size() int {
c.mu.Lock()
defer c.mu.Unlock()
return len(c.cache)
}
// Engine applies grammar constraints to model outputs using MLX.
// It uses a token→pda bridge for strict correctness with arbitrary BPE tokens.
type Engine struct {
// The compiled grammar
grammar *Grammar
// bridge for token validation
bridge *bridge
analyzer *analyzer
// Current parser state (configSet for nondeterminism)
configSet *configSet
// Token vocabulary from the model
vocab []string
tokenToID map[string]int // O(1) lookup for AcceptString
// Mask cache: configSig → valid token mask (LRU)
maskCache *maskCache
// Cached negative infinity mask for invalid tokens
negInfMask *mlx.Array
// Threshold for comparison (0.5 since mask values are 0 or 1)
threshold *mlx.Array
// Vocabulary size
vocabSize int32
// Reusable buffers for candidate filtering (avoid allocations)
candidateMark []bool // indexed by tokenID, true if in candidate set
touched []int // tokenIDs that were marked (for reset)
dpCandidates []int // candidates requiring DP validation
// Reusable buffer for valid token indices (for GPU scatter)
validTokenIDs []int32
}
// EngineOption configures an Engine
type EngineOption func(*Engine)
// WithMaskCacheSize sets the mask cache size (default 1024)
func WithMaskCacheSize(size int) EngineOption {
return func(e *Engine) {
e.maskCache = newMaskCache(size)
}
}
// NewEngine creates a new constrained decoding engine.
// grammar is the compiled grammar (use JSONGrammar() or ParseEBNF()).
// vocab is the list of token strings from the model's tokenizer.
func NewEngine(grammar *Grammar, vocab []string, opts ...EngineOption) (*Engine, error) {
if grammar == nil {
return nil, fmt.Errorf("grammar cannot be nil")
}
// Build analyzer and bridge
analyzer := newAnalyzer(vocab, grammar.matcher)
bridge := newBridge(grammar.pda, analyzer)
// Initialize config set from pda initial state
initialConfig := newConfigSet(grammar.pda.StartState, nil)
// Build token lookup map for O(1) AcceptString
tokenToID := make(map[string]int, len(vocab))
for i, tok := range vocab {
tokenToID[tok] = i
}
e := &Engine{
grammar: grammar,
bridge: bridge,
analyzer: analyzer,
configSet: initialConfig,
vocab: vocab,
tokenToID: tokenToID,
maskCache: newMaskCache(1024),
vocabSize: int32(len(vocab)),
candidateMark: make([]bool, len(vocab)),
touched: make([]int, 0, 10000),
validTokenIDs: make([]int32, 0, 10000),
}
// Apply options
for _, opt := range opts {
opt(e)
}
// Create the negative infinity mask and threshold
if e.vocabSize > 0 {
e.negInfMask = mlx.FullDtype(float32(math.Inf(-1)), mlx.DtypeFloat32, e.vocabSize)
mlx.Keep(e.negInfMask)
e.threshold = mlx.NewScalarArray(0.5)
mlx.Keep(e.threshold)
}
return e, nil
}
// ApplyMask applies grammar constraints to logits.
// Returns logits with invalid tokens set to -inf.
func (e *Engine) ApplyMask(logits *mlx.Array) *mlx.Array {
sig := e.configSet.signature()
// Check state cache first (exact state match)
if cached := e.maskCache.get(sig); cached != nil {
condition := mlx.GreaterEqual(cached, e.threshold)
return mlx.Where(condition, logits, e.negInfMask)
}
// Compute valid tokens using candidate filtering:
// 1. Get valid terminal IDs from current grammar state
// 2. Get candidate tokens (those that START with valid terminals)
// 3. Run DP validation only on candidates
// This is O(candidates) instead of O(vocab_size)
validTerminalIDs := e.bridge.validTerminalIDs(e.configSet)
// Use pre-partitioned token groups for fast candidate building
// This eliminates per-token branching - just direct slice appends
e.validTokenIDs = e.validTokenIDs[:0]
e.dpCandidates = e.dpCandidates[:0]
e.touched = e.touched[:0]
for _, tid := range validTerminalIDs {
groups := e.analyzer.terminalGroups(tid)
// Direct append of exact matches (no per-token check needed)
e.validTokenIDs = append(e.validTokenIDs, groups.ExactMatches...)
// Collect DP candidates (may have duplicates across terminals)
for _, tokenID := range groups.DPCandidates {
if !e.candidateMark[tokenID] {
e.candidateMark[tokenID] = true
e.dpCandidates = append(e.dpCandidates, tokenID)
e.touched = append(e.touched, tokenID)
}
}
}
// Reset marks for next call
for _, id := range e.touched {
e.candidateMark[id] = false
}
for _, tokenID := range e.dpCandidates {
if e.bridge.IsTokenValid(tokenID, e.configSet) {
e.validTokenIDs = append(e.validTokenIDs, int32(tokenID))
}
}
// Create and cache the mask on GPU using index updates
mask := mlx.Zeros([]int32{e.vocabSize})
if len(e.validTokenIDs) > 0 {
indices := mlx.NewArrayInt32(e.validTokenIDs, []int32{int32(len(e.validTokenIDs))})
values := mlx.Ones(int32(len(e.validTokenIDs)))
mask = mlx.PutAlongAxis(mask, indices, values, 0)
}
mlx.Keep(mask)
// Cache by state signature
e.maskCache.put(sig, mask)
// Apply mask
condition := mlx.GreaterEqual(mask, e.threshold)
return mlx.Where(condition, logits, e.negInfMask)
}
// Accept processes a token and updates the parser state.
// Returns true if the token was valid and accepted.
func (e *Engine) Accept(tokenID int) bool {
if tokenID < 0 || tokenID >= len(e.vocab) {
return false
}
newConfig := e.bridge.acceptToken(tokenID, e.configSet)
if newConfig == nil {
return false
}
e.configSet = newConfig
return true
}
// AcceptString processes a token string directly.
// Returns true if the token was valid and accepted.
func (e *Engine) AcceptString(token string) bool {
if id, ok := e.tokenToID[token]; ok {
return e.Accept(id)
}
return false
}
// IsComplete returns true if the current state is accepting.
func (e *Engine) IsComplete() bool {
return e.bridge.isAccepting(e.configSet)
}
// Reset resets the engine to initial state.
func (e *Engine) Reset() {
e.configSet = newConfigSet(e.grammar.pda.StartState, nil)
}
// validTokens returns the indices of tokens that are currently valid.
func (e *Engine) validTokens() []int {
return e.bridge.validTokens(e.configSet)
}
// validTerminals returns the valid terminal patterns from the current state.
func (e *Engine) validTerminals() []string {
return e.bridge.validTerminals(e.configSet)
}
// Close releases MLX resources.
func (e *Engine) Close() {
if e.maskCache != nil {
e.maskCache.clear()
}
if e.negInfMask != nil {
e.negInfMask.Free()
}
if e.threshold != nil {
e.threshold.Free()
}
}

View File

@@ -1,414 +0,0 @@
//go:build mlx
package grammar
import (
"fmt"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// newBenchEngine creates a JSON engine for benchmarks
func newBenchEngine(b *testing.B, vocab []string) *Engine {
b.Helper()
grammar, err := JSONGrammar()
if err != nil {
b.Fatalf("failed to create JSON grammar: %v", err)
}
e, err := NewEngine(grammar, vocab)
if err != nil {
b.Fatalf("failed to create engine: %v", err)
}
return e
}
// Vocabulary sizes to test (matching real models)
var vocabSizes = []int{
32000, // Llama 2
128000, // Llama 3
256000, // Large models
}
// createBenchVocabN creates a vocabulary of size n with realistic token distribution
func createBenchVocabN(n int) []string {
vocab := make([]string, n)
// JSON structural tokens (first 20)
jsonTokens := []string{
"{", "}", "[", "]", ":", ",",
"true", "false", "null",
" ", "\n", "\t", "\r",
"\"", "'",
}
for i, t := range jsonTokens {
if i < n {
vocab[i] = t
}
}
// String tokens (indices 20-1000)
stringIdx := 20
for i := 0; i < 980 && stringIdx+i < n; i++ {
vocab[stringIdx+i] = fmt.Sprintf("\"token%d\"", i)
}
// Number tokens (indices 1000-2000)
numberIdx := 1000
for i := 0; i < 1000 && numberIdx+i < n; i++ {
vocab[numberIdx+i] = fmt.Sprintf("%d", i)
}
// Generic tokens (rest)
for i := 2000; i < n; i++ {
vocab[i] = fmt.Sprintf("tok%d", i)
}
return vocab
}
// ============ Core Performance Benchmarks ============
// BenchmarkApplyMask_32k measures mask application with 32k vocab
func BenchmarkApplyMask_32k(b *testing.B) {
benchmarkApplyMask(b, 32000)
}
// BenchmarkApplyMask_128k measures mask application with 128k vocab
func BenchmarkApplyMask_128k(b *testing.B) {
benchmarkApplyMask(b, 128000)
}
// BenchmarkApplyMask_256k measures mask application with 256k vocab
func BenchmarkApplyMask_256k(b *testing.B) {
benchmarkApplyMask(b, 256000)
}
func benchmarkApplyMask(b *testing.B, vocabSize int) {
vocab := createBenchVocabN(vocabSize)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(vocabSize))
mlx.Keep(logits)
// Warm up
for i := 0; i < 10; i++ {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
}
b.ReportMetric(float64(vocabSize), "vocab_size")
}
// ============ state-Dependent Benchmarks ============
// BenchmarkApplyMaskAfterBrace measures mask after { (STRING or } valid)
func BenchmarkApplyMaskAfterBrace(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
e.AcceptString("{")
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
b.ResetTimer()
for i := 0; i < b.N; i++ {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
}
}
// BenchmarkApplyMaskMidObject measures mask in middle of object
func BenchmarkApplyMaskMidObject(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
// state: {"key": _value_
e.AcceptString("{")
e.AcceptString("\"key\"")
e.AcceptString(":")
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
b.ResetTimer()
for i := 0; i < b.N; i++ {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
}
}
// ============ Token Sequence Benchmarks ============
// BenchmarkSequence_SimpleObject benchmarks {"key": "value"}
func BenchmarkSequence_SimpleObject(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
e.Reset()
for _, token := range sequence {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
e.AcceptString(token)
}
}
b.ReportMetric(float64(len(sequence)), "tokens")
}
// BenchmarkSequence_NestedObject benchmarks {"a": {"b": {"c": 1}}}
func BenchmarkSequence_NestedObject(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
sequence := []string{
"{", "\"a\"", ":", "{", "\"b\"", ":", "{", "\"c\"", ":", "1", "}", "}", "}",
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
e.Reset()
for _, token := range sequence {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
e.AcceptString(token)
}
}
b.ReportMetric(float64(len(sequence)), "tokens")
}
// BenchmarkSequence_LargeArray benchmarks [1, 2, 3, ..., 100]
func BenchmarkSequence_LargeArray(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
// Build sequence: [1, 2, 3, ..., 50]
sequence := []string{"["}
for i := 1; i <= 50; i++ {
sequence = append(sequence, fmt.Sprintf("%d", i))
if i < 50 {
sequence = append(sequence, ",")
}
}
sequence = append(sequence, "]")
b.ResetTimer()
for i := 0; i < b.N; i++ {
e.Reset()
for _, token := range sequence {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
e.AcceptString(token)
}
}
b.ReportMetric(float64(len(sequence)), "tokens")
}
// BenchmarkSequence_MixedTypes benchmarks complex mixed-type object
func BenchmarkSequence_MixedTypes(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
sequence := []string{
"{",
"\"name\"", ":", "\"test\"", ",",
"\"count\"", ":", "42", ",",
"\"enabled\"", ":", "true", ",",
"\"data\"", ":", "null", ",",
"\"items\"", ":", "[", "1", ",", "2", ",", "3", "]",
"}",
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
e.Reset()
for _, token := range sequence {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
e.AcceptString(token)
}
}
b.ReportMetric(float64(len(sequence)), "tokens")
}
// ============ Component Benchmarks ============
// BenchmarkValidInputs measures pda valid input computation
func BenchmarkValidInputs(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = e.validTerminals()
}
}
// BenchmarkStateTransition measures pda state transition
func BenchmarkStateTransition(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
e.Reset()
for _, token := range sequence {
e.AcceptString(token)
}
}
}
// BenchmarkConstrainedGrammar_128k benchmarks x/grammar (graph only, no eval).
func BenchmarkConstrainedGrammar_128k(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
// Warm up
for i := 0; i < 10; i++ {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = e.ApplyMask(logits) // Graph only, no eval
}
}
// BenchmarkNewEngine measures one-time engine initialization.
func BenchmarkNewEngine_32k(b *testing.B) {
benchmarkNewEngine(b, 32000)
}
func BenchmarkNewEngine_128k(b *testing.B) {
benchmarkNewEngine(b, 128000)
}
func benchmarkNewEngine(b *testing.B, vocabSize int) {
vocab := createBenchVocabN(vocabSize)
b.ResetTimer()
for i := 0; i < b.N; i++ {
e := newBenchEngine(b, vocab)
e.Close()
}
}
// ============ Memory Benchmarks ============
func BenchmarkMemoryAllocs_32k(b *testing.B) {
benchmarkMemoryAllocs(b, 32000)
}
func BenchmarkMemoryAllocs_128k(b *testing.B) {
benchmarkMemoryAllocs(b, 128000)
}
func benchmarkMemoryAllocs(b *testing.B, vocabSize int) {
vocab := createBenchVocabN(vocabSize)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(vocabSize))
mlx.Keep(logits)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
}
}
// ============ No-Eval Benchmarks (simulating LLM graph integration) ============
// BenchmarkApplyMaskNoEval_128k measures mask generation WITHOUT GPU sync
// This simulates adding mask to LLM compute graph
func BenchmarkApplyMaskNoEval_128k(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
// Warm up
for i := 0; i < 10; i++ {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = e.ApplyMask(logits) // No Eval - just build graph
}
}
// BenchmarkSequenceNoEval simulates real LLM usage - build graph, eval once at end
func BenchmarkSequenceNoEval_SimpleObject(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
e.Reset()
var lastMasked *mlx.Array
for _, token := range sequence {
lastMasked = e.ApplyMask(logits) // Build graph only
e.AcceptString(token)
}
mlx.Eval(lastMasked) // Single eval at end
}
b.ReportMetric(float64(len(sequence)), "tokens")
}

View File

@@ -1,689 +0,0 @@
//go:build mlx
package grammar
import (
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// newTestEngine creates a JSON engine for testing
func newTestEngine(t testing.TB, vocab []string) *Engine {
t.Helper()
grammar, err := JSONGrammar()
if err != nil {
t.Fatalf("failed to create JSON grammar: %v", err)
}
e, err := NewEngine(grammar, vocab)
if err != nil {
t.Fatalf("failed to create engine: %v", err)
}
return e
}
// Mock vocabulary for testing
func testVocab() []string {
return []string{
"{", // 0: object start
"}", // 1: object end
"[", // 2: array start
"]", // 3: array end
":", // 4: colon
",", // 5: comma
"\"key\"", // 6: string (quoted)
"\"val\"", // 7: string (quoted)
"123", // 8: number
"-42.5", // 9: number
"true", // 10: boolean
"false", // 11: boolean
"null", // 12: null
" ", // 13: whitespace (should be ignored)
"\n", // 14: whitespace (should be ignored)
"subword", // 15: bare word (NOT valid JSON - requires quotes)
"hello", // 16: bare word (NOT valid JSON - requires quotes)
}
}
func TestNewEngine(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
if e.vocabSize != int32(len(vocab)) {
t.Errorf("vocabSize = %d, want %d", e.vocabSize, len(vocab))
}
// Verify grammar is set
if e.grammar == nil {
t.Error("grammar should not be nil")
}
// Verify analyzer is set
if e.analyzer == nil {
t.Error("analyzer should not be nil")
}
}
func TestEngineValidTokens(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// At start, any value type should be valid
validTokens := e.validTokens()
// Should include object start, array start, strings, numbers, booleans, null
// Note: bare words like "subword" and "hello" are NOT valid JSON strings
// (JSON strings must be quoted)
expectedTokens := map[int]bool{
0: true, // {
2: true, // [
6: true, // "key"
7: true, // "val"
8: true, // 123
9: true, // -42.5
10: true, // true
11: true, // false
12: true, // null
}
// Check that expected tokens are present
validSet := make(map[int]bool)
for _, idx := range validTokens {
validSet[idx] = true
}
for idx := range expectedTokens {
if !validSet[idx] {
t.Errorf("expected token %d (%s) to be valid", idx, vocab[idx])
}
}
if validSet[15] || validSet[16] {
t.Error("bare words should not be valid JSON at the start state")
}
}
func TestEngineAccept(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// Accept { should work
if !e.Accept(0) { // {
t.Error("should accept {")
}
// After {, valid tokens should be STRING or }
validTokens := e.validTokens()
validSet := make(map[int]bool)
for _, idx := range validTokens {
validSet[idx] = true
}
// STRING tokens (indices 6, 7) and } (index 1) should be valid
if !validSet[1] {
t.Error("} should be valid after {")
}
if !validSet[6] && !validSet[7] {
t.Error("STRING should be valid after { (for keys)")
}
}
func TestEngineAcceptSequence(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// Accept {"key": "val"}
sequence := []int{0, 6, 4, 7, 1} // {, "key", :, "val", }
for i, tokenID := range sequence {
if !e.Accept(tokenID) {
t.Fatalf("failed to accept token %d (%s) at position %d",
tokenID, vocab[tokenID], i)
}
}
if !e.IsComplete() {
t.Error("should be in complete state after valid JSON")
}
}
func TestEngineReset(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// Accept some tokens
e.Accept(0) // {
e.Accept(1) // }
if !e.IsComplete() {
t.Error("should be complete after {}")
}
// Reset
e.Reset()
// Should be back to initial state
if e.IsComplete() {
t.Error("should not be complete after reset")
}
// Should be able to accept new sequence
if !e.Accept(0) { // {
t.Error("should accept { after reset")
}
}
func TestEngineInvalidTokenRejection(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// Accept { first
if !e.Accept(0) {
t.Fatal("should accept {")
}
// Now try to accept [ which is invalid after {
// (After {, only STRING or } are valid)
if e.Accept(2) { // [
t.Error("should not accept [ after { (expecting STRING or })")
}
}
func TestEngineAcceptString(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// Accept using string directly
if !e.AcceptString("{") {
t.Error("should accept {")
}
if !e.AcceptString("\"key\"") {
t.Error("should accept string key")
}
if !e.AcceptString(":") {
t.Error("should accept :")
}
if !e.AcceptString("123") {
t.Error("should accept number")
}
if !e.AcceptString("}") {
t.Error("should accept }")
}
if !e.IsComplete() {
t.Error("should be complete after valid JSON")
}
}
func TestJSONBackslashEscape(t *testing.T) {
vocab := []string{`"`, `\`, "n", "a"}
e := newTestEngine(t, vocab)
defer e.Close()
// Valid escape: "\n"
if !e.AcceptString(`"`) {
t.Fatal("should accept string start")
}
if !e.AcceptString(`\`) {
t.Fatal("should accept escape prefix")
}
if !e.AcceptString("n") {
t.Fatal("should accept escape code")
}
if !e.AcceptString(`"`) {
t.Fatal("should accept string end")
}
if !e.IsComplete() {
t.Error("should be complete after escaped string")
}
// Invalid escape: "\a"
e.Reset()
if !e.AcceptString(`"`) {
t.Fatal("should accept string start")
}
if !e.AcceptString(`\`) {
t.Fatal("should accept escape prefix")
}
if e.AcceptString("a") {
t.Error("should reject invalid escape code")
}
}
func TestEngineNegInfMask(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// Verify negInfMask exists and has correct shape
if e.negInfMask == nil {
t.Fatal("negInfMask should not be nil")
}
}
func TestEngineMaskCache(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// Create test logits
logits := mlx.Ones(int32(len(vocab)))
// Apply mask - should populate cache
_ = e.ApplyMask(logits)
// Check cache was populated
cacheSize := e.maskCache.size()
if cacheSize == 0 {
t.Error("mask cache should have at least one entry after ApplyMask")
}
}
func TestEngineEmptyVocab(t *testing.T) {
e := newTestEngine(t, []string{})
defer e.Close()
if e.vocabSize != 0 {
t.Errorf("vocabSize = %d, want 0", e.vocabSize)
}
}
func TestEngineLargeVocab(t *testing.T) {
// Create a large vocabulary (simulating real model vocab)
vocab := make([]string, 32000)
for i := range vocab {
vocab[i] = "token"
}
// Add some actual JSON tokens
vocab[0] = "{"
vocab[1] = "}"
vocab[2] = "["
vocab[3] = "]"
vocab[4] = ":"
vocab[5] = ","
vocab[6] = "\"test\""
vocab[7] = "123"
vocab[8] = "true"
vocab[9] = "false"
vocab[10] = "null"
e := newTestEngine(t, vocab)
defer e.Close()
if e.vocabSize != 32000 {
t.Errorf("vocabSize = %d, want 32000", e.vocabSize)
}
// Test that it still works correctly
if !e.Accept(0) { // {
t.Error("should accept {")
}
if !e.Accept(1) { // }
t.Error("should accept }")
}
if !e.IsComplete() {
t.Error("should be complete after {}")
}
}
// TestE2E_JSONDecoding tests end-to-end JSON constrained decoding.
func TestE2E_JSONDecoding(t *testing.T) {
// Create a realistic vocabulary with JSON tokens
vocab := []string{
// Structural tokens
"{", "}", "[", "]", ":", ",",
// Keywords
"true", "false", "null",
// Quoted strings
`"name"`, `"value"`, `"items"`, `"count"`, `"enabled"`,
`"hello"`, `"world"`, `"test"`,
// Numbers
"0", "1", "2", "3", "42", "123", "-1", "-42",
// Whitespace
" ", "\n", "\t",
// Multi-terminal tokens (span multiple JSON lexemes)
`"key":`, `},`, `],`, `{"`, `["`,
// Partial/invalid tokens (should be rejected)
"invalid", "foo", "bar",
}
grammar, err := JSONGrammar()
if err != nil {
t.Fatalf("failed to create JSON grammar: %v", err)
}
engine, err := NewEngine(grammar, vocab)
if err != nil {
t.Fatalf("failed to create engine: %v", err)
}
defer engine.Close()
tests := []struct {
name string
tokens []string
wantPass bool
}{
// Simple values
{"empty object", []string{"{", "}"}, true},
{"empty array", []string{"[", "]"}, true},
{"true literal", []string{"true"}, true},
{"null literal", []string{"null"}, true},
{"number", []string{"42"}, true},
{"negative number", []string{"-42"}, true},
{"quoted string", []string{`"hello"`}, true},
// Objects
{"simple object", []string{"{", `"name"`, ":", `"value"`, "}"}, true},
{"object with single-digit numbers", []string{"{", `"count"`, ":", "1", ",", `"value"`, ":", "2", "}"}, true},
{"multi-terminal key", []string{"{", `"key":`, `"value"`, "}"}, true},
// Arrays
{"array of numbers", []string{"[", "42", "]"}, true},
{"array of single digits", []string{"[", "1", ",", "2", "]"}, true},
{"array of strings", []string{"[", `"hello"`, ",", `"world"`, "]"}, true},
{"nested array", []string{"[", "[", "42", "]", "]"}, true},
// Nested structures
{"nested object", []string{"{", `"items"`, ":", "{", `"count"`, ":", "42", "}", "}"}, true},
{"object with array", []string{"{", `"items"`, ":", "[", "42", "]", "}"}, true},
// Invalid sequences
{"unclosed object", []string{"{", `"name"`, ":"}, false}, // incomplete
{"double comma", []string{"[", "42", ",", ",", "42", "]"}, false}, // invalid
{"missing value", []string{"{", `"name"`, ":", "}"}, false}, // missing value
{"bare word", []string{"invalid"}, false}, // not valid JSON
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
engine.Reset()
// Process each token
allAccepted := true
for i, token := range tt.tokens {
if !engine.AcceptString(token) {
if tt.wantPass {
t.Errorf("token %d (%q) rejected unexpectedly", i, token)
}
allAccepted = false
break
}
}
if tt.wantPass {
if !allAccepted {
return // Already reported error
}
if !engine.IsComplete() {
t.Errorf("expected complete parse, but not in accepting state")
}
} else {
// For invalid sequences, we expect either rejection or incomplete
if allAccepted && engine.IsComplete() {
t.Errorf("expected rejection or incomplete, but parse succeeded")
}
}
})
}
}
// TestE2E_SimpleExpressionGrammar tests a custom expression grammar.
func TestE2E_SimpleExpressionGrammar(t *testing.T) {
// Simple expression grammar: expr = term { ("+" | "-") term }
// term = number | "(" expr ")"
// number = digit { digit }
// digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9"
exprGrammar := `
expr = term { addop term } .
addop = "+" | "-" .
term = factor { mulop factor } .
mulop = "*" | "/" .
factor = number | "(" expr ")" .
number = digit { digit } .
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
`
grammar, err := ParseEBNF(exprGrammar, "expr")
if err != nil {
t.Fatalf("failed to parse expression grammar: %v", err)
}
// Vocabulary for expression tokens
vocab := []string{
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
"+", "-", "*", "/",
"(", ")",
// Multi-digit numbers as single tokens
"10", "42", "100", "123",
// Invalid tokens
"x", "y", "invalid",
}
engine, err := NewEngine(grammar, vocab)
if err != nil {
t.Fatalf("failed to create engine: %v", err)
}
defer engine.Close()
tests := []struct {
name string
tokens []string
wantPass bool
}{
{"single digit", []string{"5"}, true},
{"multi-digit", []string{"1", "2", "3"}, true},
{"addition", []string{"1", "+", "2"}, true},
{"subtraction", []string{"5", "-", "3"}, true},
{"multiplication", []string{"2", "*", "3"}, true},
{"division", []string{"8", "/", "2"}, true},
{"complex expr", []string{"1", "+", "2", "*", "3"}, true},
{"parentheses", []string{"(", "1", "+", "2", ")", "*", "3"}, true},
{"nested parens", []string{"(", "(", "1", ")", ")"}, true},
// Invalid
{"just operator", []string{"+"}, false},
{"double operator", []string{"1", "+", "+", "2"}, false},
{"unclosed paren", []string{"(", "1", "+", "2"}, false},
{"variable", []string{"x"}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
engine.Reset()
allAccepted := true
for i, token := range tt.tokens {
if !engine.AcceptString(token) {
if tt.wantPass {
t.Errorf("token %d (%q) rejected unexpectedly", i, token)
}
allAccepted = false
break
}
}
if tt.wantPass {
if !allAccepted {
return
}
if !engine.IsComplete() {
t.Errorf("expected complete parse, but not in accepting state")
}
} else {
if allAccepted && engine.IsComplete() {
t.Errorf("expected rejection or incomplete, but parse succeeded")
}
}
})
}
}
// TestE2E_IdentifierGrammar tests a grammar with character ranges.
func TestE2E_IdentifierGrammar(t *testing.T) {
// Identifier grammar using character ranges
identGrammar := `
ident = letter { letter | digit } .
letter = "a" … "z" | "A" … "Z" | "_" .
digit = "0" … "9" .
`
grammar, err := ParseEBNF(identGrammar, "ident")
if err != nil {
t.Fatalf("failed to parse identifier grammar: %v", err)
}
// Vocabulary with letters and digits
vocab := []string{
"a", "b", "c", "x", "y", "z",
"A", "B", "C", "X", "Y", "Z",
"_",
"0", "1", "2", "9",
// Multi-char tokens
"foo", "bar", "myVar", "test123",
// Invalid starting chars
"1abc", "123",
}
engine, err := NewEngine(grammar, vocab)
if err != nil {
t.Fatalf("failed to create engine: %v", err)
}
defer engine.Close()
tests := []struct {
name string
tokens []string
wantPass bool
}{
{"single letter", []string{"a"}, true},
{"uppercase", []string{"A"}, true},
{"underscore", []string{"_"}, true},
{"multi-letter", []string{"a", "b", "c"}, true},
{"letter then digit", []string{"x", "1"}, true},
{"underscore prefix", []string{"_", "a", "1"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
engine.Reset()
allAccepted := true
for i, token := range tt.tokens {
if !engine.AcceptString(token) {
if tt.wantPass {
t.Errorf("token %d (%q) rejected unexpectedly", i, token)
}
allAccepted = false
break
}
}
if tt.wantPass && allAccepted && !engine.IsComplete() {
t.Errorf("expected complete parse, but not in accepting state")
}
})
}
}
// TestE2E_UnicodeRange ensures unicode ranges compile and match tokens.
func TestE2E_UnicodeRange(t *testing.T) {
greekGrammar := `
greek = "α" … "ω" .
`
grammar, err := ParseEBNF(greekGrammar, "greek")
if err != nil {
t.Fatalf("failed to parse unicode grammar: %v", err)
}
vocab := []string{"α", "β", "ω", "a"}
engine, err := NewEngine(grammar, vocab)
if err != nil {
t.Fatalf("failed to create engine: %v", err)
}
defer engine.Close()
if !engine.AcceptString("β") {
t.Error("should accept beta")
}
if !engine.IsComplete() {
t.Error("should be complete after single rune")
}
engine.Reset()
if engine.AcceptString("a") {
t.Error("should reject ASCII outside unicode range")
}
}
// TestE2E_NondeterminismPreserved tests that nondeterministic paths are preserved.
func TestE2E_NondeterminismPreserved(t *testing.T) {
// This grammar has nondeterminism: "ab" could be parsed as
// a single token or as two tokens "a" "b"
ambiguousGrammar := `
start = item item .
item = "a" | "b" | "ab" .
`
grammar, err := ParseEBNF(ambiguousGrammar, "start")
if err != nil {
t.Fatalf("failed to parse grammar: %v", err)
}
// Vocabulary with both single and combined tokens
vocab := []string{"a", "b", "ab"}
engine, err := NewEngine(grammar, vocab)
if err != nil {
t.Fatalf("failed to create engine: %v", err)
}
defer engine.Close()
// Test: "ab" "a" should be valid (ab as first item, a as second)
t.Run("ab then a", func(t *testing.T) {
engine.Reset()
if !engine.AcceptString("ab") {
t.Error("should accept ab")
}
if !engine.AcceptString("a") {
t.Error("should accept a after ab")
}
if !engine.IsComplete() {
t.Error("should be complete")
}
})
t.Run("a then ab", func(t *testing.T) {
engine.Reset()
if !engine.AcceptString("a") {
t.Error("should accept a")
}
if !engine.AcceptString("ab") {
t.Error("should accept ab after a")
}
if !engine.IsComplete() {
t.Error("should be complete")
}
})
t.Run("a then a", func(t *testing.T) {
engine.Reset()
if !engine.AcceptString("a") {
t.Error("should accept first a")
}
if !engine.AcceptString("a") {
t.Error("should accept second a")
}
if !engine.IsComplete() {
t.Error("should be complete")
}
})
}

View File

@@ -1,614 +0,0 @@
//go:build mlx
// Package grammar provides GPU-accelerated constrained decoding using MLX.
// It compiles EBNF grammars to pushdown automata (pda) with precomputed token masks.
// For JSON Schema conversion, see the grammar/schema subpackage.
package grammar
import (
"encoding/binary"
"fmt"
"io"
"strings"
"golang.org/x/exp/ebnf"
)
// stackSymbol represents a symbol that can be pushed onto the pda stack.
type stackSymbol int
const (
stackEmpty stackSymbol = iota
// Additional stack symbols will be generated per-grammar
)
// state represents a pda state.
type state int
const (
stateError state = -1
stateStart state = 0
stateAccept state = 1
// Additional states will be generated per-grammar
)
// transition represents a pda transition.
// On input matching Pattern, from FromState with stackTop:
// - Move to ToState
// - Pop StackPop symbols, push StackPush symbols
type transition struct {
FromState state
stackTop stackSymbol // What must be on stack top (stackEmpty = don't care)
Pattern string // Input pattern to match (token or character class)
ToState state
StackPop int // Number of symbols to pop
StackPush []stackSymbol // Symbols to push (in order, first pushed first)
}
// pda represents a compiled pushdown automaton.
type pda struct {
States int // Total number of states
StackSymbols int // Total number of stack symbols
StartState state // Initial state
AcceptStates map[state]bool // Set of accepting states
Transitions map[state][]transition // Transitions indexed by from-state
// For token-level matching
Terminals []string // All terminal symbols (patterns to match)
}
// newPDA creates an empty pda.
func newPDA() *pda {
return &pda{
States: 2, // Error and Start
StackSymbols: 1, // Empty
StartState: stateStart,
AcceptStates: make(map[state]bool),
Transitions: make(map[state][]transition),
Terminals: make([]string, 0),
}
}
// addState adds a new state and returns its ID.
func (p *pda) addState() state {
s := state(p.States)
p.States++
return s
}
// addStackSymbol adds a new stack symbol and returns its ID.
func (p *pda) addStackSymbol() stackSymbol {
s := stackSymbol(p.StackSymbols)
p.StackSymbols++
return s
}
// addTransition adds a transition to the pda.
func (p *pda) addTransition(t transition) {
p.Transitions[t.FromState] = append(p.Transitions[t.FromState], t)
}
// addTerminal registers a terminal pattern and returns its index.
func (p *pda) addTerminal(pattern string) int {
for i, t := range p.Terminals {
if t == pattern {
return i
}
}
p.Terminals = append(p.Terminals, pattern)
return len(p.Terminals) - 1
}
// compiler compiles EBNF grammars to PDAs.
type compiler struct {
grammar ebnf.Grammar
pda *pda
// Maps production names to their entry/exit states
prodEntry map[string]state
prodExit map[string]state
}
// compile parses an EBNF grammar and compiles it to a pda.
func compile(name string, src io.Reader, start string) (*pda, error) {
grammar, err := ebnf.Parse(name, src)
if err != nil {
return nil, fmt.Errorf("parse grammar: %w", err)
}
if err := ebnf.Verify(grammar, start); err != nil {
return nil, fmt.Errorf("verify grammar: %w", err)
}
c := &compiler{
grammar: grammar,
pda: newPDA(),
prodEntry: make(map[string]state),
prodExit: make(map[string]state),
}
// Create entry/exit states for each production
for name := range grammar {
c.prodEntry[name] = c.pda.addState()
c.prodExit[name] = c.pda.addState()
}
// compile each production
for name, prod := range grammar {
if err := c.compileProduction(name, prod); err != nil {
return nil, fmt.Errorf("compile production %q: %w", name, err)
}
}
// Set start state to entry of start production
if entry, ok := c.prodEntry[start]; ok {
// Add epsilon transition from pda start to grammar start
c.pda.addTransition(transition{
FromState: stateStart,
Pattern: "", // epsilon
ToState: entry,
})
} else {
return nil, fmt.Errorf("start production %q not found", start)
}
// Mark exit of start production as accepting
if exit, ok := c.prodExit[start]; ok {
c.pda.AcceptStates[exit] = true
}
return c.pda, nil
}
// compileString is a convenience function to compile from a string.
func compileString(grammar string, start string) (*pda, error) {
return compile("grammar", strings.NewReader(grammar), start)
}
func (c *compiler) compileProduction(name string, prod *ebnf.Production) error {
entry := c.prodEntry[name]
exit := c.prodExit[name]
return c.compileExpr(prod.Expr, entry, exit)
}
func (c *compiler) compileExpr(expr ebnf.Expression, entry, exit state) error {
switch e := expr.(type) {
case *ebnf.Name:
return c.compileName(e, entry, exit)
case *ebnf.Token:
return c.compileToken(e, entry, exit)
case ebnf.Sequence:
return c.compileSequence(e, entry, exit)
case ebnf.Alternative:
return c.compileAlternative(e, entry, exit)
case *ebnf.Option:
return c.compileOption(e, entry, exit)
case *ebnf.Repetition:
return c.compileRepetition(e, entry, exit)
case *ebnf.Group:
return c.compileExpr(e.Body, entry, exit)
case *ebnf.Range:
return c.compileRange(e, entry, exit)
case nil:
// Empty production - direct epsilon transition
c.pda.addTransition(transition{
FromState: entry,
Pattern: "",
ToState: exit,
})
return nil
default:
return fmt.Errorf("unsupported expression type: %T", expr)
}
}
func (c *compiler) compileName(n *ebnf.Name, entry, exit state) error {
// Reference to another production
prodName := n.String
prodEntry, ok := c.prodEntry[prodName]
if !ok {
return fmt.Errorf("undefined production: %s", prodName)
}
prodExit := c.prodExit[prodName]
// Use a unique stack symbol per call site so returns are unambiguous.
stackSym := c.pda.addStackSymbol()
// Push return address, go to production entry
c.pda.addTransition(transition{
FromState: entry,
Pattern: "", // epsilon
ToState: prodEntry,
StackPush: []stackSymbol{stackSym},
})
// On production exit, pop and return
c.pda.addTransition(transition{
FromState: prodExit,
stackTop: stackSym,
Pattern: "", // epsilon
ToState: exit,
StackPop: 1,
})
return nil
}
func (c *compiler) compileToken(t *ebnf.Token, entry, exit state) error {
// terminal symbol - add transition that consumes this token
pattern := t.String
c.pda.addTerminal(pattern)
c.pda.addTransition(transition{
FromState: entry,
Pattern: pattern,
ToState: exit,
})
return nil
}
func (c *compiler) compileSequence(seq ebnf.Sequence, entry, exit state) error {
if len(seq) == 0 {
// Empty sequence - epsilon transition
c.pda.addTransition(transition{
FromState: entry,
Pattern: "",
ToState: exit,
})
return nil
}
// Chain: entry -> s1 -> s2 -> ... -> exit
current := entry
for i, expr := range seq {
var next state
if i == len(seq)-1 {
next = exit
} else {
next = c.pda.addState()
}
if err := c.compileExpr(expr, current, next); err != nil {
return err
}
current = next
}
return nil
}
func (c *compiler) compileAlternative(alt ebnf.Alternative, entry, exit state) error {
// Each alternative goes from entry to exit
for _, expr := range alt {
if err := c.compileExpr(expr, entry, exit); err != nil {
return err
}
}
return nil
}
func (c *compiler) compileOption(opt *ebnf.Option, entry, exit state) error {
// Optional: can skip (epsilon) or take the body
// Epsilon transition (skip)
c.pda.addTransition(transition{
FromState: entry,
Pattern: "",
ToState: exit,
})
// Or take the body
return c.compileExpr(opt.Body, entry, exit)
}
func (c *compiler) compileRepetition(rep *ebnf.Repetition, entry, exit state) error {
// Repetition {body}: zero or more
// entry -> exit (skip)
// entry -> body -> entry (loop back)
// Skip transition
c.pda.addTransition(transition{
FromState: entry,
Pattern: "",
ToState: exit,
})
// Loop: entry -> (body) -> entry
return c.compileExpr(rep.Body, entry, entry)
}
func (c *compiler) compileRange(r *ebnf.Range, entry, exit state) error {
// Character range like "a" … "z" or "\u03b1" … "\u03c9"
begin := strings.Trim(r.Begin.String, "\"")
end := strings.Trim(r.End.String, "\"")
// Unescape bounds first (so "\u03b1" works)
beginUnesc, err := unescapeLiteral(begin)
if err != nil {
return fmt.Errorf("invalid range begin: %w", err)
}
endUnesc, err := unescapeLiteral(end)
if err != nil {
return fmt.Errorf("invalid range end: %w", err)
}
// Validate as single runes (not bytes) for Unicode support
beginRunes := []rune(beginUnesc)
endRunes := []rune(endUnesc)
if len(beginRunes) != 1 || len(endRunes) != 1 {
return fmt.Errorf("range bounds must be single characters: %q..%q", r.Begin.String, r.End.String)
}
// Use unescaped rune strings in pattern (consistent with matcher)
pattern := fmt.Sprintf("[%s-%s]", string(beginRunes[0]), string(endRunes[0]))
c.pda.addTerminal(pattern)
c.pda.addTransition(transition{
FromState: entry,
Pattern: pattern,
ToState: exit,
})
return nil
}
// runtime represents a pda execution instance.
type runtime struct {
pda *pda
state state
stack []stackSymbol
}
// newRuntime creates a new pda runtime.
func newRuntime(pda *pda) *runtime {
return &runtime{
pda: pda,
state: pda.StartState,
stack: make([]stackSymbol, 0, 32),
}
}
// stackTop returns the top of the stack, or stackEmpty if empty.
func (r *runtime) stackTop() stackSymbol {
if len(r.stack) == 0 {
return stackEmpty
}
return r.stack[len(r.stack)-1]
}
// isAccepting returns true if we can reach an accepting state via epsilon transitions
// with an empty stack.
func (r *runtime) isAccepting() bool {
return r.canReachAccept(r.state, r.stack, make(map[stateStackKey]bool))
}
func (r *runtime) canReachAccept(state state, stack []stackSymbol, visited map[stateStackKey]bool) bool {
// Check if this state is accepting with empty stack
if r.pda.AcceptStates[state] && len(stack) == 0 {
return true
}
// Avoid infinite loops
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
if visited[key] {
return false
}
visited[key] = true
// Try epsilon transitions
for _, t := range r.pda.Transitions[state] {
if t.Pattern != "" {
continue // Not epsilon
}
// Check stack constraint
stackTop := stackEmpty
if len(stack) > 0 {
stackTop = stack[len(stack)-1]
}
if t.stackTop != stackEmpty && t.stackTop != stackTop {
continue
}
// Simulate stack operations
newStack := make([]stackSymbol, len(stack))
copy(newStack, stack)
if t.StackPop > 0 && len(newStack) >= t.StackPop {
newStack = newStack[:len(newStack)-t.StackPop]
}
newStack = append(newStack, t.StackPush...)
if r.canReachAccept(t.ToState, newStack, visited) {
return true
}
}
return false
}
// Reset resets the runtime to initial state.
func (r *runtime) Reset() {
r.state = r.pda.StartState
r.stack = r.stack[:0]
}
// validInputs returns all valid input patterns from current state.
func (r *runtime) validInputs() []string {
var valid []string
seen := make(map[string]bool)
visited := make(map[stateStackKey]bool)
// Make a copy of the stack for simulation
simStack := make([]stackSymbol, len(r.stack))
copy(simStack, r.stack)
r.collectValidInputs(r.state, simStack, seen, visited, &valid)
return valid
}
// stateStackKey is used to detect cycles in epsilon closure
type stateStackKey struct {
state state
stackSig string
}
func stackSignature(stack []stackSymbol) string {
if len(stack) == 0 {
return ""
}
buf := make([]byte, len(stack)*8)
for i, sym := range stack {
binary.LittleEndian.PutUint64(buf[i*8:], uint64(sym))
}
return string(buf)
}
func (r *runtime) collectValidInputs(state state, simStack []stackSymbol, seen map[string]bool, visited map[stateStackKey]bool, valid *[]string) {
// Get stack top for comparisons
stackTop := stackEmpty
if len(simStack) > 0 {
stackTop = simStack[len(simStack)-1]
}
// Check for cycles to avoid infinite loops
key := stateStackKey{state: state, stackSig: stackSignature(simStack)}
if visited[key] {
return
}
visited[key] = true
transitions := r.pda.Transitions[state]
for _, t := range transitions {
// Check stack constraint
if t.stackTop != stackEmpty && t.stackTop != stackTop {
continue
}
if t.Pattern == "" {
// Epsilon transition - simulate stack operations
newStack := make([]stackSymbol, len(simStack))
copy(newStack, simStack)
// Pop
if t.StackPop > 0 {
if len(newStack) < t.StackPop {
continue // Can't pop, skip this transition
}
newStack = newStack[:len(newStack)-t.StackPop]
}
// Push
newStack = append(newStack, t.StackPush...)
r.collectValidInputs(t.ToState, newStack, seen, visited, valid)
} else {
// terminal - add if not seen
if !seen[t.Pattern] {
seen[t.Pattern] = true
*valid = append(*valid, t.Pattern)
}
}
}
}
// matchesPattern checks if input matches a pattern.
// Patterns can be:
// - Exact strings: "a", "{", "true"
// - Character ranges: "[a-z]", "[0-9]", "[#-~]"
func matchesPattern(input, pattern string) bool {
// Exact match
if input == pattern {
return true
}
// Check for character range pattern [X-Y]
if len(pattern) == 5 && pattern[0] == '[' && pattern[2] == '-' && pattern[4] == ']' {
if len(input) != 1 {
return false
}
ch := input[0]
low := pattern[1]
high := pattern[3]
return ch >= low && ch <= high
}
return false
}
// Accept tries to accept an input, returning true if successful.
func (r *runtime) Accept(input string) bool {
return r.accept(input, make(map[stateStackKey]bool))
}
func (r *runtime) accept(input string, visited map[stateStackKey]bool) bool {
key := stateStackKey{state: r.state, stackSig: stackSignature(r.stack)}
if visited[key] {
return false
}
visited[key] = true
transitions := r.pda.Transitions[r.state]
// First, process any epsilon transitions to reach a state that can accept input
// This is a simplified version - full implementation would need epsilon closure
for _, t := range transitions {
if matchesPattern(input, t.Pattern) {
if t.stackTop != stackEmpty && t.stackTop != r.stackTop() {
continue
}
if t.StackPop > len(r.stack) {
continue
}
// Apply transition
r.applyTransition(t)
return true
}
}
// Try epsilon transitions first
for _, t := range transitions {
if t.Pattern == "" {
if t.stackTop != stackEmpty && t.stackTop != r.stackTop() {
continue
}
if t.StackPop > len(r.stack) {
continue
}
// Save state for backtracking
oldState := r.state
oldStack := make([]stackSymbol, len(r.stack))
copy(oldStack, r.stack)
r.applyTransition(t)
if r.accept(input, visited) {
return true
}
// Backtrack
r.state = oldState
r.stack = oldStack
}
}
return false
}
func (r *runtime) applyTransition(t transition) {
// Pop
if t.StackPop > 0 && len(r.stack) >= t.StackPop {
r.stack = r.stack[:len(r.stack)-t.StackPop]
}
// Push
r.stack = append(r.stack, t.StackPush...)
// Move to new state
r.state = t.ToState
}

View File

@@ -1,540 +0,0 @@
//go:build mlx
package grammar
import (
"testing"
)
func TestCompileSimpleGrammar(t *testing.T) {
// Simple grammar: S = "a" "b" .
grammar := `S = "a" "b" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
if pda == nil {
t.Fatal("pda is nil")
}
// Should have terminals "a" and "b"
if len(pda.Terminals) != 2 {
t.Errorf("expected 2 terminals, got %d: %v", len(pda.Terminals), pda.Terminals)
}
// Test runtime
rt := newRuntime(pda)
// Should accept "a" then "b"
if !rt.Accept("a") {
t.Error("should accept 'a'")
}
if !rt.Accept("b") {
t.Error("should accept 'b'")
}
if !rt.isAccepting() {
t.Error("should be in accepting state")
}
}
func TestCompileAlternative(t *testing.T) {
// Grammar: S = "a" | "b" .
grammar := `S = "a" | "b" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
// Test accepting "a"
rt := newRuntime(pda)
if !rt.Accept("a") {
t.Error("should accept 'a'")
}
if !rt.isAccepting() {
t.Error("should be accepting after 'a'")
}
// Test accepting "b"
rt.Reset()
if !rt.Accept("b") {
t.Error("should accept 'b'")
}
if !rt.isAccepting() {
t.Error("should be accepting after 'b'")
}
// Test rejecting "c"
rt.Reset()
if rt.Accept("c") {
t.Error("should not accept 'c'")
}
}
func TestCompileRepetition(t *testing.T) {
// Grammar: S = {"a"} .
grammar := `S = {"a"} .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
// Empty should be accepted (zero repetitions)
rt := newRuntime(pda)
if !rt.isAccepting() {
t.Error("empty should be accepting")
}
// "a" should be accepted
rt.Reset()
if !rt.Accept("a") {
t.Error("should accept first 'a'")
}
if !rt.isAccepting() {
t.Error("should be accepting after one 'a'")
}
// "aa" should be accepted
if !rt.Accept("a") {
t.Error("should accept second 'a'")
}
if !rt.isAccepting() {
t.Error("should be accepting after two 'a's")
}
}
func TestCompileOption(t *testing.T) {
// Grammar: S = ["a"] "b" .
grammar := `S = ["a"] "b" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
// "b" alone should be accepted
rt := newRuntime(pda)
if !rt.Accept("b") {
t.Error("should accept 'b' alone")
}
if !rt.isAccepting() {
t.Error("should be accepting after 'b'")
}
// "ab" should be accepted
rt.Reset()
if !rt.Accept("a") {
t.Error("should accept 'a'")
}
if !rt.Accept("b") {
t.Error("should accept 'b' after 'a'")
}
if !rt.isAccepting() {
t.Error("should be accepting after 'ab'")
}
}
func TestCompileRecursive(t *testing.T) {
// Grammar with recursion: S = "(" S ")" | "x" .
grammar := `S = "(" S ")" | "x" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
// "x" should be accepted
rt := newRuntime(pda)
if !rt.Accept("x") {
t.Error("should accept 'x'")
}
if !rt.isAccepting() {
t.Error("should be accepting after 'x'")
}
// "(x)" should be accepted
rt.Reset()
if !rt.Accept("(") {
t.Error("should accept '('")
}
if !rt.Accept("x") {
t.Error("should accept 'x' inside parens")
}
if !rt.Accept(")") {
t.Error("should accept ')'")
}
if !rt.isAccepting() {
t.Error("should be accepting after '(x)'")
}
// "((x))" should be accepted
rt.Reset()
if !rt.Accept("(") {
t.Error("should accept first '('")
}
if !rt.Accept("(") {
t.Error("should accept second '('")
}
if !rt.Accept("x") {
t.Error("should accept 'x'")
}
if !rt.Accept(")") {
t.Error("should accept first ')'")
}
if !rt.Accept(")") {
t.Error("should accept second ')'")
}
if !rt.isAccepting() {
t.Error("should be accepting after '((x))'")
}
}
func TestValidInputs(t *testing.T) {
// Grammar: S = "a" | "b" .
grammar := `S = "a" | "b" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
rt := newRuntime(pda)
valid := rt.validInputs()
// Should have both "a" and "b" as valid
hasA, hasB := false, false
for _, v := range valid {
if v == "a" {
hasA = true
}
if v == "b" {
hasB = true
}
}
if !hasA {
t.Error("'a' should be valid input")
}
if !hasB {
t.Error("'b' should be valid input")
}
}
// TestValidInputsAfterAccept tests that validInputs returns correct values
// after accepting tokens, ensuring proper stack simulation.
func TestValidInputsAfterAccept(t *testing.T) {
// Grammar: S = "a" "b" "c" .
grammar := `S = "a" "b" "c" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
rt := newRuntime(pda)
// Initially only "a" should be valid
valid := rt.validInputs()
if len(valid) != 1 || valid[0] != "a" {
t.Errorf("initially expected only 'a', got %v", valid)
}
// After accepting "a", only "b" should be valid
if !rt.Accept("a") {
t.Fatal("failed to accept 'a'")
}
valid = rt.validInputs()
if len(valid) != 1 || valid[0] != "b" {
t.Errorf("after 'a', expected only 'b', got %v", valid)
}
// After accepting "b", only "c" should be valid
if !rt.Accept("b") {
t.Fatal("failed to accept 'b'")
}
valid = rt.validInputs()
if len(valid) != 1 || valid[0] != "c" {
t.Errorf("after 'ab', expected only 'c', got %v", valid)
}
}
// TestValidInputsWithRepetitionInProduction tests the critical case where
// a repetition exists inside a called production. This requires proper
// stack simulation to determine when closing symbols are valid.
func TestValidInputsWithRepetitionInProduction(t *testing.T) {
// Grammar similar to JSON:
// S = "(" items ")" .
// items = item { "," item } .
// item = "x" .
grammar := `
S = "(" items ")" .
items = item { "," item } .
item = "x" .
`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
rt := newRuntime(pda)
// Initially only "(" should be valid
valid := rt.validInputs()
if len(valid) != 1 || valid[0] != "(" {
t.Errorf("initially expected only '(', got %v", valid)
}
// Accept "("
if !rt.Accept("(") {
t.Fatal("failed to accept '('")
}
// After "(", should be able to accept "x" (item)
valid = rt.validInputs()
hasX := false
for _, v := range valid {
if v == "x" {
hasX = true
}
}
if !hasX {
t.Errorf("after '(', expected 'x' to be valid, got %v", valid)
}
// Accept first item "x"
if !rt.Accept("x") {
t.Fatal("failed to accept 'x'")
}
// After "(x", should be able to accept "," (more items) OR ")" (end)
valid = rt.validInputs()
hasComma, hasClose := false, false
for _, v := range valid {
if v == "," {
hasComma = true
}
if v == ")" {
hasClose = true
}
}
if !hasComma {
t.Errorf("after '(x', expected ',' to be valid, got %v", valid)
}
if !hasClose {
t.Errorf("after '(x', expected ')' to be valid, got %v", valid)
}
// Accept comma for another item
if !rt.Accept(",") {
t.Fatal("failed to accept ','")
}
// After "(x,", should only be able to accept "x" (next item)
valid = rt.validInputs()
if len(valid) != 1 || valid[0] != "x" {
t.Errorf("after '(x,', expected only 'x', got %v", valid)
}
// Accept second item "x"
if !rt.Accept("x") {
t.Fatal("failed to accept second 'x'")
}
// CRITICAL: After "(x,x", should be able to accept "," OR ")"
// This tests the stack simulation fix - we need to properly
// follow epsilon transitions through the production call stack.
valid = rt.validInputs()
hasComma, hasClose = false, false
for _, v := range valid {
if v == "," {
hasComma = true
}
if v == ")" {
hasClose = true
}
}
if !hasComma {
t.Errorf("after '(x,x', expected ',' to be valid, got %v", valid)
}
if !hasClose {
t.Errorf("after '(x,x', expected ')' to be valid, got %v", valid)
}
// Close with ")"
if !rt.Accept(")") {
t.Fatal("failed to accept ')'")
}
if !rt.isAccepting() {
t.Error("should be accepting after '(x,x)'")
}
}
// TestValidInputsNestedCalls tests validInputs with deeply nested production calls.
func TestValidInputsNestedCalls(t *testing.T) {
// Grammar: A = "start" B "end" . B = "middle" .
grammar := `
A = "start" B "end" .
B = "middle" .
`
pda, err := compileString(grammar, "A")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
rt := newRuntime(pda)
// After "start", should accept "middle" (from B)
rt.Accept("start")
valid := rt.validInputs()
if len(valid) != 1 || valid[0] != "middle" {
t.Errorf("after 'start', expected 'middle', got %v", valid)
}
// After "start middle", should accept "end"
rt.Accept("middle")
valid = rt.validInputs()
if len(valid) != 1 || valid[0] != "end" {
t.Errorf("after 'start middle', expected 'end', got %v", valid)
}
}
func TestReturnAddressDisambiguation(t *testing.T) {
// Grammar where the same production is called from different contexts:
// S = A "x" | "c" A "y" .
// A = "a" .
grammar := `
S = A "x" | "c" A "y" .
A = "a" .
`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
rt := newRuntime(pda)
if !rt.Accept("c") {
t.Fatal("failed to accept 'c'")
}
if !rt.Accept("a") {
t.Fatal("failed to accept 'a'")
}
valid := rt.validInputs()
if len(valid) != 1 || valid[0] != "y" {
t.Errorf("after 'ca', expected only 'y', got %v", valid)
}
rt.Reset()
rt.Accept("c")
rt.Accept("a")
if rt.Accept("x") {
t.Error("should not accept 'x' after 'ca'")
}
}
// TestValidInputsRecursiveWithStack tests validInputs with recursive grammars
// which heavily exercise the stack simulation.
func TestValidInputsRecursiveWithStack(t *testing.T) {
// Grammar: S = "(" S ")" | "x" .
grammar := `S = "(" S ")" | "x" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
rt := newRuntime(pda)
// Initially: "(" or "x" should be valid
valid := rt.validInputs()
hasParen, hasX := false, false
for _, v := range valid {
if v == "(" {
hasParen = true
}
if v == "x" {
hasX = true
}
}
if !hasParen || !hasX {
t.Errorf("initially expected '(' and 'x', got %v", valid)
}
// After "(": "(" or "x" should be valid (nested S)
rt.Accept("(")
valid = rt.validInputs()
hasParen, hasX = false, false
for _, v := range valid {
if v == "(" {
hasParen = true
}
if v == "x" {
hasX = true
}
}
if !hasParen || !hasX {
t.Errorf("after '(', expected '(' and 'x', got %v", valid)
}
// After "((": "(" or "x" should still be valid
rt.Accept("(")
valid = rt.validInputs()
hasParen, hasX = false, false
for _, v := range valid {
if v == "(" {
hasParen = true
}
if v == "x" {
hasX = true
}
}
if !hasParen || !hasX {
t.Errorf("after '((', expected '(' and 'x', got %v", valid)
}
// After "((x": only ")" should be valid
rt.Accept("x")
valid = rt.validInputs()
if len(valid) != 1 || valid[0] != ")" {
t.Errorf("after '((x', expected only ')', got %v", valid)
}
// After "((x)": only ")" should be valid (closing outer)
rt.Accept(")")
valid = rt.validInputs()
if len(valid) != 1 || valid[0] != ")" {
t.Errorf("after '((x)', expected only ')', got %v", valid)
}
}
// TestRejectionAfterValid tests that invalid inputs are rejected
// at various points in the grammar.
func TestRejectionAfterValid(t *testing.T) {
// Grammar: S = "a" "b" .
grammar := `S = "a" "b" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
rt := newRuntime(pda)
// "b" should be rejected initially
if rt.Accept("b") {
t.Error("'b' should be rejected initially")
}
// Accept "a"
rt.Accept("a")
// "a" should be rejected after "a"
if rt.Accept("a") {
t.Error("'a' should be rejected after 'a'")
}
// "c" should be rejected (not in grammar)
if rt.Accept("c") {
t.Error("'c' should be rejected (not in grammar)")
}
}

View File

@@ -1,56 +0,0 @@
# Example Grammars
This directory contains example EBNF grammars for constrained decoding.
## Usage
```bash
go run -tags mlx ./x/imagegen/cmd/engine/ \
-model /path/to/model \
-prompt "Your prompt" \
-grammar x/grammar/grammars/json.ebnf \
-grammar-start value
```
## Available Grammars
| File | Start Rule | Description |
|------|------------|-------------|
| `json.ebnf` | `value` | Standard JSON (RFC 8259) |
| `expression.ebnf` | `expr` | Arithmetic expressions (+, -, *, /, parens) |
| `identifier.ebnf` | `ident` | Programming language identifiers |
| `boolean.ebnf` | `expr` | Boolean expressions (AND, OR, NOT) |
| `list.ebnf` | `list` | Comma-separated word list |
| `yesno.ebnf` | `response` | Simple yes/no responses |
| `date.ebnf` | `date` | Dates in YYYY-MM-DD format |
| `email.ebnf` | `email` | Basic email addresses |
| `phone.ebnf` | `phone` | US phone numbers |
| `hexcolor.ebnf` | `color` | CSS hex colors (#RGB or #RRGGBB) |
| `url.ebnf` | `url` | HTTP/HTTPS URLs |
## Grammar Syntax
**Note:** Comments are not supported. Grammar files must contain only EBNF productions.
The grammars use EBNF notation:
- `=` defines a production rule
- `|` is alternation (or)
- `{ }` is repetition (zero or more)
- `[ ]` is optional (zero or one)
- `" "` is a literal string
- `…` is a character range (e.g., `"a" … "z"`)
- `.` ends a production
## Writing Custom Grammars
1. Define your grammar in a `.ebnf` file
2. Choose a start rule name
3. Pass `-grammar path/to/grammar.ebnf -grammar-start rulename`
Example custom grammar for RGB colors:
```ebnf
color = "#" hexdigit hexdigit hexdigit hexdigit hexdigit hexdigit .
hexdigit = "0" "9" | "a" "f" | "A" "F" .
```

View File

@@ -1,7 +0,0 @@
expr = term { " OR " term } .
term = factor { " AND " factor } .
factor = "NOT " factor | atom | "(" expr ")" .
atom = "true" | "false" | ident .
ident = letter { letter | digit } .
letter = "a" "z" | "A" "Z" .
digit = "0" "9" .

View File

@@ -1,6 +0,0 @@
date = year "-" month "-" day .
year = digit digit digit digit .
month = ( "0" digit1to9 ) | ( "1" ( "0" | "1" | "2" ) ) .
day = ( "0" digit1to9 ) | ( ( "1" | "2" ) digit ) | ( "3" ( "0" | "1" ) ) .
digit1to9 = "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .

View File

@@ -1,5 +0,0 @@
email = localpart "@" domain .
localpart = word { "." word } .
domain = word { "." word } .
word = alphanum { alphanum | "-" | "_" } .
alphanum = "a" "z" | "A" "Z" | "0" "9" .

View File

@@ -1,7 +0,0 @@
expr = term { addop term } .
addop = "+" | "-" .
term = factor { mulop factor } .
mulop = "*" | "/" .
factor = number | "(" expr ")" .
number = [ "-" ] digit { digit } .
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .

View File

@@ -1,4 +0,0 @@
color = "#" ( hex6 | hex3 ) .
hex6 = hexdigit hexdigit hexdigit hexdigit hexdigit hexdigit .
hex3 = hexdigit hexdigit hexdigit .
hexdigit = "0" "9" | "a" "f" | "A" "F" .

View File

@@ -1,3 +0,0 @@
ident = letter { letter | digit | "_" } .
letter = "a" "z" | "A" "Z" | "_" .
digit = "0" "9" .

View File

@@ -1,16 +0,0 @@
value = object | array | string | number | "true" | "false" | "null" .
object = "{" [ members ] "}" .
members = pair { "," pair } .
pair = string ":" value .
array = "[" [ elements ] "]" .
elements = value { "," value } .
string = "\"" { char } "\"" .
char = unescaped | escaped .
unescaped = " " | "!" | "#" "[" | "]" "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" ) .
number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
onenine = "1" "9" .
digit = "0" "9" .

View File

@@ -1,27 +0,0 @@
root = array .
value = object | array | string | number | "true" | "false" | "null" .
object = "{" ws "}" | "{" members "}" .
members = member { "," member } .
member = ws string ws ":" element .
array = "[" ws "]" | "[" elements "]" .
elements = element { "," element } .
element = ws value ws .
string = "\"" { character } "\"" .
character = unescaped | escaped .
unescaped = " " | "!" | "#" "[" | "]" "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
unicode = "u" hex hex hex hex .
hex = "0" … "9" | "A" … "F" | "a" … "f" .
number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
digit = "0" "9" .
onenine = "1" "9" .
ws = { " " | "\t" | "\n" | "\r" } .

View File

@@ -1,4 +0,0 @@
list = item { ", " item } .
item = word .
word = letter { letter } .
letter = "a" "z" | "A" "Z" .

View File

@@ -1,19 +0,0 @@
root = "[" ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person { "," ws person } ws "]" .
person = "{" ws name_field "," ws age_field "," ws email_field ws "}" .
name_field = "\"" "n" "a" "m" "e" "\"" ws ":" ws string .
age_field = "\"" "a" "g" "e" "\"" ws ":" ws number .
email_field = "\"" "e" "m" "a" "i" "l" "\"" ws ":" ws string .
string = "\"" { character } "\"" .
character = unescaped | escaped .
unescaped = " " | "!" | "#" "[" | "]" "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" ) .
number = [ "-" ] integer .
integer = "0" | onenine { digit } .
digit = "0" … "9" .
onenine = "1" … "9" .
ws = { " " | "\t" | "\n" | "\r" } .

View File

@@ -1,15 +0,0 @@
root = "{" ws name_field "," ws age_field "," ws email_field ws "}" .
name_field = "\"name\"" ws ":" ws string .
age_field = "\"age\"" ws ":" ws number .
email_field = "\"email\"" ws ":" ws string .
string = "\"" { character } "\"" .
character = " " | "!" | "#" "~" .
number = [ "-" ] integer .
integer = "0" | onenine { digit } .
digit = "0" "9" .
onenine = "1" "9" .
ws = { " " | "\t" | "\n" | "\r" } .

View File

@@ -1,7 +0,0 @@
phone = parenformat | dashformat .
parenformat = "(" areacode ") " exchange "-" subscriber .
dashformat = areacode "-" exchange "-" subscriber .
areacode = digit digit digit .
exchange = digit digit digit .
subscriber = digit digit digit digit .
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .

View File

@@ -1,11 +0,0 @@
url = scheme "://" host [ ":" port ] [ path ] [ query ] .
scheme = "http" | "https" .
host = word { "." word } .
port = digit { digit } .
path = "/" { pathseg } .
pathseg = word [ "/" ] .
query = "?" param { "&" param } .
param = word "=" word .
word = alphanum { alphanum | "-" | "_" } .
alphanum = "a" "z" | "A" "Z" | "0" "9" .
digit = "0" "9" .

View File

@@ -1,3 +0,0 @@
response = affirmative | negative .
affirmative = "yes" | "Yes" | "YES" | "y" | "Y" | "true" | "True" .
negative = "no" | "No" | "NO" | "n" | "N" | "false" | "False" .

View File

@@ -1,69 +0,0 @@
//go:build mlx
package grammar
// JSONGrammarEBNF is the EBNF grammar for JSON (character-level).
// Based on https://www.json.org/json-en.html
//
// This grammar operates at the character level. The engine validates
// tokens by matching them as sequences of these character-level terminals.
const JSONGrammarEBNF = `
json = value .
value = object | array | string | number | "true" | "false" | "null" .
object = "{" ws "}" | "{" members "}" .
members = member { "," member } .
member = ws string ws ":" element .
array = "[" ws "]" | "[" elements "]" .
elements = element { "," element } .
element = ws value ws .
string = "\"" { character } "\"" .
character = unescaped | escaped .
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
unicode = "u" hex hex hex hex .
hex = "0" … "9" | "A" … "F" | "a" … "f" .
number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
digit = "0" … "9" .
onenine = "1" … "9" .
ws = { " " | "\t" | "\n" | "\r" } .
`
// JSONObjectGrammarEBNF is like JSONGrammarEBNF but only allows objects at the top level.
const JSONObjectGrammarEBNF = `
json = object .
value = object | array | string | number | "true" | "false" | "null" .
object = "{" ws "}" | "{" members "}" .
members = member { "," member } .
member = ws string ws ":" element .
array = "[" ws "]" | "[" elements "]" .
elements = element { "," element } .
element = ws value ws .
string = "\"" { character } "\"" .
character = unescaped | escaped .
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
unicode = "u" hex hex hex hex .
hex = "0" … "9" | "A" … "F" | "a" … "f" .
number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
digit = "0" … "9" .
onenine = "1" … "9" .
ws = { " " | "\t" | "\n" | "\r" } .
`

View File

@@ -1,726 +0,0 @@
//go:build mlx
// Package schema converts OpenAI-compatible JSON Schema into constrained grammars.
package schema
import (
"encoding/json"
"fmt"
"regexp"
"sort"
"strings"
"github.com/ollama/ollama/x/grammar"
)
// schemaNode represents OpenAI-compatible JSON Schema for structured outputs.
// See: https://platform.openai.com/docs/guides/structured-outputs
type schemaNode struct {
// Core types
Type interface{} `json:"type"` // string, []string, or nil
// Object properties
Properties map[string]*schemaNode `json:"properties"`
Required []string `json:"required"`
AdditionalProperties interface{} `json:"additionalProperties"`
// Array properties
Items *schemaNode `json:"items"`
MinItems *int `json:"minItems"`
MaxItems *int `json:"maxItems"`
// String properties
Pattern string `json:"pattern"` // Regex pattern
Format string `json:"format"` // date-time, email, uuid, etc.
// Number properties (noted but not enforced in grammar - validated post-generation)
Minimum *float64 `json:"minimum"`
Maximum *float64 `json:"maximum"`
ExclusiveMinimum *float64 `json:"exclusiveMinimum"`
ExclusiveMaximum *float64 `json:"exclusiveMaximum"`
MultipleOf *float64 `json:"multipleOf"`
// Enum and const
Enum []interface{} `json:"enum"`
Const interface{} `json:"const"`
// Composition
AnyOf []*schemaNode `json:"anyOf"`
OneOf []*schemaNode `json:"oneOf"` // Treated same as anyOf for grammar
// References and definitions
Ref string `json:"$ref"`
Defs map[string]*schemaNode `json:"$defs"`
// Description (ignored for grammar but useful for docs)
Description string `json:"description"`
}
// converter handles JSON Schema to EBNF conversion with state.
type converter struct {
schema *schemaNode
definitions map[string]*schemaNode // Resolved $defs
usedTypes map[string]bool
rules []string
ruleNum int
definedRefs map[string]bool // Track which refs we've already defined as rules
}
// EBNF converts a JSON Schema to EBNF grammar
func EBNF(schemaJSON string) (string, error) {
var schema schemaNode
if err := json.Unmarshal([]byte(schemaJSON), &schema); err != nil {
return "", fmt.Errorf("failed to parse JSON Schema: %w", err)
}
conv := &converter{
schema: &schema,
definitions: schema.Defs,
usedTypes: make(map[string]bool),
definedRefs: make(map[string]bool),
}
return conv.convert()
}
func (c *converter) convert() (string, error) {
var b strings.Builder
// Generate root rule
rootExpr := c.schemaToExpr(c.schema, "root")
b.WriteString("root = ")
b.WriteString(rootExpr)
b.WriteString(" .\n")
// Add generated rules (refs, items, etc.)
for _, rule := range c.rules {
b.WriteString(rule)
b.WriteString("\n")
}
// Add primitives based on usage
c.addPrimitives(&b)
return b.String(), nil
}
func (c *converter) addPrimitives(b *strings.Builder) {
if c.usedTypes["string"] {
b.WriteString(`
string = "\"" { character } "\"" .
`)
}
if c.usedTypes["string"] || c.usedTypes["character"] {
b.WriteString(`
character = unescaped | escaped .
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
unicode = "u" hex hex hex hex .
`)
}
if c.usedTypes["number"] {
b.WriteString(`
number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
`)
}
if c.usedTypes["integer"] {
b.WriteString(`
int = [ "-" ] ( "0" | onenine { digit } ) .
`)
}
if c.usedTypes["number"] || c.usedTypes["integer"] || c.usedTypes["digit"] {
b.WriteString(`
digit = "0" … "9" .
`)
}
// onenine only needed for number/integer, not for digit-only formats
if c.usedTypes["number"] || c.usedTypes["integer"] {
b.WriteString(`onenine = "1" … "9" .
`)
}
if c.usedTypes["string"] || c.usedTypes["character"] || c.usedTypes["hex"] {
b.WriteString(`
hex = "0" … "9" | "A" … "F" | "a" … "f" .
`)
}
if c.usedTypes["ws"] {
b.WriteString(`
ws = { " " | "\t" | "\n" | "\r" } .
`)
}
}
func (c *converter) schemaToExpr(schema *schemaNode, name string) string {
if schema == nil {
c.usedTypes["string"] = true
c.usedTypes["number"] = true
return "( string | number | object | array | \"true\" | \"false\" | \"null\" )"
}
// Handle $ref first
if schema.Ref != "" {
return c.resolveRef(schema.Ref)
}
// Handle const
if schema.Const != nil {
return c.constToExpr(schema.Const)
}
// Handle enum
if len(schema.Enum) > 0 {
return c.enumToExpr(schema.Enum)
}
// Handle anyOf / oneOf
if len(schema.AnyOf) > 0 {
return c.anyOfToExpr(schema.AnyOf, name)
}
if len(schema.OneOf) > 0 {
return c.anyOfToExpr(schema.OneOf, name)
}
// Handle type
types := c.getTypes(schema.Type)
if len(types) == 0 {
// No type specified, could be anything
c.usedTypes["string"] = true
c.usedTypes["number"] = true
return "( string | number | \"true\" | \"false\" | \"null\" )"
}
if len(types) == 1 {
return c.typeToExpr(types[0], schema, name)
}
// Multiple types (e.g., ["string", "null"])
var parts []string
for _, t := range types {
parts = append(parts, c.typeToExpr(t, schema, name))
}
return "( " + strings.Join(parts, " | ") + " )"
}
func (c *converter) typeToExpr(typeName string, schema *schemaNode, name string) string {
switch typeName {
case "object":
return c.objectToExpr(schema, name)
case "array":
return c.arrayToExpr(schema, name)
case "string":
return c.stringToExpr(schema, name)
case "number":
c.usedTypes["number"] = true
return "number"
case "integer":
c.usedTypes["integer"] = true
c.usedTypes["digit"] = true
return "int"
case "boolean":
return `( "true" | "false" )`
case "null":
return `"null"`
default:
c.usedTypes["string"] = true
c.usedTypes["number"] = true
return "string"
}
}
func (c *converter) objectToExpr(schema *schemaNode, name string) string {
c.usedTypes["ws"] = true
if len(schema.Properties) == 0 {
return `"{" ws "}"`
}
// Sort properties for deterministic output
// Required properties come first, in their required order
var propOrder []string
requiredSet := make(map[string]bool)
for _, r := range schema.Required {
requiredSet[r] = true
propOrder = append(propOrder, r)
}
// Add any non-required properties (though OpenAI requires all to be required)
var optionalProps []string
for propName := range schema.Properties {
if !requiredSet[propName] {
optionalProps = append(optionalProps, propName)
}
}
sort.Strings(optionalProps)
propOrder = append(propOrder, optionalProps...)
var propExprs []string
first := true
for _, propName := range propOrder {
propSchema, exists := schema.Properties[propName]
if !exists {
continue
}
propExpr := c.schemaToExpr(propSchema, propName)
prefix := ""
if !first {
prefix = `"," ws `
}
first = false
propExprs = append(propExprs, fmt.Sprintf(`%s"\"%s\"" ws ":" ws %s`, prefix, propName, propExpr))
}
if len(propExprs) == 0 {
return `"{" ws "}"`
}
return `"{" ws ` + strings.Join(propExprs, " ") + ` ws "}"`
}
func (c *converter) arrayToExpr(schema *schemaNode, name string) string {
c.usedTypes["ws"] = true
itemExpr := "value"
if schema.Items != nil {
itemExpr = c.schemaToExpr(schema.Items, name+"_item")
} else {
c.usedTypes["string"] = true
c.usedTypes["number"] = true
}
// Create item rule
c.ruleNum++
itemRule := fmt.Sprintf("item%d", c.ruleNum)
c.rules = append(c.rules, fmt.Sprintf("%s = %s .", itemRule, itemExpr))
// Handle minItems/maxItems
if schema.MinItems != nil || schema.MaxItems != nil {
return c.arrayWithBounds(itemRule, schema.MinItems, schema.MaxItems)
}
// Default: zero or more items
return fmt.Sprintf(`( "[" ws "]" | "[" ws %s { "," ws %s } ws "]" )`, itemRule, itemRule)
}
func (c *converter) arrayWithBounds(itemRule string, minItems, maxItems *int) string {
min := 0
max := -1 // unlimited
if minItems != nil {
min = *minItems
}
if maxItems != nil {
max = *maxItems
}
if min == 0 && max < 0 {
// No constraints
return fmt.Sprintf(`( "[" ws "]" | "[" ws %s { "," ws %s } ws "]" )`, itemRule, itemRule)
}
if min == 0 && max == 0 {
return `"[" ws "]"`
}
// Build pattern for bounded array
// For min=2, max=4: item "," item [ "," item ] [ "," item ]
var parts []string
// Required items
for i := 0; i < min; i++ {
if i > 0 {
parts = append(parts, `"," ws`)
}
parts = append(parts, itemRule)
}
// Optional items up to max
if max > min {
for i := min; i < max; i++ {
if i == 0 {
parts = append(parts, fmt.Sprintf(`[ %s`, itemRule))
} else {
parts = append(parts, fmt.Sprintf(`[ "," ws %s`, itemRule))
}
}
// Close all optional brackets
for i := min; i < max; i++ {
parts = append(parts, "]")
}
} else if max < 0 {
// Unlimited after min
if min > 0 {
parts = append(parts, fmt.Sprintf(`{ "," ws %s }`, itemRule))
} else {
parts = append(parts, fmt.Sprintf(`[ %s { "," ws %s } ]`, itemRule, itemRule))
}
}
if min == 0 {
return fmt.Sprintf(`( "[" ws "]" | "[" ws %s ws "]" )`, strings.Join(parts, " "))
}
return fmt.Sprintf(`"[" ws %s ws "]"`, strings.Join(parts, " "))
}
func (c *converter) stringToExpr(schema *schemaNode, name string) string {
// Handle format
if schema.Format != "" {
return c.formatToExpr(schema.Format)
}
// Handle pattern (regex)
if schema.Pattern != "" {
return c.patternToExpr(schema.Pattern, name)
}
// Default string
c.usedTypes["string"] = true
if name == "root" {
c.usedTypes["character"] = true
return `"\"" { character } "\""`
}
return "string"
}
func (c *converter) formatToExpr(format string) string {
switch format {
case "date":
// YYYY-MM-DD
c.ruleNum++
c.usedTypes["digit"] = true
ruleName := fmt.Sprintf("date%d", c.ruleNum)
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit digit digit "-" digit digit "-" digit digit "\"" .`, ruleName))
return ruleName
case "time":
// HH:MM:SS
c.ruleNum++
c.usedTypes["digit"] = true
ruleName := fmt.Sprintf("time%d", c.ruleNum)
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit ":" digit digit ":" digit digit "\"" .`, ruleName))
return ruleName
case "date-time":
// YYYY-MM-DDTHH:MM:SSZ or with offset
c.ruleNum++
c.usedTypes["digit"] = true
ruleName := fmt.Sprintf("datetime%d", c.ruleNum)
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit digit digit "-" digit digit "-" digit digit "T" digit digit ":" digit digit ":" digit digit ( "Z" | ( "+" | "-" ) digit digit ":" digit digit ) "\"" .`, ruleName))
return ruleName
case "email":
// Simplified email pattern
c.ruleNum++
ruleName := fmt.Sprintf("email%d", c.ruleNum)
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" emailchar { emailchar } "@" emailchar { emailchar } "." emailchar { emailchar } "\"" .`, ruleName))
c.rules = append(c.rules, `emailchar = "a" … "z" | "A" … "Z" | "0" … "9" | "." | "-" | "_" .`)
return ruleName
case "uuid":
// 8-4-4-4-12 hex pattern
c.ruleNum++
ruleName := fmt.Sprintf("uuid%d", c.ruleNum)
c.usedTypes["hex"] = true
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" hex hex hex hex hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex hex hex hex hex hex hex hex hex "\"" .`, ruleName))
return ruleName
case "ipv4":
c.ruleNum++
c.usedTypes["digit"] = true
ruleName := fmt.Sprintf("ipv4_%d", c.ruleNum)
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit { digit } "." digit { digit } "." digit { digit } "." digit { digit } "\"" .`, ruleName))
return ruleName
case "uri", "hostname":
// Fallback to general string for complex formats
c.usedTypes["string"] = true
return "string"
default:
c.usedTypes["string"] = true
return "string"
}
}
func (c *converter) patternToExpr(pattern string, name string) string {
// Try to convert simple regex patterns to EBNF
// This handles common cases; complex regex falls back to string
// Remove anchors
pattern = strings.TrimPrefix(pattern, "^")
pattern = strings.TrimSuffix(pattern, "$")
// Try to parse and convert
expr, ok := c.regexToEBNF(pattern)
if !ok {
// Fallback to general string
c.usedTypes["string"] = true
return "string"
}
c.ruleNum++
ruleName := fmt.Sprintf("pattern%d", c.ruleNum)
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" %s "\"" .`, ruleName, expr))
return ruleName
}
func (c *converter) regexToEBNF(pattern string) (string, bool) {
// Simple regex to EBNF converter
// Handles: literals, [a-z], [A-Z], [0-9], +, *, ?, basic groups
var result strings.Builder
i := 0
for i < len(pattern) {
ch := pattern[i]
switch ch {
case '[':
// Character class
end := strings.Index(pattern[i:], "]")
if end == -1 {
return "", false
}
class := pattern[i+1 : i+end]
ebnfClass, ok := c.charClassToEBNF(class)
if !ok {
return "", false
}
result.WriteString(ebnfClass)
i += end + 1
case '(':
// Group - find matching )
depth := 1
start := i + 1
j := start
for j < len(pattern) && depth > 0 {
if pattern[j] == '(' {
depth++
} else if pattern[j] == ')' {
depth--
}
j++
}
if depth != 0 {
return "", false
}
groupContent := pattern[start : j-1]
groupExpr, ok := c.regexToEBNF(groupContent)
if !ok {
return "", false
}
result.WriteString("( ")
result.WriteString(groupExpr)
result.WriteString(" )")
i = j
case '|':
result.WriteString(" | ")
i++
case '+':
// One or more - wrap previous in { } and add one required
// This is a simplification
return "", false // TODO: handle properly
case '*':
// Zero or more - need to wrap previous
return "", false // TODO: handle properly
case '?':
// Optional - need to wrap previous in [ ]
return "", false // TODO: handle properly
case '\\':
// Escape sequence
if i+1 >= len(pattern) {
return "", false
}
next := pattern[i+1]
switch next {
case 'd':
result.WriteString("digit")
c.usedTypes["digit"] = true
case 'w':
result.WriteString(`( "a" … "z" | "A" … "Z" | "0" … "9" | "_" )`)
case 's':
result.WriteString(`( " " | "\t" )`)
default:
result.WriteString(fmt.Sprintf(`"%c"`, next))
}
i += 2
default:
// Literal character
if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' || ch == '-' || ch == '.' {
result.WriteString(fmt.Sprintf(`"%c" `, ch))
} else {
// Special char, try to escape
result.WriteString(fmt.Sprintf(`"%c" `, ch))
}
i++
}
}
return strings.TrimSpace(result.String()), true
}
func (c *converter) charClassToEBNF(class string) (string, bool) {
// Handle character classes like a-z, A-Z, 0-9
if class == "a-zA-Z0-9_" || class == "a-zA-Z_" {
return `( "a" … "z" | "A" … "Z" | "0" … "9" | "_" )`, true
}
if class == "a-zA-Z0-9" {
return `( "a" … "z" | "A" … "Z" | "0" … "9" )`, true
}
if class == "a-z" {
return `"a" … "z"`, true
}
if class == "A-Z" {
return `"A" … "Z"`, true
}
if class == "0-9" {
c.usedTypes["digit"] = true
return "digit", true
}
// Try to parse range patterns
if matched, _ := regexp.MatchString(`^[a-zA-Z]-[a-zA-Z]$`, class); matched {
return fmt.Sprintf(`"%c" … "%c"`, class[0], class[2]), true
}
if matched, _ := regexp.MatchString(`^[0-9]-[0-9]$`, class); matched {
return fmt.Sprintf(`"%c" … "%c"`, class[0], class[2]), true
}
return "", false
}
func (c *converter) anyOfToExpr(schemas []*schemaNode, name string) string {
var parts []string
for i, s := range schemas {
expr := c.schemaToExpr(s, fmt.Sprintf("%s_opt%d", name, i))
parts = append(parts, expr)
}
return "( " + strings.Join(parts, " | ") + " )"
}
func (c *converter) enumToExpr(values []interface{}) string {
var parts []string
for _, v := range values {
parts = append(parts, c.constToExpr(v))
}
return "( " + strings.Join(parts, " | ") + " )"
}
func (c *converter) constToExpr(v interface{}) string {
switch val := v.(type) {
case string:
return fmt.Sprintf(`"\"%s\""`, c.escapeString(val))
case float64:
if val == float64(int(val)) {
return fmt.Sprintf(`"%d"`, int(val))
}
return fmt.Sprintf(`"%v"`, val)
case bool:
if val {
return `"true"`
}
return `"false"`
case nil:
return `"null"`
default:
c.usedTypes["string"] = true
return "string"
}
}
func (c *converter) resolveRef(ref string) string {
// Handle #/$defs/name references
if strings.HasPrefix(ref, "#/$defs/") {
defName := strings.TrimPrefix(ref, "#/$defs/")
return c.resolveDefRef(defName)
}
// Handle root recursion #
if ref == "#" {
return "root"
}
// Unknown ref format
c.usedTypes["string"] = true
return "string"
}
func (c *converter) resolveDefRef(defName string) string {
// Check if we've already defined this as a rule
ruleName := "def_" + defName
if c.definedRefs[defName] {
return ruleName
}
// Mark as defined to prevent infinite recursion
c.definedRefs[defName] = true
// Look up the definition
if c.definitions == nil {
c.usedTypes["string"] = true
return "string"
}
defSchema, ok := c.definitions[defName]
if !ok {
c.usedTypes["string"] = true
return "string"
}
// Generate the rule
expr := c.schemaToExpr(defSchema, ruleName)
c.rules = append(c.rules, fmt.Sprintf("%s = %s .", ruleName, expr))
return ruleName
}
func (c *converter) getTypes(t interface{}) []string {
switch v := t.(type) {
case string:
return []string{v}
case []interface{}:
var types []string
for _, item := range v {
if s, ok := item.(string); ok {
types = append(types, s)
}
}
return types
}
return nil
}
func (c *converter) escapeString(s string) string {
s = strings.ReplaceAll(s, `\`, `\\`)
s = strings.ReplaceAll(s, `"`, `\"`)
return s
}
// Grammar converts a JSON Schema string into a compiled grammar.
func Grammar(schemaJSON string) (*grammar.Grammar, error) {
ebnf, err := EBNF(schemaJSON)
if err != nil {
return nil, err
}
return grammar.ParseEBNF(ebnf, "root")
}

View File

@@ -1,336 +0,0 @@
//go:build mlx
package schema
import (
"testing"
gram "github.com/ollama/ollama/x/grammar"
"github.com/ollama/ollama/x/imagegen/mlx"
)
func TestJSONEBNF(t *testing.T) {
tests := []struct {
name string
schema string
}{
{
name: "simple object",
schema: `{
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"}
},
"required": ["name", "age"]
}`,
},
{
name: "with enum",
schema: `{
"type": "object",
"properties": {
"status": {"enum": ["active", "inactive", "pending"]}
},
"required": ["status"]
}`,
},
{
name: "array of objects",
schema: `{
"type": "array",
"items": {
"type": "object",
"properties": {
"id": {"type": "integer"}
},
"required": ["id"]
}
}`,
},
{
name: "nested object",
schema: `{
"type": "object",
"properties": {
"user": {
"type": "object",
"properties": {
"email": {"type": "string"}
},
"required": ["email"]
}
},
"required": ["user"]
}`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ebnf, err := EBNF(tc.schema)
if err != nil {
t.Fatalf("EBNF failed: %v", err)
}
// Try to compile it
grammar, err := gram.ParseEBNF(ebnf, "root")
if err != nil {
t.Fatalf("ParseEBNF failed: %v", err)
}
if grammar == nil {
t.Fatal("grammar is nil")
}
})
}
}
func TestGrammarEngine(t *testing.T) {
schema := `{
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"}
},
"required": ["name", "age"]
}`
grammar, err := Grammar(schema)
if err != nil {
t.Fatalf("Grammar failed: %v", err)
}
vocab := []string{
"{", "}", "[", "]", ":", ",",
"\"name\"", "\"age\"", "\"test\"",
"\"", "a", "b", "c",
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
" ", "\n",
"true", "false", "null",
}
engine, err := gram.NewEngine(grammar, vocab)
if err != nil {
t.Fatalf("grammar.NewEngine failed: %v", err)
}
defer engine.Close()
logits := mlx.Ones(int32(len(vocab)))
mlx.Keep(logits)
// Test that we can apply mask
masked := engine.ApplyMask(logits)
mlx.Eval(masked)
}
// TestOpenAIStructuredOutputs tests features required for OpenAI compatibility
func TestOpenAIStructuredOutputs(t *testing.T) {
tests := []struct {
name string
schema string
}{
{
name: "anyOf union",
schema: `{
"type": "object",
"properties": {
"value": {
"anyOf": [
{"type": "string"},
{"type": "integer"}
]
}
},
"required": ["value"]
}`,
},
{
name: "nullable string via type array",
schema: `{
"type": "object",
"properties": {
"name": {"type": ["string", "null"]}
},
"required": ["name"]
}`,
},
{
name: "$ref with $defs",
schema: `{
"type": "object",
"properties": {
"person": {"$ref": "#/$defs/Person"}
},
"required": ["person"],
"$defs": {
"Person": {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"}
},
"required": ["name", "age"]
}
}
}`,
},
{
name: "const value",
schema: `{
"type": "object",
"properties": {
"type": {"const": "user"}
},
"required": ["type"]
}`,
},
{
name: "format date-time",
schema: `{
"type": "object",
"properties": {
"created": {"type": "string", "format": "date-time"}
},
"required": ["created"]
}`,
},
{
name: "format date",
schema: `{
"type": "object",
"properties": {
"birthday": {"type": "string", "format": "date"}
},
"required": ["birthday"]
}`,
},
{
name: "format email",
schema: `{
"type": "object",
"properties": {
"email": {"type": "string", "format": "email"}
},
"required": ["email"]
}`,
},
{
name: "format uuid",
schema: `{
"type": "object",
"properties": {
"id": {"type": "string", "format": "uuid"}
},
"required": ["id"]
}`,
},
{
name: "array with minItems maxItems",
schema: `{
"type": "object",
"properties": {
"tags": {
"type": "array",
"items": {"type": "string"},
"minItems": 1,
"maxItems": 3
}
},
"required": ["tags"]
}`,
},
{
name: "deeply nested with refs",
schema: `{
"type": "object",
"properties": {
"company": {
"type": "object",
"properties": {
"name": {"type": "string"},
"employees": {
"type": "array",
"items": {"$ref": "#/$defs/Employee"}
}
},
"required": ["name", "employees"]
}
},
"required": ["company"],
"$defs": {
"Employee": {
"type": "object",
"properties": {
"name": {"type": "string"},
"role": {"enum": ["engineer", "manager", "intern"]}
},
"required": ["name", "role"]
}
}
}`,
},
{
name: "multiple refs same def",
schema: `{
"type": "object",
"properties": {
"from": {"$ref": "#/$defs/Address"},
"to": {"$ref": "#/$defs/Address"}
},
"required": ["from", "to"],
"$defs": {
"Address": {
"type": "object",
"properties": {
"city": {"type": "string"},
"zip": {"type": "string"}
},
"required": ["city", "zip"]
}
}
}`,
},
{
name: "oneOf variant",
schema: `{
"type": "object",
"properties": {
"result": {
"oneOf": [
{
"type": "object",
"properties": {"success": {"type": "boolean"}},
"required": ["success"]
},
{
"type": "object",
"properties": {"error": {"type": "string"}},
"required": ["error"]
}
]
}
},
"required": ["result"]
}`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ebnf, err := EBNF(tc.schema)
if err != nil {
t.Fatalf("EBNF failed: %v", err)
}
grammar, err := gram.ParseEBNF(ebnf, "root")
if err != nil {
t.Fatalf("ParseEBNF failed: %v", err)
}
if grammar == nil {
t.Fatal("grammar is nil")
}
})
}
}

View File

@@ -1,105 +0,0 @@
//go:build mlx
package grammar
import "unicode/utf8"
// terminalType distinguishes different kinds of grammar terminals
type terminalType int
const (
terminalLiteral terminalType = iota // Exact string: "true", "{"
terminalRange // Character range: [a-z], [0-9]
)
// terminal represents a compiled grammar terminal
type terminal struct {
ID int
Type terminalType
Pattern string // Original pattern from grammar
Unescaped string // Unescaped literal (for terminalLiteral)
LowRune rune // For unicode ranges: low bound
HighRune rune // For unicode ranges: high bound
}
// terminalMatch represents a terminal that matched at a position
type terminalMatch struct {
TerminalID int
Length int // Number of bytes consumed
}
// trieNode is a node in the literal matching trie
type trieNode struct {
children [256]*trieNode // Byte-indexed children
terminalID int // -1 if not accepting, else terminal ID
}
// terminalMatcher tests which terminals match at a position in a byte slice
type terminalMatcher struct {
// Trie for literal matching (fast path)
literalTrie *trieNode
// Range terminals (single-byte matches)
ranges []terminal
// All terminals for enumeration
terminals []terminal
// Pattern to terminal ID map for fast lookup (keyed by raw pattern)
patternToID map[string]int
}
// addLiteralToTrie adds a literal pattern to the trie
func (m *terminalMatcher) addLiteralToTrie(pattern string, terminalID int) {
node := m.literalTrie
for i := 0; i < len(pattern); i++ {
c := pattern[i]
if node.children[c] == nil {
node.children[c] = &trieNode{terminalID: -1}
}
node = node.children[c]
}
node.terminalID = terminalID
}
// matchesAt returns all terminals that match at pos in data
func (m *terminalMatcher) matchesAt(data []byte, pos int) []terminalMatch {
if pos >= len(data) {
return nil
}
var matches []terminalMatch
// Check literal matches via trie
node := m.literalTrie
for i := pos; i < len(data) && node != nil; i++ {
c := data[i]
node = node.children[c]
if node != nil && node.terminalID >= 0 {
matches = append(matches, terminalMatch{
TerminalID: node.terminalID,
Length: i - pos + 1,
})
}
}
// Check range matches (unicode-aware)
r, runeLen := utf8.DecodeRune(data[pos:])
if r != utf8.RuneError {
for _, rng := range m.ranges {
if r >= rng.LowRune && r <= rng.HighRune {
matches = append(matches, terminalMatch{
TerminalID: rng.ID,
Length: runeLen,
})
}
}
}
return matches
}
// terminalCount returns the number of terminals
func (m *terminalMatcher) terminalCount() int {
return len(m.terminals)
}

View File

@@ -234,3 +234,17 @@ ollama create z-image
3. Copy config files (*.json) as config layers
4. Write manifest
```
## FP8 Quantization
Z-Image supports FP8 quantization to reduce memory usage by ~50% while maintaining image quality.
### Usage
```bash
cd ./weights/Z-Image-Turbo
ollama create z-image-fp8 --quantize fp8
```
This quantizes weights during import. The resulting model will be ~15GB instead of ~31GB.

View File

@@ -1,10 +1,8 @@
package api
import (
"encoding/base64"
"fmt"
"net/http"
"os"
"strconv"
"strings"
"time"
@@ -101,10 +99,10 @@ func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.Com
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
var imagePath string
var imageBase64 string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imagePath = extractPath(resp.Content)
imageBase64 = extractBase64(resp.Content)
} else {
progress := parseProgress(resp.Content)
if progress.Total > 0 {
@@ -118,14 +116,14 @@ func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.Com
return
}
c.SSEvent("done", buildResponse(imagePath, format))
c.SSEvent("done", buildResponse(imageBase64, format))
}
func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
var imagePath string
var imageBase64 string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imagePath = extractPath(resp.Content)
imageBase64 = extractBase64(resp.Content)
}
})
if err != nil {
@@ -133,7 +131,7 @@ func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.
return
}
c.JSON(http.StatusOK, buildResponse(imagePath, format))
c.JSON(http.StatusOK, buildResponse(imageBase64, format))
}
func parseSize(size string) (int32, int32) {
@@ -152,9 +150,9 @@ func parseSize(size string) (int32, int32) {
return int32(w), int32(h)
}
func extractPath(content string) string {
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
return strings.TrimSpace(content[idx+16:])
func extractBase64(content string) string {
if strings.HasPrefix(content, "IMAGE_BASE64:") {
return content[13:]
}
return ""
}
@@ -165,23 +163,21 @@ func parseProgress(content string) ImageProgressEvent {
return ImageProgressEvent{Step: step, Total: total}
}
func buildResponse(imagePath, format string) ImageGenerationResponse {
func buildResponse(imageBase64, format string) ImageGenerationResponse {
resp := ImageGenerationResponse{
Created: time.Now().Unix(),
Data: make([]ImageData, 1),
}
if imagePath == "" {
if imageBase64 == "" {
return resp
}
if format == "url" {
resp.Data[0].URL = "file://" + imagePath
// URL format not supported when using base64 transfer
resp.Data[0].B64JSON = imageBase64
} else {
data, err := os.ReadFile(imagePath)
if err == nil {
resp.Data[0].B64JSON = base64.StdEncoding.EncodeToString(data)
}
resp.Data[0].B64JSON = imageBase64
}
return resp

197
x/imagegen/cache/teacache.go vendored Normal file
View File

@@ -0,0 +1,197 @@
//go:build mlx
// Package cache provides caching mechanisms for diffusion model inference.
package cache
import (
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TeaCache implements Timestep Embedding Aware Caching for diffusion models.
// It caches the transformer output and reuses it when timestep values
// are similar between consecutive steps.
//
// For CFG (classifier-free guidance), it caches pos and neg predictions
// separately and always computes CFG fresh to avoid error amplification.
//
// Reference: "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model"
// https://github.com/ali-vilab/TeaCache
type TeaCache struct {
// Cached transformer output from last computed step (non-CFG mode)
cachedOutput *mlx.Array
// Cached CFG outputs (pos and neg separately)
cachedPosOutput *mlx.Array
cachedNegOutput *mlx.Array
// Previous timestep value for difference calculation
prevTimestep float32
// Accumulated difference for rescaling
accumulatedDiff float32
// Configuration
threshold float32 // Threshold for recomputation decision
rescaleFactor float32 // Model-specific rescaling factor
skipEarlySteps int // Number of early steps to never cache
// Statistics
cacheHits int
cacheMisses int
}
// TeaCacheConfig holds configuration for TeaCache.
type TeaCacheConfig struct {
// Threshold for recomputation. Lower = more cache hits, potential quality loss.
// Recommended: 0.05-0.15 for image models
Threshold float32
// Rescale factor to adjust timestep embedding differences.
// Model-specific, typically 1.0-2.0
RescaleFactor float32
// SkipEarlySteps: number of early steps to always compute (never cache).
// Set to 2-3 for CFG mode to preserve structure. 0 = no skipping.
SkipEarlySteps int
}
// DefaultTeaCacheConfig returns default configuration for TeaCache.
func DefaultTeaCacheConfig() *TeaCacheConfig {
return &TeaCacheConfig{
Threshold: 0.1,
RescaleFactor: 1.0,
}
}
// NewTeaCache creates a new TeaCache instance.
func NewTeaCache(cfg *TeaCacheConfig) *TeaCache {
if cfg == nil {
cfg = DefaultTeaCacheConfig()
}
return &TeaCache{
threshold: cfg.Threshold,
rescaleFactor: cfg.RescaleFactor,
skipEarlySteps: cfg.SkipEarlySteps,
}
}
// ShouldCompute determines if we should compute the full forward pass
// or reuse the cached output based on timestep similarity.
//
// Algorithm:
// 1. First step always computes
// 2. Subsequent steps compare |currTimestep - prevTimestep| * rescaleFactor
// 3. If accumulated difference > threshold, compute new output
// 4. Otherwise, reuse cached output
func (tc *TeaCache) ShouldCompute(step int, timestep float32) bool {
// Always compute early steps (critical for structure)
// Check both regular cache and CFG cache
hasCachedOutput := tc.cachedOutput != nil || tc.HasCFGCache()
if step < tc.skipEarlySteps || step == 0 || !hasCachedOutput {
return true
}
// Compute absolute difference between current and previous timestep
diff := timestep - tc.prevTimestep
if diff < 0 {
diff = -diff
}
// Apply rescaling factor
scaledDiff := diff * tc.rescaleFactor
// Accumulate difference (helps track drift over multiple cached steps)
tc.accumulatedDiff += scaledDiff
// Decision based on accumulated difference
if tc.accumulatedDiff > tc.threshold {
tc.accumulatedDiff = 0 // Reset accumulator
return true
}
return false
}
// UpdateCache stores the computed output for potential reuse (non-CFG mode).
func (tc *TeaCache) UpdateCache(output *mlx.Array, timestep float32) {
// Free previous cached output
if tc.cachedOutput != nil {
tc.cachedOutput.Free()
}
// Store new cached values
tc.cachedOutput = output
tc.prevTimestep = timestep
tc.cacheMisses++
}
// UpdateCFGCache stores pos and neg outputs separately for CFG mode.
// This allows CFG to be computed fresh each step, avoiding error amplification.
func (tc *TeaCache) UpdateCFGCache(posOutput, negOutput *mlx.Array, timestep float32) {
// Free previous cached outputs
if tc.cachedPosOutput != nil {
tc.cachedPosOutput.Free()
}
if tc.cachedNegOutput != nil {
tc.cachedNegOutput.Free()
}
// Store new cached values
tc.cachedPosOutput = posOutput
tc.cachedNegOutput = negOutput
tc.prevTimestep = timestep
tc.cacheMisses++
}
// GetCached returns the cached output (non-CFG mode).
func (tc *TeaCache) GetCached() *mlx.Array {
tc.cacheHits++
return tc.cachedOutput
}
// GetCFGCached returns cached pos and neg outputs for CFG mode.
func (tc *TeaCache) GetCFGCached() (pos, neg *mlx.Array) {
tc.cacheHits++
return tc.cachedPosOutput, tc.cachedNegOutput
}
// HasCFGCache returns true if CFG cache is available.
func (tc *TeaCache) HasCFGCache() bool {
return tc.cachedPosOutput != nil && tc.cachedNegOutput != nil
}
// Arrays returns all arrays that should be kept alive.
func (tc *TeaCache) Arrays() []*mlx.Array {
var arrays []*mlx.Array
if tc.cachedOutput != nil {
arrays = append(arrays, tc.cachedOutput)
}
if tc.cachedPosOutput != nil {
arrays = append(arrays, tc.cachedPosOutput)
}
if tc.cachedNegOutput != nil {
arrays = append(arrays, tc.cachedNegOutput)
}
return arrays
}
// Stats returns cache hit/miss statistics.
func (tc *TeaCache) Stats() (hits, misses int) {
return tc.cacheHits, tc.cacheMisses
}
// Free releases all cached arrays.
func (tc *TeaCache) Free() {
if tc.cachedOutput != nil {
tc.cachedOutput.Free()
tc.cachedOutput = nil
}
if tc.cachedPosOutput != nil {
tc.cachedPosOutput.Free()
tc.cachedPosOutput = nil
}
if tc.cachedNegOutput != nil {
tc.cachedNegOutput.Free()
tc.cachedNegOutput = nil
}
}

View File

@@ -44,62 +44,64 @@ func DefaultOptions() ImageGenOptions {
}
}
// Show displays information about an image generation model.
func Show(modelName string, w io.Writer) error {
manifest, err := LoadManifest(modelName)
if err != nil {
return fmt.Errorf("failed to load manifest: %w", err)
}
// Count total size
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
totalSize += layer.Size
}
}
// Read model_index.json for architecture
var architecture string
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
var index struct {
Architecture string `json:"architecture"`
}
if json.Unmarshal(data, &index) == nil {
architecture = index.Architecture
}
}
// Estimate parameter count from total size (assuming BF16 = 2 bytes per param)
paramCount := totalSize / 2
paramStr := formatParamCount(paramCount)
// Print Model info
fmt.Fprintln(w, " Model")
if architecture != "" {
fmt.Fprintf(w, " %-20s %s\n", "architecture", architecture)
}
fmt.Fprintf(w, " %-20s %s\n", "parameters", paramStr)
fmt.Fprintf(w, " %-20s %s\n", "quantization", "BF16")
fmt.Fprintln(w)
// Print Capabilities
fmt.Fprintln(w, " Capabilities")
fmt.Fprintf(w, " %s\n", "image")
fmt.Fprintln(w)
return nil
// ModelInfo contains metadata about an image generation model.
type ModelInfo struct {
Architecture string
ParameterCount int64
Quantization string
}
// formatParamCount formats parameter count as human-readable string.
func formatParamCount(count int64) string {
if count >= 1_000_000_000 {
return fmt.Sprintf("%.1fB", float64(count)/1_000_000_000)
// GetModelInfo returns metadata about an image generation model.
func GetModelInfo(modelName string) (*ModelInfo, error) {
manifest, err := LoadManifest(modelName)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
if count >= 1_000_000 {
return fmt.Sprintf("%.1fM", float64(count)/1_000_000)
info := &ModelInfo{}
// Read model_index.json for architecture, parameter count, and quantization
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
var index struct {
Architecture string `json:"architecture"`
ParameterCount int64 `json:"parameter_count"`
Quantization string `json:"quantization"`
}
if json.Unmarshal(data, &index) == nil {
info.Architecture = index.Architecture
info.ParameterCount = index.ParameterCount
info.Quantization = index.Quantization
}
}
return fmt.Sprintf("%d", count)
// Fallback: detect quantization from tensor names if not in config
if info.Quantization == "" {
for _, layer := range manifest.Manifest.Layers {
if strings.HasSuffix(layer.Name, ".weight_scale") {
info.Quantization = "FP8"
break
}
}
if info.Quantization == "" {
info.Quantization = "BF16"
}
}
// Fallback: estimate parameter count if not in config
if info.ParameterCount == 0 {
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
totalSize += layer.Size
}
}
}
// Assume BF16 (2 bytes/param) as rough estimate
info.ParameterCount = totalSize / 2
}
return info, nil
}
// RegisterFlags adds image generation flags to the given command.
@@ -121,11 +123,6 @@ func RegisterFlags(cmd *cobra.Command) {
// Returns true if it handled the request, false if the caller should continue with normal flow.
// Supports flags: --width, --height, --steps, --seed, --negative
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
// Verify it's a valid image gen model
if ResolveModelName(name) == "" {
return fmt.Errorf("unknown image generation model: %s", name)
}
// Get options from flags (with env var defaults)
opts := DefaultOptions()
if cmd != nil && cmd.Flags() != nil {
@@ -183,8 +180,7 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
p.Add("", spinner)
var stepBar *progress.StepBar
var imagePath string
var imageBase64 string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
@@ -203,11 +199,9 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
return nil
}
// Handle final response with image path
if resp.Done && strings.Contains(content, "Image saved to:") {
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
imagePath = strings.TrimSpace(content[idx+16:])
}
// Handle final response with base64 image data
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
imageBase64 = content[13:]
}
return nil
@@ -218,9 +212,27 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
return err
}
if imagePath != "" {
displayImageInTerminal(imagePath)
fmt.Printf("Image saved to: %s\n", imagePath)
if imageBase64 != "" {
// Decode base64 and save to CWD
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
if err != nil {
return fmt.Errorf("failed to decode image: %w", err)
}
// Create filename from prompt
safeName := sanitizeFilename(prompt)
if len(safeName) > 50 {
safeName = safeName[:50]
}
timestamp := time.Now().Format("20060102-150405")
filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
if err := os.WriteFile(filename, imageData, 0o644); err != nil {
return fmt.Errorf("failed to save image: %w", err)
}
displayImageInTerminal(filename)
fmt.Printf("Image saved to: %s\n", filename)
}
return nil
@@ -306,7 +318,7 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
p.Add("", spinner)
var stepBar *progress.StepBar
var imagePath string
var imageBase64 string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
@@ -326,11 +338,9 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
return nil
}
// Handle final response with image path
if resp.Done && strings.Contains(content, "Image saved to:") {
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
imagePath = strings.TrimSpace(content[idx+16:])
}
// Handle final response with base64 image data
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
imageBase64 = content[13:]
}
return nil
@@ -342,25 +352,30 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
continue
}
// Copy image to current directory with descriptive name
if imagePath != "" {
// Save image to current directory with descriptive name
if imageBase64 != "" {
// Decode base64 image data
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error decoding image: %v\n", err)
continue
}
// Create filename from prompt (sanitized)
safeName := sanitizeFilename(line)
if len(safeName) > 50 {
safeName = safeName[:50]
}
timestamp := time.Now().Format("20060102-150405")
newName := fmt.Sprintf("%s-%s.png", safeName, timestamp)
filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
// Copy file to CWD
if err := copyFile(imagePath, newName); err != nil {
fmt.Fprintf(os.Stderr, "Error saving to current directory: %v\n", err)
displayImageInTerminal(imagePath)
fmt.Printf("Image saved to: %s\n", imagePath)
} else {
displayImageInTerminal(newName)
fmt.Printf("Image saved to: %s\n", newName)
if err := os.WriteFile(filename, imageData, 0o644); err != nil {
fmt.Fprintf(os.Stderr, "Error saving image: %v\n", err)
continue
}
displayImageInTerminal(filename)
fmt.Printf("Image saved to: %s\n", filename)
}
fmt.Println()
@@ -381,24 +396,6 @@ func sanitizeFilename(s string) string {
return result.String()
}
// copyFile copies a file from src to dst.
func copyFile(src, dst string) error {
sourceFile, err := os.Open(src)
if err != nil {
return err
}
defer sourceFile.Close()
destFile, err := os.Create(dst)
if err != nil {
return err
}
defer destFile.Close()
_, err = io.Copy(destFile, sourceFile)
return err
}
// printInteractiveHelp prints help for interactive mode commands.
func printInteractiveHelp(opts ImageGenOptions) {
fmt.Fprintln(os.Stderr, "Commands:")
@@ -509,10 +506,7 @@ func displayImageInTerminal(imagePath string) bool {
// Send in chunks for large images
const chunkSize = 4096
for i := 0; i < len(encoded); i += chunkSize {
end := i + chunkSize
if end > len(encoded) {
end = len(encoded)
}
end := min(i+chunkSize, len(encoded))
chunk := encoded[i:end]
if i == 0 {

View File

@@ -29,9 +29,10 @@ const MinOllamaVersion = "0.14.0"
// CreateModel imports a tensor-based model from a local directory.
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
// If quantize is "fp8", weights will be quantized to mxfp8 format during import.
//
// TODO (jmorganca): Replace with API-based creation when promoted to production.
func CreateModel(modelName, modelDir string, p *progress.Progress) error {
func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) error {
if !imagegen.IsTensorModelDir(modelDir) {
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
}
@@ -58,18 +59,77 @@ func CreateModel(modelName, modelDir string, p *progress.Progress) error {
// Create tensor layer callback for individual tensors
// name is path-style: "component/tensor_name"
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32) (imagegen.LayerInfo, error) {
// When quantize is true, returns multiple layers (weight + scales)
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, doQuantize bool) ([]imagegen.LayerInfo, error) {
if doQuantize {
// Check if quantization is supported
if !QuantizeSupported() {
return nil, fmt.Errorf("quantization requires MLX support")
}
// Quantize the tensor (affine mode returns weight, scales, qbiases)
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape)
if err != nil {
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
}
// Create layer for quantized weight
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
// Create layer for scales (use _scale suffix convention)
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers := []imagegen.LayerInfo{
{
Digest: weightLayer.Digest,
Size: weightLayer.Size,
MediaType: weightLayer.MediaType,
Name: name, // Keep original name for weight
},
{
Digest: scalesLayer.Digest,
Size: scalesLayer.Size,
MediaType: scalesLayer.MediaType,
Name: name + "_scale", // Add _scale suffix
},
}
// Add qbiases layer if present (affine mode)
if qbiasData != nil {
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers = append(layers, imagegen.LayerInfo{
Digest: qbiasLayer.Digest,
Size: qbiasLayer.Size,
MediaType: qbiasLayer.MediaType,
Name: name + "_qbias", // Add _qbias suffix
})
}
return layers, nil
}
// Non-quantized path: just create a single layer
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
if err != nil {
return imagegen.LayerInfo{}, err
return nil, err
}
layer.Name = name
return imagegen.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
return []imagegen.LayerInfo{
{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
},
}, nil
}
@@ -119,7 +179,7 @@ func CreateModel(modelName, modelDir string, p *progress.Progress) error {
p.Add("imagegen", spinner)
}
err := imagegen.CreateModel(modelName, modelDir, createLayer, createTensorLayer, writeManifest, progressFn)
err := imagegen.CreateModel(modelName, modelDir, quantize, createLayer, createTensorLayer, writeManifest, progressFn)
spinner.Stop()
if err != nil {
return err

View File

@@ -0,0 +1,120 @@
//go:build mlx
package client
import (
"fmt"
"io"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// quantizeTensor loads a tensor from safetensors format, quantizes it to affine int8,
// and returns safetensors data for the quantized weights, scales, and biases.
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
tmpDir := ensureTempDir()
// Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
tmpFile, err := os.CreateTemp(tmpDir, "quant-input-*.safetensors")
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to create temp file: %w", err)
}
tmpPath := tmpFile.Name()
defer os.Remove(tmpPath)
if _, err := io.Copy(tmpFile, r); err != nil {
tmpFile.Close()
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to write temp file: %w", err)
}
tmpFile.Close()
// Load the tensor using MLX's native loader
st, err := mlx.LoadSafetensorsNative(tmpPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to load safetensors: %w", err)
}
defer st.Free()
// Get the tensor (it's stored as "data" in our minimal safetensors format)
arr := st.Get("data")
if arr == nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("tensor 'data' not found in safetensors")
}
// Convert to BFloat16 if needed (quantize expects float type)
if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
arr = mlx.AsType(arr, mlx.DtypeBFloat16)
mlx.Eval(arr)
}
// Quantize with affine mode: group_size=32, bits=8
// Note: mxfp8 mode doesn't have matmul kernels in MLX, affine mode does
qweight, scales, qbiases := mlx.Quantize(arr, 32, 8, "affine")
// Eval and make contiguous for data access
qweight = mlx.Contiguous(qweight)
scales = mlx.Contiguous(scales)
if qbiases != nil {
qbiases = mlx.Contiguous(qbiases)
mlx.Eval(qweight, scales, qbiases)
} else {
mlx.Eval(qweight, scales)
}
// Get shapes
qweightShape = qweight.Shape()
scalesShape = scales.Shape()
// Save quantized weight using MLX's native safetensors (correctly handles uint32 dtype)
qweightPath := filepath.Join(tmpDir, "qweight.safetensors")
defer os.Remove(qweightPath)
if err := mlx.SaveSafetensors(qweightPath, map[string]*mlx.Array{"data": qweight}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save quantized weight: %w", err)
}
qweightData, err = os.ReadFile(qweightPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read quantized weight: %w", err)
}
// Save scales using MLX's native safetensors
scalesPath := filepath.Join(tmpDir, "scales.safetensors")
defer os.Remove(scalesPath)
if err := mlx.SaveSafetensors(scalesPath, map[string]*mlx.Array{"data": scales}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save scales: %w", err)
}
scalesData, err = os.ReadFile(scalesPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read scales: %w", err)
}
// Affine mode returns qbiases for zero-point offset
if qbiases != nil {
qbiasShape = qbiases.Shape()
qbiasPath := filepath.Join(tmpDir, "qbias.safetensors")
defer os.Remove(qbiasPath)
if err := mlx.SaveSafetensors(qbiasPath, map[string]*mlx.Array{"data": qbiases}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save qbiases: %w", err)
}
qbiasData, err = os.ReadFile(qbiasPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read qbiases: %w", err)
}
}
return qweightData, scalesData, qbiasData, qweightShape, scalesShape, qbiasShape, nil
}
// QuantizeSupported returns true if quantization is supported (MLX build)
func QuantizeSupported() bool {
return true
}
// ensureTempDir creates the temp directory for quantization if it doesn't exist
func ensureTempDir() string {
tmpDir := filepath.Join(os.TempDir(), "ollama-quantize")
os.MkdirAll(tmpDir, 0755)
return tmpDir
}

View File

@@ -0,0 +1,18 @@
//go:build !mlx
package client
import (
"fmt"
"io"
)
// quantizeTensor is not available without MLX
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
}
// QuantizeSupported returns false when MLX is not available
func QuantizeSupported() bool {
return false
}

View File

@@ -8,7 +8,6 @@ import (
"time"
"unicode/utf8"
"github.com/ollama/ollama/x/grammar"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -110,11 +109,7 @@ type input struct {
Temperature float32
TopP float32
TopK int
WiredLimitGB int // Metal wired memory limit in GB (default 32)
JSONMode bool // Enable JSON grammar constraint
GrammarEBNF string // Raw EBNF grammar string
GrammarStart string // Start rule name for grammar
Vocab []string // Vocabulary for constrained decoding
WiredLimitGB int // Metal wired memory limit in GB (default 32)
}
type output struct {
@@ -132,11 +127,9 @@ type Decoder struct {
temp float32
topK int
topP float32
token *mlx.Array // Current token (kept across pools)
oldCacheState []*mlx.Array // Preallocated slice for old cache state
image *mlx.Array // Optional image for multimodal prefill
grammar *grammar.Engine // Optional grammar constraint engine
grammarVocab []string // Vocab for grammar debug
token *mlx.Array // Current token (kept across pools)
oldCacheState []*mlx.Array // Preallocated slice for old cache state
image *mlx.Array // Optional image for multimodal prefill
}
func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
@@ -152,12 +145,6 @@ func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
}
}
// SetGrammar enables constrained decoding with the given grammar engine.
func (d *Decoder) SetGrammar(g *grammar.Engine, vocab []string) {
d.grammar = g
d.grammarVocab = vocab
}
// SetImage sets the image for multimodal prefill (call before prefill)
func (d *Decoder) SetImage(img *mlx.Array) {
d.image = img
@@ -235,16 +222,6 @@ func (d *Decoder) prefill(inputIDs []int32) int {
} else {
logits = d.model.Forward(x, d.caches)
}
// Apply grammar constraints if enabled
if d.grammar != nil {
shape := logits.Shape()
lastLogits := mlx.Slice(logits, []int32{0, shape[1] - 1, 0}, []int32{1, shape[1], d.vocabSize})
lastLogits = mlx.Reshape(lastLogits, d.vocabSize)
maskedLogits := d.grammar.ApplyMask(lastLogits)
logits = mlx.Reshape(maskedLogits, 1, 1, d.vocabSize)
}
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
})
// Keep cache state (token auto-kept by AsyncEval)
@@ -268,15 +245,6 @@ func (d *Decoder) prefill(inputIDs []int32) int {
func (d *Decoder) step() int32 {
prevToken := d.token
// Sync on previous token FIRST to get its value and update grammar state
// This must happen before computing the next mask
val := prevToken.ItemInt32()
// Update grammar state with the token we just synced
if d.grammar != nil {
d.grammar.Accept(int(val))
}
// Save old cache state (reuse preallocated slice)
d.oldCacheState = d.oldCacheState[:0]
for _, c := range d.caches {
@@ -285,18 +253,6 @@ func (d *Decoder) step() int32 {
withStream(func() {
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
// Apply grammar constraints if enabled
if d.grammar != nil {
// Get last position logits: [1, 1, vocab] -> [vocab]
shape := logits.Shape()
lastLogits := mlx.Slice(logits, []int32{0, shape[1] - 1, 0}, []int32{1, shape[1], d.vocabSize})
lastLogits = mlx.Reshape(lastLogits, d.vocabSize)
maskedLogits := d.grammar.ApplyMask(lastLogits)
// Reshape back to [1, 1, vocab] for sample()
logits = mlx.Reshape(maskedLogits, 1, 1, d.vocabSize)
}
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
})
// Keep token and new cache state so they survive cleanup
@@ -306,6 +262,9 @@ func (d *Decoder) step() int32 {
}
mlx.AsyncEval(d.token)
// Sync on previous token (GPU already working on next step)
val := prevToken.ItemInt32()
// Free old token and old cache state
prevToken.Free()
for _, arr := range d.oldCacheState {
@@ -330,48 +289,6 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
tok := m.Tokenizer()
dec := NewDecoder(m, temp, in.TopK, in.TopP)
// Set up grammar constraint if enabled
var grammarEngine *grammar.Engine
var grammarVocab []string
if (in.JSONMode || in.GrammarEBNF != "") && len(in.Vocab) > 0 {
var compiled *grammar.Grammar
var err error
if in.GrammarEBNF != "" {
// Custom EBNF grammar
startRule := in.GrammarStart
if startRule == "" {
startRule = "root"
}
compiled, err = grammar.ParseEBNF(in.GrammarEBNF, startRule)
if err != nil {
return fmt.Errorf("failed to parse grammar: %w", err)
}
fmt.Printf("[Grammar mode: start=%s]\n", startRule)
} else {
// JSON object grammar (only allows objects at top level)
compiled, err = grammar.JSONObjectGrammar()
if err != nil {
return fmt.Errorf("failed to create JSON grammar: %w", err)
}
fmt.Println("[JSON object mode enabled]")
}
// Pad vocab to match model's vocab size if needed
grammarVocab = in.Vocab
modelVocabSize := int(m.VocabSize())
if len(grammarVocab) < modelVocabSize {
padded := make([]string, modelVocabSize)
copy(padded, grammarVocab)
grammarVocab = padded
}
grammarEngine, err = grammar.NewEngine(compiled, grammarVocab)
if err != nil {
return fmt.Errorf("failed to create grammar engine: %w", err)
}
defer grammarEngine.Close()
}
// Apply chat template - use image template if we have an image
prompt := in.Prompt
var tokens []int32
@@ -387,10 +304,6 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
tokens = tok.Encode(prompt, true)
}
if grammarEngine != nil {
dec.SetGrammar(grammarEngine, grammarVocab)
}
prefillStart := time.Now()
prefillTokens := dec.prefill(tokens)
// Prefill measurement should include time to first token (like mlx-lm)
@@ -414,11 +327,6 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
if text := streamer.Write(tok.Decode([]int32{firstToken})); text != "" {
cb(output{Text: text})
}
// Check if grammar is complete after first token
if dec.grammar != nil && dec.grammar.IsComplete() {
cb(output{Done: true, PrefillTokSec: prefillTokSec, GenTokSec: float64(genTokens) / time.Since(genStart).Seconds()})
return nil
}
for n := 1; n < maxTokens; n++ {
if ctx.Err() != nil {
@@ -433,10 +341,6 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
if text := streamer.Write(tok.Decode([]int32{token})); text != "" {
cb(output{Text: text})
}
// Check if grammar is complete (valid JSON document finished)
if dec.grammar != nil && dec.grammar.IsComplete() {
break
}
if n%256 == 0 {
mlx.ClearCache()

View File

@@ -11,9 +11,11 @@ import (
"os"
"path/filepath"
"runtime/pprof"
"strings"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/glm_image"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
"github.com/ollama/ollama/x/imagegen/models/llama"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
@@ -44,9 +46,6 @@ func main() {
topP := flag.Float64("top-p", 0.9, "Top-p sampling")
topK := flag.Int("top-k", 40, "Top-k sampling")
imagePath := flag.String("image", "", "Image path for multimodal models")
jsonMode := flag.Bool("json", false, "Enable JSON grammar constraint (output will be valid JSON)")
grammarFile := flag.String("grammar", "", "Path to EBNF grammar file for constrained decoding")
grammarStart := flag.String("grammar-start", "root", "Start rule name for grammar (default: root)")
// Image generation params
width := flag.Int("width", 1024, "Image width")
@@ -64,12 +63,16 @@ func main() {
// Legacy mode flags
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
glmImageFlag := flag.Bool("glm-image", false, "GLM-Image generation")
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
var inputImages stringSlice
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
teaCache := flag.Bool("teacache", false, "Enable TeaCache for faster inference")
teaCacheThreshold := flag.Float64("teacache-threshold", 0.1, "TeaCache threshold (lower = more aggressive caching)")
fusedQKV := flag.Bool("fused-qkv", false, "Enable fused QKV projection for faster attention")
flag.Parse()
@@ -102,13 +105,44 @@ func main() {
}
var img *mlx.Array
img, err = m.GenerateFromConfig(context.Background(), &zimage.GenerateConfig{
Prompt: *prompt,
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
CapturePath: *gpuCapture,
LayerCache: *layerCache,
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
CapturePath: *gpuCapture,
TeaCache: *teaCache,
TeaCacheThreshold: float32(*teaCacheThreshold),
FusedQKV: *fusedQKV,
})
if err == nil {
err = saveImageArray(img, *out)
}
case *glmImageFlag:
m := &glm_image.Model{}
// Use LoadFromPath if model path looks like a directory, otherwise use Load (ollama manifest)
var loadErr error
if strings.HasPrefix(*modelPath, ".") || strings.HasPrefix(*modelPath, "/") {
loadErr = m.LoadFromPath(*modelPath)
} else {
loadErr = m.Load(*modelPath)
}
if loadErr != nil {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(context.Background(), &glm_image.GenerateConfig{
Prompt: *prompt,
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
Temperature: float32(*temperature),
TopP: float32(*topP),
GuidanceScale: float32(*cfgScale),
MaxVisualTokens: int32(*maxTokens),
})
if err == nil {
err = saveImageArray(img, *out)
@@ -189,20 +223,6 @@ func main() {
}
}
// Get vocab for constrained decoding if needed
var vocab []string
var grammarEBNF string
if *jsonMode || *grammarFile != "" {
vocab = m.Tokenizer().Vocab()
}
if *grammarFile != "" {
data, err := os.ReadFile(*grammarFile)
if err != nil {
log.Fatalf("failed to read grammar file: %v", err)
}
grammarEBNF = string(data)
}
err = generate(context.Background(), m, input{
Prompt: *prompt,
Image: image,
@@ -211,10 +231,6 @@ func main() {
TopP: float32(*topP),
TopK: *topK,
WiredLimitGB: *wiredLimitGB,
JSONMode: *jsonMode,
GrammarEBNF: grammarEBNF,
GrammarStart: *grammarStart,
Vocab: vocab,
}, func(out output) {
if out.Text != "" {
fmt.Print(out.Text)

View File

@@ -40,13 +40,15 @@ type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo)
// CreateModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is "fp8", linear weights in transformer/text_encoder are quantized to mxfp8 format.
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTensorLayer TensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
var layers []LayerInfo
var configLayer LayerInfo
var totalParams int64 // Count parameters from original tensor shapes
// Components to process - extract individual tensors from each
components := []string{"text_encoder", "transformer", "vae"}
components := []string{"text_encoder", "transformer", "vae", "vision_language_encoder"}
for _, component := range components {
componentDir := filepath.Join(modelDir, component)
@@ -74,7 +76,11 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
}
tensorNames := extractor.ListTensors()
fn(fmt.Sprintf("importing %s/%s (%d tensors)", component, entry.Name(), len(tensorNames)))
quantizeMsg := ""
if quantize == "fp8" && component != "vae" {
quantizeMsg = ", quantizing to fp8"
}
fn(fmt.Sprintf("importing %s/%s (%d tensors%s)", component, entry.Name(), len(tensorNames), quantizeMsg))
for _, tensorName := range tensorNames {
td, err := extractor.GetTensor(tensorName)
@@ -83,16 +89,30 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
}
// Count parameters from original tensor shape
if len(td.Shape) > 0 {
numElements := int64(1)
for _, dim := range td.Shape {
numElements *= int64(dim)
}
totalParams += numElements
}
// Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors
// Use path-style name: "component/tensor_name"
fullName := component + "/" + tensorName
layer, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape)
// Determine if this tensor should be quantized
doQuantize := quantize == "fp8" && ShouldQuantize(tensorName, component)
// createTensorLayer returns multiple layers if quantizing (weight + scales)
newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, doQuantize)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
}
layers = append(layers, layer)
layers = append(layers, newLayers...)
}
extractor.Close()
@@ -106,10 +126,13 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
"text_encoder/generation_config.json",
"transformer/config.json",
"vae/config.json",
"vision_language_encoder/config.json",
"scheduler/scheduler_config.json",
"tokenizer/tokenizer.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
"processor/tokenizer.json", // GLM-Image main tokenizer
"processor/tokenizer_config.json", // GLM-Image tokenizer config
}
for _, cfgPath := range configFiles {
@@ -122,7 +145,7 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
var r io.Reader
// For model_index.json, normalize to Ollama format
// For model_index.json, normalize to Ollama format and add metadata
if cfgPath == "model_index.json" {
data, err := os.ReadFile(fullPath)
if err != nil {
@@ -141,6 +164,16 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
}
delete(cfg, "_diffusers_version")
// Add parameter count (counted from tensor shapes during import)
cfg["parameter_count"] = totalParams
// Add quantization info
if quantize == "fp8" {
cfg["quantization"] = "FP8"
} else {
cfg["quantization"] = "BF16"
}
data, err = json.MarshalIndent(cfg, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal %s: %w", cfgPath, err)

View File

@@ -60,9 +60,12 @@ func ArrayToImage(arr *mlx.Array) (*image.RGBA, error) {
}
// Transform to [H, W, C] for image conversion
img := mlx.Squeeze(arr, 0)
img = mlx.Transpose(img, 1, 2, 0)
img = mlx.Contiguous(img)
// Free intermediate arrays to avoid memory leak
squeezed := mlx.Squeeze(arr, 0)
transposed := mlx.Transpose(squeezed, 1, 2, 0)
squeezed.Free()
img := mlx.Contiguous(transposed)
transposed.Free()
mlx.Eval(img)
imgShape := img.Shape()

19
x/imagegen/imagegen.md Normal file
View File

@@ -0,0 +1,19 @@
# Image generation models (experimental)
Experimental image generation models are available for **macOS** in Ollama:
## Available models
- [Z-Image-Turbo](https://ollama.com/x/z-image-turbo)
```
ollama run x/z-image-turbo
```
> **Note**: [`x`](https://ollama.com/x) is a username on ollama.com where the maintainer team uploads experimental models
More models coming soon:
1. Qwen-Image-2512
2. Qwen-Image-Edit-2511
3. GLM-Image

View File

@@ -27,6 +27,7 @@ var modelVRAMEstimates = map[string]uint64{
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
"FluxPipeline": 21 * GB, // ~21GB for Flux (same architecture)
"QwenImagePipeline": 80 * GB, // TODO: verify actual requirements, using conservative estimate for now
"GlmImagePipeline": 80 * GB, // ~34GB weights + ~46GB working memory for 9B+7B hybrid model
}
// CheckPlatformSupport validates that image generation is supported on the current platform.

View File

@@ -607,6 +607,11 @@ func (a *Array) Valid() bool {
return a != nil && a.c.ctx != nil
}
// Kept returns true if the array is marked to survive Eval() cleanup.
func (a *Array) Kept() bool {
return a != nil && a.kept
}
func int32ToCInt(s []int32) *C.int {
if len(s) == 0 {
return nil
@@ -1480,6 +1485,44 @@ func (a *Array) ItemInt32() int32 {
return int32(val)
}
// Bytes copies the raw bytes out of the array without type conversion.
// Works with common dtypes (float32, int32, uint32, uint8).
// For non-contiguous arrays, call Contiguous() first.
// Note: Triggers cleanup of non-kept arrays.
func (a *Array) Bytes() []byte {
cleanup()
nbytes := a.Nbytes()
if nbytes == 0 {
return nil
}
// Get raw pointer based on dtype
var ptr unsafe.Pointer
switch a.Dtype() {
case DtypeFloat32:
ptr = unsafe.Pointer(C.mlx_array_data_float32(a.c))
case DtypeInt32:
ptr = unsafe.Pointer(C.mlx_array_data_int32(a.c))
case DtypeUint32:
ptr = unsafe.Pointer(C.mlx_array_data_uint32(a.c))
case DtypeUint8:
ptr = unsafe.Pointer(C.mlx_array_data_uint8(a.c))
default:
// For other types (bf16, f16, etc), convert to float32
arr := AsType(a, DtypeFloat32)
arr.Eval()
ptr = unsafe.Pointer(C.mlx_array_data_float32(arr.c))
nbytes = arr.Nbytes()
}
if ptr == nil {
return nil
}
data := make([]byte, nbytes)
copy(data, unsafe.Slice((*byte)(ptr), nbytes))
return data
}
// ============ Utility ============
// String returns a string representation
@@ -1658,6 +1701,34 @@ func (s *SafetensorsFile) Free() {
C.mlx_map_string_to_string_free(s.metadata)
}
// SaveSafetensors saves arrays to a safetensors file using MLX's native implementation.
// This correctly handles all dtypes including uint32 for quantized weights.
func SaveSafetensors(path string, arrays map[string]*Array) error {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
// Create the map
cArrays := C.mlx_map_string_to_array_new()
defer C.mlx_map_string_to_array_free(cArrays)
// Add each array to the map
for name, arr := range arrays {
cName := C.CString(name)
C.mlx_map_string_to_array_insert(cArrays, cName, arr.c)
C.free(unsafe.Pointer(cName))
}
// Create empty metadata (optional)
cMeta := C.mlx_map_string_to_string_new()
defer C.mlx_map_string_to_string_free(cMeta)
// Save
if C.mlx_save_safetensors(cPath, cArrays, cMeta) != 0 {
return fmt.Errorf("failed to save safetensors: %s", path)
}
return nil
}
// ============ NPY Loading ============
// LoadNpy loads a numpy array from an npy file
@@ -1729,14 +1800,6 @@ func init() {
// Lock main goroutine to OS thread for CUDA context stability.
// CUDA contexts are bound to threads; Go can migrate goroutines between threads.
runtime.LockOSThread()
// Avoid Metal device init crashes on systems without Metal.
if runtime.GOOS == "darwin" {
if MetalIsAvailable() {
SetDefaultDeviceGPU()
} else {
SetDefaultDeviceCPU()
}
}
RandomState[0] = RandomKey(uint64(time.Now().UnixMilli()))
Keep(RandomState[0]) // Global state should persist
}
@@ -1994,7 +2057,8 @@ func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, trans
// Returns (quantized_weights, scales, biases).
// groupSize: number of elements quantized together (default 64)
// bits: bits per element, 2, 4, or 8 (default 4)
// mode: "affine" (default) or "mxfp4"
// mode: "affine" (default), "mxfp4", or "mxfp8"
// Note: mxfp8 mode returns nil biases (only weights and scales)
func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
@@ -2003,14 +2067,21 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
res := C.mlx_vector_array_new()
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream())
// Result is a vector of 3 arrays: [weights, scales, biases]
// Result is a vector of arrays: [weights, scales, biases?]
// mxfp8 mode returns only 2 elements (no biases)
vecSize := int(C.mlx_vector_array_size(res))
var w0, w1, w2 C.mlx_array
C.mlx_vector_array_get(&w0, res, 0)
C.mlx_vector_array_get(&w1, res, 1)
C.mlx_vector_array_get(&w2, res, 2)
if vecSize >= 3 {
C.mlx_vector_array_get(&w2, res, 2)
}
C.mlx_vector_array_free(res)
return newArray(w0), newArray(w1), newArray(w2)
if vecSize >= 3 {
return newArray(w0), newArray(w1), newArray(w2)
}
return newArray(w0), newArray(w1), nil
}
// Dequantize reconstructs weights from quantized form.

View File

@@ -0,0 +1,693 @@
//go:build mlx
// Package glm_image implements the GLM-Image hybrid AR + diffusion model.
package glm_image
import (
"context"
"fmt"
"math"
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// ByT5Tokenizer is a simple byte-level tokenizer for ByT5
// ByT5 uses bytes as tokens: each byte (0-255) maps to token ID (3-258)
// Special tokens: 0=pad, 1=eos, 2=unk
type ByT5Tokenizer struct {
PadTokenID int32
EOSTokenID int32
UNKTokenID int32
}
// NewByT5Tokenizer creates a new ByT5 tokenizer
func NewByT5Tokenizer() *ByT5Tokenizer {
return &ByT5Tokenizer{
PadTokenID: 0,
EOSTokenID: 1,
UNKTokenID: 2,
}
}
// Encode converts a string to token IDs
func (t *ByT5Tokenizer) Encode(text string) []int32 {
bytes := []byte(text)
tokens := make([]int32, len(bytes))
for i, b := range bytes {
// Standard ByT5 tokenization: bytes 0-255 map to tokens 3-258
// (tokens 0, 1, 2 are PAD, EOS, UNK)
tokens[i] = int32(b) + 3
}
return tokens
}
// Decode converts token IDs back to a string
func (t *ByT5Tokenizer) Decode(tokens []int32) string {
bytes := make([]byte, 0, len(tokens))
for _, tok := range tokens {
if tok >= 3 && tok < 259 {
bytes = append(bytes, byte(tok-3))
}
}
return string(bytes)
}
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // For CFG (optional, not typically used with GLM-Image)
GuidanceScale float32 // Guidance scale (default: 1.5)
Width int32 // Image width (default: 1024, must be divisible by 32)
Height int32 // Image height (default: 1024, must be divisible by 32)
Steps int // Diffusion denoising steps (default: 50)
Seed int64 // Random seed
Progress ProgressFunc // Optional progress callback
// AR generation options
MaxVisualTokens int32 // Max visual tokens to generate (default: 256)
Temperature float32 // AR sampling temperature (default: 0.9)
TopP float32 // Nucleus sampling (default: 0.75)
}
// ProgressFunc is called during generation with stage and step progress.
type ProgressFunc func(stage string, step, totalSteps int)
// Model represents a GLM-Image hybrid model.
type Model struct {
ModelName string
Tokenizer *ByT5Tokenizer // For T5 text encoder (glyph embeddings)
GLMTokenizer *GLMTokenizer // For AR model (visual token generation)
TextEncoder *T5TextEncoder
VisionLanguageEncoder *VisionLanguageEncoder
Transformer *DiffusionTransformer
VAEDecoder *VAEDecoder
}
// Load loads the GLM-Image model from ollama blob storage.
func (m *Model) Load(modelName string) error {
fmt.Printf("Loading GLM-Image model from manifest: %s...\n", modelName)
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelName = modelName
// Load manifest
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return fmt.Errorf("load manifest: %w", err)
}
// Create ByT5 tokenizer (byte-level, no vocabulary file needed)
// Used for T5 text encoder (glyph embeddings)
fmt.Print(" Creating ByT5 tokenizer... ")
m.Tokenizer = NewByT5Tokenizer()
fmt.Println("✓")
// Load GLM tokenizer for AR model (visual token generation)
fmt.Print(" Loading GLM tokenizer... ")
glmTok, err := NewGLMTokenizer(manifest)
if err != nil {
return fmt.Errorf("glm tokenizer: %w", err)
}
m.GLMTokenizer = glmTok
fmt.Println("✓")
// Load T5 text encoder (~830MB)
m.TextEncoder = &T5TextEncoder{}
if err := m.TextEncoder.Load(manifest); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load vision-language encoder (~19GB, 9B params)
m.VisionLanguageEncoder = &VisionLanguageEncoder{}
if err := m.VisionLanguageEncoder.Load(manifest); err != nil {
return fmt.Errorf("vision language encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VisionLanguageEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load diffusion transformer (~13GB, 7B params)
m.Transformer = &DiffusionTransformer{}
if err := m.Transformer.Load(manifest); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE decoder (~775MB)
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.Load(manifest); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
return nil
}
// LoadFromPath loads the model from a directory path (not ollama manifest)
func (m *Model) LoadFromPath(modelPath string) error {
fmt.Printf("Loading GLM-Image model from path: %s...\n", modelPath)
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelName = modelPath
// Create ByT5 tokenizer (byte-level, no vocabulary file needed)
fmt.Print(" Creating ByT5 tokenizer... ")
m.Tokenizer = NewByT5Tokenizer()
fmt.Println("✓")
// Load GLM tokenizer for AR model (visual token generation)
fmt.Print(" Loading GLM tokenizer... ")
glmTok, err := NewGLMTokenizerFromPath(modelPath)
if err != nil {
return fmt.Errorf("glm tokenizer: %w", err)
}
m.GLMTokenizer = glmTok
fmt.Println("✓")
// Load T5 text encoder
m.TextEncoder = &T5TextEncoder{}
if err := m.TextEncoder.LoadFromPath(filepath.Join(modelPath, "text_encoder")); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load vision-language encoder
m.VisionLanguageEncoder = &VisionLanguageEncoder{}
if err := m.VisionLanguageEncoder.LoadFromPath(filepath.Join(modelPath, "vision_language_encoder")); err != nil {
return fmt.Errorf("vision language encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VisionLanguageEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load diffusion transformer
m.Transformer = &DiffusionTransformer{}
if err := m.Transformer.LoadFromPath(filepath.Join(modelPath, "transformer")); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE decoder
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.LoadFromPath(filepath.Join(modelPath, "vae")); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
return nil
}
// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
start := time.Now()
result, err := m.generate(ctx, cfg)
if err != nil {
return nil, err
}
fmt.Printf("Generated in %.2fs (%d diffusion steps)\n", time.Since(start).Seconds(), cfg.Steps)
return result, nil
}
// GenerateImage implements model.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.Generate(prompt, width, height, steps, seed)
}
// generate is the internal generation pipeline.
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Width <= 0 {
cfg.Width = 1024
}
if cfg.Height <= 0 {
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 50
}
if cfg.GuidanceScale <= 0 {
cfg.GuidanceScale = 1.5
}
// Calculate MaxVisualTokens based on image dimensions
// GLM-Image generates TWO grids of visual tokens:
// 1. First: prev (small) grid - prevTokenH × prevTokenW tokens
// 2. Then: target (large) grid - tokenH × tokenW tokens
// After generation, we extract only the TARGET grid tokens for diffusion.
factor := int32(32)
tokenH := cfg.Height / factor
tokenW := cfg.Width / factor
targetGridTokens := tokenH * tokenW
// Compute prev grid dimensions using diffusers formula:
// ratio = token_h / token_w
// prev_token_h = int(sqrt(ratio) * 16)
// prev_token_w = int(sqrt(1/ratio) * 16)
ratio := float64(tokenH) / float64(tokenW)
prevTokenH := int32(math.Sqrt(ratio) * 16)
prevTokenW := int32(math.Sqrt(1/ratio) * 16)
prevGridTokens := prevTokenH * prevTokenW
// Total tokens to generate = prev grid + target grid
// (diffusers does max_new_tokens = total + 1 for EOS, but we stop on EOS anyway)
cfg.MaxVisualTokens = prevGridTokens + targetGridTokens
if cfg.Temperature <= 0 {
cfg.Temperature = 0.9
}
if cfg.TopP <= 0 {
cfg.TopP = 0.75
}
// Ensure dimensions are divisible by 32
cfg.Width = (cfg.Width / 32) * 32
cfg.Height = (cfg.Height / 32) * 32
tcfg := m.Transformer.Config
latentH := cfg.Height / 8
latentW := cfg.Width / 8
// Progress callback helper
progress := func(stage string, step, total int) {
if cfg.Progress != nil {
cfg.Progress(stage, step, total)
}
}
// === PHASE 1: T5 Text Encoding ===
fmt.Println("[T5] Encoding glyph text...")
progress("text_encoding", 0, 1)
textEmbed := m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
mlx.Keep(textEmbed)
mlx.Eval(textEmbed)
fmt.Printf("[T5] Done, shape: %v\n", textEmbed.Shape())
progress("text_encoding", 1, 1)
// === PHASE 2: AR Visual Token Generation ===
fmt.Printf("[AR] Generating %d visual tokens...\n", cfg.MaxVisualTokens)
progress("ar_generation", 0, int(cfg.MaxVisualTokens))
visualTokens := m.VisionLanguageEncoder.Generate(
cfg.Prompt,
m.GLMTokenizer,
cfg.MaxVisualTokens,
cfg.Temperature,
cfg.TopP,
cfg.Seed,
cfg.Height,
cfg.Width,
func(step int) {
if step%100 == 0 || step < 10 {
fmt.Printf("[AR] Step %d/%d\n", step, cfg.MaxVisualTokens)
}
progress("ar_generation", step, int(cfg.MaxVisualTokens))
},
)
mlx.Keep(visualTokens)
mlx.Eval(visualTokens)
fmt.Printf("[AR] Done generating visual tokens\n")
progress("ar_generation", int(cfg.MaxVisualTokens), int(cfg.MaxVisualTokens))
vtShape := visualTokens.Shape()
totalGenerated := vtShape[1]
fmt.Printf("[AR] Generated %d tokens total\n", totalGenerated)
// Extract only the TARGET grid tokens (skip the prev grid tokens)
// diffusers: large_image_tokens = outputs[input_length + large_image_start_offset : ...]
// large_image_start_offset = prev_grid_size
var targetGridVisualTokens *mlx.Array
if totalGenerated >= prevGridTokens+targetGridTokens {
// Full generation completed - extract target grid
targetGridVisualTokens = mlx.Slice(visualTokens,
[]int32{0, prevGridTokens},
[]int32{1, prevGridTokens + targetGridTokens})
mlx.Keep(targetGridVisualTokens)
mlx.Eval(targetGridVisualTokens)
} else if totalGenerated > prevGridTokens {
// Partial target grid - take what we have
actualTargetTokens := totalGenerated - prevGridTokens
targetGridVisualTokens = mlx.Slice(visualTokens,
[]int32{0, prevGridTokens},
[]int32{1, totalGenerated})
mlx.Keep(targetGridVisualTokens)
mlx.Eval(targetGridVisualTokens)
fmt.Printf("WARNING: Partial target grid: got %d/%d target tokens\n",
actualTargetTokens, targetGridTokens)
} else {
// Not enough tokens - EOS came too early
return nil, fmt.Errorf("AR generation stopped too early: got %d tokens, need at least %d (prev grid) + 1",
totalGenerated, prevGridTokens)
}
// === PHASE 3: Diffusion Decoding ===
// Setup scheduler with dynamic shift based on image size
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
imgSeqLen := (latentH / tcfg.PatchSize) * (latentW / tcfg.PatchSize)
scheduler.SetTimestepsWithDynamicShift(cfg.Steps, imgSeqLen)
// Initialize noise latents [B, C, H, W]
latents := scheduler.InitNoise([]int32{1, tcfg.InChannels, latentH, latentW}, cfg.Seed)
mlx.Eval(latents)
// Upsample TARGET grid visual tokens 2x to match patch count (matching diffusers)
// target_grid tokens -> 2x upsample -> patch_count
// e.g., 32x32=1024 tokens -> 64x64=4096 patches for 1024x1024
visualTokensUpsampled := upsampleTokens(targetGridVisualTokens, tokenH, tokenW, 2)
// Prepare prior embeddings from upsampled visual tokens (VQ codebook lookup + projection)
priorEmbed := m.Transformer.EmbedPriorTokens(visualTokensUpsampled)
mlx.Keep(priorEmbed)
mlx.Eval(priorEmbed)
// Prepare text conditioning (project T5 embeddings)
textCond := m.Transformer.ProjectTextEmbeddings(textEmbed)
mlx.Keep(textCond)
mlx.Eval(textCond)
// === CFG Setup ===
// For classifier-free guidance, we need unconditional (negative) text embeddings
// GLM-Image uses empty string "" for negative prompt
doCFG := cfg.GuidanceScale > 1.0
var negativeTextCond *mlx.Array
if doCFG {
// Encode empty string for negative prompt
negativeTextEmbed := m.TextEncoder.EncodePrompt(m.Tokenizer, "")
mlx.Keep(negativeTextEmbed)
mlx.Eval(negativeTextEmbed)
negativeTextCond = m.Transformer.ProjectTextEmbeddings(negativeTextEmbed)
mlx.Keep(negativeTextCond)
mlx.Eval(negativeTextCond)
negativeTextEmbed.Free()
}
// Prepare conditioning inputs
targetSize := mlx.NewArray([]float32{float32(cfg.Height), float32(cfg.Width)}, []int32{1, 2})
cropCoords := mlx.NewArray([]float32{0, 0}, []int32{1, 2}) // Default: no crop offset
targetSize = mlx.ToBFloat16(targetSize)
cropCoords = mlx.ToBFloat16(cropCoords)
mlx.Keep(targetSize)
mlx.Keep(cropCoords)
mlx.Eval(targetSize, cropCoords)
pH := latentH / tcfg.PatchSize
pW := latentW / tcfg.PatchSize
// Denoising loop
fmt.Printf("[Diffusion] Starting %d denoising steps...\n", cfg.Steps)
progress("diffusion", 0, cfg.Steps)
for i := 0; i < cfg.Steps; i++ {
fmt.Printf("[Diffusion] Step %d/%d (timestep=%.1f)\n", i+1, cfg.Steps, scheduler.Timesteps[i]-1)
// Check for cancellation
if ctx != nil {
select {
case <-ctx.Done():
textEmbed.Free()
visualTokens.Free()
// visualTokensUpsampled points to visualTokens, don't double-free
priorEmbed.Free()
textCond.Free()
latents.Free()
return nil, ctx.Err()
default:
}
}
// Get timestep value for the transformer
// scheduler.Timesteps contains raw timestep values (1000 down to ~20)
// Pass timestep - 1 to match diffusers: timestep = t.expand(latents.shape[0]) - 1
timestepVal := scheduler.Timesteps[i] - 1
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{timestepVal}, []int32{1}))
// Patchify latents [B, C, H, W] -> [B, L, C*p*p]
patches := PatchifyLatents(latents, tcfg.PatchSize)
// Transformer forward with MMDiT architecture
// Conditional pass (with text + prior embeddings)
outputCond := m.Transformer.ForwardWithPriorDrop(
patches,
priorEmbed,
textCond,
timestep,
targetSize,
cropCoords,
pH,
pW,
false, // priorTokenDrop = false for conditional
)
// Unpatchify [B, L, C*p*p] -> [B, C, H, W]
noisePredCond := UnpatchifyLatents(outputCond, latentH, latentW, tcfg.PatchSize, tcfg.OutChannels)
var noisePred *mlx.Array
if doCFG {
// Unconditional pass (empty text, dropped prior embeddings)
outputUncond := m.Transformer.ForwardWithPriorDrop(
patches,
priorEmbed, // Still passed but will be ignored due to priorTokenDrop=true
negativeTextCond,
timestep,
targetSize,
cropCoords,
pH,
pW,
true, // priorTokenDrop = true for unconditional
)
noisePredUncond := UnpatchifyLatents(outputUncond, latentH, latentW, tcfg.PatchSize, tcfg.OutChannels)
// CFG formula: noise_pred = uncond + guidance_scale * (cond - uncond)
diff := mlx.Sub(noisePredCond, noisePredUncond)
scaled := mlx.MulScalar(diff, cfg.GuidanceScale)
noisePred = mlx.Add(noisePredUncond, scaled)
} else {
noisePred = noisePredCond
}
// Scheduler step
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
mlx.Eval(latents)
oldLatents.Free()
progress("diffusion", i+1, cfg.Steps)
}
// Cleanup intermediate arrays
textEmbed.Free()
visualTokens.Free()
// visualTokensUpsampled points to visualTokens, don't double-free
priorEmbed.Free()
textCond.Free()
if negativeTextCond != nil {
negativeTextCond.Free()
}
targetSize.Free()
cropCoords.Free()
// === PHASE 4: VAE Decode ===
progress("vae_decode", 0, 1)
decoded := m.VAEDecoder.Decode(latents)
mlx.Eval(decoded)
latents.Free()
progress("vae_decode", 1, 1)
return decoded, nil
}
// upsampleTokens performs nearest-neighbor upsampling of visual tokens
// Converts from prev_grid (e.g., 16x16) to target_grid (e.g., 32x32 for 2x, 64x64 for 4x)
// scale must be 2 or 4
//
// Handles early EOS gracefully: if tokens has fewer than prevH*prevW elements,
// missing tokens are padded with 0 (visual token padding value).
func upsampleTokens(tokens *mlx.Array, prevH, prevW int32, scale int32) *mlx.Array {
// tokens: [1, N] where N <= prevH*prevW (may be shorter if early EOS)
// Each token at (i, j) becomes scale*scale tokens in the output
mlx.Eval(tokens)
tokenData := tokens.DataInt32()
numTokens := int32(len(tokenData))
expectedTokens := prevH * prevW
// Warn if we got fewer tokens than expected (early EOS)
if numTokens < expectedTokens {
fmt.Printf("WARNING: upsampleTokens got %d tokens, expected %d (padding with 0)\n",
numTokens, expectedTokens)
}
targetH := prevH * scale
targetW := prevW * scale
upsampled := make([]int32, targetH*targetW)
for i := int32(0); i < prevH; i++ {
for j := int32(0); j < prevW; j++ {
srcIdx := i*prevW + j
// Handle early EOS: use 0 (padding) for missing tokens
var val int32
if srcIdx < numTokens {
val = tokenData[srcIdx]
} else {
val = 0 // Padding token
}
// Place in scale*scale positions
dstI := i * scale
dstJ := j * scale
for di := int32(0); di < scale; di++ {
for dj := int32(0); dj < scale; dj++ {
upsampled[(dstI+di)*targetW+(dstJ+dj)] = val
}
}
}
}
return mlx.NewArrayInt32(upsampled, []int32{1, targetH * targetW})
}
// PatchifyLatents converts [B, C, H, W] to [B, L, C*p*p]
func PatchifyLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
shape := latents.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
pH := H / patchSize
pW := W / patchSize
// Reshape: [B, C, H, W] -> [B, C, pH, p, pW, p]
x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
// Transpose: -> [B, pH, pW, C, p, p]
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
// Flatten: -> [B, pH*pW, C*p*p]
return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
}
// UnpatchifyLatents converts [B, L, C*p*p] back to [B, C, H, W]
func UnpatchifyLatents(patches *mlx.Array, H, W, patchSize, channels int32) *mlx.Array {
shape := patches.Shape()
B := shape[0]
pH := H / patchSize
pW := W / patchSize
// Reshape: [B, L, C*p*p] -> [B, pH, pW, C, p, p]
x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
// Transpose: -> [B, C, pH, p, pW, p]
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
// Reshape: -> [B, C, H, W]
return mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
}
// CalculateShift computes the dynamic shift for flow matching based on image sequence length.
func CalculateShift(imgSeqLen int32) float32 {
cfg := DefaultSchedulerConfig()
if !cfg.UseDynamicShifting {
return 0
}
// Sqrt-based shift calculation (matches diffusers)
m := float32(math.Sqrt(float64(imgSeqLen) / float64(cfg.BaseImageSeqLen)))
return m*cfg.MaxShift + cfg.BaseShift
}
// UpsampleTokens2x upsamples token IDs by 2x using nearest neighbor interpolation
// tokens: [B, H*W] -> [B, (H*2)*(W*2)]
// This matches diffusers' _upsample_token_ids function
func UpsampleTokens2x(tokens *mlx.Array, gridH, gridW int32) *mlx.Array {
shape := tokens.Shape()
B := shape[0]
// Reshape to [B, 1, H, W] for interpolation
tokens = mlx.Reshape(tokens, B, 1, gridH, gridW)
// Convert to float for interpolation
tokensFloat := mlx.AsType(tokens, mlx.DtypeFloat32)
// 2x nearest neighbor upsample
// [B, 1, H, W] -> [B, 1, H*2, W*2]
upsampled := nearestUpsample2x(tokensFloat)
// Convert back to int and reshape to [B, H*2*W*2]
upsampled = mlx.AsType(upsampled, mlx.DtypeInt32)
return mlx.Reshape(upsampled, B, gridH*2*gridW*2)
}
// nearestUpsample2x performs 2x nearest neighbor upsampling on NCHW tensor
func nearestUpsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
// Repeat each element 2x2
// [B, C, H, W] -> [B, C, H, 1, W, 1] -> [B, C, H, 2, W, 2] -> [B, C, H*2, W*2]
x = mlx.Reshape(x, B, C, H, 1, W, 1)
// Tile to repeat each pixel 2x2
x = mlx.Tile(x, []int32{1, 1, 1, 2, 1, 2})
// Reshape to final size
return mlx.Reshape(x, B, C, H*2, W*2)
}

View File

@@ -0,0 +1,358 @@
//go:build mlx
package glm_image
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"github.com/ollama/ollama/x/imagegen"
)
// GLMTokenizer implements the GLM tokenizer for the AR model
// This is a BPE-style tokenizer with ignore_merges=true, meaning it does
// greedy longest-match tokenization from the vocab without runtime merging.
type GLMTokenizer struct {
Vocab map[string]int32 // token string -> token ID
VocabReverse map[int32]string // token ID -> token string
SpecialTokens map[string]int32 // special token strings -> IDs
// Special token IDs
SopTokenID int32 // <sop> = grid_bos_token (167845)
EopTokenID int32 // <eop> = grid_eos_token (167846)
BosTokenID int32 // <|dit_token_16384|> = visual BOS (16384)
EosTokenID int32 // <|dit_token_16385|> = visual EOS (16385)
PadTokenID int32
// Sorted vocab keys by length (longest first) for greedy matching
sortedTokens []string
}
// tokenizerJSON represents the structure of tokenizer.json
type tokenizerJSON struct {
Model struct {
Vocab map[string]int32 `json:"vocab"`
} `json:"model"`
AddedTokens []struct {
ID int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
} `json:"added_tokens"`
}
// NewGLMTokenizer creates a GLM tokenizer from the model manifest
func NewGLMTokenizer(manifest *imagegen.ModelManifest) (*GLMTokenizer, error) {
// Read tokenizer.json from processor directory in manifest
data, err := manifest.ReadConfig("processor/tokenizer.json")
if err != nil {
return nil, fmt.Errorf("failed to read tokenizer.json from manifest: %w", err)
}
var tj tokenizerJSON
if err := json.Unmarshal(data, &tj); err != nil {
return nil, fmt.Errorf("failed to parse tokenizer.json: %w", err)
}
tok := &GLMTokenizer{
Vocab: make(map[string]int32),
VocabReverse: make(map[int32]string),
SpecialTokens: make(map[string]int32),
}
// Load vocab from model section
for token, id := range tj.Model.Vocab {
tok.Vocab[token] = id
tok.VocabReverse[id] = token
}
// Load added tokens (special tokens including dit_tokens)
for _, at := range tj.AddedTokens {
tok.Vocab[at.Content] = at.ID
tok.VocabReverse[at.ID] = at.Content
if at.Special {
tok.SpecialTokens[at.Content] = at.ID
}
}
// Set special token IDs
tok.SopTokenID = 167845 // <sop>
tok.EopTokenID = 167846 // <eop>
tok.BosTokenID = 16384 // <|dit_token_16384|>
tok.EosTokenID = 16385 // <|dit_token_16385|>
tok.PadTokenID = 16385 // Same as EOS
// Build sorted token list for greedy matching (longest first)
tok.sortedTokens = make([]string, 0, len(tok.Vocab))
for token := range tok.Vocab {
tok.sortedTokens = append(tok.sortedTokens, token)
}
sort.Slice(tok.sortedTokens, func(i, j int) bool {
return len(tok.sortedTokens[i]) > len(tok.sortedTokens[j])
})
fmt.Printf("Loaded GLM tokenizer with %d tokens\n", len(tok.Vocab))
return tok, nil
}
// NewGLMTokenizerFromPath creates a GLM tokenizer from a directory path
func NewGLMTokenizerFromPath(modelPath string) (*GLMTokenizer, error) {
// Read tokenizer.json from processor directory
tokenizerPath := filepath.Join(modelPath, "processor", "tokenizer.json")
data, err := os.ReadFile(tokenizerPath)
if err != nil {
return nil, fmt.Errorf("failed to read tokenizer.json: %w", err)
}
var tj tokenizerJSON
if err := json.Unmarshal(data, &tj); err != nil {
return nil, fmt.Errorf("failed to parse tokenizer.json: %w", err)
}
tok := &GLMTokenizer{
Vocab: make(map[string]int32),
VocabReverse: make(map[int32]string),
SpecialTokens: make(map[string]int32),
}
// Load vocab from model section
for token, id := range tj.Model.Vocab {
tok.Vocab[token] = id
tok.VocabReverse[id] = token
}
// Load added tokens (special tokens including dit_tokens)
for _, at := range tj.AddedTokens {
tok.Vocab[at.Content] = at.ID
tok.VocabReverse[at.ID] = at.Content
if at.Special {
tok.SpecialTokens[at.Content] = at.ID
}
}
// Set special token IDs
tok.SopTokenID = 167845 // <sop>
tok.EopTokenID = 167846 // <eop>
tok.BosTokenID = 16384 // <|dit_token_16384|>
tok.EosTokenID = 16385 // <|dit_token_16385|>
tok.PadTokenID = 16385 // Same as EOS
// Build sorted token list for greedy matching (longest first)
tok.sortedTokens = make([]string, 0, len(tok.Vocab))
for token := range tok.Vocab {
tok.sortedTokens = append(tok.sortedTokens, token)
}
sort.Slice(tok.sortedTokens, func(i, j int) bool {
return len(tok.sortedTokens[i]) > len(tok.sortedTokens[j])
})
fmt.Printf("Loaded GLM tokenizer with %d tokens\n", len(tok.Vocab))
return tok, nil
}
// Encode tokenizes a string into token IDs
// This uses greedy longest-match tokenization with GPT-2 style space handling
func (t *GLMTokenizer) Encode(text string) []int32 {
if text == "" {
return []int32{}
}
var tokens []int32
// First, check for and handle special tokens
// Replace special tokens with placeholders, encode, then restore
specialReplacements := make(map[string]int32)
for special, id := range t.SpecialTokens {
if strings.Contains(text, special) {
specialReplacements[special] = id
}
}
// Process text character by character with special token handling
i := 0
isFirstToken := true
for i < len(text) {
// Check for special tokens first
foundSpecial := false
for special, id := range specialReplacements {
if strings.HasPrefix(text[i:], special) {
tokens = append(tokens, id)
i += len(special)
isFirstToken = false
foundSpecial = true
break
}
}
if foundSpecial {
continue
}
// Handle regular text with GPT-2 style space prefix
// "Ġ" (U+0120) represents a space before a token
remaining := text[i:]
// Try to find the longest matching token
matched := false
for _, token := range t.sortedTokens {
// Skip special tokens in regular matching
if _, isSpecial := t.SpecialTokens[token]; isSpecial {
continue
}
// Check if this token matches
tokenText := token
// Handle the Ġ prefix (represents space)
if strings.HasPrefix(token, "Ġ") {
// This token expects a leading space
if i > 0 || !isFirstToken {
// Check if remaining starts with space + token content
tokenContent := token[len("Ġ"):]
if strings.HasPrefix(remaining, " "+tokenContent) {
tokens = append(tokens, t.Vocab[token])
i += 1 + len(tokenContent) // space + content
isFirstToken = false
matched = true
break
}
}
} else {
// Regular token without space prefix
if strings.HasPrefix(remaining, tokenText) {
tokens = append(tokens, t.Vocab[token])
i += len(tokenText)
isFirstToken = false
matched = true
break
}
}
}
if !matched {
// No token found - skip this character (or use UNK)
// For now, just skip unknown characters
i++
}
}
return tokens
}
// EncodeForGeneration encodes a prompt with grid tokens for image generation
// Format: {prompt}<sop>{token_h} {token_w}<eop><sop>{prev_h} {prev_w}<eop><|dit_token_16384|>
//
// Uses GPT-2 style tokenization where " 32" becomes "Ġ32" (a single token with
// space prefix), matching the HuggingFace tokenizer behavior.
func (t *GLMTokenizer) EncodeForGeneration(prompt string, targetHeight, targetWidth int32) []int32 {
// Calculate grid dimensions
factor := int32(32)
height := (targetHeight / factor) * factor
width := (targetWidth / factor) * factor
tokenH := height / factor
tokenW := width / factor
// Calculate previous grid dimensions
ratio := float64(tokenH) / float64(tokenW)
prevTokenH := int32(sqrt(ratio) * 16)
prevTokenW := int32(sqrt(1.0/ratio) * 16)
// Encode the prompt text
promptTokens := t.Encode(prompt)
// Build the full sequence:
// [prompt tokens] <sop> [tokenH] [Ġ+tokenW] <eop> <sop> [prevH] [Ġ+prevW] <eop> <bos>
// Note: HF tokenizer treats " 32" as "Ġ32" (single token), not "Ġ" + "32"
var tokens []int32
tokens = append(tokens, promptTokens...)
// First grid: <sop> H W <eop>
// First number has no space prefix, second number has space prefix (Ġ)
tokens = append(tokens, t.SopTokenID)
tokens = append(tokens, t.encodeNumber(tokenH)...)
tokens = append(tokens, t.encodeSpaceNumber(tokenW)...) // " W" as Ġ+W
tokens = append(tokens, t.EopTokenID)
// Second grid: <sop> prevH prevW <eop>
tokens = append(tokens, t.SopTokenID)
tokens = append(tokens, t.encodeNumber(prevTokenH)...)
tokens = append(tokens, t.encodeSpaceNumber(prevTokenW)...) // " prevW" as Ġ+prevW
tokens = append(tokens, t.EopTokenID)
// BOS token (start of image generation)
tokens = append(tokens, t.BosTokenID)
return tokens
}
// encodeNumber encodes a number - first tries as a whole token, falls back to digit-by-digit
func (t *GLMTokenizer) encodeNumber(n int32) []int32 {
s := fmt.Sprintf("%d", n)
// First try: look up the whole number as a single token
if id, ok := t.Vocab[s]; ok {
return []int32{id}
}
// Fallback: encode digit by digit
var tokens []int32
for _, c := range s {
if id, ok := t.Vocab[string(c)]; ok {
tokens = append(tokens, id)
}
}
return tokens
}
// encodeSpaceNumber encodes " N" as "ĠN" (space-prefixed number) matching HF tokenizer
// GPT-2 style: " 32" becomes single token "Ġ32", not "Ġ" + "32"
func (t *GLMTokenizer) encodeSpaceNumber(n int32) []int32 {
s := fmt.Sprintf("%d", n)
// First try: look up "Ġ{number}" as a single token (e.g., "Ġ32")
spaceToken := "Ġ" + s
if id, ok := t.Vocab[spaceToken]; ok {
return []int32{id}
}
// Fallback: bare space Ġ + number tokens
var tokens []int32
if spaceID, ok := t.Vocab["Ġ"]; ok {
tokens = append(tokens, spaceID)
}
tokens = append(tokens, t.encodeNumber(n)...)
return tokens
}
// sqrt is a helper for float64 sqrt
func sqrt(x float64) float64 {
if x <= 0 {
return 0
}
// Newton's method
z := x
for i := 0; i < 10; i++ {
z = z - (z*z-x)/(2*z)
}
return z
}
// Decode converts token IDs back to a string
func (t *GLMTokenizer) Decode(tokens []int32) string {
var sb strings.Builder
for _, id := range tokens {
if token, ok := t.VocabReverse[id]; ok {
// Handle Ġ prefix (convert back to space)
if strings.HasPrefix(token, "Ġ") {
sb.WriteString(" ")
sb.WriteString(token[len("Ġ"):])
} else {
sb.WriteString(token)
}
}
}
return sb.String()
}

View File

@@ -0,0 +1,159 @@
//go:build mlx
package glm_image
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// FlowMatchSchedulerConfig holds scheduler configuration
type FlowMatchSchedulerConfig struct {
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
BaseShift float32 `json:"base_shift"` // 0.25
MaxShift float32 `json:"max_shift"` // 0.75
BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
MaxImageSeqLen int32 `json:"max_image_seq_len"` // 4096
UseDynamicShifting bool `json:"use_dynamic_shifting"` // true
TimeShiftType string `json:"time_shift_type"` // "linear"
}
// DefaultSchedulerConfig returns the default config for GLM-Image
func DefaultSchedulerConfig() *FlowMatchSchedulerConfig {
return &FlowMatchSchedulerConfig{
NumTrainTimesteps: 1000,
BaseShift: 0.25,
MaxShift: 0.75,
BaseImageSeqLen: 256,
MaxImageSeqLen: 4096,
UseDynamicShifting: true,
TimeShiftType: "linear",
}
}
// FlowMatchScheduler implements FlowMatchEulerDiscreteScheduler
type FlowMatchScheduler struct {
Config *FlowMatchSchedulerConfig
Timesteps []float32 // Raw timesteps for transformer conditioning (unshifted)
Sigmas []float32 // Shifted sigmas for Euler step calculation
NumSteps int
}
// NewFlowMatchScheduler creates a new scheduler
func NewFlowMatchScheduler(cfg *FlowMatchSchedulerConfig) *FlowMatchScheduler {
return &FlowMatchScheduler{Config: cfg}
}
// SetTimestepsWithDynamicShift sets timesteps with dynamic shifting based on image size
// Following diffusers: raw timesteps are used for conditioning, shifted sigmas for step calculation
func (s *FlowMatchScheduler) SetTimestepsWithDynamicShift(numSteps int, imgSeqLen int32) {
s.NumSteps = numSteps
// Calculate shift (mu) based on image sequence length
mu := s.calculateShift(imgSeqLen)
// Create timesteps: linspace from sigma_max_t to sigma_min_t
// sigma_max = 1.0, sigma_min ~= 0.001 (near 0 but not exactly 0)
// Then apply time shift and append terminal sigma=0
s.Timesteps = make([]float32, numSteps)
s.Sigmas = make([]float32, numSteps+1) // +1 for terminal sigma
numTrainTimesteps := float32(s.Config.NumTrainTimesteps)
// Create base sigmas: linspace from 1.0 to small value (matching diffusers)
for i := 0; i < numSteps; i++ {
// linspace from 1000 to ~20 (sigma_min * num_train_timesteps)
tRaw := numTrainTimesteps - float32(i)*(numTrainTimesteps-1.0)/float32(numSteps-1)
s.Timesteps[i] = tRaw
// Convert to sigma [0, 1]
sigma := tRaw / numTrainTimesteps
// Apply time shift if enabled
if s.Config.UseDynamicShifting && mu > 0 {
sigma = s.applyShift(mu, sigma)
}
s.Sigmas[i] = sigma
}
// Append terminal sigma = 0 (the final clean image)
s.Sigmas[numSteps] = 0
}
// calculateShift computes dynamic shift based on image sequence length
// Uses the sqrt-based formula from diffusers:
// m = (image_seq_len / base_seq_len) ** 0.5
// mu = m * max_shift + base_shift
func (s *FlowMatchScheduler) calculateShift(imgSeqLen int32) float32 {
cfg := s.Config
if !cfg.UseDynamicShifting {
return 0
}
// Sqrt-based shift calculation (matches diffusers pipeline_glm_image.py)
m := float32(math.Sqrt(float64(imgSeqLen) / float64(cfg.BaseImageSeqLen)))
mu := m*cfg.MaxShift + cfg.BaseShift
return mu
}
// applyShift applies time shift transformation
// mu: the computed shift value
// t: sigma value in [0, 1]
func (s *FlowMatchScheduler) applyShift(mu float32, t float32) float32 {
if t <= 0 {
return 0
}
if t >= 1 {
return 1
}
// sigma=1.0 for both shift types
sigma := float32(1.0)
if s.Config.TimeShiftType == "linear" {
// Linear: mu / (mu + (1/t - 1)^sigma)
return mu / (mu + float32(math.Pow(float64(1.0/t-1.0), float64(sigma))))
}
// Exponential (default): exp(mu) / (exp(mu) + (1/t - 1)^sigma)
expMu := float32(math.Exp(float64(mu)))
return expMu / (expMu + float32(math.Pow(float64(1.0/t-1.0), float64(sigma))))
}
// Step performs one denoising step
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, stepIdx int) *mlx.Array {
sigma := s.Sigmas[stepIdx]
sigmaNext := s.Sigmas[stepIdx+1]
// Euler step: x_{t-dt} = x_t + dt * v_t
dt := sigmaNext - sigma // Negative (going from noise to clean)
scaledOutput := mlx.MulScalar(modelOutput, dt)
return mlx.Add(sample, scaledOutput)
}
// InitNoise creates initial noise
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16)
}
// AddNoise adds noise to clean samples for a given timestep (for img2img)
func (s *FlowMatchScheduler) AddNoise(cleanSample, noise *mlx.Array, timestepIdx int) *mlx.Array {
// In flow matching: x_t = (1-sigma) * x_0 + sigma * noise
// Use sigmas (shifted) for the interpolation
sigma := s.Sigmas[timestepIdx]
oneMinusSigma := 1.0 - sigma
scaledClean := mlx.MulScalar(cleanSample, oneMinusSigma)
scaledNoise := mlx.MulScalar(noise, sigma)
return mlx.Add(scaledClean, scaledNoise)
}
// GetTimesteps returns all timesteps
func (s *FlowMatchScheduler) GetTimesteps() []float32 {
return s.Timesteps
}

View File

@@ -0,0 +1,497 @@
//go:build mlx
package glm_image
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"regexp"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// T5Config holds T5 encoder configuration
type T5Config struct {
DModel int32 `json:"d_model"` // 1472
DFF int32 `json:"d_ff"` // 3584
DKV int32 `json:"d_kv"` // 64
NumHeads int32 `json:"num_heads"` // 6
NumLayers int32 `json:"num_layers"` // 12
VocabSize int32 `json:"vocab_size"` // 384 (byte-level)
LayerNormEps float32 `json:"layer_norm_epsilon"` // 1e-6
IsGatedAct bool `json:"is_gated_act"` // true (gated-gelu)
// Relative position bias
RelativeAttentionNumBuckets int32 `json:"relative_attention_num_buckets"` // 32
RelativeAttentionMaxDistance int32 `json:"relative_attention_max_distance"` // 128
}
// T5TextEncoder is the T5 encoder for text conditioning
type T5TextEncoder struct {
Config *T5Config
// Embedding (shared for ByT5)
SharedEmbed *nn.Embedding `weight:"shared"`
// Encoder layers
Layers []*T5Block `weight:"encoder.block"`
// Final layer norm
FinalNorm *T5LayerNorm `weight:"encoder.final_layer_norm"`
// Relative position bias (from first layer, shared across all)
RelativeAttentionBias *mlx.Array `weight:"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"`
}
// T5Block is a single T5 encoder block
type T5Block struct {
// Self attention
Layer0 *T5LayerSelfAttention `weight:"layer.0"`
// FFN
Layer1 *T5LayerFF `weight:"layer.1"`
}
// T5LayerSelfAttention is T5's self-attention layer
type T5LayerSelfAttention struct {
SelfAttention *T5Attention `weight:"SelfAttention"`
LayerNorm *T5LayerNorm `weight:"layer_norm"`
}
// T5Attention implements T5's relative attention
type T5Attention struct {
Q *mlx.Array `weight:"q.weight"` // No bias in T5
K *mlx.Array `weight:"k.weight"`
V *mlx.Array `weight:"v.weight"`
O *mlx.Array `weight:"o.weight"`
NHeads int32
DKV int32
Scale float32
}
// T5LayerFF is T5's feedforward layer with gated-gelu
type T5LayerFF struct {
DenseReluDense *T5DenseGatedGelu `weight:"DenseReluDense"`
LayerNorm *T5LayerNorm `weight:"layer_norm"`
}
// T5DenseGatedGelu is T5's gated-gelu FFN
type T5DenseGatedGelu struct {
Wi0 *mlx.Array `weight:"wi_0.weight"` // gate projection
Wi1 *mlx.Array `weight:"wi_1.weight"` // up projection
Wo *mlx.Array `weight:"wo.weight"` // down projection
}
// T5LayerNorm is T5's RMSNorm variant (no bias, no mean subtraction)
type T5LayerNorm struct {
Weight *mlx.Array `weight:"weight"`
Eps float32
}
// Load loads the T5 text encoder from manifest
func (m *T5TextEncoder) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading T5 text encoder... ")
// Load config
var cfg T5Config
if err := manifest.ReadConfigJSON("text_encoder/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Pre-allocate layers
m.Layers = make([]*T5Block, cfg.NumLayers)
// Load weights
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
// LoadFromPath loads the T5 text encoder from a directory path
func (m *T5TextEncoder) LoadFromPath(path string) error {
fmt.Print(" Loading T5 text encoder... ")
// Load config
var cfg T5Config
configPath := filepath.Join(path, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("read config: %w", err)
}
if err := json.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("parse config: %w", err)
}
m.Config = &cfg
// Pre-allocate layers
m.Layers = make([]*T5Block, cfg.NumLayers)
// Load weights from safetensors files
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
func (m *T5TextEncoder) initComputedFields() {
cfg := m.Config
m.FinalNorm.Eps = cfg.LayerNormEps
for _, block := range m.Layers {
attn := block.Layer0.SelfAttention
attn.NHeads = cfg.NumHeads
attn.DKV = cfg.DKV
attn.Scale = float32(1.0 / math.Sqrt(float64(cfg.DKV)))
block.Layer0.LayerNorm.Eps = cfg.LayerNormEps
block.Layer1.LayerNorm.Eps = cfg.LayerNormEps
}
}
// Forward encodes text tokens
func (m *T5TextEncoder) Forward(tokens *mlx.Array) *mlx.Array {
cfg := m.Config
// Get embeddings
h := m.SharedEmbed.Forward(tokens)
// Compute relative position bias once
seqLen := tokens.Shape()[1]
posBias := m.computeRelativePositionBias(seqLen)
// Forward through layers
for _, block := range m.Layers {
h = block.Forward(h, posBias, cfg.LayerNormEps)
}
// Final norm
h = m.FinalNorm.Forward(h)
return h
}
// extractGlyphTexts extracts quoted text (glyphs) from the prompt
// This matches diffusers' get_glyph_texts from pipeline_glm_image.py
// Glyph texts are used for text rendering guidance in the generated image
func extractGlyphTexts(prompt string) []string {
var glyphTexts []string
// Extract text in single quotes: 'text'
re1 := regexp.MustCompile(`'([^']*)'`)
for _, match := range re1.FindAllStringSubmatch(prompt, -1) {
if len(match) > 1 {
glyphTexts = append(glyphTexts, match[1])
}
}
// Extract text in Unicode curly double quotes: "text"
re2 := regexp.MustCompile(`"([^""]*)"`)
for _, match := range re2.FindAllStringSubmatch(prompt, -1) {
if len(match) > 1 {
glyphTexts = append(glyphTexts, match[1])
}
}
// Extract text in ASCII double quotes: "text"
re3 := regexp.MustCompile(`"([^"]*)"`)
for _, match := range re3.FindAllStringSubmatch(prompt, -1) {
if len(match) > 1 {
glyphTexts = append(glyphTexts, match[1])
}
}
// Extract text in Japanese quotes: 「text」
re4 := regexp.MustCompile(`「([^「」]*)」`)
for _, match := range re4.FindAllStringSubmatch(prompt, -1) {
if len(match) > 1 {
glyphTexts = append(glyphTexts, match[1])
}
}
return glyphTexts
}
// EncodePrompt encodes the prompt text using the ByT5 tokenizer and encoder
// This provides text conditioning for the diffusion transformer via the glyph projector
//
// IMPORTANT: This encodes only the GLYPH TEXTS (quoted strings in the prompt), not the
// full prompt. Glyph texts are used for text rendering guidance in the generated image.
// Multiple glyph texts are encoded and concatenated to form the conditioning signal.
// This matches diffusers' _get_glyph_embeds() behavior.
func (m *T5TextEncoder) EncodePrompt(tok *ByT5Tokenizer, prompt string) *mlx.Array {
// Extract glyph texts from prompt (text in quotes)
glyphTexts := extractGlyphTexts(prompt)
// If no glyph texts found, encode empty string (matches diffusers: [""] fallback)
if len(glyphTexts) == 0 {
glyphTexts = []string{""}
}
// Encode each glyph text and collect token sequences
// Matching diffusers' _get_glyph_embeds() which batches all glyph texts
var allTokenSeqs [][]int32
for _, glyphText := range glyphTexts {
// ByT5 uses byte-level encoding: each byte (0-255) -> token (3-258)
tokens := tok.Encode(glyphText)
// Add EOS token (1) at the end to match HuggingFace tokenizer behavior
tokens = append(tokens, tok.EOSTokenID)
allTokenSeqs = append(allTokenSeqs, tokens)
}
// Process each glyph text through the encoder
var allEmbeddings []*mlx.Array
for _, tokens := range allTokenSeqs {
tokenLen := len(tokens)
if tokenLen == 0 {
continue
}
// Create token array [1, L]
tokensArr := mlx.NewArrayInt32(tokens, []int32{1, int32(tokenLen)})
// Forward through encoder
output := m.Forward(tokensArr)
mlx.Eval(output)
allEmbeddings = append(allEmbeddings, output)
}
// Concatenate all glyph embeddings along sequence dimension
var output *mlx.Array
if len(allEmbeddings) == 0 {
// Fallback: return single zero embedding
output = mlx.Zeros([]int32{1, 1, m.Config.DModel}, mlx.DtypeBFloat16)
} else if len(allEmbeddings) == 1 {
output = allEmbeddings[0]
} else {
output = mlx.Concatenate(allEmbeddings, 1)
}
mlx.Eval(output)
return output
}
// computeRelativePositionBias computes T5's relative position encoding
func (m *T5TextEncoder) computeRelativePositionBias(seqLen int32) *mlx.Array {
cfg := m.Config
// Create relative position matrix
// For each (query_pos, key_pos) pair, compute bucketed relative position
numBuckets := cfg.RelativeAttentionNumBuckets
maxDistance := cfg.RelativeAttentionMaxDistance
// Create position indices
contextPos := make([]int32, seqLen*seqLen)
memoryPos := make([]int32, seqLen*seqLen)
for i := int32(0); i < seqLen; i++ {
for j := int32(0); j < seqLen; j++ {
contextPos[i*seqLen+j] = i
memoryPos[i*seqLen+j] = j
}
}
// Compute relative positions and bucket them
buckets := make([]int32, seqLen*seqLen)
for i := int32(0); i < seqLen*seqLen; i++ {
relPos := memoryPos[i] - contextPos[i]
buckets[i] = relativePosistionBucket(relPos, numBuckets, maxDistance, false)
}
// Create bucket indices array
bucketsArr := mlx.NewArrayInt32(buckets, []int32{seqLen, seqLen})
// Look up bias: RelativeAttentionBias shape is [numBuckets, numHeads] = [32, 6]
// Take along axis 0 (buckets dimension) -> [seqLen, seqLen, numHeads]
bias := mlx.Take(m.RelativeAttentionBias, bucketsArr, 0) // [seqLen, seqLen, numHeads]
// Transpose to [numHeads, seqLen, seqLen]
bias = mlx.Transpose(bias, 2, 0, 1) // [numHeads, seqLen, seqLen]
bias = mlx.ExpandDims(bias, 0) // [1, numHeads, seqLen, seqLen]
return bias
}
// relativePosistionBucket computes the bucket for a relative position
func relativePosistionBucket(relativePosition int32, numBuckets int32, maxDistance int32, bidirectional bool) int32 {
var bucket int32 = 0
var n int32 = -relativePosition
if bidirectional {
numBuckets /= 2
if n < 0 {
bucket += numBuckets
n = -n
}
} else {
if n < 0 {
n = 0
}
}
// Half buckets are for exact positions, half are for log-spaced
maxExact := numBuckets / 2
if n < maxExact {
bucket += n
} else {
// Log-spaced buckets
logVal := math.Log(float64(n)/float64(maxExact)) / math.Log(float64(maxDistance)/float64(maxExact))
bucket += maxExact + int32(logVal*float64(numBuckets-maxExact))
if bucket > numBuckets-1 {
bucket = numBuckets - 1
}
}
return bucket
}
// Forward for T5Block
func (b *T5Block) Forward(x *mlx.Array, posBias *mlx.Array, eps float32) *mlx.Array {
// Self attention with residual
h := b.Layer0.Forward(x, posBias, eps)
// FFN with residual
h = b.Layer1.Forward(h, eps)
return h
}
// Forward for T5LayerSelfAttention
func (l *T5LayerSelfAttention) Forward(x *mlx.Array, posBias *mlx.Array, eps float32) *mlx.Array {
// Pre-norm
normed := l.LayerNorm.Forward(x)
// Attention
attnOut := l.SelfAttention.Forward(normed, posBias)
// Residual
return mlx.Add(x, attnOut)
}
// Forward for T5Attention
func (attn *T5Attention) Forward(x *mlx.Array, posBias *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
D := shape[2]
// Q, K, V projections (no bias)
// Weights are [out_features, in_features], so we use matmul with transpose
q := mlx.Matmul(x, mlx.Transpose(attn.Q, 1, 0))
k := mlx.Matmul(x, mlx.Transpose(attn.K, 1, 0))
v := mlx.Matmul(x, mlx.Transpose(attn.V, 1, 0))
// Reshape to [B, L, nheads, d_kv]
q = mlx.Reshape(q, B, L, attn.NHeads, attn.DKV)
k = mlx.Reshape(k, B, L, attn.NHeads, attn.DKV)
v = mlx.Reshape(v, B, L, attn.NHeads, attn.DKV)
// Transpose to [B, nheads, L, d_kv]
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
// Attention scores with relative position bias
// T5 uses UNSCALED dot-product attention: scores = q @ k.T + pos_bias
// (no 1/sqrt(d_k) scale factor like in standard transformers)
scores := mlx.Matmul(q, mlx.Transpose(k, 0, 1, 3, 2))
scores = mlx.Add(scores, posBias)
// Softmax
attnWeights := mlx.Softmax(scores, -1)
// Attend to values
out := mlx.Matmul(attnWeights, v)
// Transpose back [B, nheads, L, d_kv] -> [B, L, nheads, d_kv]
out = mlx.Transpose(out, 0, 2, 1, 3)
// Reshape to [B, L, D]
out = mlx.Reshape(out, B, L, attn.NHeads*attn.DKV)
// Output projection
out = mlx.Matmul(out, mlx.Transpose(attn.O, 1, 0))
_ = D // Silence unused warning
return out
}
// Forward for T5LayerFF
func (l *T5LayerFF) Forward(x *mlx.Array, eps float32) *mlx.Array {
// Pre-norm
normed := l.LayerNorm.Forward(x)
// FFN
ffOut := l.DenseReluDense.Forward(normed)
// Residual
return mlx.Add(x, ffOut)
}
// geluNew implements the GELU activation with tanh approximation (gelu_new)
// This matches HuggingFace transformers' gelu_new/OpenAI GPT implementation
// Formula: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
func geluNew(x *mlx.Array) *mlx.Array {
sqrt2OverPi := float32(0.7978845608) // sqrt(2/π)
coeff := float32(0.044715)
x3 := mlx.Mul(mlx.Mul(x, x), x)
inner := mlx.MulScalar(mlx.Add(x, mlx.MulScalar(x3, coeff)), sqrt2OverPi)
return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
}
// Forward for T5DenseGatedGelu (gated-gelu activation)
func (d *T5DenseGatedGelu) Forward(x *mlx.Array) *mlx.Array {
// Gate projection with GELU activation (T5 v1.1/ByT5 uses gelu_new)
gate := mlx.Matmul(x, mlx.Transpose(d.Wi0, 1, 0))
gate = geluNew(gate)
// Up projection
up := mlx.Matmul(x, mlx.Transpose(d.Wi1, 1, 0))
// Gated output
h := mlx.Mul(gate, up)
// Down projection
return mlx.Matmul(h, mlx.Transpose(d.Wo, 1, 0))
}
// Forward for T5LayerNorm (RMSNorm variant)
func (ln *T5LayerNorm) Forward(x *mlx.Array) *mlx.Array {
// T5 uses RMSNorm: x * rsqrt(mean(x^2) + eps) * weight
variance := mlx.Mean(mlx.Square(x), -1, true)
x = mlx.Mul(x, mlx.RSqrt(mlx.AddScalar(variance, ln.Eps)))
return mlx.Mul(x, ln.Weight)
}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,477 @@
//go:build mlx
package glm_image
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// VAEConfig holds VAE decoder configuration
type VAEConfig struct {
InChannels int32 `json:"in_channels"` // 3
OutChannels int32 `json:"out_channels"` // 3
LatentChannels int32 `json:"latent_channels"` // 16
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 512, 1024, 1024]
LayersPerBlock int32 `json:"layers_per_block"` // 3
NormNumGroups int32 `json:"norm_num_groups"` // 32
ScalingFactor float32 `json:"scaling_factor"` // 0.18215
ShiftFactor *float32 `json:"shift_factor"` // null
LatentsMean []float32 `json:"latents_mean"` // [16 values]
LatentsStd []float32 `json:"latents_std"` // [16 values]
}
// VAEDecoder is the VAE latent decoder
type VAEDecoder struct {
Config *VAEConfig
// Decoder components
ConvIn *VAEConv2d `weight:"decoder.conv_in"`
MidBlock *VAEMidBlock `weight:"decoder.mid_block"`
UpBlocks []*VAEUpBlock `weight:"decoder.up_blocks"`
ConvNormOut *GroupNorm `weight:"decoder.conv_norm_out"`
ConvOut *VAEConv2d `weight:"decoder.conv_out"`
}
// VAEConv2d is a 2D convolution layer
type VAEConv2d struct {
Weight *mlx.Array `weight:"weight"`
Bias *mlx.Array `weight:"bias"`
Stride int32
Padding int32
}
// GroupNorm is group normalization
type GroupNorm struct {
Weight *mlx.Array `weight:"weight"`
Bias *mlx.Array `weight:"bias"`
NumGroups int32
Eps float32
}
// VAEMidBlock is the middle block of the VAE
type VAEMidBlock struct {
Resnets []*VAEResnetBlock `weight:"resnets"`
}
// VAEUpBlock is an upsampling block
type VAEUpBlock struct {
Resnets []*VAEResnetBlock `weight:"resnets"`
Upsamplers []*VAEUpsampler `weight:"upsamplers"`
}
// VAEResnetBlock is a residual block
type VAEResnetBlock struct {
Norm1 *GroupNorm `weight:"norm1"`
Conv1 *VAEConv2d `weight:"conv1"`
Norm2 *GroupNorm `weight:"norm2"`
Conv2 *VAEConv2d `weight:"conv2"`
ConvShortcut *VAEConv2d `weight:"conv_shortcut,optional"` // Optional, for channel mismatch
}
// VAEUpsampler is an upsampling layer
type VAEUpsampler struct {
Conv *VAEConv2d `weight:"conv"`
}
// Load loads the VAE decoder from manifest
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading VAE decoder... ")
// Load config
var cfg VAEConfig
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Initialize structure based on config
numBlocks := len(cfg.BlockOutChannels)
m.UpBlocks = make([]*VAEUpBlock, numBlocks)
// Pre-allocate MidBlock resnets (VAE mid_block typically has 2 resnets)
m.MidBlock = &VAEMidBlock{
Resnets: make([]*VAEResnetBlock, 2),
}
// Pre-allocate UpBlocks with their resnets and upsamplers
// VAE decoder has layers_per_block+1 resnets per up_block (to match encoder)
// And all but the last up_block has an upsampler
for i := 0; i < numBlocks; i++ {
numResnets := cfg.LayersPerBlock + 1 // typically 4 resnets
m.UpBlocks[i] = &VAEUpBlock{
Resnets: make([]*VAEResnetBlock, numResnets),
}
// All but the last block has upsamplers
if i < numBlocks-1 {
m.UpBlocks[i].Upsamplers = make([]*VAEUpsampler, 1)
}
}
// Load weights
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
// Initialize GroupNorm parameters
m.initGroupNorms()
fmt.Println("✓")
return nil
}
// LoadFromPath loads the VAE decoder from a directory path
func (m *VAEDecoder) LoadFromPath(path string) error {
fmt.Print(" Loading VAE decoder... ")
// Load config
var cfg VAEConfig
configPath := filepath.Join(path, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("read config: %w", err)
}
if err := json.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("parse config: %w", err)
}
m.Config = &cfg
// Initialize structure based on config
numBlocks := len(cfg.BlockOutChannels)
m.UpBlocks = make([]*VAEUpBlock, numBlocks)
// Pre-allocate MidBlock resnets (VAE mid_block typically has 2 resnets)
m.MidBlock = &VAEMidBlock{
Resnets: make([]*VAEResnetBlock, 2),
}
// Pre-allocate UpBlocks with their resnets and upsamplers
for i := 0; i < numBlocks; i++ {
numResnets := cfg.LayersPerBlock + 1
m.UpBlocks[i] = &VAEUpBlock{
Resnets: make([]*VAEResnetBlock, numResnets),
}
if i < numBlocks-1 {
m.UpBlocks[i].Upsamplers = make([]*VAEUpsampler, 1)
}
}
// Load weights from safetensors files
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
// Initialize GroupNorm parameters
m.initGroupNorms()
fmt.Println("✓")
return nil
}
func (m *VAEDecoder) initGroupNorms() {
cfg := m.Config
numGroups := cfg.NormNumGroups
eps := float32(1e-6) // Must match diffusers VAE (1e-6, not 1e-5)
if m.ConvNormOut != nil {
m.ConvNormOut.NumGroups = numGroups
m.ConvNormOut.Eps = eps
}
if m.MidBlock != nil {
for _, resnet := range m.MidBlock.Resnets {
if resnet.Norm1 != nil {
resnet.Norm1.NumGroups = numGroups
resnet.Norm1.Eps = eps
}
if resnet.Norm2 != nil {
resnet.Norm2.NumGroups = numGroups
resnet.Norm2.Eps = eps
}
}
}
for _, upBlock := range m.UpBlocks {
if upBlock == nil {
continue
}
for _, resnet := range upBlock.Resnets {
if resnet == nil {
continue
}
if resnet.Norm1 != nil {
resnet.Norm1.NumGroups = numGroups
resnet.Norm1.Eps = eps
}
if resnet.Norm2 != nil {
resnet.Norm2.NumGroups = numGroups
resnet.Norm2.Eps = eps
}
}
}
}
// Decode decodes latents to an image
func (m *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
cfg := m.Config
// Apply latent denormalization if mean/std are provided
// This matches diffusers GLM-Image: latents = latents * std + mean
// Note: GLM-Image does NOT divide by scaling_factor (unlike standard SD VAEs)
if len(cfg.LatentsMean) > 0 && len(cfg.LatentsStd) > 0 {
latents = m.denormalizeLatents(latents)
}
// Convert from NCHW to NHWC for processing
// [B, C, H, W] -> [B, H, W, C]
x := mlx.Transpose(latents, 0, 2, 3, 1)
// Initial convolution
x = m.ConvIn.Forward(x)
// Mid block
x = m.MidBlock.Forward(x)
// Up blocks (forward order - index 0 is at lowest resolution/highest channels)
for i := 0; i < len(m.UpBlocks); i++ {
if m.UpBlocks[i] != nil {
x = m.UpBlocks[i].Forward(x)
}
}
// Final normalization and convolution
x = m.ConvNormOut.Forward(x)
x = mlx.SiLU(x)
x = m.ConvOut.Forward(x)
// Convert back to NCHW
// [B, H, W, C] -> [B, C, H, W]
x = mlx.Transpose(x, 0, 3, 1, 2)
// Clamp to valid range and convert to [0, 1]
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
x = mlx.AddScalar(x, 1.0)
x = mlx.DivScalar(x, 2.0)
return x
}
// denormalizeLatents applies the latent mean/std denormalization
func (m *VAEDecoder) denormalizeLatents(latents *mlx.Array) *mlx.Array {
cfg := m.Config
// Create mean and std arrays [1, C, 1, 1] for broadcasting
mean := mlx.NewArray(cfg.LatentsMean, []int32{1, int32(len(cfg.LatentsMean)), 1, 1})
std := mlx.NewArray(cfg.LatentsStd, []int32{1, int32(len(cfg.LatentsStd)), 1, 1})
// Denormalize: latents * std + mean
latents = mlx.Mul(latents, std)
latents = mlx.Add(latents, mean)
return latents
}
// Forward for VAEConv2d
func (c *VAEConv2d) Forward(x *mlx.Array) *mlx.Array {
// x: [B, H, W, C_in] (NHWC)
// PyTorch weight: [C_out, C_in, kH, kW] (OIHW)
// MLX conv2d expects weight: [C_out, kH, kW, C_in] (OHWI)
// So we need to transpose from OIHW to OHWI
stride := c.Stride
if stride == 0 {
stride = 1
}
padding := c.Padding
if padding == 0 {
// Default to same padding for 3x3 kernels
wShape := c.Weight.Shape()
if len(wShape) >= 3 && wShape[2] == 3 {
padding = 1
}
}
// Transpose weight from OIHW [out, in, h, w] to OHWI [out, h, w, in]
weight := mlx.Transpose(c.Weight, 0, 2, 3, 1)
out := mlx.Conv2d(x, weight, stride, padding)
if c.Bias != nil {
// Bias: [C_out] -> [1, 1, 1, C_out]
bias := mlx.Reshape(c.Bias, 1, 1, 1, -1)
out = mlx.Add(out, bias)
}
return out
}
// Forward for GroupNorm
func (gn *GroupNorm) Forward(x *mlx.Array) *mlx.Array {
// x: [B, H, W, C] (NHWC)
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
numGroups := gn.NumGroups
if numGroups == 0 {
numGroups = 32
}
groupSize := C / numGroups
// Reshape to [B, H, W, groups, groupSize]
x = mlx.Reshape(x, B, H, W, numGroups, groupSize)
// Compute mean and variance per group
mean := mlx.Mean(x, 1, true)
mean = mlx.Mean(mean, 2, true)
mean = mlx.Mean(mean, 4, true)
xCentered := mlx.Sub(x, mean)
variance := mlx.Mean(mlx.Square(xCentered), 1, true)
variance = mlx.Mean(variance, 2, true)
variance = mlx.Mean(variance, 4, true)
// Normalize
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
// Reshape back
xNorm = mlx.Reshape(xNorm, B, H, W, C)
// Scale and shift
if gn.Weight != nil {
weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
xNorm = mlx.Mul(xNorm, weight)
}
if gn.Bias != nil {
bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
xNorm = mlx.Add(xNorm, bias)
}
return xNorm
}
// Forward for VAEMidBlock
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
for _, resnet := range mb.Resnets {
x = resnet.Forward(x)
}
return x
}
// Forward for VAEUpBlock
func (ub *VAEUpBlock) Forward(x *mlx.Array) *mlx.Array {
// Apply resnets
for _, resnet := range ub.Resnets {
if resnet != nil {
x = resnet.Forward(x)
}
}
// Apply upsamplers
for _, upsampler := range ub.Upsamplers {
if upsampler != nil {
x = upsampler.Forward(x)
}
}
return x
}
// Forward for VAEResnetBlock
func (rb *VAEResnetBlock) Forward(x *mlx.Array) *mlx.Array {
residual := x
// First norm + activation + conv
h := rb.Norm1.Forward(x)
h = mlx.SiLU(h)
h = rb.Conv1.Forward(h)
// Second norm + activation + conv
h = rb.Norm2.Forward(h)
h = mlx.SiLU(h)
h = rb.Conv2.Forward(h)
// Shortcut for channel mismatch
if rb.ConvShortcut != nil {
residual = rb.ConvShortcut.Forward(residual)
}
return mlx.Add(h, residual)
}
// Forward for VAEUpsampler (2x nearest neighbor upsample + conv)
func (us *VAEUpsampler) Forward(x *mlx.Array) *mlx.Array {
// x: [B, H, W, C]
// 2x nearest neighbor upsample
x = upsample2x(x)
// Conv
if us.Conv != nil {
x = us.Conv.Forward(x)
}
return x
}
// upsample2x performs 2x nearest neighbor upsampling.
// Input and output are in NHWC format: [B, H, W, C] -> [B, H*2, W*2, C]
func upsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
// Create indices [0, 0, 1, 1, 2, 2, ...] for nearest neighbor
hIndices := make([]int32, H*2)
for i := int32(0); i < H; i++ {
hIndices[i*2] = i
hIndices[i*2+1] = i
}
wIndices := make([]int32, W*2)
for i := int32(0); i < W; i++ {
wIndices[i*2] = i
wIndices[i*2+1] = i
}
hIdx := mlx.NewArrayInt32(hIndices, []int32{H * 2})
wIdx := mlx.NewArrayInt32(wIndices, []int32{W * 2})
// Take along height axis
x = mlx.Reshape(x, B*H, W, C)
x = mlx.Take(x, wIdx, 1) // [B*H, W*2, C]
x = mlx.Reshape(x, B, H, W*2, C)
// Take along width axis - transpose to [B, W*2, H, C], take, transpose back
x = mlx.Transpose(x, 0, 2, 1, 3) // [B, W*2, H, C]
x = mlx.Reshape(x, B*(W*2), H, C)
x = mlx.Take(x, hIdx, 1) // [B*(W*2), H*2, C]
x = mlx.Reshape(x, B, W*2, H*2, C)
x = mlx.Transpose(x, 0, 2, 1, 3) // [B, H*2, W*2, C]
return x
}

View File

@@ -0,0 +1,982 @@
//go:build mlx
package glm_image
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// VisionLanguageConfig holds GLM-Image AR generator configuration
type VisionLanguageConfig struct {
// Text model config
HiddenSize int32 `json:"hidden_size"` // 4096
NumHiddenLayers int32 `json:"num_hidden_layers"` // 40
IntermediateSize int32 `json:"intermediate_size"` // 13696
NumAttentionHeads int32 `json:"num_attention_heads"` // 32
NumKeyValueHeads int32 `json:"num_key_value_heads"` // 2
VocabSize int32 `json:"vocab_size"` // 168064
RMSNormEps float32 `json:"rms_norm_eps"` // 1e-5
// RoPE config
RopeTheta float32 `json:"rope_theta"` // 10000
PartialRotaryFactor float32 `json:"partial_rotary_factor"` // 0.5
MRoPESection []int32 `json:"mrope_section"` // [8, 12, 12]
// Visual token config
VisionVocabSize int32 `json:"vision_vocab_size"` // 16512
ImageStartTokenID int32 `json:"image_start_token_id"` // 16384
ImageEndTokenID int32 `json:"image_end_token_id"` // 16385
ImageTokenID int32 `json:"image_token_id"` // 167855
// Computed
HeadDim int32
}
// VisionLanguageEncoder is the 9B AR generator
type VisionLanguageEncoder struct {
Config *VisionLanguageConfig
// Embedding
EmbedTokens *nn.Embedding `weight:"model.language_model.embed_tokens"`
// Transformer layers
Layers []*GLMBlock `weight:"model.language_model.layers"`
// Final norm
FinalNorm *nn.RMSNorm `weight:"model.language_model.norm"`
// LM Head
LMHead *mlx.Array `weight:"lm_head.weight"`
}
// GLMBlock is a single transformer block in GLM-4 style
type GLMBlock struct {
// Pre-attention norm (GLM uses post-LN variant)
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
PostSelfAttnNorm *nn.RMSNorm `weight:"post_self_attn_layernorm"`
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
PostMLPLayerNorm *nn.RMSNorm `weight:"post_mlp_layernorm"`
// Attention
SelfAttn *GLMAttention `weight:"self_attn"`
// MLP (fused gate_up)
MLP *GLMMLP `weight:"mlp"`
}
// GLMAttention implements GQA with partial rotary and MRoPE
type GLMAttention struct {
QProj *mlx.Array `weight:"q_proj.weight"`
KProj *mlx.Array `weight:"k_proj.weight"`
VProj *mlx.Array `weight:"v_proj.weight"`
OProj *mlx.Array `weight:"o_proj.weight"`
// QKV have biases in GLM
QBias *mlx.Array `weight:"q_proj.bias"`
KBias *mlx.Array `weight:"k_proj.bias"`
VBias *mlx.Array `weight:"v_proj.bias"`
// Computed
NHeads int32
NKVHeads int32
HeadDim int32
Scale float32
PartialRotary float32 // Only rotate this fraction of head_dim
RopeTheta float32
MRoPESection []int32 // [8, 12, 12] - frequency pairs per dimension (temporal, height, width)
}
// ARCache holds KV caches for all layers using the shared cache implementation
type ARCache struct {
Layers []cache.Cache
}
// NewARCache creates a new cache for the given number of layers
func NewARCache(numLayers int32) *ARCache {
layers := make([]cache.Cache, numLayers)
for i := range layers {
layers[i] = cache.NewKVCache()
}
return &ARCache{Layers: layers}
}
// Free releases all cached tensors
func (c *ARCache) Free() {
for _, layer := range c.Layers {
for _, arr := range layer.State() {
if arr != nil {
arr.Free()
}
}
}
}
// GLMMLP implements fused gate_up SwiGLU MLP
type GLMMLP struct {
// GLM uses fused gate_up_proj: [hidden, 2*intermediate]
GateUpProj *mlx.Array `weight:"gate_up_proj.weight"`
DownProj *mlx.Array `weight:"down_proj.weight"`
}
// Load loads the vision-language encoder from manifest
func (m *VisionLanguageEncoder) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading vision-language encoder... ")
// Load config
var rawCfg struct {
TextConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
VisionVocabSize int32 `json:"vision_vocab_size"`
RopeParameters struct {
RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
MRoPESection []int32 `json:"mrope_section"`
} `json:"rope_parameters"`
} `json:"text_config"`
ImageStartTokenID int32 `json:"image_start_token_id"`
ImageEndTokenID int32 `json:"image_end_token_id"`
ImageTokenID int32 `json:"image_token_id"`
}
if err := manifest.ReadConfigJSON("vision_language_encoder/config.json", &rawCfg); err != nil {
return fmt.Errorf("config: %w", err)
}
cfg := &VisionLanguageConfig{
HiddenSize: rawCfg.TextConfig.HiddenSize,
NumHiddenLayers: rawCfg.TextConfig.NumHiddenLayers,
IntermediateSize: rawCfg.TextConfig.IntermediateSize,
NumAttentionHeads: rawCfg.TextConfig.NumAttentionHeads,
NumKeyValueHeads: rawCfg.TextConfig.NumKeyValueHeads,
VocabSize: rawCfg.TextConfig.VocabSize,
RMSNormEps: rawCfg.TextConfig.RMSNormEps,
VisionVocabSize: rawCfg.TextConfig.VisionVocabSize,
RopeTheta: rawCfg.TextConfig.RopeParameters.RopeTheta,
PartialRotaryFactor: rawCfg.TextConfig.RopeParameters.PartialRotaryFactor,
MRoPESection: rawCfg.TextConfig.RopeParameters.MRoPESection,
ImageStartTokenID: rawCfg.ImageStartTokenID,
ImageEndTokenID: rawCfg.ImageEndTokenID,
ImageTokenID: rawCfg.ImageTokenID,
}
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
m.Config = cfg
// Pre-allocate layers
m.Layers = make([]*GLMBlock, cfg.NumHiddenLayers)
// Load weights
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vision_language_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Printf("✓ [%d layers]\n", cfg.NumHiddenLayers)
return nil
}
// LoadFromPath loads the vision-language encoder from a directory path
func (m *VisionLanguageEncoder) LoadFromPath(path string) error {
fmt.Print(" Loading vision-language encoder... ")
// Load config
var rawCfg struct {
TextConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
VisionVocabSize int32 `json:"vision_vocab_size"`
RopeParameters struct {
RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
MRoPESection []int32 `json:"mrope_section"`
} `json:"rope_parameters"`
} `json:"text_config"`
ImageStartTokenID int32 `json:"image_start_token_id"`
ImageEndTokenID int32 `json:"image_end_token_id"`
ImageTokenID int32 `json:"image_token_id"`
}
configPath := filepath.Join(path, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("read config: %w", err)
}
if err := json.Unmarshal(data, &rawCfg); err != nil {
return fmt.Errorf("parse config: %w", err)
}
cfg := &VisionLanguageConfig{
HiddenSize: rawCfg.TextConfig.HiddenSize,
NumHiddenLayers: rawCfg.TextConfig.NumHiddenLayers,
IntermediateSize: rawCfg.TextConfig.IntermediateSize,
NumAttentionHeads: rawCfg.TextConfig.NumAttentionHeads,
NumKeyValueHeads: rawCfg.TextConfig.NumKeyValueHeads,
VocabSize: rawCfg.TextConfig.VocabSize,
RMSNormEps: rawCfg.TextConfig.RMSNormEps,
VisionVocabSize: rawCfg.TextConfig.VisionVocabSize,
RopeTheta: rawCfg.TextConfig.RopeParameters.RopeTheta,
PartialRotaryFactor: rawCfg.TextConfig.RopeParameters.PartialRotaryFactor,
MRoPESection: rawCfg.TextConfig.RopeParameters.MRoPESection,
ImageStartTokenID: rawCfg.ImageStartTokenID,
ImageEndTokenID: rawCfg.ImageEndTokenID,
ImageTokenID: rawCfg.ImageTokenID,
}
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
m.Config = cfg
// Pre-allocate layers
m.Layers = make([]*GLMBlock, cfg.NumHiddenLayers)
// Load weights
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Printf("✓ [%d layers]\n", cfg.NumHiddenLayers)
return nil
}
func (m *VisionLanguageEncoder) initComputedFields() {
cfg := m.Config
for _, block := range m.Layers {
block.SelfAttn.NHeads = cfg.NumAttentionHeads
block.SelfAttn.NKVHeads = cfg.NumKeyValueHeads
block.SelfAttn.HeadDim = cfg.HeadDim
block.SelfAttn.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
block.SelfAttn.PartialRotary = cfg.PartialRotaryFactor
block.SelfAttn.RopeTheta = cfg.RopeTheta
block.SelfAttn.MRoPESection = cfg.MRoPESection
// Set norm eps
block.InputLayerNorm.Eps = cfg.RMSNormEps
block.PostSelfAttnNorm.Eps = cfg.RMSNormEps
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
block.PostMLPLayerNorm.Eps = cfg.RMSNormEps
}
m.FinalNorm.Eps = cfg.RMSNormEps
}
// Generate autoregressively generates visual tokens with KV caching
func (m *VisionLanguageEncoder) Generate(
prompt string,
tok *GLMTokenizer,
maxTokens int32,
temperature float32,
topP float32,
seed int64,
targetHeight, targetWidth int32,
progressFn func(int),
) *mlx.Array {
cfg := m.Config
// Encode prompt with grid tokens using GLM tokenizer
// Format: {prompt}<sop>{h} {w}<eop><sop>{prev_h} {prev_w}<eop><|dit_token_16384|>
tokens := tok.EncodeForGeneration(prompt, targetHeight, targetWidth)
// Calculate grid dimensions for MRoPE position IDs
factor := int32(32)
tokenH := targetHeight / factor
tokenW := targetWidth / factor
ratio := float64(tokenH) / float64(tokenW)
prevTokenH := int32(math.Sqrt(ratio) * 16)
prevTokenW := int32(math.Sqrt(1.0/ratio) * 16)
prevGridSize := prevTokenH * prevTokenW
// Create KV cache for all layers
cache := NewARCache(cfg.NumHiddenLayers)
defer cache.Free()
// ===== PREFILL PHASE =====
// Process entire prompt at once, populate cache
promptLen := int32(len(tokens))
tokenArr := mlx.NewArrayInt32(tokens, []int32{1, promptLen})
h := m.EmbedTokens.Forward(tokenArr)
tokenArr.Free()
mlx.Eval(h)
// Compute position IDs for prefill (text tokens use same position for all dims)
prefillPositions := make([][]int32, 3)
for dim := 0; dim < 3; dim++ {
prefillPositions[dim] = make([]int32, promptLen)
for i := int32(0); i < promptLen; i++ {
prefillPositions[dim][i] = i
}
}
// Forward through layers (prefill)
for i, layer := range m.Layers {
oldH := h
h = layer.ForwardWithCache(h, promptLen, 0, cfg.RMSNormEps, cache.Layers[i], prefillPositions)
if i > 0 {
oldH.Free()
}
}
// Eval h and cache arrays together so cache is materialized
evalArgs := []*mlx.Array{h}
for _, lc := range cache.Layers {
evalArgs = append(evalArgs, lc.State()...)
}
mlx.Eval(evalArgs...)
// Final norm and get logits for last position
preNormH := h
h = m.FinalNorm.Forward(h, cfg.RMSNormEps)
preNormH.Free()
lastH := mlx.Slice(h, []int32{0, promptLen - 1, 0}, []int32{1, promptLen, cfg.HiddenSize})
h.Free()
lastH = mlx.Reshape(lastH, 1, cfg.HiddenSize)
logits := mlx.Matmul(lastH, mlx.Transpose(m.LMHead, 1, 0))
lastH.Free()
// Sample first token
var sampleCounter int64 = 0
nextToken := sampleVisualToken(logits, temperature, topP, cfg, seed, &sampleCounter)
logits.Free()
// AR generation loop with caching
// Visual tokens are stored as VQ codebook indices [0, 16383]
// The LM head outputs indices [0, 16511] where:
// - [0, 16383] are VQ codes
// - 16384 is BOS
// - 16385 is EOS
visualTokens := make([]int32, 0, maxTokens)
posOffset := promptLen
visualTokenIdx := int32(0) // Index within visual token sequence for grid position calculation
// Preallocate slice for old cache state to reuse
oldCacheState := make([]*mlx.Array, 0, len(m.Layers)*2)
for i := int32(0); i < maxTokens; i++ {
if progressFn != nil {
progressFn(int(i))
}
// Check for end token (EOS = 16385)
if nextToken == cfg.ImageEndTokenID {
break
}
// Skip BOS token (16384), only store actual VQ codes [0, 16383]
if nextToken == cfg.ImageStartTokenID {
// BOS token - skip storing but continue generation
} else if nextToken < cfg.ImageStartTokenID {
// This is an actual VQ code [0, 16383] - store it
visualTokens = append(visualTokens, nextToken)
}
// Tokens >= 16386 are other special tokens, skip them
// ===== DECODE PHASE =====
// Save old cache state before forward (to free after eval)
oldCacheState = oldCacheState[:0]
for _, lc := range cache.Layers {
oldCacheState = append(oldCacheState, lc.State()...)
}
// Only process the new token, use cached K,V
tokenArr := mlx.NewArrayInt32([]int32{nextToken}, []int32{1, 1})
h := m.EmbedTokens.Forward(tokenArr)
tokenArr.Free()
// Compute MRoPE position IDs for this visual token
// Visual tokens are arranged in two grids: prev grid then target grid
// Position dimensions: [temporal, height, width]
decodePositions := computeVisualTokenPositions(
visualTokenIdx, posOffset, promptLen,
prevTokenH, prevTokenW, prevGridSize,
tokenH, tokenW,
)
// Forward through layers (decode with cache)
for j, layer := range m.Layers {
oldH := h
h = layer.ForwardWithCache(h, 1, posOffset, cfg.RMSNormEps, cache.Layers[j], decodePositions)
if j > 0 { // Don't free the embedding on first layer
oldH.Free()
}
}
// Eval h and new cache state
newCacheState := make([]*mlx.Array, 0, len(m.Layers)*2)
for _, lc := range cache.Layers {
newCacheState = append(newCacheState, lc.State()...)
}
mlx.Eval(append([]*mlx.Array{h}, newCacheState...)...)
// Free old cache state (now that new state is evaluated)
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
// Final norm
preNormH := h
h = m.FinalNorm.Forward(h, cfg.RMSNormEps)
preNormH.Free()
// Get logits (h is already [1, 1, hidden_size])
h = mlx.Reshape(h, 1, cfg.HiddenSize)
logits := mlx.Matmul(h, mlx.Transpose(m.LMHead, 1, 0))
h.Free()
// Sample next token
nextToken = sampleVisualToken(logits, temperature, topP, cfg, seed, &sampleCounter)
logits.Free()
posOffset++
visualTokenIdx++
// Periodically clear cache to release intermediate memory
if i%256 == 0 {
mlx.ClearCache()
}
}
if len(visualTokens) == 0 {
// Return at least one token to avoid empty tensor issues
visualTokens = append(visualTokens, 0)
}
return mlx.NewArrayInt32(visualTokens, []int32{1, int32(len(visualTokens))})
}
// computeVisualTokenPositions computes MRoPE position IDs for a visual token
// Returns [3][1] position IDs for temporal, height, and width dimensions
//
// MRoPE position encoding for GLM-Image visual tokens:
// - temporal: CONSTANT within each grid (= decode_pos at grid start)
// - height: decode_pos + row index within grid
// - width: decode_pos + column index within grid
//
// Between grids, decode_pos advances by max(grid_h, grid_w) to ensure
// sufficient positional separation.
func computeVisualTokenPositions(
visualIdx int32, absPos int32, promptLen int32,
prevH, prevW, prevSize int32,
targetH, targetW int32,
) [][]int32 {
positions := make([][]int32, 3)
for dim := 0; dim < 3; dim++ {
positions[dim] = make([]int32, 1)
}
// First grid (prev grid) starts at decode_pos = promptLen
prevGridDecodePos := promptLen
// Second grid (target grid) starts after first grid
// next_pos = prev_decode_pos + max(prevH, prevW)
maxPrev := prevH
if prevW > maxPrev {
maxPrev = prevW
}
targetGridDecodePos := prevGridDecodePos + maxPrev
// Compute position IDs based on which grid the token is in
if visualIdx < prevSize {
// Token is in the prev grid (prev_token_h × prev_token_w)
row := visualIdx / prevW
col := visualIdx % prevW
// temporal is CONSTANT for all tokens in this grid
positions[0][0] = prevGridDecodePos
// height and width are relative to grid's decode_pos
positions[1][0] = prevGridDecodePos + row
positions[2][0] = prevGridDecodePos + col
} else {
// Token is in the target grid (token_h × token_w)
targetIdx := visualIdx - prevSize
row := targetIdx / targetW
col := targetIdx % targetW
// temporal is CONSTANT for all tokens in this grid
positions[0][0] = targetGridDecodePos
// height and width are relative to grid's decode_pos
positions[1][0] = targetGridDecodePos + row
positions[2][0] = targetGridDecodePos + col
}
_ = targetH // Used for documentation clarity
_ = absPos // No longer used - kept for API compatibility
return positions
}
// sampleVisualToken samples from the visual vocabulary using top-p (nucleus) sampling
// Note: For GLM-Image, greedy decoding is not allowed as it may cause repetitive outputs
// Returns a visual token ID in range [0, 16511] which directly indexes into the embedding table
// sampleCounter is incremented for each call to ensure different random values
func sampleVisualToken(logits *mlx.Array, temperature float32, topP float32, cfg *VisionLanguageConfig, seed int64, sampleCounter *int64) int32 {
// The LMHead outputs logits for visual tokens only (shape [1, 16512])
// Output index directly corresponds to vocab ID [0, 16511]
// No offset needed - the visual tokens are at vocab IDs [0, 16511]
visualLogits := logits
// Apply temperature
if temperature != 1.0 && temperature > 0 {
visualLogits = mlx.DivScalar(visualLogits, temperature)
}
// Apply softmax to get probabilities
probs := mlx.Softmax(visualLogits, -1)
mlx.Eval(probs)
// Get the sampled index using top-p sampling
// This directly gives us the vocab ID in [0, 16511]
// Special tokens: 16384 = BOS, 16385 = EOS
// Use seed + counter for reproducible but different random values
effectiveSeed := seed + *sampleCounter
*sampleCounter++
return sampleTopP(probs, topP, effectiveSeed)
}
// sampleTopP implements nucleus (top-p) sampling
// probs: [1, vocab_size] probability distribution
// topP: cumulative probability threshold (e.g., 0.75)
// seed: random seed for reproducible sampling
func sampleTopP(probs *mlx.Array, topP float32, seed int64) int32 {
// Negate probs for descending sort (Argsort only does ascending)
negProbs := mlx.MulScalar(probs, -1)
sortedIndices := mlx.Argsort(negProbs, -1)
sortedProbs := mlx.TakeAlongAxis(probs, sortedIndices, -1)
cumProbs := mlx.Cumsum(sortedProbs, -1)
mlx.Eval(sortedIndices, sortedProbs, cumProbs)
// Find cutoff index where cumulative probability exceeds topP
probsData := sortedProbs.Data()
cumProbsData := cumProbs.Data()
indicesData := sortedIndices.DataInt32()
// Calculate cutoff and renormalize
var cutoffIdx int
var totalProb float32
for i, cp := range cumProbsData {
totalProb += probsData[i]
if cp >= topP {
cutoffIdx = i + 1 // Include this token
break
}
}
if cutoffIdx == 0 {
cutoffIdx = len(probsData) // Use all tokens if topP is very high
}
// Sample from the truncated distribution
// Renormalize the truncated probabilities
truncatedProbs := make([]float32, cutoffIdx)
for i := 0; i < cutoffIdx; i++ {
truncatedProbs[i] = probsData[i] / totalProb
}
// Sample using random number with provided seed for reproducibility
r := mlx.RandomUniform([]int32{1}, uint64(seed))
mlx.Eval(r)
randVal := r.Data()[0]
// Find the sampled token
var cumulative float32
for i := 0; i < cutoffIdx; i++ {
cumulative += truncatedProbs[i]
if randVal < cumulative {
return indicesData[i]
}
}
// Fallback to the last token in truncated set
return indicesData[cutoffIdx-1]
}
// Forward for GLMBlock
func (b *GLMBlock) Forward(x *mlx.Array, seqLen int32, eps float32) *mlx.Array {
return b.ForwardWithCache(x, seqLen, 0, eps, nil, nil)
}
// ForwardWithCache performs block forward with optional KV caching and MRoPE
// positionIDs: [3][L] - position indices for MRoPE (nil = use sequential positions)
func (b *GLMBlock) ForwardWithCache(x *mlx.Array, seqLen int32, posOffset int32, eps float32, kvcache cache.Cache, positionIDs [][]int32) *mlx.Array {
// Pre-attention norm
normed := b.InputLayerNorm.Forward(x, eps)
// Self-attention with RoPE/MRoPE and cache
attnOut := b.SelfAttn.ForwardWithCache(normed, seqLen, posOffset, kvcache, positionIDs)
// Post-attention norm (GLM-4 style)
attnOut = b.PostSelfAttnNorm.Forward(attnOut, eps)
// Residual connection
x = mlx.Add(x, attnOut)
// Post-attention layer norm
normed = b.PostAttnLayerNorm.Forward(x, eps)
// MLP
mlpOut := b.MLP.Forward(normed)
// Post-MLP norm
mlpOut = b.PostMLPLayerNorm.Forward(mlpOut, eps)
// Residual connection
x = mlx.Add(x, mlpOut)
return x
}
// Forward for GLMAttention (without cache - used for prefill)
func (attn *GLMAttention) Forward(x *mlx.Array, seqLen int32) *mlx.Array {
return attn.ForwardWithCache(x, seqLen, 0, nil, nil)
}
// ForwardWithCache performs attention with optional KV caching and MRoPE
// posOffset is the position offset for RoPE (0 for prefill, cached_len for decode)
// positionIDs: [3][L] - if nil, uses sequential positions for all dims (text mode)
// kvcache is updated in-place if provided
func (attn *GLMAttention) ForwardWithCache(x *mlx.Array, seqLen int32, posOffset int32, kvcache cache.Cache, positionIDs [][]int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
// Q, K, V projections
q := mlx.Matmul(x, mlx.Transpose(attn.QProj, 1, 0))
k := mlx.Matmul(x, mlx.Transpose(attn.KProj, 1, 0))
v := mlx.Matmul(x, mlx.Transpose(attn.VProj, 1, 0))
// Add biases
if attn.QBias != nil {
q = mlx.Add(q, attn.QBias)
}
if attn.KBias != nil {
k = mlx.Add(k, attn.KBias)
}
if attn.VBias != nil {
v = mlx.Add(v, attn.VBias)
}
// Reshape to [B, L, nheads, head_dim]
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
// Apply partial RoPE or MRoPE
rotaryDim := int32(float32(attn.HeadDim) * attn.PartialRotary)
if len(attn.MRoPESection) == 3 && positionIDs != nil {
// Use MRoPE with explicit position IDs
q = applyPartialMRoPE(q, positionIDs, rotaryDim, attn.RopeTheta, attn.MRoPESection)
k = applyPartialMRoPE(k, positionIDs, rotaryDim, attn.RopeTheta, attn.MRoPESection)
} else if len(attn.MRoPESection) == 3 {
// Use MRoPE with sequential positions (same for all dims - text mode)
seqPositions := make([][]int32, 3)
for dim := 0; dim < 3; dim++ {
seqPositions[dim] = make([]int32, L)
for i := int32(0); i < L; i++ {
seqPositions[dim][i] = i + posOffset
}
}
q = applyPartialMRoPE(q, seqPositions, rotaryDim, attn.RopeTheta, attn.MRoPESection)
k = applyPartialMRoPE(k, seqPositions, rotaryDim, attn.RopeTheta, attn.MRoPESection)
} else {
// Fallback to standard RoPE
q = applyPartialRoPEWithOffset(q, L, posOffset, rotaryDim, attn.RopeTheta)
k = applyPartialRoPEWithOffset(k, L, posOffset, rotaryDim, attn.RopeTheta)
}
// Transpose to [B, nheads, L, head_dim]
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
// Update cache and get full K, V for attention
if kvcache != nil {
k, v = kvcache.Update(k, v, int(L))
}
// Repeat KV for GQA
kExpanded := k
vExpanded := v
if attn.NKVHeads < attn.NHeads {
repeats := attn.NHeads / attn.NKVHeads
kExpanded = repeatKV(k, repeats)
vExpanded = repeatKV(v, repeats)
}
// Scaled dot-product attention with causal mask
out := mlx.ScaledDotProductAttention(q, kExpanded, vExpanded, attn.Scale, true)
// Transpose back [B, nheads, L, head_dim] -> [B, L, nheads, head_dim]
out = mlx.Transpose(out, 0, 2, 1, 3)
// Reshape to [B, L, hidden_size]
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
// Output projection
out = mlx.Matmul(out, mlx.Transpose(attn.OProj, 1, 0))
return out
}
// applyPartialRoPE applies RoPE to only the first rotaryDim dimensions
func applyPartialRoPE(x *mlx.Array, seqLen int32, rotaryDim int32, theta float32) *mlx.Array {
return applyPartialRoPEWithOffset(x, seqLen, 0, rotaryDim, theta)
}
// applyPartialRoPEWithOffset applies RoPE with a position offset
func applyPartialRoPEWithOffset(x *mlx.Array, seqLen int32, posOffset int32, rotaryDim int32, theta float32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
if rotaryDim <= 0 || rotaryDim > D {
rotaryDim = D
}
// Split into rotary and pass-through parts
xRot := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, rotaryDim})
xPass := mlx.Slice(x, []int32{0, 0, 0, rotaryDim}, []int32{B, L, H, D})
// Apply RoPE to rotary part with position offset
xRot = applyRoPEWithOffset(xRot, L, posOffset, theta)
// Concatenate back
return mlx.Concatenate([]*mlx.Array{xRot, xPass}, 3)
}
// applyPartialMRoPE applies Multi-dimensional RoPE (MRoPE) to the first rotaryDim dimensions
// positionIDs: [3, L] - position indices for each dimension (temporal, height, width)
// mrope_section: [8, 12, 12] - frequency pairs per dimension
// For text tokens: all 3 dimensions have the same sequential position
// For image tokens: temporal=seq_idx, height=row, width=col
func applyPartialMRoPE(x *mlx.Array, positionIDs [][]int32, rotaryDim int32, theta float32, mropeSection []int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
if rotaryDim <= 0 || rotaryDim > D {
rotaryDim = D
}
// Split into rotary and pass-through parts
xRot := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, rotaryDim})
xPass := mlx.Slice(x, []int32{0, 0, 0, rotaryDim}, []int32{B, L, H, D})
// Apply MRoPE to rotary part
xRot = applyMRoPE(xRot, positionIDs, theta, mropeSection)
// Concatenate back
return mlx.Concatenate([]*mlx.Array{xRot, xPass}, 3)
}
// applyMRoPE applies multi-dimensional rotary position embedding
// x: [B, L, H, D] where D is the rotary dimension
// positionIDs: [3][L] - positions for temporal, height, width dimensions
// mropeSection: [8, 12, 12] - frequency pairs per dimension
func applyMRoPE(x *mlx.Array, positionIDs [][]int32, theta float32, mropeSection []int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
half := D / 2
// Validate mrope_section sums to half (number of frequency pairs)
var totalPairs int32
for _, s := range mropeSection {
totalPairs += s
}
if totalPairs != half {
// Fallback to standard RoPE if section doesn't match
return applyRoPEWithOffset(x, L, 0, theta)
}
// Build angles for each position dimension (matching Python's MRoPE approach)
// Python: compute freqs for all dims, then apply_mrope selects freq ranges, then duplicate
// Order: [temporal_8, height_12, width_12] -> duplicate -> [t8, h12, w12, t8, h12, w12]
angleVals := make([]*mlx.Array, 3)
freqOffset := int32(0)
for dim := 0; dim < 3; dim++ {
numPairs := mropeSection[dim]
if numPairs == 0 {
continue
}
// Compute inverse frequencies for this section
// Each dimension uses DIFFERENT frequency ranges:
// - Temporal: frequencies 0 to section[0]-1
// - Height: frequencies section[0] to section[0]+section[1]-1
// - Width: frequencies section[0]+section[1] to sum(section)-1
freqsArr := make([]float32, numPairs)
for i := int32(0); i < numPairs; i++ {
globalIdx := freqOffset + i
freqsArr[i] = float32(1.0 / math.Pow(float64(theta), float64(2*globalIdx)/float64(D)))
}
freqs := mlx.NewArray(freqsArr, []int32{numPairs})
// Position indices for this dimension
posArr := make([]float32, L)
for i := int32(0); i < L; i++ {
posArr[i] = float32(positionIDs[dim][i])
}
pos := mlx.NewArray(posArr, []int32{L})
// Compute angles: [L, numPairs] = outer(pos, freqs)
posExpanded := mlx.Reshape(pos, L, 1)
freqsExpanded := mlx.Reshape(freqs, 1, numPairs)
angleVals[dim] = mlx.Mul(posExpanded, freqsExpanded)
freqOffset += numPairs
}
// Concatenate all sections: [L, half] = [L, 32]
allAngles := mlx.Concatenate(angleVals, 1)
// Duplicate AFTER concatenation: [L, D] = [L, 64]
// This gives: [temporal_8, height_12, width_12, temporal_8, height_12, width_12]
allAngles = mlx.Concatenate([]*mlx.Array{allAngles, allAngles}, 1)
// Compute cos/sin
allCos := mlx.Cos(allAngles)
allSin := mlx.Sin(allAngles)
// Reshape for broadcasting: [1, L, 1, D] to match x [B, L, H, D]
allCos = mlx.Reshape(allCos, 1, L, 1, D)
allSin = mlx.Reshape(allSin, 1, L, 1, D)
// x_rotated = cat([-x_imag, x_real], dim=-1)
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half}) // x_real
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) // x_imag
x2Neg := mlx.MulScalar(x2, -1) // -x_imag
xRotated := mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3) // [-x_imag, x_real]
// out = x * cos + x_rotated * sin
return mlx.Add(mlx.Mul(x, allCos), mlx.Mul(xRotated, allSin))
}
// applyRoPE applies rotary position embedding
func applyRoPE(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
return applyRoPEWithOffset(x, seqLen, 0, theta)
}
// applyRoPEWithOffset applies rotary position embedding with position offset
// Uses the split-half approach (matches diffusers GLM-Image with use_real_unbind_dim=-2)
func applyRoPEWithOffset(x *mlx.Array, seqLen int32, posOffset int32, theta float32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
half := D / 2
// Compute inverse frequencies: 1 / (theta^(2i/d))
freqsArr := make([]float32, half)
for i := int32(0); i < half; i++ {
freqsArr[i] = float32(1.0 / math.Pow(float64(theta), float64(2*i)/float64(D)))
}
freqs := mlx.NewArray(freqsArr, []int32{half})
// Position indices with offset
posArr := make([]float32, L)
for i := int32(0); i < L; i++ {
posArr[i] = float32(i + posOffset)
}
pos := mlx.NewArray(posArr, []int32{L})
// Compute angles: [L, half] = outer(pos, freqs)
posExpanded := mlx.Reshape(pos, L, 1)
freqsExpanded := mlx.Reshape(freqs, 1, half)
angles := mlx.Mul(posExpanded, freqsExpanded)
// Duplicate angles to match diffusers: cat([angles, angles], dim=-1) -> [L, D]
anglesDup := mlx.Concatenate([]*mlx.Array{angles, angles}, 1)
// Cos and sin: [L, 1, D] for broadcasting to [B, L, H, D]
cosVals := mlx.Cos(anglesDup)
sinVals := mlx.Sin(anglesDup)
cosVals = mlx.Reshape(cosVals, L, 1, D)
sinVals = mlx.Reshape(sinVals, L, 1, D)
// x_rotated = cat([-x_imag, x_real], dim=-1) where x_real=x[..., :half], x_imag=x[..., half:]
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half}) // x_real
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) // x_imag
x2Neg := mlx.MulScalar(x2, -1) // -x_imag
xRotated := mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3) // [-x_imag, x_real]
// out = x * cos + x_rotated * sin
return mlx.Add(mlx.Mul(x, cosVals), mlx.Mul(xRotated, sinVals))
}
// repeatKV repeats key/value heads for GQA
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
if repeats == 1 {
return x
}
shape := x.Shape()
// x: [B, nkvheads, L, head_dim]
x = mlx.ExpandDims(x, 2)
// x: [B, nkvheads, 1, L, head_dim]
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
// x: [B, nkvheads, repeats, L, head_dim]
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
}
// Forward for GLMMLP (fused gate_up SwiGLU)
func (m *GLMMLP) Forward(x *mlx.Array) *mlx.Array {
// gate_up_proj outputs [gate, up] concatenated
gateUp := mlx.Matmul(x, mlx.Transpose(m.GateUpProj, 1, 0))
shape := gateUp.Shape()
halfDim := shape[len(shape)-1] / 2
// Split into gate and up
gate := mlx.Slice(gateUp, []int32{0, 0, 0}, []int32{shape[0], shape[1], halfDim})
up := mlx.Slice(gateUp, []int32{0, 0, halfDim}, []int32{shape[0], shape[1], shape[2]})
// SwiGLU: silu(gate) * up
gate = mlx.SiLU(gate)
h := mlx.Mul(gate, up)
// Down projection
return mlx.Matmul(h, mlx.Transpose(m.DownProj, 1, 0))
}

View File

@@ -311,8 +311,8 @@ type Model struct {
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
func (m *Model) NumLayers() int { return len(m.Layers) }
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
func (m *Model) NumLayers() int { return len(m.Layers) }
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
func (m *Model) NewCache(int32) []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))

View File

@@ -222,6 +222,14 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
mlx.Keep(posEmb, negEmb)
}
// Pre-compute batched embeddings for CFG (single forward pass optimization)
var batchedEmb *mlx.Array
if useCFG {
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// Scheduler
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
@@ -264,10 +272,19 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
var output *mlx.Array
if useCFG {
// True CFG: run twice and combine with norm rescaling
// CFG Batching: single forward pass with batch=2
// Note: layer caching with CFG is not supported yet (would need 2 caches)
posOutput := m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
negOutput := m.Transformer.Forward(patches, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
batchedTimestep := mlx.Tile(timestep, []int32{2})
// Single batched forward pass
batchedOutput := m.Transformer.Forward(batchedPatches, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
L := batchedOutput.Shape()[1]
D := batchedOutput.Shape()[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
@@ -305,6 +322,9 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
if negEmb != nil {
negEmb.Free()
}
if batchedEmb != nil {
batchedEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
if stepCache != nil {

View File

@@ -241,6 +241,14 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
mlx.Eval(posEmb, negEmb)
}
// Pre-compute batched embeddings for CFG (single forward pass optimization)
var batchedEmb *mlx.Array
if useCFG {
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// Encode all input images to latents and concatenate
fmt.Println("Encoding images to latents...")
allImageLatentsPacked := make([]*mlx.Array, len(vaeImages))
@@ -291,11 +299,18 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
var output *mlx.Array
if useCFG {
posOutput := m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
negOutput := m.Transformer.Forward(latentInput, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// CFG Batching: single forward pass with batch=2
// Tile inputs: [1, L, D] -> [2, L, D]
batchedLatentInput := mlx.Tile(latentInput, []int32{2, 1, 1})
batchedTimestep := mlx.Tile(timestep, []int32{2})
posOutput = mlx.Slice(posOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, posOutput.Shape()[2]})
negOutput = mlx.Slice(negOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, negOutput.Shape()[2]})
// Single batched forward pass
batchedOutput := m.Transformer.Forward(batchedLatentInput, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
D := batchedOutput.Shape()[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, imgSeqLen, D})
output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale)
} else {
@@ -317,6 +332,9 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
if negEmb != nil {
negEmb.Free()
}
if batchedEmb != nil {
batchedEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
imageLatentsPacked.Free()

View File

@@ -28,12 +28,12 @@ type Qwen3Config struct {
// Qwen3Attention implements Qwen3 attention with QK norms
type Qwen3Attention struct {
QProj *nn.Linear `weight:"q_proj"`
KProj *nn.Linear `weight:"k_proj"`
VProj *nn.Linear `weight:"v_proj"`
OProj *nn.Linear `weight:"o_proj"`
QNorm *nn.RMSNorm `weight:"q_norm"`
KNorm *nn.RMSNorm `weight:"k_norm"`
QProj nn.LinearLayer `weight:"q_proj"`
KProj nn.LinearLayer `weight:"k_proj"`
VProj nn.LinearLayer `weight:"v_proj"`
OProj nn.LinearLayer `weight:"o_proj"`
QNorm *nn.RMSNorm `weight:"q_norm"`
KNorm *nn.RMSNorm `weight:"k_norm"`
// Computed fields
NHeads int32
NKVHeads int32
@@ -136,9 +136,9 @@ func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
// Qwen3MLP implements Qwen3 SwiGLU MLP
type Qwen3MLP struct {
GateProj *nn.Linear `weight:"gate_proj"`
UpProj *nn.Linear `weight:"up_proj"`
DownProj *nn.Linear `weight:"down_proj"`
GateProj nn.LinearLayer `weight:"gate_proj"`
UpProj nn.LinearLayer `weight:"up_proj"`
DownProj nn.LinearLayer `weight:"down_proj"`
}
// Forward applies the MLP

View File

@@ -36,8 +36,8 @@ type TransformerConfig struct {
// TimestepEmbedder creates sinusoidal timestep embeddings
// Output dimension is 256 (fixed), used for AdaLN modulation
type TimestepEmbedder struct {
Linear1 *nn.Linear `weight:"mlp.0"`
Linear2 *nn.Linear `weight:"mlp.2"`
Linear1 nn.LinearLayer `weight:"mlp.0"`
Linear2 nn.LinearLayer `weight:"mlp.2"`
FreqEmbedSize int32 // 256 (computed)
}
@@ -74,7 +74,7 @@ func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
// XEmbedder embeds image patches to model dimension
type XEmbedder struct {
Linear *nn.Linear `weight:"2-1"`
Linear nn.LinearLayer `weight:"2-1"`
}
// Forward embeds patchified image latents
@@ -86,7 +86,7 @@ func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
// CapEmbedder projects caption features to model dimension
type CapEmbedder struct {
Norm *nn.RMSNorm `weight:"0"`
Linear *nn.Linear `weight:"1"`
Linear nn.LinearLayer `weight:"1"`
PadToken *mlx.Array // loaded separately at root level
}
@@ -100,12 +100,13 @@ func (ce *CapEmbedder) Forward(capFeats *mlx.Array) *mlx.Array {
// FeedForward implements SwiGLU FFN
type FeedForward struct {
W1 *nn.Linear `weight:"w1"` // gate projection
W2 *nn.Linear `weight:"w2"` // down projection
W3 *nn.Linear `weight:"w3"` // up projection
W1 nn.LinearLayer `weight:"w1"` // gate projection
W2 nn.LinearLayer `weight:"w2"` // down projection
W3 nn.LinearLayer `weight:"w3"` // up projection
OutDim int32 // computed from W2
}
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
@@ -115,6 +116,7 @@ func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
// Reshape for matmul
x = mlx.Reshape(x, B*L, D)
gate := ff.W1.Forward(x)
gate = mlx.SiLU(gate)
up := ff.W3.Forward(x)
@@ -126,17 +128,69 @@ func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
// Attention implements multi-head attention with QK norm
type Attention struct {
ToQ *nn.Linear `weight:"to_q"`
ToK *nn.Linear `weight:"to_k"`
ToV *nn.Linear `weight:"to_v"`
ToOut *nn.Linear `weight:"to_out.0"`
ToQ nn.LinearLayer `weight:"to_q"`
ToK nn.LinearLayer `weight:"to_k"`
ToV nn.LinearLayer `weight:"to_v"`
ToOut nn.LinearLayer `weight:"to_out.0"`
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
NormK *mlx.Array `weight:"norm_k.weight"`
// Computed fields
NHeads int32
HeadDim int32
Dim int32
Scale float32
// Fused QKV (computed at init time for efficiency, not loaded from weights)
ToQKV nn.LinearLayer `weight:"-"` // Fused Q+K+V projection (created by FuseQKV)
Fused bool `weight:"-"` // Whether to use fused QKV path
// Computed fields (not loaded from weights)
NHeads int32 `weight:"-"`
HeadDim int32 `weight:"-"`
Dim int32 `weight:"-"`
Scale float32 `weight:"-"`
}
// FuseQKV creates a fused QKV projection by concatenating weights.
// This reduces 3 matmuls to 1 for a ~5-10% speedup.
// Note: Fusion is skipped for quantized weights as it would require complex
// dequant-concat-requant operations. The FP8 memory bandwidth savings outweigh
// the ~5% fusion benefit.
func (attn *Attention) FuseQKV() {
if attn.ToQ == nil || attn.ToK == nil || attn.ToV == nil {
return
}
// Skip fusion for quantized weights - type assert to check
toQ, qOk := attn.ToQ.(*nn.Linear)
toK, kOk := attn.ToK.(*nn.Linear)
toV, vOk := attn.ToV.(*nn.Linear)
if !qOk || !kOk || !vOk {
// One or more are QuantizedLinear, skip fusion
return
}
if toQ.Weight == nil || toK.Weight == nil || toV.Weight == nil {
return
}
// Concatenate weights: [dim, dim] x 3 -> [3*dim, dim]
// Weight shapes: ToQ.Weight [out_dim, in_dim], etc.
qWeight := toQ.Weight
kWeight := toK.Weight
vWeight := toV.Weight
// Concatenate along output dimension (axis 0)
fusedWeight := mlx.Concatenate([]*mlx.Array{qWeight, kWeight, vWeight}, 0)
// Evaluate fused weight to ensure it's materialized
mlx.Eval(fusedWeight)
// Create fused linear layer
fusedLinear := &nn.Linear{Weight: fusedWeight}
// Handle bias if present
if toQ.Bias != nil && toK.Bias != nil && toV.Bias != nil {
fusedBias := mlx.Concatenate([]*mlx.Array{toQ.Bias, toK.Bias, toV.Bias}, 0)
mlx.Eval(fusedBias)
fusedLinear.Bias = fusedBias
}
attn.ToQKV = fusedLinear
attn.Fused = true
}
// Forward computes attention
@@ -146,11 +200,24 @@ func (attn *Attention) Forward(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
L := shape[1]
D := shape[2]
// Project Q, K, V
xFlat := mlx.Reshape(x, B*L, D)
q := attn.ToQ.Forward(xFlat)
k := attn.ToK.Forward(xFlat)
v := attn.ToV.Forward(xFlat)
var q, k, v *mlx.Array
if attn.Fused && attn.ToQKV != nil {
// Fused QKV path: single matmul then split
qkv := attn.ToQKV.Forward(xFlat) // [B*L, 3*dim]
// Split into Q, K, V along last dimension
// Each has shape [B*L, dim]
q = mlx.Slice(qkv, []int32{0, 0}, []int32{B * L, attn.Dim})
k = mlx.Slice(qkv, []int32{0, attn.Dim}, []int32{B * L, 2 * attn.Dim})
v = mlx.Slice(qkv, []int32{0, 2 * attn.Dim}, []int32{B * L, 3 * attn.Dim})
} else {
// Separate Q, K, V projections
q = attn.ToQ.Forward(xFlat)
k = attn.ToK.Forward(xFlat)
v = attn.ToV.Forward(xFlat)
}
// Reshape to [B, L, nheads, head_dim]
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
@@ -227,7 +294,7 @@ type TransformerBlock struct {
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
AdaLN *nn.Linear `weight:"adaLN_modulation.0,optional"` // only if modulation
AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
// Computed fields
HasModulation bool
Dim int32
@@ -281,8 +348,8 @@ func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *ml
// FinalLayer outputs the denoised patches
type FinalLayer struct {
AdaLN *nn.Linear `weight:"adaLN_modulation.1"` // [256] -> [dim]
Output *nn.Linear `weight:"linear"` // [dim] -> [out_channels]
AdaLN nn.LinearLayer `weight:"adaLN_modulation.1"` // [256] -> [dim]
Output nn.LinearLayer `weight:"linear"` // [dim] -> [out_channels]
OutDim int32 // computed from Output
}
@@ -350,12 +417,11 @@ func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
m.Layers = make([]*TransformerBlock, cfg.NLayers)
// Load weights from tensor blobs with BF16 conversion
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
@@ -377,7 +443,7 @@ func (m *Transformer) loadWeights(weights safetensors.WeightSource) error {
func (m *Transformer) initComputedFields() {
cfg := m.TransformerConfig
m.TEmbed.FreqEmbedSize = 256
m.FinalLayer.OutDim = m.FinalLayer.Output.Weight.Shape()[0]
m.FinalLayer.OutDim = m.FinalLayer.Output.OutputDim()
m.CapEmbed.Norm.Eps = 1e-6
for _, block := range m.NoiseRefiners {
@@ -391,6 +457,20 @@ func (m *Transformer) initComputedFields() {
}
}
// FuseAllQKV fuses QKV projections in all attention layers for efficiency.
// This reduces 3 matmuls to 1 per attention layer, providing ~5-10% speedup.
func (m *Transformer) FuseAllQKV() {
for _, block := range m.NoiseRefiners {
block.Attention.FuseQKV()
}
for _, block := range m.ContextRefiners {
block.Attention.FuseQKV()
}
for _, block := range m.Layers {
block.Attention.FuseQKV()
}
}
// initTransformerBlock sets computed fields on a transformer block
func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
block.Dim = cfg.Dim
@@ -404,7 +484,7 @@ func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
attn.Scale = float32(1.0 / math.Sqrt(float64(attn.HeadDim)))
// Init feedforward OutDim
block.FeedForward.OutDim = block.FeedForward.W2.Weight.Shape()[0]
block.FeedForward.OutDim = block.FeedForward.W2.OutputDim()
// Set eps on all RMSNorm layers
block.AttentionNorm1.Eps = cfg.NormEps
@@ -423,6 +503,8 @@ type RoPECache struct {
UnifiedSin *mlx.Array
ImgLen int32
CapLen int32
GridH int32 // Image token grid height
GridW int32 // Image token grid width
}
// PrepareRoPECache precomputes RoPE values for the given image and caption lengths.
@@ -456,6 +538,8 @@ func (m *Transformer) PrepareRoPECache(hTok, wTok, capLen int32) *RoPECache {
UnifiedSin: unifiedSin,
ImgLen: imgLen,
CapLen: capLen,
GridH: hTok,
GridW: wTok,
}
}

View File

@@ -104,6 +104,8 @@ func (gn *GroupNormLayer) forwardTiled(x *mlx.Array, B, H, W, C int32) *mlx.Arra
groupSize := C / gn.NumGroups
// Keep the input - we need it for slicing tiles later
// Track if we were the ones who kept it, so we can restore state after
wasKept := x.Kept()
mlx.Keep(x)
// Compute per-group mean and variance using flattened spatial dimensions
@@ -205,6 +207,10 @@ func (gn *GroupNormLayer) forwardTiled(x *mlx.Array, B, H, W, C int32) *mlx.Arra
}
// Clean up kept arrays
// Restore x's kept state - only free if we were the ones who kept it
if !wasKept {
x.Free()
}
mean.Free()
invStd.Free()
if weightGN != nil {
@@ -734,18 +740,26 @@ func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
h := vae.ConvIn.Forward(z)
mlx.Eval(h)
prev := h
h = vae.MidBlock.Forward(h)
prev.Free()
for _, upBlock := range vae.UpBlocks {
prev = h
h = upBlock.Forward(h)
prev.Free()
}
prev := h
prev = h
h = vae.ConvNormOut.Forward(h)
mlx.Eval(h) // Eval after GroupNorm to avoid grid dimension issues
prev.Free()
prev = h
h = mlx.SiLU(h)
h = vae.ConvOut.Forward(h)
mlx.Eval(h)
prev.Free()
// VAE outputs [-1, 1], convert to [0, 1]
h = mlx.MulScalar(h, 0.5)
@@ -754,7 +768,6 @@ func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
// Convert NHWC -> NCHW for output
h = mlx.Transpose(h, 0, 3, 1, 2)
prev.Free()
mlx.Eval(h)
return h

View File

@@ -26,10 +26,12 @@ type GenerateConfig struct {
Progress ProgressFunc // Optional progress callback
CapturePath string // GPU capture path (debug)
// Layer caching options (speedup via shallow layer reuse)
LayerCache bool // Enable layer caching (default: false)
CacheInterval int // Refresh cache every N steps (default: 3)
CacheLayers int // Number of shallow layers to cache (default: 15)
// TeaCache options (timestep embedding aware caching)
TeaCache bool // TeaCache is always enabled for faster inference
TeaCacheThreshold float32 // Threshold for cache reuse (default: 0.1, lower = more aggressive)
// Fused QKV (fuse Q/K/V projections into single matmul)
FusedQKV bool // Enable fused QKV projection (default: false)
}
// ProgressFunc is called during generation with step progress.
@@ -42,6 +44,7 @@ type Model struct {
TextEncoder *Qwen3TextEncoder
Transformer *Transformer
VAEDecoder *VAEDecoder
qkvFused bool // Track if QKV has been fused (do only once)
}
// Load loads the Z-Image model from ollama blob storage.
@@ -196,13 +199,17 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
}
if cfg.LayerCache {
if cfg.CacheInterval <= 0 {
cfg.CacheInterval = 3
}
if cfg.CacheLayers <= 0 {
cfg.CacheLayers = 15 // Half of 30 layers
}
// TeaCache enabled by default
cfg.TeaCache = true
if cfg.TeaCacheThreshold <= 0 {
cfg.TeaCacheThreshold = 0.15
}
// Enable fused QKV if requested (only fuse once)
if cfg.FusedQKV && !m.qkvFused {
m.Transformer.FuseAllQKV()
m.qkvFused = true
fmt.Println(" Fused QKV enabled")
}
useCFG := cfg.NegativePrompt != ""
@@ -260,12 +267,54 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
mlx.Eval(ropeCache.UnifiedCos)
}
// Step cache for shallow layer reuse (DeepCache/Learning-to-Cache style)
var stepCache *cache.StepCache
if cfg.LayerCache {
stepCache = cache.NewStepCache(cfg.CacheLayers)
fmt.Printf(" Layer caching enabled: %d layers, refresh every %d steps\n",
cfg.CacheLayers, cfg.CacheInterval)
// Pre-compute batched embeddings for CFG (outside the loop for efficiency)
var batchedEmb *mlx.Array
if useCFG {
// Concatenate embeddings once: [1, L, D] + [1, L, D] -> [2, L, D]
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// TeaCache for timestep-aware caching
// For CFG mode, we cache pos/neg separately, skip early steps, and always compute CFG fresh
var teaCache *cache.TeaCache
if cfg.TeaCache {
skipEarly := 0
if useCFG {
skipEarly = 3 // Skip first 3 steps for CFG to preserve structure
}
teaCache = cache.NewTeaCache(&cache.TeaCacheConfig{
Threshold: cfg.TeaCacheThreshold,
RescaleFactor: 1.0,
SkipEarlySteps: skipEarly,
})
if useCFG {
fmt.Printf(" TeaCache enabled (CFG mode): threshold=%.2f, skip first %d steps\n", cfg.TeaCacheThreshold, skipEarly)
} else {
fmt.Printf(" TeaCache enabled: threshold=%.2f\n", cfg.TeaCacheThreshold)
}
}
// cleanup frees all kept arrays when we need to abort early
cleanup := func() {
posEmb.Free()
if negEmb != nil {
negEmb.Free()
}
ropeCache.ImgCos.Free()
ropeCache.ImgSin.Free()
ropeCache.CapCos.Free()
ropeCache.CapSin.Free()
ropeCache.UnifiedCos.Free()
ropeCache.UnifiedSin.Free()
if batchedEmb != nil {
batchedEmb.Free()
}
if teaCache != nil {
teaCache.Free()
}
latents.Free()
}
// Denoising loop
@@ -277,6 +326,7 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
if ctx != nil {
select {
case <-ctx.Done():
cleanup()
return nil, ctx.Err()
default:
}
@@ -289,50 +339,77 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
}
tCurr := scheduler.Timesteps[i]
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
var noisePred *mlx.Array
patches := PatchifyLatents(latents, tcfg.PatchSize)
// TeaCache: check if we should compute or reuse cached output
shouldCompute := teaCache == nil || teaCache.ShouldCompute(i, tCurr)
var output *mlx.Array
if stepCache != nil {
// Use layer caching for faster inference
if shouldCompute {
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
patches := PatchifyLatents(latents, tcfg.PatchSize)
var output *mlx.Array
if useCFG {
posOutput := m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
// Note: CFG with layer cache shares the cache between pos/neg
// This is approximate but fast - neg prompt uses same cached shallow layers
negOutput := m.Transformer.ForwardWithCache(patches, timestep, negEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
diff := mlx.Sub(posOutput, negOutput)
// CFG Batching: single forward pass with batch=2
// Tile patches: [1, L, D] -> [2, L, D]
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
// Tile timestep: [1] -> [2]
batchedTimestep := mlx.Tile(timestep, []int32{2})
// Single batched forward pass (RoPE broadcasts from [1,L,H,D] to [2,L,H,D])
batchedOutput := m.Transformer.Forward(batchedPatches, batchedTimestep, batchedEmb, ropeCache)
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
outputShape := batchedOutput.Shape()
L := outputShape[1]
D := outputShape[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
// Convert to noise predictions (unpatchify and negate)
posPred := UnpatchifyLatents(posOutput, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
posPred = mlx.Neg(posPred)
negPred := UnpatchifyLatents(negOutput, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
negPred = mlx.Neg(negPred)
// Cache pos/neg separately for TeaCache
if teaCache != nil {
teaCache.UpdateCFGCache(posPred, negPred, tCurr)
mlx.Keep(teaCache.Arrays()...)
}
// Apply CFG: noisePred = neg + scale * (pos - neg)
diff := mlx.Sub(posPred, negPred)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
output = mlx.Add(negOutput, scaledDiff)
} else {
output = m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
}
} else {
// Standard forward without caching
if useCFG {
posOutput := m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
negOutput := m.Transformer.Forward(patches, timestep, negEmb, ropeCache)
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
output = mlx.Add(negOutput, scaledDiff)
noisePred = mlx.Add(negPred, scaledDiff)
} else {
// Non-CFG forward pass
output = m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
}
}
noisePred = UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
noisePred = mlx.Neg(noisePred)
noisePred := UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
noisePred = mlx.Neg(noisePred)
// Update TeaCache
if teaCache != nil {
teaCache.UpdateCache(noisePred, tCurr)
mlx.Keep(teaCache.Arrays()...)
}
}
} else if useCFG && teaCache != nil && teaCache.HasCFGCache() {
// CFG mode: get cached pos/neg and compute CFG fresh
posPred, negPred := teaCache.GetCFGCached()
diff := mlx.Sub(posPred, negPred)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
noisePred = mlx.Add(negPred, scaledDiff)
fmt.Printf(" [TeaCache: reusing cached pos/neg outputs]\n")
} else {
// Non-CFG mode: reuse cached noise prediction
noisePred = teaCache.GetCached()
fmt.Printf(" [TeaCache: reusing cached output]\n")
}
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
// Keep latents and any cached arrays
if stepCache != nil {
mlx.Keep(stepCache.Arrays()...)
}
mlx.Eval(latents)
oldLatents.Free()
@@ -361,8 +438,14 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
ropeCache.CapSin.Free()
ropeCache.UnifiedCos.Free()
ropeCache.UnifiedSin.Free()
if stepCache != nil {
stepCache.Free()
if batchedEmb != nil {
batchedEmb.Free()
}
if teaCache != nil {
hits, misses := teaCache.Stats()
fmt.Printf(" TeaCache stats: %d hits, %d misses (%.1f%% cache rate)\n",
hits, misses, float64(hits)/float64(hits+misses)*100)
teaCache.Free()
}
// VAE decode

View File

@@ -10,6 +10,13 @@ type Layer interface {
Forward(x *mlx.Array) *mlx.Array
}
// LinearLayer is an interface for linear layers (both regular and quantized).
// This allows swapping between Linear and QuantizedLinear at runtime.
type LinearLayer interface {
Forward(x *mlx.Array) *mlx.Array
OutputDim() int32 // Returns the output dimension of the layer
}
// Linear applies an affine transformation: y = x @ W.T + b
// Weight is stored as [out_features, in_features], matching PyTorch/MLX convention.
type Linear struct {
@@ -49,6 +56,11 @@ func (l *Linear) Forward(x *mlx.Array) *mlx.Array {
return mlx.Linear(x, w)
}
// OutputDim returns the output dimension of the linear layer.
func (l *Linear) OutputDim() int32 {
return l.Weight.Shape()[0]
}
// ToQuantized converts this Linear to a QuantizedLinear.
func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear {
qw, scales, qbiases := mlx.Quantize(l.Weight, groupSize, bits, mode)
@@ -84,6 +96,13 @@ func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
return out
}
// OutputDim returns the output dimension of the quantized linear layer.
// For mxfp8/mxfp4, quantized weight shape is [out_features, in_features / group_size].
// The output dimension is the first dimension of the weight.
func (ql *QuantizedLinear) OutputDim() int32 {
return ql.Weight.Shape()[0]
}
// RMSNorm represents an RMS normalization layer.
type RMSNorm struct {
Weight *mlx.Array `weight:"weight"`

22
x/imagegen/quantize.go Normal file
View File

@@ -0,0 +1,22 @@
package imagegen
import (
"io"
"strings"
)
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
// When quantize is true, returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error)
// ShouldQuantize returns true if a tensor should be quantized.
// Quantizes linear weights only, skipping VAE, embeddings, norms, and biases.
func ShouldQuantize(name, component string) bool {
if component == "vae" {
return false
}
if strings.Contains(name, "embed") || strings.Contains(name, "norm") {
return false
}
return strings.HasSuffix(name, ".weight")
}

View File

@@ -13,16 +13,21 @@ import (
"net/http"
"os"
"os/signal"
"path/filepath"
"sync"
"syscall"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/glm_image"
"github.com/ollama/ollama/x/imagegen/models/zimage"
)
// ImageModel is the interface for image generation models
type ImageModel interface {
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error)
}
// Request is the image generation request format
type Request struct {
Prompt string `json:"prompt"`
@@ -34,15 +39,17 @@ type Request struct {
// Response is streamed back for each progress update
type Response struct {
Content string `json:"content"`
Content string `json:"content,omitempty"`
Image string `json:"image,omitempty"` // Base64-encoded PNG
Done bool `json:"done"`
}
// Server holds the model and handles requests
type Server struct {
mu sync.Mutex
model *zimage.Model
model ImageModel
modelName string
modelType string // "zimage" or "glm_image"
}
// Execute is the entry point for the image runner subprocess
@@ -72,15 +79,35 @@ func Execute(args []string) error {
requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
}
// Load model
model := &zimage.Model{}
if err := model.Load(*modelName); err != nil {
return fmt.Errorf("failed to load model: %w", err)
// Detect model type and load appropriate model
modelType, err := detectModelType(*modelName)
if err != nil {
return fmt.Errorf("failed to detect model type: %w", err)
}
var model ImageModel
switch modelType {
case "GlmImagePipeline":
slog.Info("loading GLM-Image model")
m := &glm_image.Model{}
if err := m.Load(*modelName); err != nil {
return fmt.Errorf("failed to load GLM-Image model: %w", err)
}
model = m
default:
// Default to zimage for ZImagePipeline, FluxPipeline, and unknown types
slog.Info("loading Z-Image model")
m := &zimage.Model{}
if err := m.Load(*modelName); err != nil {
return fmt.Errorf("failed to load Z-Image model: %w", err)
}
model = m
}
server := &Server{
model: model,
modelName: *modelName,
modelType: modelType,
}
// Set up HTTP handlers
@@ -144,7 +171,13 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
req.Height = 1024
}
if req.Steps <= 0 {
req.Steps = 9
// Default steps depend on model type
switch s.modelType {
case "GlmImagePipeline":
req.Steps = 50 // GLM-Image default
default:
req.Steps = 9 // Z-Image turbo default
}
}
if req.Seed <= 0 {
req.Seed = time.Now().UnixNano()
@@ -159,25 +192,9 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
return
}
// Generate image
// Generate image using interface method
ctx := r.Context()
img, err := s.model.GenerateFromConfig(ctx, &zimage.GenerateConfig{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: req.Seed,
Progress: func(step, total int) {
resp := Response{
Content: fmt.Sprintf("\rGenerating: step %d/%d", step, total),
Done: false,
}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
flusher.Flush()
},
})
img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed)
if err != nil {
// Don't send error for cancellation
@@ -191,10 +208,10 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
return
}
// Save image
outPath := filepath.Join(os.TempDir(), fmt.Sprintf("ollama-image-%d.png", time.Now().UnixNano()))
if err := imagegen.SaveImage(img, outPath); err != nil {
resp := Response{Content: fmt.Sprintf("error saving: %v", err), Done: true}
// Encode image as base64 PNG
imageData, err := imagegen.EncodeImageBase64(img)
if err != nil {
resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
@@ -204,14 +221,47 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
// Free the generated image array and clean up MLX state
img.Free()
mlx.ClearCache()
mlx.MetalResetPeakMemory()
// Send final response
// Send final response with image data
resp := Response{
Content: fmt.Sprintf("\n\nImage saved to: %s\n", outPath),
Done: true,
Image: imageData,
Done: true,
}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
flusher.Flush()
}
// detectModelType reads the model manifest and returns the pipeline class name
func detectModelType(modelName string) (string, error) {
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return "", err
}
data, err := manifest.ReadConfig("model_index.json")
if err != nil {
return "ZImagePipeline", nil // Default to Z-Image
}
// Try both _class_name (diffusers format) and architecture (ollama format)
var index struct {
ClassName string `json:"_class_name"`
Architecture string `json:"architecture"`
}
if err := json.Unmarshal(data, &index); err != nil {
return "ZImagePipeline", nil
}
// Prefer _class_name, fall back to architecture
className := index.ClassName
if className == "" {
className = index.Architecture
}
if className == "" {
return "ZImagePipeline", nil
}
return className, nil
}

View File

@@ -8,6 +8,7 @@ import (
"strings"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
)
// WeightSource is an interface for loading weights.
@@ -102,6 +103,22 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
}
}
// Handle nn.LinearLayer interface fields specially
if field.Type == reflect.TypeOf((*nn.LinearLayer)(nil)).Elem() {
if !hasTag {
continue // no tag = skip
}
layer, err := LoadLinearLayer(weights, fullPath)
if err != nil {
if !optional {
*errs = append(*errs, fullPath+": "+err.Error())
}
continue
}
fieldVal.Set(reflect.ValueOf(layer))
continue
}
// Handle by kind
switch fieldVal.Kind() {
case reflect.Ptr:
@@ -176,3 +193,64 @@ func joinPath(prefix, suffix string) string {
}
return prefix + "." + suffix
}
// LoadLinearLayer loads a linear layer from weights, automatically detecting if it's quantized.
// If {path}.weight_scale exists, dequantizes the weights.
func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error) {
// Check if this is a quantized layer by looking for scale tensor
scalePath := path + ".weight_scale"
if weights.HasTensor(scalePath) {
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, fmt.Errorf("failed to load quantized weight %s: %w", path, err)
}
scales, err := weights.GetTensor(scalePath)
if err != nil {
return nil, fmt.Errorf("failed to load scales %s: %w", scalePath, err)
}
// Bias is optional
var bias *mlx.Array
biasPath := path + ".bias"
if weights.HasTensor(biasPath) {
bias, _ = weights.GetTensor(biasPath)
}
var qbiases *mlx.Array
qbiasPath := path + ".weight_qbias"
if weights.HasTensor(qbiasPath) {
qbiases, _ = weights.GetTensor(qbiasPath)
}
if mlx.MetalIsAvailable() {
return &nn.QuantizedLinear{
Weight: weight,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GroupSize: 32,
Bits: 8,
Mode: "affine",
}, nil
}
dequantized := mlx.Dequantize(weight, scales, qbiases, 32, 8, "affine")
return nn.NewLinear(dequantized, bias), nil
}
// Load as regular Linear
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, fmt.Errorf("failed to load weight %s: %w", path, err)
}
// Bias is optional
var bias *mlx.Array
biasPath := path + ".bias"
if weights.HasTensor(biasPath) {
bias, _ = weights.GetTensor(biasPath)
}
return nn.NewLinear(weight, bias), nil
}

View File

@@ -14,7 +14,9 @@ import (
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"time"
@@ -46,7 +48,8 @@ type completionRequest struct {
// completionResponse is received from the subprocess
type completionResponse struct {
Content string `json:"content"`
Content string `json:"content,omitempty"`
Image string `json:"image,omitempty"`
Done bool `json:"done"`
}
@@ -69,7 +72,7 @@ func NewServer(modelName string) (*Server, error) {
port = rand.Intn(65535-49152) + 49152
}
// Get the ollama executable path
// Get the ollama-mlx executable path (in same directory as current executable)
exe, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
@@ -77,11 +80,42 @@ func NewServer(modelName string) (*Server, error) {
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
mlxExe := filepath.Join(filepath.Dir(exe), "ollama-mlx")
// Spawn subprocess: ollama runner --image-engine --model <path> --port <port>
cmd := exec.Command(exe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
// Spawn subprocess: ollama-mlx runner --image-engine --model <path> --port <port>
cmd := exec.Command(mlxExe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ()
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
if runtime.GOOS == "linux" {
// Build library paths: start with LibOllamaPath, then add any mlx_* subdirectories
libraryPaths := []string{ml.LibOllamaPath}
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
libraryPaths = append(libraryPaths, mlxDirs...)
}
// Append existing LD_LIBRARY_PATH if set
if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
}
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
// Update or add LD_LIBRARY_PATH in cmd.Env
found := false
for i := range cmd.Env {
if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
found = true
break
}
}
if !found {
cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
}
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
}
s := &Server{
cmd: cmd,
port: port,
@@ -112,7 +146,7 @@ func NewServer(modelName string) (*Server, error) {
}
}()
slog.Info("starting image runner subprocess", "model", modelName, "port", port)
slog.Info("starting ollama-mlx image runner subprocess", "exe", mlxExe, "model", modelName, "port", port)
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start image runner: %w", err)
}
@@ -250,15 +284,23 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
return fmt.Errorf("completion request failed: %d", resp.StatusCode)
}
// Stream responses
// Stream responses - use large buffer for base64 image data
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
for scanner.Scan() {
var cresp completionResponse
if err := json.Unmarshal(scanner.Bytes(), &cresp); err != nil {
continue
}
content := cresp.Content
// If this is the final response with an image, encode it in the content
if cresp.Done && cresp.Image != "" {
content = "IMAGE_BASE64:" + cresp.Image
}
fn(llm.CompletionResponse{
Content: cresp.Content,
Content: content,
Done: cresp.Done,
})
if cresp.Done {

View File

@@ -1082,12 +1082,6 @@ func (t *Tokenizer) GetSpecialToken(name string) (int32, bool) {
return id, ok
}
// Vocab returns the vocabulary as a slice of token strings indexed by token ID.
// This is useful for constrained decoding where we need to map tokens to grammar symbols.
func (t *Tokenizer) Vocab() []string {
return t.vocab.Values
}
// LoadVocabMerges loads a tokenizer from vocab.json + merges.txt format (GPT-style)
func LoadVocabMerges(dir string) (*Tokenizer, error) {
vocabPath := dir + "/vocab.json"

View File

@@ -45,24 +45,33 @@ func download(ctx context.Context, opts DownloadOptions) error {
return nil
}
// Filter existing
var blobs []Blob
// Calculate total from all blobs (for accurate progress reporting on resume)
var total int64
for _, b := range opts.Blobs {
total += b.Size
}
// Filter out already-downloaded blobs and track completed bytes
var blobs []Blob
var alreadyCompleted int64
for _, b := range opts.Blobs {
if fi, _ := os.Stat(filepath.Join(opts.DestDir, digestToPath(b.Digest))); fi != nil && fi.Size() == b.Size {
if opts.Logger != nil {
opts.Logger.Debug("blob already exists", "digest", b.Digest, "size", b.Size)
}
alreadyCompleted += b.Size
continue
}
blobs = append(blobs, b)
total += b.Size
}
if len(blobs) == 0 {
return nil
}
token := opts.Token
progress := newProgressTracker(total, opts.Progress)
progress.add(alreadyCompleted) // Report already-downloaded bytes upfront
d := &downloader{
client: cmp.Or(opts.Client, defaultClient),
baseURL: opts.BaseURL,
@@ -72,7 +81,7 @@ func download(ctx context.Context, opts DownloadOptions) error {
getToken: opts.GetToken,
userAgent: cmp.Or(opts.UserAgent, defaultUserAgent),
stallTimeout: cmp.Or(opts.StallTimeout, defaultStallTimeout),
progress: newProgressTracker(total, opts.Progress),
progress: progress,
speeds: &speedTracker{},
logger: opts.Logger,
}

View File

@@ -110,8 +110,6 @@ var defaultClient = &http.Client{
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
},
Timeout: 5 * time.Minute,
// Don't follow redirects automatically - we handle them manually
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},

View File

@@ -284,6 +284,83 @@ func TestDownloadSkipsExisting(t *testing.T) {
}
}
func TestDownloadResumeProgressTotal(t *testing.T) {
// Test that when resuming a download with some blobs already present:
// 1. Total reflects ALL blob sizes (not just remaining)
// 2. Completed starts at the size of already-downloaded blobs
serverDir := t.TempDir()
blob1, data1 := createTestBlob(t, serverDir, 1000)
blob2, data2 := createTestBlob(t, serverDir, 2000)
blob3, data3 := createTestBlob(t, serverDir, 3000)
// Pre-populate client with blob1 and blob2 (simulating partial download)
clientDir := t.TempDir()
for _, b := range []struct {
blob Blob
data []byte
}{{blob1, data1}, {blob2, data2}} {
path := filepath.Join(clientDir, digestToPath(b.blob.Digest))
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(path, b.data, 0o644); err != nil {
t.Fatal(err)
}
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
digest := filepath.Base(r.URL.Path)
path := filepath.Join(serverDir, digestToPath(digest))
data, err := os.ReadFile(path)
if err != nil {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
w.WriteHeader(http.StatusOK)
w.Write(data)
}))
defer server.Close()
var firstCompleted, firstTotal int64
var gotFirstProgress bool
var mu sync.Mutex
err := Download(context.Background(), DownloadOptions{
Blobs: []Blob{blob1, blob2, blob3},
BaseURL: server.URL,
DestDir: clientDir,
Concurrency: 1,
Progress: func(completed, total int64) {
mu.Lock()
defer mu.Unlock()
if !gotFirstProgress {
firstCompleted = completed
firstTotal = total
gotFirstProgress = true
}
},
})
if err != nil {
t.Fatalf("Download failed: %v", err)
}
// Total should be sum of ALL blobs, not just blob3
expectedTotal := blob1.Size + blob2.Size + blob3.Size
if firstTotal != expectedTotal {
t.Errorf("Total = %d, want %d (should include all blobs)", firstTotal, expectedTotal)
}
// First progress call should show already-completed bytes from blob1+blob2
expectedCompleted := blob1.Size + blob2.Size
if firstCompleted < expectedCompleted {
t.Errorf("First completed = %d, want >= %d (should include already-downloaded blobs)", firstCompleted, expectedCompleted)
}
// Verify blob3 was downloaded
verifyBlob(t, clientDir, blob3, data3)
}
func TestDownloadDigestMismatch(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Return wrong data

View File

@@ -54,6 +54,16 @@ func (r *Registry) RegisterBash() {
r.Register(&BashTool{})
}
// RegisterWebSearch adds the web search tool to the registry.
func (r *Registry) RegisterWebSearch() {
r.Register(&WebSearchTool{})
}
// RegisterWebFetch adds the web fetch tool to the registry.
func (r *Registry) RegisterWebFetch() {
r.Register(&WebFetchTool{})
}
// Get retrieves a tool by name.
func (r *Registry) Get(name string) (Tool, bool) {
tool, ok := r.tools[name]

162
x/tools/webfetch.go Normal file
View File

@@ -0,0 +1,162 @@
package tools
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
)
const (
webFetchAPI = "https://ollama.com/api/web_fetch"
webFetchTimeout = 30 * time.Second
)
// ErrWebFetchAuthRequired is returned when web fetch requires authentication
var ErrWebFetchAuthRequired = errors.New("web fetch requires authentication")
// WebFetchTool implements web page fetching using Ollama's hosted API.
type WebFetchTool struct{}
// Name returns the tool name.
func (w *WebFetchTool) Name() string {
return "web_fetch"
}
// Description returns a description of the tool.
func (w *WebFetchTool) Description() string {
return "Fetch and extract text content from a web page. Use this to read the full content of a URL found in search results or provided by the user."
}
// Schema returns the tool's parameter schema.
func (w *WebFetchTool) Schema() api.ToolFunction {
props := api.NewToolPropertiesMap()
props.Set("url", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The URL to fetch and extract content from",
})
return api.ToolFunction{
Name: w.Name(),
Description: w.Description(),
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
Required: []string{"url"},
},
}
}
// webFetchRequest is the request body for the web fetch API.
type webFetchRequest struct {
URL string `json:"url"`
}
// webFetchResponse is the response from the web fetch API.
type webFetchResponse struct {
Title string `json:"title"`
Content string `json:"content"`
Links []string `json:"links,omitempty"`
}
// Execute fetches content from a web page.
// Uses Ollama key signing for authentication - this makes requests via ollama.com API.
func (w *WebFetchTool) Execute(args map[string]any) (string, error) {
urlStr, ok := args["url"].(string)
if !ok || urlStr == "" {
return "", fmt.Errorf("url parameter is required")
}
// Validate URL
if _, err := url.Parse(urlStr); err != nil {
return "", fmt.Errorf("invalid URL: %w", err)
}
// Prepare request
reqBody := webFetchRequest{
URL: urlStr,
}
jsonBody, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("marshaling request: %w", err)
}
// Parse URL and add timestamp for signing
fetchURL, err := url.Parse(webFetchAPI)
if err != nil {
return "", fmt.Errorf("parsing fetch URL: %w", err)
}
q := fetchURL.Query()
q.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
fetchURL.RawQuery = q.Encode()
// Sign the request using Ollama key (~/.ollama/id_ed25519)
ctx := context.Background()
data := fmt.Appendf(nil, "%s,%s", http.MethodPost, fetchURL.RequestURI())
signature, err := auth.Sign(ctx, data)
if err != nil {
return "", fmt.Errorf("signing request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fetchURL.String(), bytes.NewBuffer(jsonBody))
if err != nil {
return "", fmt.Errorf("creating request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
if signature != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signature))
}
// Send request
client := &http.Client{Timeout: webFetchTimeout}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("sending request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("reading response: %w", err)
}
if resp.StatusCode == http.StatusUnauthorized {
return "", ErrWebFetchAuthRequired
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("web fetch API returned status %d: %s", resp.StatusCode, string(body))
}
// Parse response
var fetchResp webFetchResponse
if err := json.Unmarshal(body, &fetchResp); err != nil {
return "", fmt.Errorf("parsing response: %w", err)
}
// Format result
var sb strings.Builder
if fetchResp.Title != "" {
sb.WriteString(fmt.Sprintf("Title: %s\n\n", fetchResp.Title))
}
if fetchResp.Content != "" {
sb.WriteString("Content:\n")
sb.WriteString(fetchResp.Content)
} else {
sb.WriteString("No content could be extracted from the page.")
}
return sb.String(), nil
}